Commit 7d2c55f2 authored by 林洋洋's avatar 林洋洋

权限代码提交

parent 47e17549
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#*******************************************************************# #*******************************************************************#
#** Webserver Port **# #** Webserver Port **#
#*******************************************************************# #*******************************************************************#
# DBGPT_WEBSERVER_PORT=5670 DBGPT_WEBSERVER_PORT=5670
#*******************************************************************# #*******************************************************************#
#*** LLM PROVIDER ***# #*** LLM PROVIDER ***#
...@@ -137,7 +137,7 @@ LOCAL_DB_PORT=33333 ...@@ -137,7 +137,7 @@ LOCAL_DB_PORT=33333
LOCAL_DB_NAME=dbgpt LOCAL_DB_NAME=dbgpt
### This option determines the storage location of conversation records. The default is not configured to the old version of duckdb. It can be optionally db or file (if the value is db, the database configured by LOCAL_DB will be used) ### This option determines the storage location of conversation records. The default is not configured to the old version of duckdb. It can be optionally db or file (if the value is db, the database configured by LOCAL_DB will be used)
#CHAT_HISTORY_STORE_TYPE=db #CHAT_HISTORY_STORE_TYPE=db
REDIS_DB_URL = "redis://:ningzaichun@39.105.143.94:16379/1"
#*******************************************************************# #*******************************************************************#
#** COMMANDS **# #** COMMANDS **#
#*******************************************************************# #*******************************************************************#
......
...@@ -185,3 +185,4 @@ thirdparty ...@@ -185,3 +185,4 @@ thirdparty
/examples/**/*.gv.pdf /examples/**/*.gv.pdf
/i18n/locales/**/**/*_ai_translated.po /i18n/locales/**/**/*_ai_translated.po
/i18n/locales/**/**/*~ /i18n/locales/**/**/*~
/env_name
\ No newline at end of file
...@@ -190,7 +190,7 @@ class Config(metaclass=Singleton): ...@@ -190,7 +190,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10)) self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20)) self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))
self.SQLALCHEMY_DATABASE_URL = f"mysql+asyncmy://{self.LOCAL_DB_USER}:{self.LOCAL_DB_PASSWORD}@{self.LOCAL_DB_HOST}:{self.LOCAL_DB_PORT}/{self.LOCAL_DB_NAME}"
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db") self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")
# LLM Model Service Configuration # LLM Model Service Configuration
...@@ -304,6 +304,7 @@ class Config(metaclass=Singleton): ...@@ -304,6 +304,7 @@ class Config(metaclass=Singleton):
# global dbgpt api key # global dbgpt api key
self.API_KEYS = os.getenv("API_KEYS", None) self.API_KEYS = os.getenv("API_KEYS", None)
# Non-streaming scene retries # Non-streaming scene retries
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int( self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1) os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1)
...@@ -314,6 +315,8 @@ class Config(metaclass=Singleton): ...@@ -314,6 +315,8 @@ class Config(metaclass=Singleton):
) )
# experimental financial report model configuration # experimental financial report model configuration
self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None) self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None)
# Redis
self.REDIS_DB_URL = os.getenv("REDIS_DB_URL", "redis://:fI0#aE2*aH@39.105.143.94:6379/1")
@property @property
def local_db_manager(self) -> "ConnectorManager": def local_db_manager(self) -> "ConnectorManager":
......
import os
from fastapi.security import OAuth2PasswordBearer
SECRET_KEY = 'vgb0tnl9d58+6n-6h-ea&u^1#s0ccp!794=kbvqacjq75vzps$'
"""用于设定 JWT 令牌签名算法"""
ALGORITHM = "HS256"
DEBUG = True
DEMO = False
"""access_token 过期时间,一天"""
ACCESS_TOKEN_EXPIRE_MINUTES = 1440
"""refresh_token 过期时间,用于刷新token使用,两天"""
REFRESH_TOKEN_EXPIRE_MINUTES = 1440 * 2
"""access_token 缓存时间,用于刷新token使用,30分钟"""
ACCESS_TOKEN_CACHE_MINUTES = 30
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
"""
挂载静态目录,并添加路由访问,此路由不会在接口文档中显示
STATIC_ENABLE:是否启用静态目录访问
STATIC_URL:路由访问
STATIC_ROOT:静态文件目录绝对路径
官方文档:https://fastapi.tiangolo.com/tutorial/static-files/
"""
STATIC_ENABLE = True
STATIC_URL = "/media"
STATIC_DIR = "static"
STATIC_ROOT = os.path.join(BASE_DIR, STATIC_DIR)
"""
挂载临时文件目录,并添加路由访问,此路由不会在接口文档中显示
TEMP_DIR:临时文件目录绝对路径
官方文档:https://fastapi.tiangolo.com/tutorial/static-files/
"""
TEMP_DIR = os.path.join(BASE_DIR, "temp")
# 默认密码,"0" 默认为手机号后六位
DEFAULT_PASSWORD = "0"
# 默认头像
DEFAULT_AVATAR = "https://vv-reserve.oss-cn-hangzhou.aliyuncs.com/avatar/2023-01-27/1674820804e81e7631.png"
# 默认登陆时最大输入密码或验证码错误次数
DEFAULT_AUTH_ERROR_MAX_NUMBER = 5
# 发布/订阅通道,与定时任务程序相互关联,请勿随意更改
SUBSCRIBE = 'kinit_queue'
OAUTH_ENABLE = True
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/api/login", auto_error=False) if OAUTH_ENABLE else lambda: ""
This diff is collapsed.
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/7/16 12:42
# @File : data_types.py
# @IDE : PyCharm
# @desc : 自定义数据类型
"""
自定义数据类型 - 官方文档:https://docs.pydantic.dev/dev-v2/usage/types/custom/#adding-validation-and-serialization
"""
import datetime
from typing import Annotated, Any
from bson import ObjectId
from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
from .validator import *
def datetime_str_vali(value: str | datetime.datetime | int | float | dict):
"""
日期时间字符串验证
如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回
因为在 pydantic 2.0 中是支持 int 或 float 自动转换类型的,所以我这里添加进去,但是在处理时会使这两种类型报错
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
"""
if isinstance(value, str):
pattern = "%Y-%m-%d %H:%M:%S"
try:
datetime.datetime.strptime(value, pattern)
return value
except ValueError:
pass
elif isinstance(value, datetime.datetime):
return value.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(value, dict):
# 用于处理 mongodb 日期时间数据类型
date_str = value.get("$date")
date_format = '%Y-%m-%dT%H:%M:%S.%fZ'
# 将字符串转换为datetime.datetime类型
datetime_obj = datetime.datetime.strptime(date_str, date_format)
# 将datetime.datetime对象转换为指定的字符串格式
return datetime_obj.strftime('%Y-%m-%d %H:%M:%S')
raise ValueError("无效的日期时间或字符串数据")
# 实现自定义一个日期时间字符串的数据类型
DatetimeStr = Annotated[
str | datetime.datetime | int | float | dict,
AfterValidator(datetime_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
# 实现自定义一个手机号类型
Telephone = Annotated[
str,
AfterValidator(lambda x: vali_telephone(x)),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
# 实现自定义一个邮箱类型
Email = Annotated[
str,
AfterValidator(lambda x: vali_email(x)),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
def date_str_vali(value: str | datetime.date | int | float):
"""
日期字符串验证
如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回
因为在 pydantic 2.0 中是支持 int 或 float 自动转换类型的,所以我这里添加进去,但是在处理时会使这两种类型报错
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
"""
if isinstance(value, str):
pattern = "%Y-%m-%d"
try:
datetime.datetime.strptime(value, pattern)
return value
except ValueError:
pass
elif isinstance(value, datetime.date):
return value.strftime("%Y-%m-%d")
raise ValueError("无效的日期时间或字符串数据")
# 实现自定义一个日期字符串的数据类型
DateStr = Annotated[
str | datetime.date | int | float,
AfterValidator(date_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
def object_id_str_vali(value: str | dict | ObjectId):
"""
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
"""
if isinstance(value, str):
return value
elif isinstance(value, dict):
return value.get("$oid")
elif isinstance(value, ObjectId):
return str(value)
raise ValueError("无效的 ObjectId 数据类型")
ObjectIdStr = Annotated[
Any, # 这里不能直接使用 any,需要使用 typing.Any
AfterValidator(object_id_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
# -*- coding: utf-8 -*-
# @version : 1.0
# @Update Time : 2023/8/18 9:00
# @File : database.py
# @IDE : PyCharm
# @desc : SQLAlchemy 部分
"""
导入 SQLAlchemy 部分
安装: pip install sqlalchemy[asyncio]
官方文档:https://docs.sqlalchemy.org/en/20/intro.html#installation
"""
from typing import AsyncGenerator
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker, AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, declared_attr
from fastapi import Request
from dbgpt._private.config import Config
from dbgpt.app.apps.core.exception import CustomException
from motor.motor_asyncio import AsyncIOMotorDatabase
# 官方文档:https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.create_async_engine
# database_url dialect+driver://username:password@host:port/database
# echo:如果为True,引擎将记录所有语句以及它们的参数列表的repr()到默认的日志处理程序,该处理程序默认为sys.stdout。如果设置为字符串"debug",
# 结果行也将打印到标准输出。Engine的echo属性可以随时修改以打开和关闭日志记录;也可以使用标准的Python logging模块来直接控制日志记录。
# echo_pool=False:如果为True,连接池将记录信息性输出,如何时使连接失效以及何时将连接回收到默认的日志处理程序,该处理程序默认为sys.stdout。
# 如果设置为字符串"debug",记录将包括池的检出和检入。也可以使用标准的Python logging模块来直接控制日志记录。
# pool_pre_ping:布尔值,如果为True,将启用连接池的"pre-ping"功能,该功能在每次检出时测试连接的活动性。
# pool_recycle=-1:此设置导致池在给定的秒数后重新使用连接。默认为-1,即没有超时。例如,将其设置为3600意味着在一小时后重新使用连接。
# 请注意,特别是MySQL会在检测到连接8小时内没有活动时自动断开连接(尽管可以通过MySQLDB连接自身和服务器配置进行配置)。
# pool_size=5:在连接池内保持打开的连接数。与QueuePool以及SingletonThreadPool一起使用。
# 对于QueuePool,pool_size设置为0表示没有限制;要禁用连接池,请将poolclass设置为NullPool。
# pool_timeout=30:在从池中获取连接之前等待的秒数。仅在QueuePool中使用。这可以是一个浮点数,但受Python时间函数的限制,可能在几十毫秒内不可靠
# max_overflow 参数用于配置连接池中允许的连接 "溢出" 数量。这个参数用于在高负载情况下处理连接请求的峰值。
# 当连接池的所有连接都在使用中时,如果有新的连接请求到达,连接池可以创建额外的连接来满足这些请求,最多创建的数量由 max_overflow 参数决定。
# 创建数据库连接
async_engine = create_async_engine(
Config().SQLALCHEMY_DATABASE_URL,
echo=False,
echo_pool=False,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=5,
max_overflow=5,
connect_args={}
)
# 创建数据库会话
session_factory = async_sessionmaker(
autocommit=False,
autoflush=False,
bind=async_engine,
expire_on_commit=True,
class_=AsyncSession
)
class Base(AsyncAttrs, DeclarativeBase):
"""
创建基本映射类
稍后,我们将继承该类,创建每个 ORM 模型
"""
@declared_attr.directive
def __tablename__(cls) -> str:
"""
将表名改为小写
如果有自定义表名就取自定义,没有就取小写类名
"""
table_name = cls.__tablename__
if not table_name:
model_name = cls.__name__
ls = []
for index, char in enumerate(model_name):
if char.isupper() and index != 0:
ls.append("_")
ls.append(char)
table_name = "".join(ls).lower()
return table_name
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
"""
获取主数据库会话
数据库依赖项,它将在单个请求中使用,然后在请求完成后将其关闭。
函数的返回类型被注解为 AsyncGenerator[int, None],其中 AsyncSession 是生成的值的类型,而 None 表示异步生成器没有终止条件。
"""
async with session_factory() as session:
# 创建一个新的事务,半自动 commit
async with session.begin():
yield session
def redis_getter(request: Request) -> Redis:
"""
获取 redis 数据库对象
全局挂载,使用一个数据库对象
"""
return request.app.state.redis
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/8/8 14:18
# @File : dependencies.py
# @IDE : PyCharm
# @desc : 常用依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Body
import copy
class QueryParams:
def __init__(self, params=None):
if params:
self.page = params.page
self.limit = params.limit
self.v_order = params.v_order
self.v_order_field = params.v_order_field
def dict(self, exclude: list[str] = None) -> dict:
result = copy.deepcopy(self.__dict__)
if exclude:
for item in exclude:
try:
del result[item]
except KeyError:
pass
return result
def to_count(self, exclude: list[str] = None) -> dict:
params = self.dict(exclude=exclude)
del params["page"]
del params["limit"]
del params["v_order"]
del params["v_order_field"]
return params
class Paging(QueryParams):
"""
列表分页
"""
def __init__(self, page: int = 1, limit: int = 10, v_order_field: str = None, v_order: str = None):
super().__init__()
self.page = page
self.limit = limit
self.v_order = v_order
self.v_order_field = v_order_field
class IdList:
"""
id 列表
"""
def __init__(self, ids: list[int] = Body(..., title="ID 列表")):
self.ids = ids
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/11/16 16:44
# @File : views.py
# @IDE : PyCharm
# @desc : 项目文档
# 自定义接口文档静态文件:https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/
from fastapi import FastAPI
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
def custom_api_docs(app: FastAPI):
"""
自定义配置接口本地静态文档
"""
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
return get_swagger_ui_html(
openapi_url=app.openapi_url,
title=app.title + " - Swagger UI",
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
swagger_js_url="/media/swagger_ui/swagger-ui-bundle.js",
swagger_css_url="/media/swagger_ui/swagger-ui.css",
)
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect():
return get_swagger_ui_oauth2_redirect_html()
@app.get("/redoc", include_in_schema=False)
async def custom_redoc_html():
return get_redoc_html(
openapi_url=app.openapi_url,
title=app.title + " - ReDoc",
redoc_js_url="/media/redoc_ui/redoc.standalone.js",
)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/02/12 22:18
# @File : enum.py
# @IDE : PyCharm
# @desc : 增加枚举类方法
from enum import Enum
class SuperEnum(Enum):
@classmethod
def to_dict(cls):
"""Returns a dictionary representation of the enum."""
return {e.name: e.value for e in cls}
@classmethod
def keys(cls):
"""Returns a list of all the enum keys."""
return cls._member_names_
@classmethod
def values(cls):
"""Returns a list of all the enum values."""
return list(cls._value2member_map_.keys())
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/3/21 11:03
# @File : event.py
# @IDE : PyCharm
# @desc : 全局事件
from dbgpt._private.config import Config
from fastapi import FastAPI
# from dbgpt.app.apps.utils.cache import Cache
from redis import asyncio as aioredis
from redis.exceptions import AuthenticationError, TimeoutError, RedisError
from contextlib import asynccontextmanager
from dbgpt.app.apps.utils.tools import import_modules_async
from sqlalchemy.exc import ProgrammingError
from dbgpt.util.logger import logger
EVENTS = ["dbgpt.app.apps.core.event.connect_redis"]
@asynccontextmanager
async def lifespan(app: FastAPI):
await import_modules_async(EVENTS, "全局事件", app=app, status=True)
yield
await import_modules_async(EVENTS, "全局事件", app=app, status=False)
async def connect_redis(app: FastAPI, status: bool):
"""
把 redis 挂载到 app 对象上面
博客:https://blog.csdn.net/wgPython/article/details/107668521
博客:https://www.cnblogs.com/emunshe/p/15761597.html
官网:https://aioredis.readthedocs.io/en/latest/getting-started/
Github: https://github.com/aio-libs/aioredis-py
aioredis.from_url(url, *, encoding=None, parser=None, decode_responses=False, db=None, password=None, ssl=None,
connection_cls=None, loop=None, **kwargs) 方法是 aioredis 库中用于从 Redis 连接 URL 创建 Redis 连接对象的方法。
以下是该方法的参数说明:
url:Redis 连接 URL。例如 redis://localhost:6379/0。
encoding:可选参数,Redis 编码格式。默认为 utf-8。
parser:可选参数,Redis 数据解析器。默认为 None,表示使用默认解析器。
decode_responses:可选参数,是否将 Redis 响应解码为 Python 字符串。默认为 False。
db:可选参数,Redis 数据库编号。默认为 None。
password:可选参数,Redis 认证密码。默认为 None,表示无需认证。
ssl:可选参数,是否使用 SSL/TLS 加密连接。默认为 None。
connection_cls:可选参数,Redis 连接类。默认为 None,表示使用默认连接类。
loop:可选参数,用于创建连接对象的事件循环。默认为 None,表示使用默认事件循环。
**kwargs:可选参数,其他连接参数,用于传递给 Redis 连接类的构造函数。
aioredis.from_url() 方法的主要作用是将 Redis 连接 URL 转换为 Redis 连接对象。
除了 URL 参数外,其他参数用于指定 Redis 连接的各种选项,例如 Redis 数据库编号、密码、SSL/TLS 加密等等。可以根据需要选择使用这些选项。
health_check_interval 是 aioredis.from_url() 方法中的一个可选参数,用于设置 Redis 连接的健康检查间隔时间。
健康检查是指在 Redis 连接池中使用的连接对象会定期向 Redis 服务器发送 PING 命令来检查连接是否仍然有效。
该参数的默认值是 0,表示不进行健康检查。如果需要启用健康检查,则可以将该参数设置为一个正整数,表示检查间隔的秒数。
例如,如果需要每隔 5 秒对 Redis 连接进行一次健康检查,则可以将 health_check_interval 设置为 5
:param app:
:param status:
:return:
"""
if status:
CFG = Config()
print(CFG.REDIS_DB_URL)
rd = aioredis.from_url(CFG.REDIS_DB_URL, decode_responses=True, health_check_interval=1)
app.state.redis = rd
try:
response = await rd.ping()
if response:
print("Redis 连接成功")
else:
print("Redis 连接失败")
except AuthenticationError as e:
raise AuthenticationError(f"Redis 连接认证失败,用户名或密码错误: {e}")
except TimeoutError as e:
raise TimeoutError(f"Redis 连接超时,地址或者端口错误: {e}")
except RedisError as e:
raise RedisError(f"Redis 连接失败: {e}")
# try:
# await Cache(app.state.redis).cache_tab_names()
# except ProgrammingError as e:
# logger.error(f"sqlalchemy.exc.ProgrammingError: {e}")
# print(f"sqlalchemy.exc.ProgrammingError: {e}")
else:
print("Redis 连接关闭")
await app.state.redis.close()
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/19 15:47
# @File : exception.py
# @IDE : PyCharm
# @desc : 全局异常处理
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from fastapi.exceptions import RequestValidationError
from starlette import status
from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi import FastAPI
from dbgpt.util.logger import logger
from dbgpt.app.apps.config.settings import SECRET_KEY,DEBUG
class CustomException(Exception):
def __init__(
self,
msg: str,
code: int = status.HTTP_400_BAD_REQUEST,
status_code: int = status.HTTP_200_OK,
desc: str = None
):
self.msg = msg
self.code = code
self.status_code = status_code
self.desc = desc
def register_exception(app: FastAPI):
"""
异常捕捉
"""
@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
"""
自定义异常
"""
if DEBUG:
print("请求地址", request.url.__str__())
print("捕捉到重写CustomException异常异常:custom_exception_handler")
print(exc.desc)
print(exc.msg)
# 打印栈信息,方便追踪排查异常
logger.exception(exc)
return JSONResponse(
status_code=exc.status_code,
content={"message": exc.msg, "code": exc.code},
)
@app.exception_handler(StarletteHTTPException)
async def unicorn_exception_handler(request: Request, exc: StarletteHTTPException):
"""
重写HTTPException异常处理器
"""
if DEBUG:
print("请求地址", request.url.__str__())
print("捕捉到重写HTTPException异常异常:unicorn_exception_handler")
print(exc.detail)
# 打印栈信息,方便追踪排查异常
logger.exception(exc)
return JSONResponse(
status_code=exc.status_code,
content={
"code": exc.status_code,
"message": exc.detail,
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""
重写请求验证异常处理器
"""
if DEBUG:
print("请求地址", request.url.__str__())
print("捕捉到重写请求验证异常异常:validation_exception_handler")
print(exc.errors())
# 打印栈信息,方便追踪排查异常
logger.exception(exc)
msg = exc.errors()[0].get("msg")
if msg == "field required":
msg = "请求失败,缺少必填项!"
elif msg == "value is not a valid list":
print(exc.errors())
msg = f"类型错误,提交参数应该为列表!"
elif msg == "value is not a valid int":
msg = f"类型错误,提交参数应该为整数!"
elif msg == "value could not be parsed to a boolean":
msg = f"类型错误,提交参数应该为布尔值!"
elif msg == "Input should be a valid list":
msg = f"类型错误,输入应该是一个有效的列表!"
return JSONResponse(
status_code=200,
content=jsonable_encoder(
{
"message": msg,
"body": exc.body,
"code": status.HTTP_400_BAD_REQUEST
}
),
)
@app.exception_handler(ValueError)
async def value_exception_handler(request: Request, exc: ValueError):
"""
捕获值异常
"""
if DEBUG:
print("请求地址", request.url.__str__())
print("捕捉到值异常:value_exception_handler")
print(exc.__str__())
# 打印栈信息,方便追踪排查异常
logger.exception(exc)
return JSONResponse(
status_code=200,
content=jsonable_encoder(
{
"message": exc.__str__(),
"code": status.HTTP_400_BAD_REQUEST
}
),
)
@app.exception_handler(Exception)
async def all_exception_handler(request: Request, exc: Exception):
"""
捕获全部异常
"""
if DEBUG:
print("请求地址", request.url.__str__())
print("捕捉到全局异常:all_exception_handler")
print(exc.__str__())
# 打印栈信息,方便追踪排查异常
logger.exception(exc)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=jsonable_encoder(
{
"message": "接口异常!",
"code": status.HTTP_500_INTERNAL_SERVER_ERROR
}
),
)
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/19 15:47
# @File : middleware.py
# @IDE : PyCharm
# @desc : 中间件
"""
官方文档——中间件:https://fastapi.tiangolo.com/tutorial/middleware/
官方文档——高级中间件:https://fastapi.tiangolo.com/advanced/middleware/
"""
import datetime
import json
import time
from fastapi import Request, Response
from core.logger import logger
from fastapi import FastAPI
from fastapi.routing import APIRoute
from user_agents import parse
from application.settings import OPERATION_RECORD_METHOD, MONGO_DB_ENABLE, IGNORE_OPERATION_FUNCTION, \
DEMO_WHITE_LIST_PATH, DEMO, DEMO_BLACK_LIST_PATH
from utils.response import ErrorResponse
from apps.vadmin.record.crud import OperationRecordDal
from core.database import mongo_getter
from utils import status
def write_request_log(request: Request, response: Response):
http_version = f"http/{request.scope['http_version']}"
content_length = response.raw_headers[0][1]
process_time = response.headers["X-Process-Time"]
content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
f"{response.charset} {content_length} {process_time}"
logger.info(content)
def register_request_log_middleware(app: FastAPI):
"""
记录请求日志中间件
:param app:
:return:
"""
@app.middleware("http")
async def request_log_middleware(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
write_request_log(request, response)
return response
def register_operation_record_middleware(app: FastAPI):
"""
操作记录中间件
用于将使用认证的操作全部记录到 mongodb 数据库中
:param app:
:return:
"""
@app.middleware("http")
async def operation_record_middleware(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
if not MONGO_DB_ENABLE:
return response
telephone = request.scope.get('telephone', None)
user_id = request.scope.get('user_id', None)
user_name = request.scope.get('user_name', None)
route = request.scope.get('route')
if not telephone:
return response
elif request.method not in OPERATION_RECORD_METHOD:
return response
elif route.name in IGNORE_OPERATION_FUNCTION:
return response
process_time = time.time() - start_time
user_agent = parse(request.headers.get("user-agent"))
system = f"{user_agent.os.family} {user_agent.os.version_string}"
browser = f"{user_agent.browser.family} {user_agent.browser.version_string}"
query_params = dict(request.query_params.multi_items())
path_params = request.path_params
if isinstance(request.scope.get('body'), str):
body = request.scope.get('body')
else:
body = request.scope.get('body').decode()
if body:
body = json.loads(body)
params = {
"body": body,
"query_params": query_params if query_params else None,
"path_params": path_params if path_params else None,
}
content_length = response.raw_headers[0][1]
assert isinstance(route, APIRoute)
document = {
"process_time": process_time,
"telephone": telephone,
"user_id": user_id,
"user_name": user_name,
"request_api": request.url.__str__(),
"client_ip": request.client.host,
"system": system,
"browser": browser,
"request_method": request.method,
"api_path": route.path,
"summary": route.summary,
"description": route.description,
"tags": route.tags,
"route_name": route.name,
"status_code": response.status_code,
"content_length": content_length,
"create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"params": json.dumps(params)
}
await OperationRecordDal(mongo_getter(request)).create_data(document)
return response
def register_demo_env_middleware(app: FastAPI):
"""
演示环境中间件
:param app:
:return:
"""
@app.middleware("http")
async def demo_env_middleware(request: Request, call_next):
path = request.scope.get("path")
if request.method != "GET":
print("路由:", path, request.method)
if DEMO and request.method != "GET":
if path in DEMO_BLACK_LIST_PATH:
return ErrorResponse(
status=status.HTTP_403_FORBIDDEN,
code=status.HTTP_403_FORBIDDEN,
msg="演示环境,禁止操作"
)
elif path not in DEMO_WHITE_LIST_PATH:
return ErrorResponse(msg="演示环境,禁止操作")
return await call_next(request)
def register_jwt_refresh_middleware(app: FastAPI):
"""
JWT刷新中间件
:param app:
:return:
"""
@app.middleware("http")
async def jwt_refresh_middleware(request: Request, call_next):
response = await call_next(request)
refresh = request.scope.get('if-refresh', 0)
response.headers["if-refresh"] = str(refresh)
return response
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : validator.py
# @IDE : PyCharm
# @desc : pydantic 模型重用验证器
"""
官方文档:https://pydantic-docs.helpmanual.io/usage/validators/#reuse-validators
"""
import re
def vali_telephone(value: str) -> str:
"""
手机号验证器
:param value: 手机号
:return: 手机号
"""
if not value or len(value) != 11 or not value.isdigit():
raise ValueError("请输入正确手机号")
regex = r'^1(3\d|4[4-9]|5[0-35-9]|6[67]|7[013-8]|8[0-9]|9[0-9])\d{8}$'
if not re.match(regex, value):
raise ValueError("请输入正确手机号")
return value
def vali_email(value: str) -> str:
"""
邮箱地址验证器
:param value: 邮箱
:return: 邮箱
"""
if not value:
raise ValueError("请输入邮箱地址")
regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(regex, value):
raise ValueError("请输入正确邮箱地址")
return value
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : db_base.py
# @IDE : PyCharm
# @desc : 数据库公共 ORM 模型
from datetime import datetime
from sqlalchemy.orm import Mapped, mapped_column
from dbgpt.app.apps.core.database import Base
from sqlalchemy import DateTime, Integer, func, Boolean, inspect
# 使用命令:alembic init alembic 初始化迁移数据库环境
# 这时会生成alembic文件夹 和 alembic.ini文件
class BaseModel(Base):
"""
公共 ORM 模型,基表
"""
__abstract__ = True
id: Mapped[int] = mapped_column(Integer, primary_key=True, comment='主键ID')
create_datetime: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), comment='创建时间')
update_datetime: Mapped[datetime] = mapped_column(
DateTime,
server_default=func.now(),
onupdate=func.now(),
comment='更新时间'
)
delete_datetime: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, comment='删除时间')
is_delete: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否软删除")
@classmethod
def get_column_attrs(cls) -> list:
"""
获取模型中除 relationships 外的所有字段名称
:return:
"""
mapper = inspect(cls)
# for attr_name, column_property in mapper.column_attrs.items():
# # 假设它是单列属性
# column = column_property.columns[0]
# # 访问各种属性
# print(f"属性: {attr_name}")
# print(f"类型: {column.type}")
# print(f"默认值: {column.default}")
# print(f"服务器默认值: {column.server_default}")
return mapper.column_attrs.keys()
@classmethod
def get_attrs(cls) -> list:
"""
获取模型所有字段名称
:return:
"""
mapper = inspect(cls)
return mapper.attrs.keys()
@classmethod
def get_relationships_attrs(cls) -> list:
"""
获取模型中 relationships 所有字段名称
:return:
"""
mapper = inspect(cls)
return mapper.relationships.keys()
This diff is collapsed.
from .dict import VadminDictType, VadminDictDetails
from .settings import VadminSystemSettings, VadminSystemSettingsTab
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : user.py
# @IDE : PyCharm
# @desc : 系统字典模型
from sqlalchemy.orm import relationship, Mapped, mapped_column
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import Column, String, Boolean, ForeignKey, Integer
class VadminDictType(BaseModel):
__tablename__ = "vadmin_system_dict_type"
__table_args__ = ({'comment': '字典类型表'})
dict_name: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="字典名称")
dict_type: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="字典类型")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="字典状态,是否禁用")
remark: Mapped[str | None] = mapped_column(String(255), comment="备注")
details: Mapped[list["VadminDictDetails"]] = relationship(back_populates="dict_type")
class VadminDictDetails(BaseModel):
__tablename__ = "vadmin_system_dict_details"
__table_args__ = ({'comment': '字典详情表'})
label: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="字典标签")
value: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="字典键值")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="字典状态,是否禁用")
is_default: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否默认")
order: Mapped[int] = mapped_column(Integer, comment="字典排序")
dict_type_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("vadmin_system_dict_type.id", ondelete='CASCADE'),
comment="关联字典类型"
)
dict_type: Mapped[VadminDictType] = relationship(foreign_keys=dict_type_id, back_populates="details")
remark: Mapped[str | None] = mapped_column(String(255), comment="备注")
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : settings.py
# @IDE : PyCharm
# @desc : 系统字典模型
from sqlalchemy.orm import relationship, Mapped, mapped_column
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import String, Integer, ForeignKey, Boolean, Text
class VadminSystemSettingsTab(BaseModel):
__tablename__ = "vadmin_system_settings_tab"
__table_args__ = ({'comment': '系统配置分类表'})
title: Mapped[str] = mapped_column(String(255), comment="标题")
classify: Mapped[str] = mapped_column(String(255), index=True, nullable=False, comment="分类键")
tab_label: Mapped[str] = mapped_column(String(255), comment="tab标题")
tab_name: Mapped[str] = mapped_column(String(255), index=True, nullable=False, unique=True, comment="tab标识符")
hidden: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否隐藏")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否禁用")
settings: Mapped[list["VadminSystemSettings"]] = relationship(back_populates="tab")
class VadminSystemSettings(BaseModel):
__tablename__ = "vadmin_system_settings"
__table_args__ = ({'comment': '系统配置表'})
config_label: Mapped[str] = mapped_column(String(255), comment="配置表标签")
config_key: Mapped[str] = mapped_column(String(255), index=True, nullable=False, unique=True, comment="配置表键")
config_value: Mapped[str | None] = mapped_column(Text, comment="配置表内容")
remark: Mapped[str | None] = mapped_column(String(255), comment="备注信息")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否禁用")
tab_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("vadmin_system_settings_tab.id", ondelete='CASCADE'),
comment="关联tab标签"
)
tab: Mapped[VadminSystemSettingsTab] = relationship(foreign_keys=tab_id, back_populates="settings")
from .dict_type import DictTypeParams
from .dict_detail import DictDetailParams
from .task import TaskParams
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : dict_type.py
# @IDE : PyCharm
# @desc : 查询参数-类依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Depends
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class DictDetailParams(QueryParams):
"""
列表分页
"""
def __init__(self, dict_type_id: int = None, label: str = None, params: Paging = Depends()):
super().__init__(params)
self.dict_type_id = dict_type_id
self.label = ("like", label)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : dict_type.py
# @IDE : PyCharm
# @desc : 查询参数-类依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Depends
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class DictTypeParams(QueryParams):
"""
列表分页
"""
def __init__(self, dict_name: str = None, dict_type: str = None, params: Paging = Depends()):
super().__init__(params)
self.dict_name = ("like", dict_name)
self.dict_type = ("like", dict_type)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/6/25 14:50
# @File : task.py
# @IDE : PyCharm
# @desc : 简要说明
from fastapi import Depends
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class TaskParams(QueryParams):
"""
列表分页
"""
def __init__(self, name: str = None, _id: str = None, group: str = None, params: Paging = Depends()):
super().__init__(params)
self.name = ("like", name)
self.group = group
self._id = ("ObjectId", _id)
self.v_order = "desc"
class TaskRecordParams(QueryParams):
"""
列表分页
"""
def __init__(self, job_id: str = None, name: str = None, params: Paging = Depends()):
super().__init__(params)
self.job_id = ("like", job_id)
self.name = ("like", name)
self.v_order = "desc"
from .dict import DictType, DictDetails, DictTypeSimpleOut, DictDetailsSimpleOut, DictTypeOptionsOut
from .settings_tab import SettingsTab, SettingsTabSimpleOut
from .settings import Settings, SettingsSimpleOut
from .task import Task, TaskSimpleOut
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : dict.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict, Field
from dbgpt.app.apps.core.data_types import DatetimeStr
class DictType(BaseModel):
dict_name: str
dict_type: str
disabled: bool | None = False
remark: str | None = None
class DictTypeSimpleOut(DictType):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
class DictTypeOptionsOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
label: str = Field(alias='dict_name')
value: int = Field(alias='id')
disabled: bool
class DictDetails(BaseModel):
label: str
value: str
disabled: bool | None = False
is_default: bool | None = False
remark: str | None = None
order: int | None = None
dict_type_id: int
class DictDetailsSimpleOut(DictDetails):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : settings.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict
from dbgpt.app.apps.core.data_types import DatetimeStr
class Settings(BaseModel):
config_label: str | None = None
config_key: str
config_value: str | None = None
remark: str | None = None
disabled: bool | None = None
tab_id: int
class SettingsSimpleOut(Settings):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : settings_tab.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict
from dbgpt.app.apps.core.data_types import DatetimeStr
class SettingsTab(BaseModel):
title: str
classify: str
tab_label: str
tab_name: str
hidden: bool
class SettingsTabSimpleOut(SettingsTab):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/6/25 15:08
# @File : task.py
# @IDE : PyCharm
# @desc : 简要说明
from pydantic import BaseModel, Field, ConfigDict
from dbgpt.app.apps.core.data_types import DatetimeStr, ObjectIdStr
class Task(BaseModel):
name: str
group: str | None = None
job_class: str
exec_strategy: str
expression: str
is_active: bool | None = True # 临时字段,不在表中创建
remark: str | None = None
start_date: DatetimeStr | None = None
end_date: DatetimeStr | None = None
class TaskSimpleOut(Task):
model_config = ConfigDict(from_attributes=True)
id: ObjectIdStr = Field(..., alias='_id')
create_datetime: DatetimeStr
update_datetime: DatetimeStr
last_run_datetime: DatetimeStr | None = None # 临时字段,不在表中创建
This diff is collapsed.
import base64
from Crypto.Cipher import AES # 安装:pip install pycryptodome
# 密钥(key), 密斯偏移量(iv) CBC模式加密
# base64 详解:https://cloud.tencent.com/developer/article/1099008
_key = '0CoJUm6Qywm6ts68' # 自己密钥
def aes_encrypt(data: str):
"""
加密
"""
vi = '0102030405060708'
pad = lambda s: s + (16 - len(s) % 16) * chr(16 - len(s) % 16)
data = pad(data)
# 字符串补位
cipher = AES.new(_key.encode('utf8'), AES.MODE_CBC, vi.encode('utf8'))
encrypted_bytes = cipher.encrypt(data.encode('utf8'))
# 加密后得到的是bytes类型的数据
encode_strs = base64.urlsafe_b64encode(encrypted_bytes)
# 使用Base64进行编码,返回byte字符串
# 对byte字符串按utf-8进行解码
return encode_strs.decode('utf8')
def aes_decrypt(data):
"""
解密
"""
vi = '0102030405060708'
data = data.encode('utf8')
encode_bytes = base64.urlsafe_b64decode(data)
# 将加密数据转换位bytes类型数据
cipher = AES.new(_key.encode('utf8'), AES.MODE_CBC, vi.encode('utf8'))
text_decrypted = cipher.decrypt(encode_bytes)
unpad = lambda s: s[0:-s[-1]]
text_decrypted = unpad(text_decrypted)
# 补位
text_decrypted = text_decrypted.decode('utf8')
return text_decrypted
if __name__ == '__main__':
_data = '16658273438153332588-95YEUPJR' # 需要加密的内容
enctext = aes_encrypt(_data)
print(enctext)
# enctext = "Wzll1oiVs9UKAySY1-xSy_CbrZmelVwyqu8P0CZTrrc="
# _text_decrypted = aes_decrypt(_key, enctext)
# print(_text_decrypted)
# #!/usr/bin/python
# # -*- coding: utf-8 -*-
# # @version : 1.0
# # @Create Time : 2022/3/21 11:03
# # @File : cache.py
# # @IDE : PyCharm
# # @desc : 缓存
#
# from typing import List
#
# from sqlalchemy import false
# from sqlalchemy.future import select
# from sqlalchemy.orm import joinedload
# from dbgpt.util.logger import logger # 注意:报错就在这里,如果只写 core.logger 会写入日志报错,很难排查
# from dbgpt.app.apps.core.database import db_getter
# from dbgpt.app.apps.vadmin.system.models import VadminSystemSettingsTab
# import json
# from redis.asyncio.client import Redis
# from dbgpt.app.apps.core.exception import CustomException
# from dbgpt.app.apps.utils import status
#
#
# class Cache:
#
# DEFAULT_TAB_NAMES = ["wx_server", "aliyun_sms", "aliyun_oss", "web_email"]
#
# def __init__(self, rd: Redis):
# self.rd = rd
#
# async def __get_tab_name_values(self, tab_names: List[str]):
# """
# 获取系统配置标签下的标签信息
# """
# async_session = db_getter()
# session = await async_session.__anext__()
# model = VadminSystemSettingsTab
# v_options = [joinedload(model.settings)]
# sql = select(model).where(
# model.is_delete == false(),
# model.tab_name.in_(tab_names),
# model.disabled == false()
# ).options(*[load for load in v_options])
# queryset = await session.execute(sql)
# datas = queryset.scalars().unique().all()
# return self.__generate_values(datas)
#
# @classmethod
# def __generate_values(cls, datas: List[VadminSystemSettingsTab]):
# """
# 生成字典值
# """
# return {
# tab.tab_name: {
# item.config_key: item.config_value
# for item in tab.settings
# if not item.disabled
# }
# for tab in datas
# }
#
# async def cache_tab_names(self, tab_names: List[str] = None):
# """
# 缓存系统配置
# 如果手动修改了mysql数据库中的配置
# 那么需要在redis中将对应的tab_name删除
# """
#
# if not tab_names:
# tab_names = self.DEFAULT_TAB_NAMES
# datas = await self.__get_tab_name_values(tab_names)
#
# for k, v in datas.items():
# await self.rd.client().set(k, json.dumps(v))
#
# async def get_tab_name(self, tab_name: str, retry: int = 3):
# """
# 获取系统配置
# :param tab_name: 配置表标签名称
# :param retry: 重试次数
# """
# result = await self.rd.get(tab_name)
# if not result and retry > 0:
# logger.error(f"未从Redis中获取到{tab_name}配置信息,正在重新更新配置信息,重试次数:{retry}。")
# await self.cache_tab_names([tab_name])
# return await self.get_tab_name(tab_name, retry - 1)
# elif not result and retry == 0:
# raise CustomException(f"获取{tab_name}配置信息失败,请联系管理员!", code=status.HTTP_ERROR)
# else:
# return json.loads(result)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/3 17:23
# @File : count.py
# @IDE : PyCharm
# @desc : 计数
from redis.asyncio.client import Redis
class Count:
"""
计数
"""
def __init__(self, rd: Redis, key):
self.rd = rd
self.key = key
async def add(self, ex: int = None) -> int:
await self.rd.set(self.key, await self.get_count() + 1, ex=ex)
return await self.get_count()
async def subtract(self, ex: int = None) -> int:
await self.rd.set(self.key, await self.get_count() - 1, ex=ex)
return await self.get_count()
async def get_count(self) -> int:
number = await self.rd.get(self.key)
if number:
return int(number)
return 0
async def reset(self) -> None:
await self.rd.set(self.key, 0)
async def delete(self) -> None:
await self.rd.delete(self.key)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/5/6 9:29
# @File : task.py
# @IDE : PyCharm
# @desc : 任务基础类
import re
import pymysql
from dbgpt.util.logger import logger
from dbgpt._private.config import Config
class DBGetter:
def __init__(self):
self.mysql_cursor = None
self.mysql_conn = None
def conn_mysql(self) -> None:
"""
连接系统中配置的 mysql 数据库
"""
CFG = Config()
try:
username = CFG.LOCAL_DB_USER
password = CFG.LOCAL_DB_PASSWORD
host = CFG.LOCAL_DB_HOST
port = CFG.LOCAL_DB_PORT
database = CFG.LOCAL_DB_NAME
self.mysql_conn = pymysql.connect(
host=host,
port=port,
user=username,
password=password,
database=database
)
self.mysql_cursor = self.mysql_conn.cursor()
except pymysql.err.OperationalError as e:
logger.error(f"数据库连接失败,{e}")
raise ValueError("数据库连接失败!")
except AttributeError as e:
logger.error(f"数据库链接解析失败,{e}")
raise ValueError("数据库链接解析失败!")
def close_mysql(self) -> None:
"""
关闭 mysql 链接
"""
try:
self.mysql_cursor.close()
self.mysql_conn.close()
except AttributeError as e:
logger.error(f"未连接数据库,无需关闭!,{e}")
raise ValueError("未连接数据库,无需关闭!")
if __name__ == '__main__':
t = DBGetter()
t.conn_mysql()
t.close_mysql()
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/14 9:56
# @File : __init__.py.py
# @IDE : PyCharm
# @desc : 简要说明
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/5/6 17:25
# @File : excel_manage.py
# @IDE : PyCharm
# @desc : EXCEL 文件操作
import datetime
import os
import re
from pathlib import Path
from openpyxl.utils import get_column_letter
from openpyxl import load_workbook, Workbook
from dbgpt.app.apps.config.settings import STATIC_ROOT, STATIC_URL
from openpyxl.styles import Alignment, Font, PatternFill, Border, Side
from dbgpt.app.apps.utils.file.file_base import FileBase
from .excel_schema import AlignmentModel, FontModel, PatternFillModel
class ExcelManage:
"""
excel 文件序列化
"""
# 列名,A-Z
EXCEL_COLUMNS = [chr(a) for a in range(ord('A'), ord('Z') + 1)]
def __init__(self):
self.sheet = None
self.wb = None
def open_workbook(self, file: str, read_only: bool = False, data_only: bool = False) -> None:
"""
初始化 excel 文件
:param file: 文件名称或者对象
:param read_only: 是否只读,优化读取速度
:param data_only: 是否加载文件对象
:return:
"""
# 加载excel文件,获取表单
self.wb = load_workbook(file, read_only=read_only, data_only=data_only)
def open_sheet(
self,
sheet_name: str = None,
file: str = None,
read_only: bool = False,
data_only: bool = False
) -> None:
"""
初始化 excel 文件
:param sheet_name: 表单名称,为空则默认第一个
:param file:
:param read_only:
:param data_only:
:return:
"""
# 加载excel文件,获取表单
if not self.wb:
self.open_workbook(file, read_only, data_only)
if sheet_name:
if sheet_name in self.get_sheets():
self.sheet = self.wb[sheet_name]
else:
self.sheet = self.wb.create_sheet(sheet_name)
else:
self.sheet = self.wb.active
def get_sheets(self) -> list:
"""
读取所有工作区名称
:return: 一维数组
"""
return self.wb.sheetnames
def create_excel(self, sheet_name: str = None) -> None:
"""
创建 excel 文件
:param sheet_name: 表单名称,为空则默认第一个
:return:
"""
# 加载excel文件,获取表单
self.wb = Workbook()
self.sheet = self.wb.active
if sheet_name:
self.sheet.title = sheet_name
def readlines(self, min_row: int = 1, min_col: int = 1, max_row: int = None, max_col: int = None) -> list:
"""
读取指定表单所有数据
:param min_row: 最小行
:param min_col: 最小列
:param max_row: 最大行
:param max_col: 最大列
:return: 二维数组
"""
rows = self.sheet.iter_rows(min_row=min_row, min_col=min_col, max_row=max_row, max_col=max_col)
result = []
for row in rows:
_row = []
for cell in row:
_row.append(cell.value)
if any(_row):
result.append(_row)
return result
def get_header(self, row: int = 1, col: int = None, asterisk: bool = False) -> list:
"""
读取指定表单的表头(第一行数据)
:param row: 指定行
:param col: 最大列
:param asterisk: 是否去除 * 号
:return: 一维数组
"""
rows = self.sheet.iter_rows(min_row=row, max_col=col)
result = []
for row in rows:
for cell in row:
value = cell.value.replace("*", "").strip() if asterisk else cell.value.strip()
result.append(value)
break
return result
def write_list(self, rows: list, header: list = None) -> None:
"""
写入 excel文件
:param rows: 行数据集
:param header: 表头
:return:
"""
if header:
self.sheet.append(header)
pattern_fill_style = PatternFillModel(start_color='D9D9D9', end_color='D9D9D9', fill_type='solid')
font_style = FontModel(bold=True)
self.__set_row_style(1, len(header), pattern_fill_style=pattern_fill_style, font_style=font_style)
for index, data in enumerate(rows):
format_columns = {
"date_columns": []
}
for i in range(0, len(data)):
if isinstance(data[i], datetime.datetime):
data[i] = data[i].strftime("%Y-%m-%d %H:%M:%S")
format_columns["date_columns"].append(i + 1)
elif isinstance(data[i], bool):
data[i] = 1 if data[i] else 0
self.sheet.append(data)
self.__set_row_style(index + 2 if header else index + 1, len(data))
self.__set_row_format(index + 2 if header else index + 1, format_columns)
self.__auto_width()
self.__set_row_border()
def save_excel(self, path: str = "excel_manage"):
"""
保存 excel 文件到本地 static 目录
:param path: static 指定目录类别
:return:
"""
file_path = FileBase.generate_static_file_path(path=path, suffix="xlsx")
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
self.wb.save(file_path)
return {
"local_path": file_path,
"remote_path": file_path.replace(STATIC_ROOT, STATIC_URL)
}
def __set_row_style(
self,
row: int,
max_column: int,
alignment_style: AlignmentModel = AlignmentModel(),
font_style: FontModel = FontModel(),
pattern_fill_style: PatternFillModel = PatternFillModel()
):
"""
设置行样式
:param row: 行
:param max_column: 最大列
:param alignment_style: 单元格内容的对齐设置
:param font_style: 单元格内容的字体样式设置
:param pattern_fill_style: 单元格的填充模式设置
:return:
"""
for index in range(0, max_column):
alignment = Alignment(**alignment_style.model_dump())
font = Font(**font_style.model_dump())
pattern_fill = PatternFill(**pattern_fill_style.model_dump())
self.sheet.cell(row=row, column=index + 1).alignment = alignment
self.sheet.cell(row=row, column=index + 1).font = font
self.sheet.cell(row=row, column=index + 1).fill = pattern_fill
def __set_row_format(self, row: int, columns: dict):
"""
格式化行数据类型
:param row: 行
:param columns: 列数据
:return:
"""
for index in columns.get("date_columns", []):
self.sheet.cell(row=row, column=index).number_format = "yyyy-mm-dd h:mm:ss"
def __set_row_border(self):
"""
设置行边框
:return:
"""
# 创建 Border 对象并设置边框样式
border = Border(
left=Side(border_style="thin", color="000000"),
right=Side(border_style="thin", color="000000"),
top=Side(border_style="thin", color="000000"),
bottom=Side(border_style="thin", color="000000")
)
# 设置整个表格的边框
for row in self.sheet.iter_rows():
for cell in row:
cell.border = border
def __auto_width(self):
"""
设置自适应列宽
:return:
"""
# 设置一个字典用于保存列宽数据
dims = {}
# 遍历表格数据,获取自适应列宽数据
for row in self.sheet.rows:
for cell in row:
if cell.value:
# 遍历整个表格,把该列所有的单元格文本进行长度对比,找出最长的单元格
# 在对比单元格文本时需要将中文字符识别为1.7个长度,英文字符识别为1个,这里只需要将文本长度直接加上中文字符数量即可
# re.findall('([\u4e00-\u9fa5])', cell.value)能够识别大部分中文字符
cell_len = 0.7 * len(re.findall('([\u4e00-\u9fa5])', str(cell.value))) + len(str(cell.value))
dims[cell.column] = max((dims.get(cell.column, 0), cell_len))
for col, value in dims.items():
# 设置列宽,get_column_letter用于获取数字列号对应的字母列号,最后值+2是用来调整最终效果的
self.sheet.column_dimensions[get_column_letter(col)].width = value + 10
def close(self):
"""
关闭文件
:return:
"""
self.wb.close()
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/08/24 22:19
# @File : excel_schema.py
# @IDE : PyCharm
# @desc :
from pydantic import BaseModel, Field
class AlignmentModel(BaseModel):
horizontal: str = Field('center', description="水平对齐方式。可选值:'left'、'center'、'right'、'justify'、'distributed'")
vertical: str = Field('center', description="垂直对齐方式。可选值:'top'、'center'、'bottom'、'justify'、'distributed'")
textRotation: int = Field(0, description="文本旋转角度(以度为单位)。默认为 0。")
wrapText: bool = Field(None, description="自动换行文本。设置为 True 时,如果文本内容超出单元格宽度,会自动换行显示。")
shrinkToFit: bool = Field(
None,
description="缩小字体以适应单元格。设置为 True 时,如果文本内容超出单元格宽度,会自动缩小字体大小以适应。"
)
indent: int = Field(0, description="文本缩进级别。默认为 0。")
relativeIndent: int = Field(0, description="相对缩进级别。默认为 0。")
justifyLastLine: bool = Field(
None,
description="对齐换行文本的末尾行。设置为 True 时,如果设置了文本换行,末尾的行也会被对齐。"
)
readingOrder: int = Field(0, description="阅读顺序。默认为 0。")
class Config:
title = "对齐设置模型"
description = "用于设置单元格内容的对齐样式。"
class FontModel(BaseModel):
name: str = Field(None, description="字体名称")
size: float = Field(None, description="字体大小(磅为单位)")
bold: bool = Field(None, description="是否加粗")
italic: bool = Field(None, description="是否斜体")
underline: str = Field(None, description="下划线样式。可选值:'single'、'double'、'singleAccounting'、'doubleAccounting'")
strikethrough: bool = Field(None, description="是否有删除线")
outline: bool = Field(None, description="是否轮廓字体")
shadow: bool = Field(None, description="是否阴影字体")
condense: bool = Field(None, description="是否压缩字体")
extend: bool = Field(None, description="是否扩展字体")
vertAlign: str = Field(None, description="垂直对齐方式。可选值:'superscript'、'subscript'、'baseline'")
color: dict = Field(None, description="字体颜色")
scheme: str = Field(None, description="字体方案。可选值:'major'、'minor'")
charset: int = Field(None, description="字符集编号")
family: int = Field(None, description="字体族编号")
class Config:
title = "字体设置模型"
description = "用于设置单元格内容的字体样式"
class PatternFillModel(BaseModel):
start_color: str = Field("FFFFFF", description="起始颜色(RGB 值或颜色名称)")
end_color: str = Field("FFFFFF", description="结束颜色(RGB 值或颜色名称)")
fill_type: str = Field("solid", description="填充类型('none'、'solid'、'darkDown' 等)")
class Config:
title = "填充模式设置模型"
description = "单元格的填充模式设置"
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/12/5 8:45
# @File : import_manage.py
# @IDE : PyCharm
# @desc : 数据导入管理
from typing import List
from fastapi import UploadFile
from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils import status
from .excel_manage import ExcelManage
from dbgpt.app.apps.utils.file.file_manage import FileManage
from .write_xlsx import WriteXlsx
from ..tools import list_dict_find
from enum import Enum
class FieldType(Enum):
list = "list"
str = "str"
class ImportManage(ExcelManage):
"""
数据导入管理
只支持 XLSX 类型文件:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
1. 判断文件类型
2. 保存文件为临时文件
3. 获取文件中的数据
4. 逐行检查数据,通过则创建数据
5. 不通过则添加到错误列表
6. 统计数量并返回
"""
file_type = ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"]
def __init__(self, file: UploadFile, headers: List[dict]):
super().__init__()
self.__table_data = None
self.__table_header = None
self.errors = []
self.success = []
self.success_number = 0
self.error_number = 0
self.check_file_type(file)
self.file = file
self.headers = headers
@classmethod
def check_file_type(cls, file: UploadFile) -> None:
"""
验证文件类型
:param file: 上传文件
:return:
"""
if file.content_type not in cls.file_type:
raise CustomException(msg="文件类型必须为xlsx类型", code=status.HTTP_ERROR)
async def get_table_data(
self,
file_path: str = None,
sheet_name: str = None,
header_row: int = 1,
data_row: int = 2
) -> None:
"""
获取表格数据与表头
:param file_path:
:param sheet_name:
:param header_row: 表头在第几行
:param data_row: 数据开始行
:return:
"""
if file_path:
__filename = file_path
else:
__filename = await FileManage.async_save_temp_file(self.file)
self.open_sheet(sheet_name=sheet_name, file=__filename, read_only=True)
self.__table_header = self.get_header(header_row, len(self.headers), asterisk=True)
self.__table_data = self.readlines(min_row=data_row, max_col=len(self.headers))
self.close()
def check_table_data(self) -> None:
"""
检查表格数据
:return:
"""
for row in self.__table_data:
result = self.__check_row(row)
if not result[0]:
row.append(result[1])
self.errors.append(row)
self.error_number += 1
else:
self.success_number += 1
self.success.append(result[1])
def __check_row(self, row: list) -> tuple:
"""
检查行数据
检查条件:
1. 检查是否为必填项
2. 检查是否为选项列表
3. 检查是否符合规则
:param row: 数据行
:return:
"""
data = {}
for index, cell in enumerate(row):
value = cell
field = self.headers[index]
label = self.__table_header[index]
if not cell and field.get("required", False):
return False, f"{label}不能为空!"
elif field.get("options", []) and cell:
item = list_dict_find(field.get("options", []), "label", cell)
if item:
value = item.get("value")
else:
return False, f"请选择正确的{label}"
elif field.get("rules", []) and cell:
rules = field.get("rules")
for validator in rules:
try:
validator(str(cell))
except ValueError as e:
return False, f"{label}:{e.__str__()}"
if value:
field_type = field.get("type", FieldType.str)
if field_type == FieldType.list:
data[field.get("field")] = [value]
elif field_type == FieldType.str:
data[field.get("field")] = str(value)
else:
data[field.get("field")] = value
data["old_data_list"] = row
return True, data
def generate_error_url(self) -> str:
"""
成功错误数据的文件链接
:return:
"""
if self.error_number <= 0:
return ""
em = WriteXlsx()
em.create_excel(sheet_name="用户导入失败数据", save_static=True)
em.generate_template(self.headers, max_row=self.error_number)
em.write_list(self.errors)
em.close()
return em.get_file_url()
def add_error_data(self, row: dict) -> None:
"""
增加错误数据
:param row: 错误的数据行
:return:
"""
self.errors.append(row)
self.error_number += 1
self.success_number -= 1
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/11 12:01
# @File : write_xlsx.py
# @IDE : PyCharm
# @desc : 简要说明
"""
XlsxWriter:https://github.com/jmcnamara/XlsxWriter
博客教程:https://blog.csdn.net/lemonbit/article/details/113855768
"""
import os.path
import xlsxwriter
from typing import List
from dbgpt.app.apps.config.settings import STATIC_ROOT, STATIC_URL
from dbgpt.app.apps.utils.file.file_base import FileBase
from pathlib import Path
class WriteXlsx:
"""
写入xlsx文件
"""
def __init__(self):
self.file_path = None
self.sheet_name = None
self.wb = None
self.sheet = None
def create_excel(self, file_path: str = None, sheet_name: str = "sheet1", save_static: bool = False) -> None:
"""
创建 excel 文件
:param file_path: 文件绝对路径或相对路径
:param sheet_name: sheet 名称
:param save_static: 保存方式 static 静态资源或者临时文件
:return:
"""
if not file_path:
if save_static:
self.file_path = FileBase.generate_static_file_path(path="write_xlsx", suffix="xlsx")
else:
self.file_path = FileBase.generate_temp_file_path(suffix="xlsx")
elif not os.path.isabs(file_path):
if save_static:
self.file_path = FileBase.generate_static_file_path(path="write_xlsx", filename=file_path)
else:
self.file_path = FileBase.generate_temp_file_path(filename=file_path)
else:
self.file_path = file_path
Path(self.file_path).parent.mkdir(parents=True, exist_ok=True)
self.sheet_name = sheet_name
self.wb = xlsxwriter.Workbook(self.file_path)
self.sheet = self.wb.add_worksheet(sheet_name)
def generate_template(self, headers: List[dict] = None, max_row: int = 101) -> None:
"""
生成模板
:param headers: 表头
:param max_row: 设置下拉列表至最大行
:return: 文件链接地址
"""
max_row = max_row + 100
for index, field in enumerate(headers):
font_format = {
'bold': False, # 字体加粗
'align': 'center', # 水平位置设置:居中
'valign': 'vcenter', # 垂直位置设置,居中
'font_size': 11, # '字体大小设置'
}
if field.get("required", False):
# 设置单元格必填样式
field["label"] = "* " + field["label"]
font_format["font_color"] = "red"
if field.get("options", False):
# 添加数据验证,下拉列表
validate = {'validate': 'list', 'source': [item.get("label") for item in field.get("options", [])]}
self.sheet.data_validation(1, index, max_row, index, validate)
cell_format = self.wb.add_format(font_format)
self.sheet.write(0, index, field.get("label"), cell_format)
# 设置列宽
self.sheet.set_column(0, len(headers) - 1, 22)
# 设置行高
self.sheet.set_row(0, 25)
def write_list(self, rows: list, start_row: int = 1) -> None:
"""
写入 excel文件
:param rows: 行数据集
:param start_row: 开始行
"""
font_format = {
'bold': False, # 字体加粗
'align': 'center', # 水平位置设置:居中
'valign': 'vcenter', # 垂直位置设置,居中
'font_size': 11, # '字体大小设置'
}
cell_format = self.wb.add_format(font_format)
for index, row in enumerate(rows):
row_number = index+start_row
self.sheet.write_row(row_number, 0, row, cell_format)
# 设置列宽
self.sheet.set_column(0, len(rows[0]) - 1, 22)
# 设置行高
self.sheet.set_default_row(25)
def close(self) -> None:
"""
关闭文件
"""
self.wb.close()
def get_file_url(self) -> str:
"""
获取访问文件的 url
:return:
"""
if not self.file_path:
raise ValueError("还未创建文件,请先创建 excel 文件!")
assert isinstance(self.file_path, str)
if self.file_path.startswith(STATIC_ROOT):
return self.file_path.replace(STATIC_ROOT, STATIC_URL)
else:
print("write_xlsx 生成文件:", self.file_path)
raise ValueError("生成文件为临时文件,无法访问!")
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/12/12 14:30
# @File : __init__.py.py
# @IDE : PyCharm
# @desc : 简要说明
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/12/12 14:31
# @File : file_base.py
# @IDE : PyCharm
# @desc : 简要说明
import datetime
import os
from pathlib import Path
from aiopathlib import AsyncPath
from fastapi import UploadFile
from dbgpt.app.apps.config.settings import TEMP_DIR, STATIC_ROOT
from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils import status
from dbgpt.app.apps.utils.tools import generate_string
class FileBase:
IMAGE_ACCEPT = ["image/png", "image/jpeg", "image/gif", "image/x-icon"]
VIDEO_ACCEPT = ["video/mp4", "video/mpeg"]
AUDIO_ACCEPT = ["audio/wav", "audio/mp3", "audio/m4a", "audio/wma", "audio/ogg", "audio/mpeg", "audio/x-wav"]
ALL_ACCEPT = [*IMAGE_ACCEPT, *VIDEO_ACCEPT, *AUDIO_ACCEPT]
@classmethod
def get_random_filename(cls, suffix: str) -> str:
"""
生成随机文件名称,生成规则:当前时间戳 + 8位随机字符串拼接
:param suffix: 文件后缀
:return:
"""
if not suffix.startswith("."):
suffix = "." + suffix
return f"{str(int(datetime.datetime.now().timestamp())) + str(generate_string(8))}{suffix}"
@classmethod
def get_today_timestamp(cls) -> str:
"""
获取当天时间戳
:return:
"""
return str(int((datetime.datetime.now().replace(hour=0, minute=0, second=0)).timestamp()))
@classmethod
def generate_relative_path(cls, path: str, filename: str = None, suffix: str = None) -> str:
"""
生成相对路径,生成规则:自定义目录/当天日期时间戳/随机文件名称
1. filename 参数或者 suffix 参数必须填写一个
2. filename 参数和 suffix 参数都存在则优先取 suffix 参数为后缀
:param path: static 指定目录类别
:param filename: 文件名称,只用户获取后缀,不做真实文件名称,避免文件重复问题
:param suffix: 文件后缀
"""
if not filename and not suffix:
raise ValueError("filename 参数或者 suffix 参数必须填写一个")
elif not suffix and filename:
suffix = os.path.splitext(filename)[-1]
path = path.replace("\\", "/")
if path[0] == "/":
path = path[1:]
if path[-1] == "/":
path = path[:-1]
today = datetime.datetime.strftime(datetime.datetime.now(), "%Y%m%d")
return f"{path}/{today}/{cls.get_random_filename(suffix)}"
@classmethod
def generate_static_file_path(cls, path: str, filename: str = None, suffix: str = None) -> str:
"""
生成 static 静态文件路径,生成规则:自定义目录/当天日期时间戳/随机文件名称
1. filename 参数或者 suffix 参数必须填写一个
2. filename 参数和 suffix 参数都存在则优先取 suffix 参数为后缀
:param path: static 指定目录类别
:param filename: 文件名称,只用户获取后缀,不做真实文件名称,避免文件重复问题
:param suffix: 文件后缀
:return:
"""
return f"{STATIC_ROOT}/{cls.generate_relative_path(path, filename, suffix)}"
@classmethod
def generate_temp_file_path(cls, filename: str = None, suffix: str = None) -> str:
"""
生成临时文件路径,生成规则:
1. filename 参数或者 suffix 参数必须填写一个
2. filename 参数和 suffix 参数都存在则优先取 suffix 参数为后缀
:param filename: 文件名称
:param suffix: 文件后缀
:return:
"""
if not filename and not suffix:
raise ValueError("filename 参数或者 suffix 参数必须填写一个")
elif not suffix and filename:
suffix = os.path.splitext(filename)[-1]
return f"{cls.generate_temp_dir_path()}/{cls.get_random_filename(suffix)}"
@classmethod
def generate_temp_dir_path(cls) -> str:
"""
生成临时目录路径
:return:
"""
date = datetime.datetime.strftime(datetime.datetime.now(), "%Y%m%d")
file_dir = Path(TEMP_DIR) / date
if not file_dir.exists():
file_dir.mkdir(parents=True, exist_ok=True)
return str(file_dir).replace("\\", "/")
@classmethod
async def async_generate_temp_file_path(cls, filename: str) -> str:
"""
生成临时文件路径
:param filename: 文件名称
:return:
"""
return f"{await cls.async_generate_temp_dir_path()}/{filename}"
@classmethod
async def async_generate_temp_dir_path(cls) -> str:
"""
生成临时目录路径
:return:
"""
date = datetime.datetime.strftime(datetime.datetime.now(), "%Y%m%d")
file_dir = AsyncPath(TEMP_DIR) / date
path = file_dir / (generate_string(4) + str(int(datetime.datetime.now().timestamp())))
if not await path.exists():
await path.mkdir(parents=True, exist_ok=True)
return str(path).replace("\\", "/")
@classmethod
async def validate_file(cls, file: UploadFile, max_size: int = None, mime_types: list = None) -> bool:
"""
验证文件是否符合格式
:param file: 文件
:param max_size: 文件最大值,单位 MB
:param mime_types: 支持的文件类型
"""
if max_size:
size = len(await file.read()) / 1024 / 1024
if size > max_size:
raise CustomException(f"上传文件过大,不能超过{max_size}MB", status.HTTP_ERROR)
await file.seek(0)
if mime_types:
if file.content_type not in mime_types:
raise CustomException(f"上传文件格式错误,只支持 {'/'.join(mime_types)} 格式!", status.HTTP_ERROR)
return True
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/12/5 8:45
# @File : file_manage.py
# @IDE : PyCharm
# @desc : 保存图片到本地
import asyncio
import io
import os
import zipfile
from dbgpt.app.apps.config.settings import STATIC_ROOT, BASE_DIR, STATIC_URL
from fastapi import UploadFile
import sys
from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils.file.file_base import FileBase
from aiopathlib import AsyncPath
import aioshutil
class FileManage(FileBase):
"""
上传文件管理
"""
def __init__(self, file: UploadFile, path: str):
self.path = self.generate_static_file_path(path, file.filename)
self.file = file
async def save_image_local(self, accept: list = None) -> dict:
"""
保存图片文件到本地
:param accept:
:return:
"""
if accept is None:
accept = self.IMAGE_ACCEPT
await self.validate_file(self.file, max_size=5, mime_types=accept)
return await self.async_save_local()
async def save_audio_local(self, accept: list = None) -> dict:
"""
保存音频文件到本地
:param accept:
:return:
"""
if accept is None:
accept = self.AUDIO_ACCEPT
await self.validate_file(self.file, max_size=50, mime_types=accept)
return await self.async_save_local()
async def save_video_local(self, accept: list = None) -> dict:
"""
保存视频文件到本地
:param accept:
:return:
"""
if accept is None:
accept = self.VIDEO_ACCEPT
await self.validate_file(self.file, max_size=100, mime_types=accept)
return await self.async_save_local()
async def async_save_local(self) -> dict:
"""
保存文件到本地
:return: 示例:
{
'local_path': 'D:\\project\\kinit_dev\\kinit-api\\static\\system\\20240301\\1709303205HuYB3mrC.png',
'remote_path': '/media/system/20240301/1709303205HuYB3mrC.png'
}
"""
path = AsyncPath(self.path)
if sys.platform == "win32":
path = AsyncPath(self.path.replace("/", "\\"))
if not await path.parent.exists():
await path.parent.mkdir(parents=True, exist_ok=True)
await path.write_bytes(await self.file.read())
return {
"local_path": str(path),
"remote_path": STATIC_URL + str(path).replace(STATIC_ROOT, '').replace("\\", '/')
}
@classmethod
async def async_save_temp_file(cls, file: UploadFile) -> str:
"""
保存临时文件
:param file:
:return:
"""
temp_file_path = await cls.async_generate_temp_file_path(file.filename)
await AsyncPath(temp_file_path).write_bytes(await file.read())
return temp_file_path
@classmethod
async def unzip(cls, file: UploadFile, dir_path: str) -> str:
"""
解压 zip 压缩包
:param file:
:param dir_path: 解压路径
:return:
"""
if file.content_type != "application/x-zip-compressed":
raise CustomException("上传文件类型错误,必须是 zip 压缩包格式!")
# 读取上传的文件内容
contents = await file.read()
# 将文件内容转换为字节流
zip_stream = io.BytesIO(contents)
# 使用zipfile库解压字节流
with zipfile.ZipFile(zip_stream, "r") as zip_ref:
zip_ref.extractall(dir_path)
return dir_path
@staticmethod
async def async_copy_file(src: str, dst: str) -> None:
"""
异步复制文件
根目录为项目根目录,传过来的文件路径均为相对路径
:param src: 原始文件
:param dst: 目标路径。绝对路径
"""
if src[0] == "/":
src = src.lstrip("/")
src = AsyncPath(BASE_DIR) / src
if not await src.exists():
raise CustomException(f"{src} 源文件不存在!")
dst = AsyncPath(dst)
if not await dst.parent.exists():
await dst.parent.mkdir(parents=True, exist_ok=True)
await aioshutil.copyfile(src, dst)
@staticmethod
async def async_copy_dir(src: str, dst: str, dirs_exist_ok: bool = True) -> None:
"""
复制目录
:param src: 源目录
:param dst: 目标目录
:param dirs_exist_ok: 是否覆盖
"""
if not os.path.exists(dst):
raise CustomException("目标目录不存在!")
await aioshutil.copytree(src, dst, dirs_exist_ok=dirs_exist_ok)
# 晚上星月争辉,美梦陪你入睡
import random
from math import sin, cos, pi, log
from tkinter import *
CANVAS_WIDTH = 640 # 画布的宽
CANVAS_HEIGHT = 480 # 画布的高
CANVAS_CENTER_X = CANVAS_WIDTH / 2 # 画布中心的X轴坐标
CANVAS_CENTER_Y = CANVAS_HEIGHT / 2 # 画布中心的Y轴坐标
IMAGE_ENLARGE = 11 # 放大比例
HEART_COLOR = "#ff2121" # 心的颜色,这个是中国红
def heart_function(t, shrink_ratio: float = IMAGE_ENLARGE):
"""
“爱心函数生成器”
:param shrink_ratio: 放大比例
:param t: 参数
:return: 坐标
"""
# 基础函数
x = 16 * (sin(t) ** 3)
y = -(13 * cos(t) - 5 * cos(2 * t) - 2 * cos(3 * t) - cos(4 * t))
# 放大
x *= shrink_ratio
y *= shrink_ratio
# 移到画布中央
x += CANVAS_CENTER_X
y += CANVAS_CENTER_Y
return int(x), int(y)
def scatter_inside(x, y, beta=0.15):
"""
随机内部扩散
:param x: 原x
:param y: 原y
:param beta: 强度
:return: 新坐标
"""
ratio_x = - beta * log(random.random())
ratio_y = - beta * log(random.random())
dx = ratio_x * (x - CANVAS_CENTER_X)
dy = ratio_y * (y - CANVAS_CENTER_Y)
return x - dx, y - dy
def shrink(x, y, ratio):
"""
抖动
:param x: 原x
:param y: 原y
:param ratio: 比例
:return: 新坐标
"""
force = -1 / (((x - CANVAS_CENTER_X) ** 2 + (y - CANVAS_CENTER_Y) ** 2) ** 0.6) # 这个参数...
dx = ratio * force * (x - CANVAS_CENTER_X)
dy = ratio * force * (y - CANVAS_CENTER_Y)
return x - dx, y - dy
def curve(p):
"""
自定义曲线函数,调整跳动周期
:param p: 参数
:return: 正弦
"""
# 可以尝试换其他的动态函数,达到更有力量的效果(贝塞尔?)
return 2 * (2 * sin(4 * p)) / (2 * pi)
class Heart:
"""
爱心类
"""
def __init__(self, generate_frame=20):
self._points = set() # 原始爱心坐标集合
self._edge_diffusion_points = set() # 边缘扩散效果点坐标集合
self._center_diffusion_points = set() # 中心扩散效果点坐标集合
self.all_points = {} # 每帧动态点坐标
self.build(2000)
self.random_halo = 1000
self.generate_frame = generate_frame
for frame in range(generate_frame):
self.calc(frame)
def build(self, number):
# 爱心
for _ in range(number):
t = random.uniform(0, 2 * pi) # 随机不到的地方造成爱心有缺口
x, y = heart_function(t)
self._points.add((x, y))
# 爱心内扩散
for _x, _y in list(self._points):
for _ in range(3):
x, y = scatter_inside(_x, _y, 0.05)
self._edge_diffusion_points.add((x, y))
# 爱心内再次扩散
point_list = list(self._points)
for _ in range(4000):
x, y = random.choice(point_list)
x, y = scatter_inside(x, y, 0.17)
self._center_diffusion_points.add((x, y))
@staticmethod
def calc_position(x, y, ratio):
# 调整缩放比例
force = 1 / (((x - CANVAS_CENTER_X) ** 2 + (y - CANVAS_CENTER_Y) ** 2) ** 0.520) # 魔法参数
dx = ratio * force * (x - CANVAS_CENTER_X) + random.randint(-1, 1)
dy = ratio * force * (y - CANVAS_CENTER_Y) + random.randint(-1, 1)
return x - dx, y - dy
def calc(self, generate_frame):
ratio = 10 * curve(generate_frame / 10 * pi) # 圆滑的周期的缩放比例
halo_radius = int(4 + 6 * (1 + curve(generate_frame / 10 * pi)))
halo_number = int(3000 + 4000 * abs(curve(generate_frame / 10 * pi) ** 2))
all_points = []
# 光环
heart_halo_point = set() # 光环的点坐标集合
for _ in range(halo_number):
t = random.uniform(0, 2 * pi) # 随机不到的地方造成爱心有缺口
x, y = heart_function(t, shrink_ratio=11.6) # 魔法参数
x, y = shrink(x, y, halo_radius)
if (x, y) not in heart_halo_point:
# 处理新的点
heart_halo_point.add((x, y))
x += random.randint(-14, 14)
y += random.randint(-14, 14)
size = random.choice((1, 2, 2))
all_points.append((x, y, size))
# 轮廓
for x, y in self._points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 3)
all_points.append((x, y, size))
# 内容
for x, y in self._edge_diffusion_points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 2)
all_points.append((x, y, size))
for x, y in self._center_diffusion_points:
x, y = self.calc_position(x, y, ratio)
size = random.randint(1, 2)
all_points.append((x, y, size))
self.all_points[generate_frame] = all_points
def render(self, render_canvas, render_frame):
for x, y, size in self.all_points[render_frame % self.generate_frame]:
render_canvas.create_rectangle(x, y, x + size, y + size, width=0, fill=HEART_COLOR)
def draw(main: Tk, render_canvas: Canvas, render_heart: Heart, render_frame=0):
render_canvas.delete('all')
render_heart.render(render_canvas, render_frame)
main.after(160, draw, main, render_canvas, render_heart, render_frame + 1)
if __name__ == '__main__':
root = Tk() # 一个Tk
canvas = Canvas(root, bg='black', height=CANVAS_HEIGHT, width=CANVAS_WIDTH)
canvas.pack()
heart = Heart() # 心
draw(root, canvas, heart) # 开始画画~
root.mainloop()
# 依赖安装:pip install orjson
from fastapi.responses import ORJSONResponse as Response
from fastapi import status as http_status
from dbgpt.app.apps.utils import status as http
class SuccessResponse(Response):
"""
成功响应
"""
def __init__(self, data=None, msg="success", code=http.HTTP_SUCCESS, status=http_status.HTTP_200_OK, **kwargs):
self.data = {
"code": code,
"message": msg,
"data": data
}
self.data.update(kwargs)
super().__init__(content=self.data, status_code=status)
class ErrorResponse(Response):
"""
失败响应
"""
def __init__(self, msg=None, code=http.HTTP_ERROR, status=http_status.HTTP_200_OK, **kwargs):
self.data = {
"code": code,
"message": msg,
"data": []
}
self.data.update(kwargs)
super().__init__(content=self.data, status_code=status)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/8/10 22:20
# @File : status.py
# @IDE : PyCharm
# @desc : 简要说明
HTTP_SUCCESS = 200
HTTP_ERROR = 400
HTTP_401_UNAUTHORIZED = 401
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/10/9 17:09
# @File : tools.py
# @IDE : PyCharm
# @desc : 工具类
import datetime
import random
import re
import string
from typing import List, Union
import importlib
from dbgpt.util.logger import logger
def test_password(password: str) -> Union[str, bool]:
"""
检测密码强度
"""
if len(password) < 8 or len(password) > 16:
return '长度需为8-16个字符,请重新输入。'
else:
for i in password:
if 0x4e00 <= ord(i) <= 0x9fa5 or ord(i) == 0x20: # Ox4e00等十六进制数分别为中文字符和空格的Unicode编码
return '不能使用空格、中文,请重新输入。'
else:
key = 0
key += 1 if bool(re.search(r'\d', password)) else 0
key += 1 if bool(re.search(r'[A-Za-z]', password)) else 0
key += 1 if bool(re.search(r"\W", password)) else 0
if key >= 2:
return True
else:
return '至少含数字/字母/字符2种组合,请重新输入。'
def list_dict_find(options: List[dict], key: str, value: any) -> Union[dict, None]:
"""
字典列表查找
"""
return next((item for item in options if item.get(key) == value), None)
def get_time_interval(start_time: str, end_time: str, interval: int, time_format: str = "%H:%M:%S") -> List:
"""
获取时间间隔
:param end_time: 结束时间
:param start_time: 开始时间
:param interval: 间隔时间(分)
:param time_format: 字符串格式化,默认:%H:%M:%S
"""
if start_time.count(":") == 1:
start_time = f"{start_time}:00"
if end_time.count(":") == 1:
end_time = f"{end_time}:00"
start_time = datetime.datetime.strptime(start_time, "%H:%M:%S")
end_time = datetime.datetime.strptime(end_time, "%H:%M:%S")
time_range = []
while end_time > start_time:
time_range.append(start_time.strftime(time_format))
start_time = start_time + datetime.timedelta(minutes=interval)
return time_range
def generate_string(length: int = 8) -> str:
"""
生成随机字符串
:param length: 字符串长度
"""
return ''.join(random.sample(string.ascii_letters + string.digits, length))
def import_modules(modules: list, desc: str, **kwargs):
for module in modules:
if not module:
continue
try:
# 动态导入模块
module_pag = importlib.import_module(module[0:module.rindex(".")])
getattr(module_pag, module[module.rindex(".") + 1:])(**kwargs)
except ModuleNotFoundError:
logger.error(f"AttributeError:导入{desc}失败,未找到该模块:{module}")
except AttributeError:
logger.error(f"ModuleNotFoundError:导入{desc}失败,未找到该模块下的方法:{module}")
async def import_modules_async(modules: list, desc: str, **kwargs):
for module in modules:
if not module:
continue
try:
# 动态导入模块
module_pag = importlib.import_module(module[0:module.rindex(".")])
await getattr(module_pag, module[module.rindex(".") + 1:])(**kwargs)
# except TimeoutError:
# logger.error(f"asyncio.exceptions.TimeoutError:连接Mysql数据库超时")
# print(f"asyncio.exceptions.TimeoutError:连接Mysql数据库超时")
except ModuleNotFoundError:
logger.error(f"AttributeError:导入{desc}失败,未找到该模块:{module}")
except AttributeError:
logger.error(f"ModuleNotFoundError:导入{desc}失败,未找到该模块下的方法:{module}")
This diff is collapsed.
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : __init__.py
# @IDE : PyCharm
# @desc : 简要说明
from .m2m import vadmin_auth_user_roles, vadmin_auth_role_menus, vadmin_auth_user_depts, vadmin_auth_role_depts
from .menu import VadminMenu
from .role import VadminRole
from .user import VadminUser
from .dept import VadminDept
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/10/23 13:41
# @File : dept.py
# @IDE : PyCharm
# @desc : 部门模型
from sqlalchemy.orm import Mapped, mapped_column
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import String, Boolean, Integer, ForeignKey
class VadminDept(BaseModel):
__tablename__ = "vadmin_auth_dept"
__table_args__ = ({'comment': '部门表'})
name: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="部门名称")
dept_key: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="部门标识")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否禁用")
order: Mapped[int | None] = mapped_column(Integer, comment="显示排序")
desc: Mapped[str | None] = mapped_column(String(255), comment="描述")
owner: Mapped[str | None] = mapped_column(String(255), comment="负责人")
phone: Mapped[str | None] = mapped_column(String(255), comment="联系电话")
email: Mapped[str | None] = mapped_column(String(255), comment="邮箱")
parent_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey("vadmin_auth_dept.id", ondelete='CASCADE'),
comment="上级部门"
)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : m2m.py
# @IDE : PyCharm
# @desc : 关联中间表
from dbgpt.app.apps.db.db_base import Base
from sqlalchemy import ForeignKey, Column, Table, Integer
vadmin_auth_user_roles = Table(
"vadmin_auth_user_roles",
Base.metadata,
Column("user_id", Integer, ForeignKey("vadmin_auth_user.id", ondelete="CASCADE")),
Column("role_id", Integer, ForeignKey("vadmin_auth_role.id", ondelete="CASCADE")),
)
vadmin_auth_role_menus = Table(
"vadmin_auth_role_menus",
Base.metadata,
Column("role_id", Integer, ForeignKey("vadmin_auth_role.id", ondelete="CASCADE")),
Column("menu_id", Integer, ForeignKey("vadmin_auth_menu.id", ondelete="CASCADE")),
)
vadmin_auth_user_depts = Table(
"vadmin_auth_user_depts",
Base.metadata,
Column("user_id", Integer, ForeignKey("vadmin_auth_user.id", ondelete="CASCADE")),
Column("dept_id", Integer, ForeignKey("vadmin_auth_dept.id", ondelete="CASCADE")),
)
vadmin_auth_role_depts = Table(
"vadmin_auth_role_depts",
Base.metadata,
Column("role_id", Integer, ForeignKey("vadmin_auth_role.id", ondelete="CASCADE")),
Column("dept_id", Integer, ForeignKey("vadmin_auth_dept.id", ondelete="CASCADE")),
)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : menu.py
# @IDE : PyCharm
# @desc : 菜单模型
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import String, Boolean, Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
class VadminMenu(BaseModel):
__tablename__ = "vadmin_auth_menu"
__table_args__ = ({'comment': '菜单表'})
title: Mapped[str] = mapped_column(String(50), comment="名称")
icon: Mapped[str | None] = mapped_column(String(50), comment="菜单图标")
redirect: Mapped[str | None] = mapped_column(String(100), comment="重定向地址")
component: Mapped[str | None] = mapped_column(String(255), comment="前端组件地址")
path: Mapped[str | None] = mapped_column(String(50), comment="前端路由地址")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否禁用")
hidden: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否隐藏")
order: Mapped[int] = mapped_column(Integer, comment="排序")
menu_type: Mapped[str] = mapped_column(String(8), comment="菜单类型")
parent_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey("vadmin_auth_menu.id", ondelete='CASCADE'),
comment="父菜单"
)
perms: Mapped[str | None] = mapped_column(String(50), comment="权限标识", unique=False, index=True)
"""以下属性主要用于补全前端路由属性,"""
noCache: Mapped[bool] = mapped_column(
Boolean,
comment="如果设置为true,则不会被 <keep-alive> 缓存(默认 false)",
default=False
)
breadcrumb: Mapped[bool] = mapped_column(
Boolean,
comment="如果设置为false,则不会在breadcrumb面包屑中显示(默认 true)",
default=True
)
affix: Mapped[bool] = mapped_column(
Boolean,
comment="如果设置为true,则会一直固定在tag项中(默认 false)",
default=False
)
noTagsView: Mapped[bool] = mapped_column(
Boolean,
comment="如果设置为true,则不会出现在tag中(默认 false)",
default=False
)
canTo: Mapped[bool] = mapped_column(
Boolean,
comment="设置为true即使hidden为true,也依然可以进行路由跳转(默认 false)",
default=False
)
alwaysShow: Mapped[bool] = mapped_column(
Boolean,
comment="""当你一个路由下面的 children 声明的路由大于1个时,自动会变成嵌套的模式,
只有一个时,会将那个子路由当做根路由显示在侧边栏,若你想不管路由下面的 children 声明的个数都显示你的根路由,
你可以设置 alwaysShow: true,这样它就会忽略之前定义的规则,一直显示根路由(默认 true)""",
default=True
)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : role.py
# @IDE : PyCharm
# @desc : 角色模型
from sqlalchemy.orm import relationship, Mapped, mapped_column
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import String, Boolean, Integer
from .menu import VadminMenu
from .dept import VadminDept
from .m2m import vadmin_auth_role_menus, vadmin_auth_role_depts
class VadminRole(BaseModel):
__tablename__ = "vadmin_auth_role"
__table_args__ = ({'comment': '角色表'})
name: Mapped[str] = mapped_column(String(50), index=True, comment="名称")
role_key: Mapped[str] = mapped_column(String(50), index=True, comment="权限字符")
data_range: Mapped[int] = mapped_column(Integer, default=4, comment="数据权限范围")
disabled: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否禁用")
order: Mapped[int | None] = mapped_column(Integer, comment="排序")
desc: Mapped[str | None] = mapped_column(String(255), comment="描述")
is_admin: Mapped[bool] = mapped_column(Boolean, comment="是否为超级角色", default=False)
menus: Mapped[set[VadminMenu]] = relationship(secondary=vadmin_auth_role_menus)
depts: Mapped[set[VadminDept]] = relationship(secondary=vadmin_auth_role_depts)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/7/7 13:41
# @File : user.py
# @IDE : PyCharm
# @desc : 用户模型
from datetime import datetime
from sqlalchemy.orm import relationship, Mapped, mapped_column
from dbgpt.app.apps.db.db_base import BaseModel
from sqlalchemy import String, Boolean, DateTime
from passlib.context import CryptContext
from .role import VadminRole
from .dept import VadminDept
from .m2m import vadmin_auth_user_roles, vadmin_auth_user_depts
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
class VadminUser(BaseModel):
__tablename__ = "vadmin_auth_user"
__table_args__ = ({'comment': '用户表'})
avatar: Mapped[str | None] = mapped_column(String(500), comment='头像')
telephone: Mapped[str] = mapped_column(String(11), nullable=False, index=True, comment="手机号", unique=False)
email: Mapped[str | None] = mapped_column(String(50), comment="邮箱地址")
name: Mapped[str] = mapped_column(String(50), index=True, nullable=False, comment="姓名")
nickname: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="昵称")
password: Mapped[str] = mapped_column(String(255), nullable=True, comment="密码")
gender: Mapped[str | None] = mapped_column(String(8), nullable=True, comment="性别")
is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment="是否可用")
is_reset_password: Mapped[bool] = mapped_column(
Boolean,
default=False,
comment="是否已经重置密码,没有重置的,登陆系统后必须重置密码"
)
last_ip: Mapped[str | None] = mapped_column(String(50), comment="最后一次登录IP")
last_login: Mapped[datetime | None] = mapped_column(DateTime, comment="最近一次登录时间")
is_staff: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否为工作人员")
wx_server_openid: Mapped[str | None] = mapped_column(String(255), comment="服务端微信平台openid")
is_wx_server_openid: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否已有服务端微信平台openid")
roles: Mapped[set[VadminRole]] = relationship(secondary=vadmin_auth_user_roles)
depts: Mapped[set[VadminDept]] = relationship(secondary=vadmin_auth_user_depts)
@staticmethod
def get_password_hash(password: str) -> str:
"""
生成哈希密码
:param password: 原始密码
:return: 哈希密码
"""
return pwd_context.hash(password)
@staticmethod
def verify_password(password: str, hashed_password: str) -> bool:
"""
验证原始密码是否与哈希密码一致
:param password: 原始密码
:param hashed_password: 哈希密码
:return:
"""
return pwd_context.verify(password, hashed_password)
def is_admin(self) -> bool:
"""
获取该用户是否拥有最高权限
以最高权限为准
:return:
"""
return any([i.is_admin for i in self.roles])
from .user import UserParams
from .role import RoleParams
from .dept import DeptParams
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/12/18 10:19
# @File : dept.py
# @IDE : PyCharm
# @desc : 查询参数-类依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Depends, Query
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class DeptParams(QueryParams):
"""
列表分页
"""
def __init__(
self,
name: str | None = Query(None, title="部门名称"),
dept_key: str | None = Query(None, title="部门标识"),
disabled: bool | None = Query(None, title="是否禁用"),
params: Paging = Depends()
):
super().__init__(params)
self.name = ("like", name)
self.dept_key = ("like", dept_key)
self.disabled = disabled
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : role.py
# @IDE : PyCharm
# @desc : 查询参数-类依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Depends, Query
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class RoleParams(QueryParams):
"""
列表分页
"""
def __init__(
self,
name: str | None = Query(None, title="角色名称"),
role_key: str | None = Query(None, title="权限字符"),
disabled: bool | None = Query(None, title="是否禁用"),
params: Paging = Depends()
):
super().__init__(params)
self.name = ("like", name)
self.role_key = ("like", role_key)
self.disabled = disabled
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : user.py
# @IDE : PyCharm
# @desc : 查询参数-类依赖项
"""
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
"""
from fastapi import Depends, Query
from dbgpt.app.apps.core.dependencies import Paging, QueryParams
class UserParams(QueryParams):
"""
列表分页
"""
def __init__(
self,
name: str | None = Query(None, title="用户名称"),
telephone: str | None = Query(None, title="手机号"),
email: str | None = Query(None, title="邮箱"),
is_active: bool | None = Query(None, title="是否可用"),
is_staff: bool | None = Query(None, title="是否为工作人员"),
params: Paging = Depends()
):
super().__init__(params)
self.name = ("like", name)
self.telephone = ("like", telephone)
self.email = ("like", email)
self.is_active = is_active
self.is_staff = is_staff
from .user import UserOut, UserUpdate, User, UserIn, UserSimpleOut, ResetPwd, UserUpdateBaseInfo, UserPasswordOut
from .role import Role, RoleOut, RoleIn, RoleOptionsOut, RoleSimpleOut
from .menu import Menu, MenuSimpleOut, RouterOut, Meta, MenuTreeListOut
from .dept import Dept, DeptSimpleOut, DeptTreeListOut
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2023/10/25 12:19
# @File : dept.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict
from dbgpt.app.apps.core.data_types import DatetimeStr
class Dept(BaseModel):
name: str
dept_key: str
disabled: bool = False
order: int | None = None
desc: str | None = None
owner: str | None = None
phone: str | None = None
email: str | None = None
parent_id: int | None = None
class DeptSimpleOut(Dept):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
class DeptTreeListOut(DeptSimpleOut):
model_config = ConfigDict(from_attributes=True)
children: list[dict] = []
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : role.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict
from dbgpt.app.apps.core.data_types import DatetimeStr
class Menu(BaseModel):
title: str
icon: str | None = None
component: str | None = None
redirect: str | None = None
path: str | None = None
disabled: bool = False
hidden: bool = False
order: int | None = None
perms: str | None = None
parent_id: int | None = None
menu_type: str
alwaysShow: bool | None = True
noCache: bool | None = False
class MenuSimpleOut(Menu):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
class Meta(BaseModel):
title: str
icon: str | None = None
hidden: bool = False
noCache: bool | None = False
breadcrumb: bool | None = True
affix: bool | None = False
noTagsView: bool | None = False
canTo: bool | None = False
alwaysShow: bool | None = True
# 路由展示
class RouterOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
name: str | None = None
component: str | None = None
path: str
redirect: str | None = None
meta: Meta | None = None
order: int | None = None
children: list[dict] = []
class MenuTreeListOut(MenuSimpleOut):
model_config = ConfigDict(from_attributes=True)
children: list[dict] = []
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : role.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict, Field
from dbgpt.app.apps.core.data_types import DatetimeStr
from .menu import MenuSimpleOut
from .dept import DeptSimpleOut
class Role(BaseModel):
name: str
disabled: bool = False
order: int | None = None
desc: str | None = None
data_range: int = 4
role_key: str
is_admin: bool = False
class RoleSimpleOut(Role):
model_config = ConfigDict(from_attributes=True)
id: int
create_datetime: DatetimeStr
update_datetime: DatetimeStr
class RoleOut(RoleSimpleOut):
model_config = ConfigDict(from_attributes=True)
menus: list[MenuSimpleOut] = []
depts: list[DeptSimpleOut] = []
class RoleIn(Role):
menu_ids: list[int] = []
dept_ids: list[int] = []
class RoleOptionsOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
label: str = Field(alias='name')
value: int = Field(alias='id')
disabled: bool
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/18 22:19
# @File : user.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import FieldValidationInfo
from dbgpt.app.apps.core.data_types import Telephone, DatetimeStr, Email
from .role import RoleSimpleOut
from .dept import DeptSimpleOut
class User(BaseModel):
name: str
telephone: Telephone
email: Email | None = None
nickname: str | None = None
avatar: str | None = None
is_active: bool | None = True
is_staff: bool | None = True
gender: str | None = "0"
is_wx_server_openid: bool | None = False
class UserIn(User):
"""
创建用户
"""
role_ids: list[int] = []
dept_ids: list[int] = []
password: str | None = ""
class UserUpdateBaseInfo(BaseModel):
"""
更新用户基本信息
"""
name: str
telephone: Telephone
email: Email | None = None
nickname: str | None = None
gender: str | None = "0"
class UserUpdate(User):
"""
更新用户详细信息
"""
name: str | None = None
telephone: Telephone
email: Email | None = None
nickname: str | None = None
avatar: str | None = None
is_active: bool | None = True
is_staff: bool | None = False
gender: str | None = "0"
role_ids: list[int] = []
dept_ids: list[int] = []
class UserSimpleOut(User):
model_config = ConfigDict(from_attributes=True)
id: int
update_datetime: DatetimeStr
create_datetime: DatetimeStr
is_reset_password: bool | None = None
last_login: DatetimeStr | None = None
last_ip: str | None = None
class UserPasswordOut(UserSimpleOut):
model_config = ConfigDict(from_attributes=True)
password: str
class UserOut(UserSimpleOut):
model_config = ConfigDict(from_attributes=True)
roles: list[RoleSimpleOut] = []
depts: list[DeptSimpleOut] = []
class ResetPwd(BaseModel):
password: str
password_two: str
@field_validator('password_two')
def check_passwords_match(cls, v, info: FieldValidationInfo):
if 'password' in info.data and v != info.data['password']:
raise ValueError('两次密码不一致!')
return v
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/8/8 11:02
# @File : __init__.py
# @IDE : PyCharm
# @desc : 简要说明
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/24 16:44
# @File : current.py
# @IDE : PyCharm
# @desc : 获取认证后的信息工具
from typing import Annotated
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from dbgpt.app.apps.vadmin.auth.crud import UserDal
from dbgpt.app.apps.vadmin.auth.models import VadminUser, VadminRole
from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils import status
from .validation import AuthValidation
from fastapi import Request, Depends
from dbgpt.app.apps.config import settings
from dbgpt.app.apps.core.database import db_getter
from .validation.auth import Auth
class OpenAuth(AuthValidation):
"""
开放认证,无认证也可以访问
认证了以后可以获取到用户信息,无认证则获取不到
"""
async def __call__(
self,
request: Request,
token: Annotated[str, Depends(settings.oauth2_scheme)],
db: AsyncSession = Depends(db_getter)
):
"""
每次调用依赖此类的接口会执行该方法
"""
if not settings.OAUTH_ENABLE:
return Auth(db=db)
try:
telephone, password = self.validate_token(request, token)
user = await UserDal(db).get_data(telephone=telephone, password=password, v_return_none=True)
return await self.validate_user(request, user, db, is_all=True)
except CustomException:
return Auth(db=db)
class AllUserAuth(AuthValidation):
"""
支持所有用户认证
获取用户基本信息
"""
async def __call__(
self,
request: Request,
token: str = Depends(settings.oauth2_scheme),
db: AsyncSession = Depends(db_getter)
):
"""
每次调用依赖此类的接口会执行该方法
"""
if not settings.OAUTH_ENABLE:
return Auth(db=db)
telephone, password = self.validate_token(request, token)
user = await UserDal(db).get_data(telephone=telephone, password=password, v_return_none=True)
return await self.validate_user(request, user, db, is_all=True)
class FullAdminAuth(AuthValidation):
"""
只支持员工用户认证
获取员工用户完整信息
如果有权限,那么会验证该用户是否包括权限列表中的其中一个权限
"""
def __init__(self, permissions: list[str] | None = None):
if permissions:
self.permissions = set(permissions)
else:
self.permissions = None
async def __call__(
self,
request: Request,
token: str = Depends(settings.oauth2_scheme),
db: AsyncSession = Depends(db_getter)
) -> Auth:
"""
每次调用依赖此类的接口会执行该方法
"""
if not settings.OAUTH_ENABLE:
return Auth(db=db)
telephone, password = self.validate_token(request, token)
options = [
joinedload(VadminUser.roles).subqueryload(VadminRole.menus),
joinedload(VadminUser.roles).subqueryload(VadminRole.depts),
joinedload(VadminUser.depts)
]
user = await UserDal(db).get_data(
telephone=telephone,
password=password,
v_return_none=True,
v_options=options,
is_staff=True
)
result = await self.validate_user(request, user, db, is_all=False)
permissions = self.get_user_permissions(user)
if permissions != {'*.*.*'} and self.permissions:
if not (self.permissions & permissions):
raise CustomException(msg="无权限操作", code=status.HTTP_403_FORBIDDEN)
return result
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/24 16:44
# @File : views.py
# @IDE : PyCharm
# @desc : 安全认证视图
"""
JWT 表示 「JSON Web Tokens」。https://jwt.io/
它是一个将 JSON 对象编码为密集且没有空格的长字符串的标准。
通过这种方式,你可以创建一个有效期为 1 周的令牌。然后当用户第二天使用令牌重新访问时,你知道该用户仍然处于登入状态。
一周后令牌将会过期,用户将不会通过认证,必须再次登录才能获得一个新令牌。
我们需要安装 python-jose 以在 Python 中生成和校验 JWT 令牌:pip install python-jose[cryptography]
PassLib 是一个用于处理哈希密码的很棒的 Python 包。它支持许多安全哈希算法以及配合算法使用的实用程序。
推荐的算法是 「Bcrypt」:pip install passlib[bcrypt]
"""
from datetime import timedelta
from redis.asyncio import Redis
from fastapi import APIRouter, Depends, Request, Body
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.core.database import db_getter, redis_getter
from dbgpt.app.apps.core.exception import CustomException
from dbgpt.app.apps.utils import status
from dbgpt.app.apps.utils.response import SuccessResponse, ErrorResponse
from dbgpt.app.apps.config import settings
from .login_manage import LoginManage
from .validation import LoginForm, WXLoginForm
from dbgpt.app.apps.vadmin.auth.crud import MenuDal, UserDal
from dbgpt.app.apps.vadmin.auth.models import VadminUser
from .current import FullAdminAuth
from .validation.auth import Auth
import jwt
router = APIRouter()
@router.post("/v2/login", summary="API 手机号密码登录", description="Swagger API 文档登录认证")
async def api_login_for_access_token(
request: Request,
data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(db_getter)
):
user = await UserDal(db).get_data(telephone=data.username, v_return_none=True)
error_code = status.HTTP_401_UNAUTHORIZED
if not user:
raise CustomException(status_code=error_code, code=error_code, msg="该手机号不存在")
result = VadminUser.verify_password(data.password, user.password)
if not result:
raise CustomException(status_code=error_code, code=error_code, msg="手机号或密码错误")
if not user.is_active:
raise CustomException(status_code=error_code, code=error_code, msg="此手机号已被冻结")
elif not user.is_staff:
raise CustomException(status_code=error_code, code=error_code, msg="此手机号无权限")
access_token = LoginManage.create_token({"sub": user.telephone, "password": user.password})
record = LoginForm(platform='2', method='0', telephone=data.username, password=data.password)
resp = {"access_token": access_token, "token_type": "bearer"}
# await VadminLoginRecord.create_login_record(db, record, True, request, resp)
return resp
@router.get("/v2/getMenuList", summary="获取当前用户菜单树")
async def get_menu_list(auth: Auth = Depends(FullAdminAuth())):
return SuccessResponse(await MenuDal(auth.db).get_routers(auth.user))
@router.post("/v2/token/refresh", summary="刷新Token")
async def token_refresh(refresh: str = Body(..., title="刷新Token")):
error_code = status.HTTP_401_UNAUTHORIZED
try:
payload = jwt.decode(refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
telephone: str = payload.get("sub")
is_refresh: bool = payload.get("is_refresh")
password: str = payload.get("password")
if not telephone or not is_refresh or not password:
return ErrorResponse("未认证,请您重新登录", code=error_code, status=error_code)
except jwt.exceptions.InvalidSignatureError:
return ErrorResponse("无效认证,请您重新登录", code=error_code, status=error_code)
except jwt.exceptions.ExpiredSignatureError:
return ErrorResponse("登录已超时,请您重新登录", code=error_code, status=error_code)
access_token = LoginManage.create_token({"sub": telephone, "is_refresh": False, "password": password})
expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
refresh_token = LoginManage.create_token(
payload={"sub": telephone, "is_refresh": True, "password": password},
expires=expires
)
resp = {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer"
}
return SuccessResponse(resp)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/8/8 11:02
# @File : auth_util.py
# @IDE : PyCharm
# @desc : 简要说明
from datetime import datetime, timedelta
from fastapi import Request
from dbgpt.app.apps.config import settings
import jwt
from dbgpt.app.apps.vadmin.auth import models
from dbgpt.app.apps.core.database import redis_getter
from .validation import LoginValidation, LoginForm, LoginResult
class LoginManage:
"""
登录认证工具
"""
@LoginValidation
async def password_login(self, data: LoginForm, user: models.VadminUser, **kwargs) -> LoginResult:
"""
验证用户密码
"""
result = models.VadminUser.verify_password(data.password, user.password)
if result:
return LoginResult(status=True, msg="验证成功")
return LoginResult(status=False, msg="手机号或密码错误")
@staticmethod
def create_token(payload: dict, expires: timedelta = None):
"""
创建一个生成新的访问令牌的工具函数。
pyjwt:https://github.com/jpadilla/pyjwt/blob/master/docs/usage.rst
jwt 博客:https://geek-docs.com/python/python-tutorial/j_python-jwt.html
#TODO 传入的时间为UTC时间datetime.datetime类型,但是在解码时获取到的是本机时间的时间戳
"""
if expires:
expire = datetime.utcnow() + expires
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
payload.update({"exp": expire})
encoded_jwt = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/9 10:14
# @File : __init__.py.py
# @IDE : PyCharm
# @desc : 简要说明
from .auth import Auth, AuthValidation
from .login import LoginValidation, LoginForm, LoginResult, WXLoginForm
This diff is collapsed.
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/9 10:15
# @File : login.py
# @IDE : PyCharm
# @desc : 登录验证装饰器
from fastapi import Request
from pydantic import BaseModel, field_validator
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.config.settings import DEFAULT_AUTH_ERROR_MAX_NUMBER, DEMO
from dbgpt.app.apps.vadmin.auth import crud, schemas
from dbgpt.app.apps.core.database import redis_getter
from dbgpt.app.apps.core.validator import vali_telephone
from dbgpt.app.apps.utils.count import Count
class LoginForm(BaseModel):
telephone: str
password: str
method: str = '0' # 认证方式,0:密码登录,1:短信登录,2:微信一键登录
platform: str = '0' # 登录平台,0:PC端管理系统,1:移动端管理系统
# 重用验证器:https://docs.pydantic.dev/dev-v2/usage/validators/#reuse-validators
normalize_telephone = field_validator('telephone')(vali_telephone)
class WXLoginForm(BaseModel):
telephone: str | None = None
code: str
method: str = '2' # 认证方式,0:密码登录,1:短信登录,2:微信一键登录
platform: str = '1' # 登录平台,0:PC端管理系统,1:移动端管理系统
class LoginResult(BaseModel):
status: bool | None = False
user: schemas.UserPasswordOut | None = None
msg: str | None = None
class Config:
arbitrary_types_allowed = True
class LoginValidation:
"""
验证用户登录时提交的数据是否有效
"""
def __init__(self, func):
self.func = func
async def __call__(self, data: LoginForm, db: AsyncSession, request: Request) -> LoginResult:
self.result = LoginResult()
if data.platform not in ["0", "1"] or data.method not in ["0", "1"]:
self.result.msg = "无效参数"
return self.result
user = await crud.UserDal(db).get_data(telephone=data.telephone, v_return_none=True)
if not user:
self.result.msg = "该手机号不存在!"
return self.result
result = await self.func(self, data=data, user=user, request=request)
count_key = f"{data.telephone}_password_auth" if data.method == '0' else f"{data.telephone}_sms_auth"
count = Count(redis_getter(request), count_key)
if not result.status:
self.result.msg = result.msg
if not DEMO and count:
number = await count.add(ex=86400)
if number >= DEFAULT_AUTH_ERROR_MAX_NUMBER:
await count.reset()
# 如果等于最大次数,那么就将用户 is_active=False
user.is_active = False
await db.flush()
elif not user.is_active:
self.result.msg = "此手机号已被冻结!"
elif data.platform in ["0", "1"] and not user.is_staff:
self.result.msg = "此手机号无权限!"
else:
if not DEMO and count:
await count.delete()
self.result.msg = "OK"
self.result.status = True
self.result.user = schemas.UserPasswordOut.model_validate(user)
await crud.UserDal(db).update_login_info(user, request.client.host)
return self.result
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment