Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
D
db_gpt
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
linyangyang
db_gpt
Commits
cc2841b3
Commit
cc2841b3
authored
Aug 13, 2024
by
林洋洋
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改token 存到cookies
parent
8b0030b8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
136 additions
and
128 deletions
+136
-128
middleware.py
dbgpt/app/apps/core/middleware.py
+103
-111
views.py
dbgpt/app/apps/system/views.py
+10
-10
crud.py
dbgpt/app/apps/vadmin/auth/crud.py
+1
-1
user.py
dbgpt/app/apps/vadmin/auth/models/user.py
+0
-1
user.py
dbgpt/app/apps/vadmin/auth/params/user.py
+2
-0
current.py
dbgpt/app/apps/vadmin/auth/utils/current.py
+2
-2
login.py
dbgpt/app/apps/vadmin/auth/utils/login.py
+9
-2
views.py
dbgpt/app/apps/vadmin/auth/views.py
+3
-0
embeddings.py
dbgpt/rag/embedding/embeddings.py
+6
-1
No files found.
dbgpt/app/apps/core/middleware.py
View file @
cc2841b3
...
@@ -13,42 +13,34 @@ import datetime
...
@@ -13,42 +13,34 @@ import datetime
import
json
import
json
import
time
import
time
from
fastapi
import
Request
,
Response
from
fastapi
import
Request
,
Response
from
core.logger
import
logger
from
fastapi
import
FastAPI
from
fastapi.routing
import
APIRoute
from
user_agents
import
parse
from
application.settings
import
OPERATION_RECORD_METHOD
,
MONGO_DB_ENABLE
,
IGNORE_OPERATION_FUNCTION
,
\
DEMO_WHITE_LIST_PATH
,
DEMO
,
DEMO_BLACK_LIST_PATH
from
utils.response
import
ErrorResponse
from
apps.vadmin.record.crud
import
OperationRecordDal
from
core.database
import
mongo_getter
from
utils
import
status
def
write_request_log
(
request
:
Request
,
response
:
Response
):
from
fastapi
import
FastAPI
http_version
=
f
"http/{request.scope['http_version']}"
content_length
=
response
.
raw_headers
[
0
][
1
]
process_time
=
response
.
headers
[
"X-Process-Time"
]
content
=
f
"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}"
\
f
"{response.charset} {content_length} {process_time}"
logger
.
info
(
content
)
def
register_request_log_middleware
(
app
:
FastAPI
):
# def write_request_log(request: Request, response: Response):
"""
# http_version = f"http/{request.scope['http_version']}"
记录请求日志中间件
# content_length = response.raw_headers[0][1]
:param app:
# process_time = response.headers["X-Process-Time"]
:return:
# content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
"""
# f"{response.charset} {content_length} {process_time}"
# logger.info(content)
@
app
.
middleware
(
"http"
)
#
async
def
request_log_middleware
(
request
:
Request
,
call_next
):
# def register_request_log_middleware(app: FastAPI):
start_time
=
time
.
time
()
# """
response
=
await
call_next
(
request
)
# 记录请求日志中间件
process_time
=
time
.
time
()
-
start_time
# :param app:
response
.
headers
[
"X-Process-Time"
]
=
str
(
process_time
)
# :return:
write_request_log
(
request
,
response
)
# """
return
response
#
# @app.middleware("http")
# async def request_log_middleware(request: Request, call_next):
# start_time = time.time()
# response = await call_next(request)
# process_time = time.time() - start_time
# response.headers["X-Process-Time"] = str(process_time)
# write_request_log(request, response)
# return response
def
register_operation_record_middleware
(
app
:
FastAPI
):
def
register_operation_record_middleware
(
app
:
FastAPI
):
...
@@ -59,87 +51,87 @@ def register_operation_record_middleware(app: FastAPI):
...
@@ -59,87 +51,87 @@ def register_operation_record_middleware(app: FastAPI):
:return:
:return:
"""
"""
@
app
.
middleware
(
"http"
)
#
@app.middleware("http")
async
def
operation_record_middleware
(
request
:
Request
,
call_next
):
#
async def operation_record_middleware(request: Request, call_next):
start_time
=
time
.
time
()
#
start_time = time.time()
response
=
await
call_next
(
request
)
#
response = await call_next(request)
if
not
MONGO_DB_ENABLE
:
#
if not MONGO_DB_ENABLE:
return
response
#
return response
telephone
=
request
.
scope
.
get
(
'telephone'
,
None
)
#
telephone = request.scope.get('telephone', None)
user_id
=
request
.
scope
.
get
(
'user_id'
,
None
)
#
user_id = request.scope.get('user_id', None)
user_name
=
request
.
scope
.
get
(
'user_name'
,
None
)
#
user_name = request.scope.get('user_name', None)
route
=
request
.
scope
.
get
(
'route'
)
#
route = request.scope.get('route')
if
not
telephone
:
#
if not telephone:
return
response
#
return response
elif
request
.
method
not
in
OPERATION_RECORD_METHOD
:
#
elif request.method not in OPERATION_RECORD_METHOD:
return
response
#
return response
elif
route
.
name
in
IGNORE_OPERATION_FUNCTION
:
#
elif route.name in IGNORE_OPERATION_FUNCTION:
return
response
#
return response
process_time
=
time
.
time
()
-
start_time
#
process_time = time.time() - start_time
user_agent
=
parse
(
request
.
headers
.
get
(
"user-agent"
))
#
user_agent = parse(request.headers.get("user-agent"))
system
=
f
"{user_agent.os.family} {user_agent.os.version_string}"
#
system = f"{user_agent.os.family} {user_agent.os.version_string}"
browser
=
f
"{user_agent.browser.family} {user_agent.browser.version_string}"
#
browser = f"{user_agent.browser.family} {user_agent.browser.version_string}"
query_params
=
dict
(
request
.
query_params
.
multi_items
())
#
query_params = dict(request.query_params.multi_items())
path_params
=
request
.
path_params
#
path_params = request.path_params
if
isinstance
(
request
.
scope
.
get
(
'body'
),
str
):
#
if isinstance(request.scope.get('body'), str):
body
=
request
.
scope
.
get
(
'body'
)
#
body = request.scope.get('body')
else
:
#
else:
body
=
request
.
scope
.
get
(
'body'
)
.
decode
()
#
body = request.scope.get('body').decode()
if
body
:
#
if body:
body
=
json
.
loads
(
body
)
#
body = json.loads(body)
params
=
{
#
params = {
"body"
:
body
,
#
"body": body,
"query_params"
:
query_params
if
query_params
else
None
,
#
"query_params": query_params if query_params else None,
"path_params"
:
path_params
if
path_params
else
None
,
#
"path_params": path_params if path_params else None,
}
#
}
content_length
=
response
.
raw_headers
[
0
][
1
]
#
content_length = response.raw_headers[0][1]
assert
isinstance
(
route
,
APIRoute
)
#
assert isinstance(route, APIRoute)
document
=
{
#
document = {
"process_time"
:
process_time
,
#
"process_time": process_time,
"telephone"
:
telephone
,
#
"telephone": telephone,
"user_id"
:
user_id
,
#
"user_id": user_id,
"user_name"
:
user_name
,
#
"user_name": user_name,
"request_api"
:
request
.
url
.
__str__
(),
#
"request_api": request.url.__str__(),
"client_ip"
:
request
.
client
.
host
,
#
"client_ip": request.client.host,
"system"
:
system
,
#
"system": system,
"browser"
:
browser
,
#
"browser": browser,
"request_method"
:
request
.
method
,
#
"request_method": request.method,
"api_path"
:
route
.
path
,
#
"api_path": route.path,
"summary"
:
route
.
summary
,
#
"summary": route.summary,
"description"
:
route
.
description
,
#
"description": route.description,
"tags"
:
route
.
tags
,
#
"tags": route.tags,
"route_name"
:
route
.
name
,
#
"route_name": route.name,
"status_code"
:
response
.
status_code
,
#
"status_code": response.status_code,
"content_length"
:
content_length
,
#
"content_length": content_length,
"create_datetime"
:
datetime
.
datetime
.
now
()
.
strftime
(
"
%
Y-
%
m-
%
d
%
H:
%
M:
%
S"
),
#
"create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"params"
:
json
.
dumps
(
params
)
#
"params": json.dumps(params)
}
#
}
await
OperationRecordDal
(
mongo_getter
(
request
))
.
create_data
(
document
)
#
await OperationRecordDal(mongo_getter(request)).create_data(document)
return
response
#
return response
def
register_demo_env_middleware
(
app
:
FastAPI
):
#
def register_demo_env_middleware(app: FastAPI):
"""
#
"""
演示环境中间件
#
演示环境中间件
:param app:
#
:param app:
:return:
#
:return:
"""
#
"""
#
@
app
.
middleware
(
"http"
)
#
@app.middleware("http")
async
def
demo_env_middleware
(
request
:
Request
,
call_next
):
#
async def demo_env_middleware(request: Request, call_next):
path
=
request
.
scope
.
get
(
"path"
)
#
path = request.scope.get("path")
if
request
.
method
!=
"GET"
:
#
if request.method != "GET":
print
(
"路由:"
,
path
,
request
.
method
)
#
print("路由:", path, request.method)
if
DEMO
and
request
.
method
!=
"GET"
:
#
if DEMO and request.method != "GET":
if
path
in
DEMO_BLACK_LIST_PATH
:
#
if path in DEMO_BLACK_LIST_PATH:
return
ErrorResponse
(
#
return ErrorResponse(
status
=
status
.
HTTP_403_FORBIDDEN
,
#
status=status.HTTP_403_FORBIDDEN,
code
=
status
.
HTTP_403_FORBIDDEN
,
#
code=status.HTTP_403_FORBIDDEN,
msg
=
"演示环境,禁止操作"
#
msg="演示环境,禁止操作"
)
#
)
elif
path
not
in
DEMO_WHITE_LIST_PATH
:
#
elif path not in DEMO_WHITE_LIST_PATH:
return
ErrorResponse
(
msg
=
"演示环境,禁止操作"
)
#
return ErrorResponse(msg="演示环境,禁止操作")
return
await
call_next
(
request
)
#
return await call_next(request)
def
register_jwt_refresh_middleware
(
app
:
FastAPI
):
def
register_jwt_refresh_middleware
(
app
:
FastAPI
):
...
...
dbgpt/app/apps/system/views.py
View file @
cc2841b3
...
@@ -111,16 +111,16 @@ async def get_dict_detail(data_id: int, auth: Auth = Depends(AllUserAuth())):
...
@@ -111,16 +111,16 @@ async def get_dict_detail(data_id: int, auth: Auth = Depends(AllUserAuth())):
# return SuccessResponse(result)
# return SuccessResponse(result)
@
router
.
post
(
"/upload/video/to/oss"
,
summary
=
"上传视频到阿里云OSS"
)
#
@router.post("/upload/video/to/oss", summary="上传视频到阿里云OSS")
async
def
upload_video_to_oss
(
file
:
UploadFile
,
path
:
str
=
Form
(
...
)):
#
async def upload_video_to_oss(file: UploadFile, path: str = Form(...)):
result
=
await
AliyunOSS
(
BucketConf
(
**
ALIYUN_OSS
))
.
upload_video
(
path
,
file
)
#
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_video(path, file)
return
SuccessResponse
(
result
)
#
return SuccessResponse(result)
#
#
@
router
.
post
(
"/upload/file/to/oss"
,
summary
=
"上传文件到阿里云OSS"
)
#
@router.post("/upload/file/to/oss", summary="上传文件到阿里云OSS")
async
def
upload_file_to_oss
(
file
:
UploadFile
,
path
:
str
=
Form
(
...
)):
#
async def upload_file_to_oss(file: UploadFile, path: str = Form(...)):
result
=
await
AliyunOSS
(
BucketConf
(
**
ALIYUN_OSS
))
.
upload_file
(
path
,
file
)
#
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_file(path, file)
return
SuccessResponse
(
result
)
#
return SuccessResponse(result)
@
router
.
post
(
"/upload/image/to/local"
,
summary
=
"上传图片到本地"
)
@
router
.
post
(
"/upload/image/to/local"
,
summary
=
"上传图片到本地"
)
...
...
dbgpt/app/apps/vadmin/auth/crud.py
View file @
cc2841b3
...
@@ -372,7 +372,7 @@ class UserDal(DalBase):
...
@@ -372,7 +372,7 @@ class UserDal(DalBase):
user
[
"send_sms_msg"
]
=
"重置密码失败"
user
[
"send_sms_msg"
]
=
"重置密码失败"
continue
continue
password
:
str
=
user
.
pop
(
"password"
)
password
:
str
=
user
.
pop
(
"password"
)
email
:
str
=
user
.
get
(
"email"
,
None
)
#
email: str = user.get("email", None)
# if email:
# if email:
# subject = "密码已重置"
# subject = "密码已重置"
# body = f"您好,您的密码已经重置为{password},请及时登录并修改密码。"
# body = f"您好,您的密码已经重置为{password},请及时登录并修改密码。"
...
...
dbgpt/app/apps/vadmin/auth/models/user.py
View file @
cc2841b3
...
@@ -43,7 +43,6 @@ class VadminUser(BaseModel):
...
@@ -43,7 +43,6 @@ class VadminUser(BaseModel):
roles
:
Mapped
[
set
[
VadminRole
]]
=
relationship
(
secondary
=
vadmin_auth_user_roles
)
roles
:
Mapped
[
set
[
VadminRole
]]
=
relationship
(
secondary
=
vadmin_auth_user_roles
)
depts
:
Mapped
[
set
[
VadminDept
]]
=
relationship
(
secondary
=
vadmin_auth_user_depts
)
depts
:
Mapped
[
set
[
VadminDept
]]
=
relationship
(
secondary
=
vadmin_auth_user_depts
)
@
staticmethod
@
staticmethod
def
get_password_hash
(
password
:
str
)
->
str
:
def
get_password_hash
(
password
:
str
)
->
str
:
"""
"""
...
...
dbgpt/app/apps/vadmin/auth/params/user.py
View file @
cc2841b3
...
@@ -25,6 +25,7 @@ class UserParams(QueryParams):
...
@@ -25,6 +25,7 @@ class UserParams(QueryParams):
email
:
str
|
None
=
Query
(
None
,
title
=
"邮箱"
),
email
:
str
|
None
=
Query
(
None
,
title
=
"邮箱"
),
is_active
:
bool
|
None
=
Query
(
None
,
title
=
"是否可用"
),
is_active
:
bool
|
None
=
Query
(
None
,
title
=
"是否可用"
),
is_staff
:
bool
|
None
=
Query
(
None
,
title
=
"是否为工作人员"
),
is_staff
:
bool
|
None
=
Query
(
None
,
title
=
"是否为工作人员"
),
dept_id
:
int
|
None
=
Query
(
None
,
title
=
"部门信息"
),
params
:
Paging
=
Depends
()
params
:
Paging
=
Depends
()
):
):
super
()
.
__init__
(
params
)
super
()
.
__init__
(
params
)
...
@@ -33,5 +34,6 @@ class UserParams(QueryParams):
...
@@ -33,5 +34,6 @@ class UserParams(QueryParams):
self
.
email
=
(
"like"
,
email
)
self
.
email
=
(
"like"
,
email
)
self
.
is_active
=
is_active
self
.
is_active
=
is_active
self
.
is_staff
=
is_staff
self
.
is_staff
=
is_staff
self
.
dept_id
=
dept_id
dbgpt/app/apps/vadmin/auth/utils/current.py
View file @
cc2841b3
...
@@ -18,7 +18,7 @@ from dbgpt.app.apps.config import settings
...
@@ -18,7 +18,7 @@ from dbgpt.app.apps.config import settings
from
dbgpt.app.apps.core.database
import
db_getter
from
dbgpt.app.apps.core.database
import
db_getter
from
.validation.auth
import
Auth
from
.validation.auth
import
Auth
from
fastapi
import
Cookie
,
HTTPException
class
OpenAuth
(
AuthValidation
):
class
OpenAuth
(
AuthValidation
):
"""
"""
...
@@ -85,7 +85,7 @@ class FullAdminAuth(AuthValidation):
...
@@ -85,7 +85,7 @@ class FullAdminAuth(AuthValidation):
async
def
__call__
(
async
def
__call__
(
self
,
self
,
request
:
Request
,
request
:
Request
,
token
:
str
=
Depends
(
settings
.
oauth2_schem
e
),
token
:
str
=
Cookie
(
Non
e
),
db
:
AsyncSession
=
Depends
(
db_getter
)
db
:
AsyncSession
=
Depends
(
db_getter
)
)
->
Auth
:
)
->
Auth
:
"""
"""
...
...
dbgpt/app/apps/vadmin/auth/utils/login.py
View file @
cc2841b3
...
@@ -35,6 +35,8 @@ from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal
...
@@ -35,6 +35,8 @@ from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal
from
dbgpt.app.apps.vadmin.auth.models
import
VadminUser
from
dbgpt.app.apps.vadmin.auth.models
import
VadminUser
from
.current
import
FullAdminAuth
from
.current
import
FullAdminAuth
from
.validation.auth
import
Auth
from
.validation.auth
import
Auth
from
fastapi.responses
import
JSONResponse
from
fastapi
import
Response
import
jwt
import
jwt
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -89,6 +91,7 @@ async def login_for_access_token(
...
@@ -89,6 +91,7 @@ async def login_for_access_token(
payload
=
{
"sub"
:
result
.
user
.
telephone
,
"is_refresh"
:
True
,
"password"
:
result
.
user
.
password
},
payload
=
{
"sub"
:
result
.
user
.
telephone
,
"is_refresh"
:
True
,
"password"
:
result
.
user
.
password
},
expires
=
expires
expires
=
expires
)
)
resp
=
{
resp
=
{
"access_token"
:
access_token
,
"access_token"
:
access_token
,
"refresh_token"
:
refresh_token
,
"refresh_token"
:
refresh_token
,
...
@@ -96,8 +99,10 @@ async def login_for_access_token(
...
@@ -96,8 +99,10 @@ async def login_for_access_token(
"is_reset_password"
:
result
.
user
.
is_reset_password
,
"is_reset_password"
:
result
.
user
.
is_reset_password
,
"is_wx_server_openid"
:
result
.
user
.
is_wx_server_openid
"is_wx_server_openid"
:
result
.
user
.
is_wx_server_openid
}
}
response
=
JSONResponse
(
resp
)
response
.
set_cookie
(
key
=
"token"
,
value
=
access_token
,
domain
=
""
)
# await VadminLoginRecord.create_login_record(db, data, True, request, resp)
# await VadminLoginRecord.create_login_record(db, data, True, request, resp)
return
SuccessResponse
(
resp
)
return
response
except
ValueError
as
e
:
except
ValueError
as
e
:
# await VadminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)})
# await VadminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)})
return
ErrorResponse
(
msg
=
str
(
e
))
return
ErrorResponse
(
msg
=
str
(
e
))
...
@@ -134,4 +139,6 @@ async def token_refresh(refresh: str = Body(..., title="刷新Token")):
...
@@ -134,4 +139,6 @@ async def token_refresh(refresh: str = Body(..., title="刷新Token")):
"refresh_token"
:
refresh_token
,
"refresh_token"
:
refresh_token
,
"token_type"
:
"bearer"
"token_type"
:
"bearer"
}
}
return
SuccessResponse
(
resp
)
response
=
JSONResponse
(
resp
)
response
.
set_cookie
(
key
=
"jwt"
,
value
=
access_token
,
httponly
=
True
)
return
response
dbgpt/app/apps/vadmin/auth/views.py
View file @
cc2841b3
...
@@ -43,6 +43,9 @@ async def get_users(
...
@@ -43,6 +43,9 @@ async def get_users(
**
params
.
dict
(),
**
params
.
dict
(),
v_options
=
options
,
v_options
=
options
,
v_schema
=
schema
,
v_schema
=
schema
,
v_outer_join
=
[
[
models
.
vadmin_auth_user_depts
,
params
.
dept_id
==
models
.
vadmin_auth_user_depts
.
c
.
dept_id
],
],
v_return_count
=
True
v_return_count
=
True
)
)
return
SuccessResponse
(
datas
,
count
=
count
)
return
SuccessResponse
(
datas
,
count
=
count
)
...
...
dbgpt/rag/embedding/embeddings.py
View file @
cc2841b3
...
@@ -115,8 +115,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
...
@@ -115,8 +115,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
sentence_transformers
.
SentenceTransformer
.
stop_multi_process_pool
(
pool
)
sentence_transformers
.
SentenceTransformer
.
stop_multi_process_pool
(
pool
)
else
:
else
:
embeddings
=
self
.
client
.
encode
(
texts
,
**
self
.
encode_kwargs
)
embeddings
=
self
.
client
.
encode
(
texts
,
**
self
.
encode_kwargs
)
if
len
(
embeddings
):
return
embeddings
.
tolist
()
return
embeddings
.
tolist
()
else
:
return
[]
def
embed_query
(
self
,
text
:
str
)
->
List
[
float
]:
def
embed_query
(
self
,
text
:
str
)
->
List
[
float
]:
"""Compute query embeddings using a HuggingFace transformer model.
"""Compute query embeddings using a HuggingFace transformer model.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment