Commit cc2841b3 authored by 林洋洋's avatar 林洋洋

修改token 存到cookies

parent 8b0030b8
...@@ -13,42 +13,34 @@ import datetime ...@@ -13,42 +13,34 @@ import datetime
import json import json
import time import time
from fastapi import Request, Response from fastapi import Request, Response
from core.logger import logger
from fastapi import FastAPI
from fastapi.routing import APIRoute
from user_agents import parse
from application.settings import OPERATION_RECORD_METHOD, MONGO_DB_ENABLE, IGNORE_OPERATION_FUNCTION, \
DEMO_WHITE_LIST_PATH, DEMO, DEMO_BLACK_LIST_PATH
from utils.response import ErrorResponse
from apps.vadmin.record.crud import OperationRecordDal
from core.database import mongo_getter
from utils import status
def write_request_log(request: Request, response: Response): from fastapi import FastAPI
http_version = f"http/{request.scope['http_version']}"
content_length = response.raw_headers[0][1]
process_time = response.headers["X-Process-Time"]
content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
f"{response.charset} {content_length} {process_time}"
logger.info(content)
def register_request_log_middleware(app: FastAPI): # def write_request_log(request: Request, response: Response):
""" # http_version = f"http/{request.scope['http_version']}"
记录请求日志中间件 # content_length = response.raw_headers[0][1]
:param app: # process_time = response.headers["X-Process-Time"]
:return: # content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
""" # f"{response.charset} {content_length} {process_time}"
# logger.info(content)
@app.middleware("http") #
async def request_log_middleware(request: Request, call_next): # def register_request_log_middleware(app: FastAPI):
start_time = time.time() # """
response = await call_next(request) # 记录请求日志中间件
process_time = time.time() - start_time # :param app:
response.headers["X-Process-Time"] = str(process_time) # :return:
write_request_log(request, response) # """
return response #
# @app.middleware("http")
# async def request_log_middleware(request: Request, call_next):
# start_time = time.time()
# response = await call_next(request)
# process_time = time.time() - start_time
# response.headers["X-Process-Time"] = str(process_time)
# write_request_log(request, response)
# return response
def register_operation_record_middleware(app: FastAPI): def register_operation_record_middleware(app: FastAPI):
...@@ -59,87 +51,87 @@ def register_operation_record_middleware(app: FastAPI): ...@@ -59,87 +51,87 @@ def register_operation_record_middleware(app: FastAPI):
:return: :return:
""" """
@app.middleware("http") # @app.middleware("http")
async def operation_record_middleware(request: Request, call_next): # async def operation_record_middleware(request: Request, call_next):
start_time = time.time() # start_time = time.time()
response = await call_next(request) # response = await call_next(request)
if not MONGO_DB_ENABLE: # if not MONGO_DB_ENABLE:
return response # return response
telephone = request.scope.get('telephone', None) # telephone = request.scope.get('telephone', None)
user_id = request.scope.get('user_id', None) # user_id = request.scope.get('user_id', None)
user_name = request.scope.get('user_name', None) # user_name = request.scope.get('user_name', None)
route = request.scope.get('route') # route = request.scope.get('route')
if not telephone: # if not telephone:
return response # return response
elif request.method not in OPERATION_RECORD_METHOD: # elif request.method not in OPERATION_RECORD_METHOD:
return response # return response
elif route.name in IGNORE_OPERATION_FUNCTION: # elif route.name in IGNORE_OPERATION_FUNCTION:
return response # return response
process_time = time.time() - start_time # process_time = time.time() - start_time
user_agent = parse(request.headers.get("user-agent")) # user_agent = parse(request.headers.get("user-agent"))
system = f"{user_agent.os.family} {user_agent.os.version_string}" # system = f"{user_agent.os.family} {user_agent.os.version_string}"
browser = f"{user_agent.browser.family} {user_agent.browser.version_string}" # browser = f"{user_agent.browser.family} {user_agent.browser.version_string}"
query_params = dict(request.query_params.multi_items()) # query_params = dict(request.query_params.multi_items())
path_params = request.path_params # path_params = request.path_params
if isinstance(request.scope.get('body'), str): # if isinstance(request.scope.get('body'), str):
body = request.scope.get('body') # body = request.scope.get('body')
else: # else:
body = request.scope.get('body').decode() # body = request.scope.get('body').decode()
if body: # if body:
body = json.loads(body) # body = json.loads(body)
params = { # params = {
"body": body, # "body": body,
"query_params": query_params if query_params else None, # "query_params": query_params if query_params else None,
"path_params": path_params if path_params else None, # "path_params": path_params if path_params else None,
} # }
content_length = response.raw_headers[0][1] # content_length = response.raw_headers[0][1]
assert isinstance(route, APIRoute) # assert isinstance(route, APIRoute)
document = { # document = {
"process_time": process_time, # "process_time": process_time,
"telephone": telephone, # "telephone": telephone,
"user_id": user_id, # "user_id": user_id,
"user_name": user_name, # "user_name": user_name,
"request_api": request.url.__str__(), # "request_api": request.url.__str__(),
"client_ip": request.client.host, # "client_ip": request.client.host,
"system": system, # "system": system,
"browser": browser, # "browser": browser,
"request_method": request.method, # "request_method": request.method,
"api_path": route.path, # "api_path": route.path,
"summary": route.summary, # "summary": route.summary,
"description": route.description, # "description": route.description,
"tags": route.tags, # "tags": route.tags,
"route_name": route.name, # "route_name": route.name,
"status_code": response.status_code, # "status_code": response.status_code,
"content_length": content_length, # "content_length": content_length,
"create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), # "create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"params": json.dumps(params) # "params": json.dumps(params)
} # }
await OperationRecordDal(mongo_getter(request)).create_data(document) # await OperationRecordDal(mongo_getter(request)).create_data(document)
return response # return response
def register_demo_env_middleware(app: FastAPI): # def register_demo_env_middleware(app: FastAPI):
""" # """
演示环境中间件 # 演示环境中间件
:param app: # :param app:
:return: # :return:
""" # """
#
@app.middleware("http") # @app.middleware("http")
async def demo_env_middleware(request: Request, call_next): # async def demo_env_middleware(request: Request, call_next):
path = request.scope.get("path") # path = request.scope.get("path")
if request.method != "GET": # if request.method != "GET":
print("路由:", path, request.method) # print("路由:", path, request.method)
if DEMO and request.method != "GET": # if DEMO and request.method != "GET":
if path in DEMO_BLACK_LIST_PATH: # if path in DEMO_BLACK_LIST_PATH:
return ErrorResponse( # return ErrorResponse(
status=status.HTTP_403_FORBIDDEN, # status=status.HTTP_403_FORBIDDEN,
code=status.HTTP_403_FORBIDDEN, # code=status.HTTP_403_FORBIDDEN,
msg="演示环境,禁止操作" # msg="演示环境,禁止操作"
) # )
elif path not in DEMO_WHITE_LIST_PATH: # elif path not in DEMO_WHITE_LIST_PATH:
return ErrorResponse(msg="演示环境,禁止操作") # return ErrorResponse(msg="演示环境,禁止操作")
return await call_next(request) # return await call_next(request)
def register_jwt_refresh_middleware(app: FastAPI): def register_jwt_refresh_middleware(app: FastAPI):
......
...@@ -111,16 +111,16 @@ async def get_dict_detail(data_id: int, auth: Auth = Depends(AllUserAuth())): ...@@ -111,16 +111,16 @@ async def get_dict_detail(data_id: int, auth: Auth = Depends(AllUserAuth())):
# return SuccessResponse(result) # return SuccessResponse(result)
@router.post("/upload/video/to/oss", summary="上传视频到阿里云OSS") # @router.post("/upload/video/to/oss", summary="上传视频到阿里云OSS")
async def upload_video_to_oss(file: UploadFile, path: str = Form(...)): # async def upload_video_to_oss(file: UploadFile, path: str = Form(...)):
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_video(path, file) # result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_video(path, file)
return SuccessResponse(result) # return SuccessResponse(result)
#
#
@router.post("/upload/file/to/oss", summary="上传文件到阿里云OSS") # @router.post("/upload/file/to/oss", summary="上传文件到阿里云OSS")
async def upload_file_to_oss(file: UploadFile, path: str = Form(...)): # async def upload_file_to_oss(file: UploadFile, path: str = Form(...)):
result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_file(path, file) # result = await AliyunOSS(BucketConf(**ALIYUN_OSS)).upload_file(path, file)
return SuccessResponse(result) # return SuccessResponse(result)
@router.post("/upload/image/to/local", summary="上传图片到本地") @router.post("/upload/image/to/local", summary="上传图片到本地")
......
...@@ -372,7 +372,7 @@ class UserDal(DalBase): ...@@ -372,7 +372,7 @@ class UserDal(DalBase):
user["send_sms_msg"] = "重置密码失败" user["send_sms_msg"] = "重置密码失败"
continue continue
password: str = user.pop("password") password: str = user.pop("password")
email: str = user.get("email", None) # email: str = user.get("email", None)
# if email: # if email:
# subject = "密码已重置" # subject = "密码已重置"
# body = f"您好,您的密码已经重置为{password},请及时登录并修改密码。" # body = f"您好,您的密码已经重置为{password},请及时登录并修改密码。"
......
...@@ -43,7 +43,6 @@ class VadminUser(BaseModel): ...@@ -43,7 +43,6 @@ class VadminUser(BaseModel):
roles: Mapped[set[VadminRole]] = relationship(secondary=vadmin_auth_user_roles) roles: Mapped[set[VadminRole]] = relationship(secondary=vadmin_auth_user_roles)
depts: Mapped[set[VadminDept]] = relationship(secondary=vadmin_auth_user_depts) depts: Mapped[set[VadminDept]] = relationship(secondary=vadmin_auth_user_depts)
@staticmethod @staticmethod
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
""" """
......
...@@ -25,6 +25,7 @@ class UserParams(QueryParams): ...@@ -25,6 +25,7 @@ class UserParams(QueryParams):
email: str | None = Query(None, title="邮箱"), email: str | None = Query(None, title="邮箱"),
is_active: bool | None = Query(None, title="是否可用"), is_active: bool | None = Query(None, title="是否可用"),
is_staff: bool | None = Query(None, title="是否为工作人员"), is_staff: bool | None = Query(None, title="是否为工作人员"),
dept_id: int | None = Query(None, title="部门信息"),
params: Paging = Depends() params: Paging = Depends()
): ):
super().__init__(params) super().__init__(params)
...@@ -33,5 +34,6 @@ class UserParams(QueryParams): ...@@ -33,5 +34,6 @@ class UserParams(QueryParams):
self.email = ("like", email) self.email = ("like", email)
self.is_active = is_active self.is_active = is_active
self.is_staff = is_staff self.is_staff = is_staff
self.dept_id = dept_id
...@@ -18,7 +18,7 @@ from dbgpt.app.apps.config import settings ...@@ -18,7 +18,7 @@ from dbgpt.app.apps.config import settings
from dbgpt.app.apps.core.database import db_getter from dbgpt.app.apps.core.database import db_getter
from .validation.auth import Auth from .validation.auth import Auth
from fastapi import Cookie, HTTPException
class OpenAuth(AuthValidation): class OpenAuth(AuthValidation):
""" """
...@@ -85,7 +85,7 @@ class FullAdminAuth(AuthValidation): ...@@ -85,7 +85,7 @@ class FullAdminAuth(AuthValidation):
async def __call__( async def __call__(
self, self,
request: Request, request: Request,
token: str = Depends(settings.oauth2_scheme), token: str = Cookie(None),
db: AsyncSession = Depends(db_getter) db: AsyncSession = Depends(db_getter)
) -> Auth: ) -> Auth:
""" """
......
...@@ -35,6 +35,8 @@ from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal ...@@ -35,6 +35,8 @@ from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal
from dbgpt.app.apps.vadmin.auth.models import VadminUser from dbgpt.app.apps.vadmin.auth.models import VadminUser
from .current import FullAdminAuth from .current import FullAdminAuth
from .validation.auth import Auth from .validation.auth import Auth
from fastapi.responses import JSONResponse
from fastapi import Response
import jwt import jwt
router = APIRouter() router = APIRouter()
...@@ -89,6 +91,7 @@ async def login_for_access_token( ...@@ -89,6 +91,7 @@ async def login_for_access_token(
payload={"sub": result.user.telephone, "is_refresh": True, "password": result.user.password}, payload={"sub": result.user.telephone, "is_refresh": True, "password": result.user.password},
expires=expires expires=expires
) )
resp = { resp = {
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
...@@ -96,8 +99,10 @@ async def login_for_access_token( ...@@ -96,8 +99,10 @@ async def login_for_access_token(
"is_reset_password": result.user.is_reset_password, "is_reset_password": result.user.is_reset_password,
"is_wx_server_openid": result.user.is_wx_server_openid "is_wx_server_openid": result.user.is_wx_server_openid
} }
response = JSONResponse(resp)
response.set_cookie(key="token", value=access_token ,domain="")
# await VadminLoginRecord.create_login_record(db, data, True, request, resp) # await VadminLoginRecord.create_login_record(db, data, True, request, resp)
return SuccessResponse(resp) return response
except ValueError as e: except ValueError as e:
# await VadminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)}) # await VadminLoginRecord.create_login_record(db, data, False, request, {"message": str(e)})
return ErrorResponse(msg=str(e)) return ErrorResponse(msg=str(e))
...@@ -134,4 +139,6 @@ async def token_refresh(refresh: str = Body(..., title="刷新Token")): ...@@ -134,4 +139,6 @@ async def token_refresh(refresh: str = Body(..., title="刷新Token")):
"refresh_token": refresh_token, "refresh_token": refresh_token,
"token_type": "bearer" "token_type": "bearer"
} }
return SuccessResponse(resp) response = JSONResponse(resp)
response.set_cookie(key="jwt", value=access_token, httponly=True)
return response
...@@ -43,6 +43,9 @@ async def get_users( ...@@ -43,6 +43,9 @@ async def get_users(
**params.dict(), **params.dict(),
v_options=options, v_options=options,
v_schema=schema, v_schema=schema,
v_outer_join=[
[models.vadmin_auth_user_depts, params.dept_id == models.vadmin_auth_user_depts.c.dept_id],
],
v_return_count=True v_return_count=True
) )
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
......
...@@ -115,8 +115,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): ...@@ -115,8 +115,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else: else:
embeddings = self.client.encode(texts, **self.encode_kwargs) embeddings = self.client.encode(texts, **self.encode_kwargs)
if len(embeddings):
return embeddings.tolist()
else:
return []
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model. """Compute query embeddings using a HuggingFace transformer model.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment