Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
D
db_gpt
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
linyangyang
db_gpt
Commits
cc2841b3
Commit
cc2841b3
authored
Aug 13, 2024
by
林洋洋
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改token 存到cookies
parent
8b0030b8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
136 additions
and
128 deletions
+136
-128
middleware.py
dbgpt/app/apps/core/middleware.py
+103
-111
views.py
dbgpt/app/apps/system/views.py
+10
-10
crud.py
dbgpt/app/apps/vadmin/auth/crud.py
+1
-1
user.py
dbgpt/app/apps/vadmin/auth/models/user.py
+0
-1
user.py
dbgpt/app/apps/vadmin/auth/params/user.py
+2
-0
current.py
dbgpt/app/apps/vadmin/auth/utils/current.py
+2
-2
login.py
dbgpt/app/apps/vadmin/auth/utils/login.py
+9
-2
views.py
dbgpt/app/apps/vadmin/auth/views.py
+3
-0
embeddings.py
dbgpt/rag/embedding/embeddings.py
+6
-1
No files found.
dbgpt/app/apps/core/middleware.py
View file @
cc2841b3
...
...
@@ -13,42 +13,34 @@ import datetime
import
json
import
time
from
fastapi
import
Request
,
Response
from
core.logger
import
logger
from
fastapi
import
FastAPI
from
fastapi.routing
import
APIRoute
from
user_agents
import
parse
from
application.settings
import
OPERATION_RECORD_METHOD
,
MONGO_DB_ENABLE
,
IGNORE_OPERATION_FUNCTION
,
\
DEMO_WHITE_LIST_PATH
,
DEMO
,
DEMO_BLACK_LIST_PATH
from
utils.response
import
ErrorResponse
from
apps.vadmin.record.crud
import
OperationRecordDal
from
core.database
import
mongo_getter
from
utils
import
status
def
write_request_log
(
request
:
Request
,
response
:
Response
):
http_version
=
f
"http/{request.scope['http_version']}"
content_length
=
response
.
raw_headers
[
0
][
1
]
process_time
=
response
.
headers
[
"X-Process-Time"
]
content
=
f
"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}"
\
f
"{response.charset} {content_length} {process_time}"
logger
.
info
(
content
)
from
fastapi
import
FastAPI
def
register_request_log_middleware
(
app
:
FastAPI
):
"""
记录请求日志中间件
:param app:
:return:
"""
# def write_request_log(request: Request, response: Response):
# http_version = f"http/{request.scope['http_version']}"
# content_length = response.raw_headers[0][1]
# process_time = response.headers["X-Process-Time"]
# content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
# f"{response.charset} {content_length} {process_time}"
# logger.info(content)
@
app
.
middleware
(
"http"
)
async
def
request_log_middleware
(
request
:
Request
,
call_next
):
start_time
=
time
.
time
()
response
=
await
call_next
(
request
)
process_time
=
time
.
time
()
-
start_time
response
.
headers
[
"X-Process-Time"
]
=
str
(
process_time
)
write_request_log
(
request
,
response
)
return
response
#
# def register_request_log_middleware(app: FastAPI):
# """
# 记录请求日志中间件
# :param app:
# :return:
# """
#
# @app.middleware("http")
# async def request_log_middleware(request: Request, call_next):
# start_time = time.time()
# response = await call_next(request)
# process_time = time.time() - start_time
# response.headers["X-Process-Time"] = str(process_time)
# write_request_log(request, response)
# return response
def
register_operation_record_middleware
(
app
:
FastAPI
):
...
...
@@ -59,87 +51,87 @@ def register_operation_record_middleware(app: FastAPI):
:return:
"""
@
app
.
middleware
(
"http"
)
async
def
operation_record_middleware
(
request
:
Request
,
call_next
):
start_time
=
time
.
time
()
response
=
await
call_next
(
request
)
if
not
MONGO_DB_ENABLE
:
return
response
telephone
=
request
.
scope
.
get
(
'telephone'
,
None
)
user_id
=
request
.
scope
.
get
(
'user_id'
,
None
)
user_name
=
request
.
scope
.
get
(
'user_name'
,
None
)
route
=
request
.
scope
.
get
(
'route'
)
if
not
telephone
:
return
response
elif
request
.
method
not
in
OPERATION_RECORD_METHOD
:
return
response
elif
route
.
name
in
IGNORE_OPERATION_FUNCTION
:
return
response
process_time
=
time
.
time
()
-
start_time
user_agent
=
parse
(
request
.
headers
.
get
(
"user-agent"
))
system
=
f
"{user_agent.os.family} {user_agent.os.version_string}"
browser
=
f
"{user_agent.browser.family} {user_agent.browser.version_string}"
query_params
=
dict
(
request
.
query_params
.
multi_items
())
path_params
=
request
.
path_params
if
isinstance
(
request
.
scope
.
get
(
'body'
),
str
):
body
=
request
.
scope
.
get
(
'body'
)
else
:
body
=
request
.
scope
.
get
(
'body'
)
.
decode
()
if
body
:
body
=
json
.
loads
(
body
)
params
=
{
"body"
:
body
,
"query_params"
:
query_params
if
query_params
else
None
,
"path_params"
:
path_params
if
path_params
else
None
,
}
content_length
=
response
.
raw_headers
[
0
][
1
]
assert
isinstance
(
route
,
APIRoute
)
document
=
{
"process_time"
:
process_time
,
"telephone"
:
telephone
,
"user_id"
:
user_id
,
"user_name"
:
user_name
,
"request_api"
:
request
.
url
.
__str__
(),
"client_ip"
:
request
.
client
.
host
,
"system"
:
system
,
"browser"
:
browser
,
"request_method"
:
request
.
method
,
"api_path"
:
route
.
path
,
"summary"
:
route
.
summary
,
"description"
:
route
.
description
,
"tags"
:
route
.
tags
,
"route_name"
:
route
.
name
,
"status_code"
:
response
.
status_code
,
"content_length"
:
content_length
,
"create_datetime"
:
datetime
.
datetime
.
now
()
.
strftime
(
"
%
Y-
%
m-
%
d
%
H:
%
M:
%
S"
),
"params"
:
json
.
dumps
(
params
)
}
await
OperationRecordDal
(
mongo_getter
(
request
))
.
create_data
(
document
)
return
response
#
@app.middleware("http")
#
async def operation_record_middleware(request: Request, call_next):
#
start_time = time.time()
#
response = await call_next(request)
#
if not MONGO_DB_ENABLE:
#
return response
#
telephone = request.scope.get('telephone', None)
#
user_id = request.scope.get('user_id', None)
#
user_name = request.scope.get('user_name', None)
#
route = request.scope.get('route')
#
if not telephone:
#
return response
#
elif request.method not in OPERATION_RECORD_METHOD:
#
return response
#
elif route.name in IGNORE_OPERATION_FUNCTION:
#
return response
#
process_time = time.time() - start_time
#
user_agent = parse(request.headers.get("user-agent"))
#
system = f"{user_agent.os.family} {user_agent.os.version_string}"
#
browser = f"{user_agent.browser.family} {user_agent.browser.version_string}"
#
query_params = dict(request.query_params.multi_items())
#
path_params = request.path_params
#
if isinstance(request.scope.get('body'), str):
#
body = request.scope.get('body')
#
else:
#
body = request.scope.get('body').decode()
#
if body:
#
body = json.loads(body)
#
params = {
#
"body": body,
#
"query_params": query_params if query_params else None,
#
"path_params": path_params if path_params else None,
#
}
#
content_length = response.raw_headers[0][1]
#
assert isinstance(route, APIRoute)
#
document = {
#
"process_time": process_time,
#
"telephone": telephone,
#
"user_id": user_id,
#
"user_name": user_name,
#
"request_api": request.url.__str__(),
#
"client_ip": request.client.host,
#
"system": system,
#
"browser": browser,
#
"request_method": request.method,
#
"api_path": route.path,
#
"summary": route.summary,
#
"description": route.description,
#
"tags": route.tags,
#
"route_name": route.name,
#
"status_code": response.status_code,
#
"content_length": content_length,
#
"create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
#
"params": json.dumps(params)
#
}
#
await OperationRecordDal(mongo_getter(request)).create_data(document)
#
return response
def
register_demo_env_middleware
(
app
:
FastAPI
):
"""
演示环境中间件
:param app:
:return:
"""
@
app
.
middleware
(
"http"
)
async
def
demo_env_middleware
(
request
:
Request
,
call_next
):
path
=
request
.
scope
.
get
(
"path"
)
if
request
.
method
!=
"GET"
:
print
(
"路由:"
,
path
,
request
.
method
)
if
DEMO
and
request
.
method
!=
"GET"
:
if
path
in
DEMO_BLACK_LIST_PATH
:
return
ErrorResponse
(
status
=
status
.
HTTP_403_FORBIDDEN
,
code
=
status
.
HTTP_403_FORBIDDEN
,
msg
=
"演示环境,禁止操作"
)
elif
path
not
in
DEMO_WHITE_LIST_PATH
:
return
ErrorResponse
(
msg
=
"演示环境,禁止操作"
)
return
await
call_next
(
request
)
#
def register_demo_env_middleware(app: FastAPI):
#
"""
#
演示环境中间件
#
:param app:
#
:return:
#
"""
#
#
@app.middleware("http")
#
async def demo_env_middleware(request: Request, call_next):
#
path = request.scope.get("path")
#
if request.method != "GET":
#
print("路由:", path, request.method)
#
if DEMO and request.method != "GET":
#
if path in DEMO_BLACK_LIST_PATH:
#
return ErrorResponse(
#
status=status.HTTP_403_FORBIDDEN,
#
code=status.HTTP_403_FORBIDDEN,
#
msg="演示环境,禁止操作"
#
)
#
elif path not in DEMO_WHITE_LIST_PATH:
#
return ErrorResponse(msg="演示环境,禁止操作")
#
return await call_next(request)
def
register_jwt_refresh_middleware
(
app
:
FastAPI
):
...
...
dbgpt/app/apps/system/views.py
View file @
cc2841b3
...
...
@@ -111,16 +111,16 @@ async def get_dict_detail(data_id: int, auth: Auth = Depends(AllUserAuth())):
# return SuccessResponse(result)
@
router
.
post
(
"/upload/video/to/oss"
,
summary
=
"上传视频到阿里云OSS"
)
async
def
upload_video_to_oss
(
file
:
UploadFile
,
path
:
str
=
Form
(
...
)):
result
=
await
AliyunOSS
(
BucketConf
(
**
ALIYUN_OSS
))
.
upload_video
(
path
,
file
)
return
SuccessResponse
(
result
)
@
router
.
post
(
"/upload/file/to/oss"
,
summary
=
"上传文件到阿里云OSS"
)
async
def
upload_file_to_oss
(
file
:
UploadFile
,
path
:
str
=
Form
(
...
)):
result
=
await
AliyunOSS
(
BucketConf
(
**
ALIYUN_OSS
))
.
upload_file
(
path
,
file
)
return
SuccessResponse
(
result
)
#
@router.post("/upload/video/to/oss", summary="上传视频到阿里云OSS")
#
async def upload_video_to_oss(file: UploadFile, path: str = Form(...)):
#
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_video(path, file)
#
return SuccessResponse(result)
#
#
#
@router.post("/upload/file/to/oss", summary="上传文件到阿里云OSS")
#
async def upload_file_to_oss(file: UploadFile, path: str = Form(...)):
#
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_file(path, file)
#
return SuccessResponse(result)
@
router
.
post
(
"/upload/image/to/local"
,
summary
=
"上传图片到本地"
)
...
...
dbgpt/app/apps/vadmin/auth/crud.py
View file @
cc2841b3
...
...
@@ -372,7 +372,7 @@ class UserDal(DalBase):
user
[
"send_sms_msg"
]
=
"重置密码失败"
continue
password
:
str
=
user
.
pop
(
"password"
)
email
:
str
=
user
.
get
(
"email"
,
None
)
#
email: str = user.get("email", None)
# if email:
# subject = "密码已重置"
# body = f"您好,您的密码已经重置为{password},请及时登录并修改密码。"
...
...
dbgpt/app/apps/vadmin/auth/models/user.py
View file @
cc2841b3
...
...
@@ -43,7 +43,6 @@ class VadminUser(BaseModel):
roles
:
Mapped
[
set
[
VadminRole
]]
=
relationship
(
secondary
=
vadmin_auth_user_roles
)
depts
:
Mapped
[
set
[
VadminDept
]]
=
relationship
(
secondary
=
vadmin_auth_user_depts
)
@
staticmethod
def
get_password_hash
(
password
:
str
)
->
str
:
"""
...
...
dbgpt/app/apps/vadmin/auth/params/user.py
View file @
cc2841b3
...
...
@@ -25,6 +25,7 @@ class UserParams(QueryParams):
email
:
str
|
None
=
Query
(
None
,
title
=
"邮箱"
),
is_active
:
bool
|
None
=
Query
(
None
,
title
=
"是否可用"
),
is_staff
:
bool
|
None
=
Query
(
None
,
title
=
"是否为工作人员"
),
dept_id
:
int
|
None
=
Query
(
None
,
title
=
"部门信息"
),
params
:
Paging
=
Depends
()
):
super
()
.
__init__
(
params
)
...
...
@@ -33,5 +34,6 @@ class UserParams(QueryParams):
self
.
email
=
(
"like"
,
email
)
self
.
is_active
=
is_active
self
.
is_staff
=
is_staff
self
.
dept_id
=
dept_id
dbgpt/app/apps/vadmin/auth/utils/current.py
View file @
cc2841b3
...
...
@@ -18,7 +18,7 @@ from dbgpt.app.apps.config import settings
from
dbgpt.app.apps.core.database
import
db_getter
from
.validation.auth
import
Auth
from
fastapi
import
Cookie
,
HTTPException
class
OpenAuth
(
AuthValidation
):
"""
...
...
@@ -85,7 +85,7 @@ class FullAdminAuth(AuthValidation):
async
def
__call__
(
self
,
request
:
Request
,
token
:
str
=
Depends
(
settings
.
oauth2_schem
e
),
token
:
str
=
Cookie
(
Non
e
),
db
:
AsyncSession
=
Depends
(
db_getter
)
)
->
Auth
:
"""
...
...
dbgpt/app/apps/vadmin/auth/utils/login.py
View file @
cc2841b3
...
...
@@ -35,6 +35,8 @@ from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal
from
dbgpt.app.apps.vadmin.auth.models
import
VadminUser
from
.current
import
FullAdminAuth
from
.validation.auth
import
Auth
from
fastapi.responses
import
JSONResponse
from
fastapi
import
Response
import
jwt
router
=
APIRouter
()
...
...
@@ -89,6 +91,7 @@ async def login_for_access_token(
payload
=
{
"sub"
:
result
.
user
.
telephone
,
"is_refresh"
:
True
,
"password"
:
result
.
user
.
password
},
expires
=
expires
)
resp
=
{
"access_token"
:
access_token
,
"refresh_token"
:
refresh_token
,
...
...
@@ -96,8 +99,10 @@ async def login_for_access_token(
"is_reset_password"
:
result
.
user
.
is_reset_password
,
"is_wx_server_openid"
:
result
.
user
.
is_wx_server_openid
}
response
=
JSONResponse
(
resp
)
response
.
set_cookie
(
key
=
"token"
,
value
=
access_token
,
domain
=
""
)
# await VadminLoginRecord.create_login_record(db, data, True, request, resp)
return
SuccessResponse
(
resp
)
return
response
except
ValueError
as
e
:
# await VadminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)})
return
ErrorResponse
(
msg
=
str
(
e
))
...
...
@@ -134,4 +139,6 @@ async def token_refresh(refresh: str = Body(..., title="刷新Token")):
"refresh_token"
:
refresh_token
,
"token_type"
:
"bearer"
}
return
SuccessResponse
(
resp
)
response
=
JSONResponse
(
resp
)
response
.
set_cookie
(
key
=
"jwt"
,
value
=
access_token
,
httponly
=
True
)
return
response
dbgpt/app/apps/vadmin/auth/views.py
View file @
cc2841b3
...
...
@@ -43,6 +43,9 @@ async def get_users(
**
params
.
dict
(),
v_options
=
options
,
v_schema
=
schema
,
v_outer_join
=
[
[
models
.
vadmin_auth_user_depts
,
params
.
dept_id
==
models
.
vadmin_auth_user_depts
.
c
.
dept_id
],
],
v_return_count
=
True
)
return
SuccessResponse
(
datas
,
count
=
count
)
...
...
dbgpt/rag/embedding/embeddings.py
View file @
cc2841b3
...
...
@@ -115,8 +115,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
sentence_transformers
.
SentenceTransformer
.
stop_multi_process_pool
(
pool
)
else
:
embeddings
=
self
.
client
.
encode
(
texts
,
**
self
.
encode_kwargs
)
if
len
(
embeddings
):
return
embeddings
.
tolist
()
else
:
return
[]
def
embed_query
(
self
,
text
:
str
)
->
List
[
float
]:
"""Compute query embeddings using a HuggingFace transformer model.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment