Commit 0d55c620 authored by 张会鑫's avatar 张会鑫

代码提交

parent be3e97d4
import os import os
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from dbgpt.configs import ROOT_PATH
SECRET_KEY = 'vgb0tnl9d58+6n-6h-ea&u^1#s0ccp!794=kbvqacjq75vzps$' SECRET_KEY = 'vgb0tnl9d58+6n-6h-ea&u^1#s0ccp!794=kbvqacjq75vzps$'
"""用于设定 JWT 令牌签名算法""" """用于设定 JWT 令牌签名算法"""
ALGORITHM = "HS256" ALGORITHM = "HS256"
...@@ -30,6 +32,7 @@ STATIC_ENABLE = True ...@@ -30,6 +32,7 @@ STATIC_ENABLE = True
STATIC_URL = "/media" STATIC_URL = "/media"
STATIC_DIR = "static" STATIC_DIR = "static"
STATIC_ROOT = os.path.join(BASE_DIR, STATIC_DIR) STATIC_ROOT = os.path.join(BASE_DIR, STATIC_DIR)
STATIC_PATH = os.path.join(ROOT_PATH, "file")
""" """
挂载临时文件目录,并添加路由访问,此路由不会在接口文档中显示 挂载临时文件目录,并添加路由访问,此路由不会在接口文档中显示
TEMP_DIR:临时文件目录绝对路径 TEMP_DIR:临时文件目录绝对路径
......
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
from pathlib import Path from pathlib import Path
from aiopathlib import AsyncPath from aiopathlib import AsyncPath
from fastapi import UploadFile from fastapi import UploadFile
from dbgpt.app.apps.config.settings import TEMP_DIR, STATIC_ROOT from dbgpt.app.apps.config.settings import TEMP_DIR, STATIC_ROOT, STATIC_PATH
from dbgpt.app.apps.core.exception import CustomException from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils import status from dbgpt.app.apps.utils import status
from dbgpt.app.apps.utils.tools import generate_string from dbgpt.app.apps.utils.tools import generate_string
...@@ -76,7 +76,7 @@ class FileBase: ...@@ -76,7 +76,7 @@ class FileBase:
:param suffix: 文件后缀 :param suffix: 文件后缀
:return: :return:
""" """
return f"{STATIC_ROOT}/{cls.generate_relative_path(path, filename, suffix)}" return f"{STATIC_PATH}/{cls.generate_relative_path(path, filename, suffix)}"
@classmethod @classmethod
def generate_temp_file_path(cls, filename: str = None, suffix: str = None) -> str: def generate_temp_file_path(cls, filename: str = None, suffix: str = None) -> str:
......
...@@ -10,7 +10,7 @@ import io ...@@ -10,7 +10,7 @@ import io
import os import os
import zipfile import zipfile
from dbgpt.app.apps.config.settings import STATIC_ROOT, BASE_DIR, STATIC_URL from dbgpt.app.apps.config.settings import STATIC_ROOT, BASE_DIR, STATIC_URL, STATIC_PATH
from fastapi import UploadFile from fastapi import UploadFile
import sys import sys
from dbgpt.app.apps.core.exception import CustomException from dbgpt.app.apps.core.exception import CustomException
...@@ -78,7 +78,7 @@ class FileManage(FileBase): ...@@ -78,7 +78,7 @@ class FileManage(FileBase):
await path.write_bytes(await self.file.read()) await path.write_bytes(await self.file.read())
return { return {
"local_path": str(path), "local_path": str(path),
"remote_path": STATIC_URL + str(path).replace(STATIC_ROOT, '').replace("\\", '/') "remote_path": STATIC_URL + str(path).replace(STATIC_PATH, '').replace("\\", '/')
} }
@classmethod @classmethod
......
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
# @desc : 增删改查 # @desc : 增删改查
from typing import Any from typing import Any
from requests import Session
from sqlalchemy.orm.strategy_options import _AbstractLoad, contains_eager from sqlalchemy.orm.strategy_options import _AbstractLoad, contains_eager
from dbgpt.app.apps.core.exception import CustomException from dbgpt.app.apps.core.exception import CustomException
from sqlalchemy import select, false, and_ from sqlalchemy import select, false, and_, func
from dbgpt.app.apps.core.crud import DalBase from dbgpt.app.apps.core.crud import DalBase
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update from sqlalchemy import select, update
...@@ -40,6 +42,13 @@ class MediaDal(DalBase): ...@@ -40,6 +42,13 @@ class MediaDal(DalBase):
sql = update(self.model).where(self.model.id.in_(params.ids)).values(group_id=params.group_id) sql = update(self.model).where(self.model.id.in_(params.ids)).values(group_id=params.group_id)
await self.db.execute(sql) await self.db.execute(sql)
async def get_count_by_groupIds(self, ids: list) -> int:
async with Session() as session:
stmt = select(func.count(self.model.id)).where(self.model.group_id.in_(ids))
result = await session.execute(stmt)
count = result.scalar()
return count
class GroupDal(DalBase): class GroupDal(DalBase):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from fastapi import APIRouter, Depends, Body, UploadFile, Form, Request from fastapi import APIRouter, Depends, Body, UploadFile, Form, Request
from dbgpt.app.apps.utils.file.file_manage import FileManage from dbgpt.app.apps.utils.file.file_manage import FileManage
from dbgpt.app.apps.utils.response import SuccessResponse from dbgpt.app.apps.utils.response import SuccessResponse,ErrorResponse
from dbgpt.app.apps.vadmin.auth.utils.current import FullAdminAuth from dbgpt.app.apps.vadmin.auth.utils.current import FullAdminAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from . import schemas, crud from . import schemas, crud
...@@ -30,7 +30,7 @@ async def upload_video_to_local(file: UploadFile, path: str = Form(...)): ...@@ -30,7 +30,7 @@ async def upload_video_to_local(file: UploadFile, path: str = Form(...)):
async def upload_save(data: schemas.Media, auth: Auth = Depends(FullAdminAuth())): async def upload_save(data: schemas.Media, auth: Auth = Depends(FullAdminAuth())):
return SuccessResponse(await crud.MediaDal(auth.db).create_data(data=data)) return SuccessResponse(await crud.MediaDal(auth.db).create_data(data=data))
@router.post("/media/list", summary="资源列表(图片、视频)") @router.get("/media/list", summary="资源列表(图片、视频)")
async def media_image_list(params: MediaListParams = Depends(), auth: Auth = Depends(FullAdminAuth())): async def media_image_list(params: MediaListParams = Depends(), auth: Auth = Depends(FullAdminAuth())):
datas, count = await crud.MediaDal(auth.db).get_datas(**params.dict(), v_return_count=True) datas, count = await crud.MediaDal(auth.db).get_datas(**params.dict(), v_return_count=True)
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
...@@ -52,7 +52,11 @@ async def image_list(params: GroupListParams = Depends(), auth: Auth = Depends(F ...@@ -52,7 +52,11 @@ async def image_list(params: GroupListParams = Depends(), auth: Auth = Depends(F
@router.post("/group/del", summary="删除分组") @router.post("/group/del", summary="删除分组")
async def group_del(ids: IdList = Depends(), auth: Auth = Depends(FullAdminAuth())): async def group_del(ids: IdList = Depends(), auth: Auth = Depends(FullAdminAuth())):
await crud.GroupDal(auth.db).delete_datas(ids.ids, v_soft=True) # 校验分组是否为空
media_counts = await crud.MediaDal(auth.db).get_count_by_groupIds(ids.ids)
if media_counts > 0:
return ErrorResponse("分组内不为空,无法删除")
# await crud.GroupDal(auth.db).delete_datas(ids.ids, v_soft=True)
return SuccessResponse("删除成功") return SuccessResponse("删除成功")
...@@ -66,7 +70,7 @@ async def question_add(data: schemas.Question, auth: Auth = Depends(FullAdminAut ...@@ -66,7 +70,7 @@ async def question_add(data: schemas.Question, auth: Auth = Depends(FullAdminAut
return SuccessResponse(await crud.QuestionDal(auth.db).create_data(data=data)) return SuccessResponse(await crud.QuestionDal(auth.db).create_data(data=data))
@router.post("/question/list", summary="问答对列表") @router.get("/question/list", summary="问答对列表")
async def question_list(params: QuestionListParams = Depends(), auth: Auth = Depends(FullAdminAuth())): async def question_list(params: QuestionListParams = Depends(), auth: Auth = Depends(FullAdminAuth())):
datas, count = await crud.QuestionDal(auth.db).get_datas(**params.dict(), v_return_count=True) datas, count = await crud.QuestionDal(auth.db).get_datas(**params.dict(), v_return_count=True)
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
...@@ -89,7 +93,7 @@ async def correlation_add(data: schemas.Correlation, auth: Auth = Depends(FullAd ...@@ -89,7 +93,7 @@ async def correlation_add(data: schemas.Correlation, auth: Auth = Depends(FullAd
return SuccessResponse(await crud.CorrelationDal(auth.db).put_data(data_id=data.id, data=data)) return SuccessResponse(await crud.CorrelationDal(auth.db).put_data(data_id=data.id, data=data))
@router.post("/correlation/list", summary="问答对列表") @router.get("/correlation/list", summary="资源关联列表")
async def correlation_list(params: CorrelationListParams = Depends(), auth: Auth = Depends(FullAdminAuth())): async def correlation_list(params: CorrelationListParams = Depends(), auth: Auth = Depends(FullAdminAuth())):
datas, count = await crud.CorrelationDal(auth.db).get_datas(**params.dict(), v_return_count=True) datas, count = await crud.CorrelationDal(auth.db).get_datas(**params.dict(), v_return_count=True)
return SuccessResponse(datas, count=count) return SuccessResponse(datas, count=count)
......
...@@ -7,7 +7,7 @@ from functools import cache ...@@ -7,7 +7,7 @@ from functools import cache
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models") MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot") PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
STATIC_RESOURCE_PATH = os.path.join(ROOT_PATH, "dbgpt/app/apps/static") STATIC_RESOURCE_PATH = os.path.join(ROOT_PATH, "file")
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs")) LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
STATIC_MESSAGE_IMG_PATH = os.path.join(PILOT_PATH, "message/img") STATIC_MESSAGE_IMG_PATH = os.path.join(PILOT_PATH, "message/img")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment