Commit f793c556 authored by 于飞

Add keyword extraction code

parent 563c60e8
@@ -4,7 +4,10 @@ import re
 from fastapi import Request, Depends
 from dbgpt.app.apps.core.database import db_getter
 from sqlalchemy.ext.asyncio import AsyncSession
 from dbgpt.app.apps.vadmin.word.crud import SensitiveDal
+from dbgpt.app.apps.vadmin.media.crud import MediaDal, QuestionDal
 from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
 from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
@@ -13,9 +16,9 @@ class DFAFilter():
     Use DFA to keep the algorithm performing in constant time
     Sensitive word filtering
-    >>> f = DFAFilter()
-    >>> f.add("sexy")
-    >>> f.filter("hello sexy baby")
+    #>>> f = DFAFilter()
+    #>>> f.add("sexy")
+    #>>> f.filter("hello sexy baby")
     hello **** baby
     '''
@@ -52,10 +55,43 @@ class DFAFilter():
         for keyword in f:
             self.add(keyword.strip())
-    # Load sensitive words from the database
-    async def parse_from_db(self, db: AsyncSession):
+    # Load image resources from the database
+    async def parse_picture_from_db(self, db: AsyncSession):
+        # db: AsyncSession = Depends(db_getter)
+        print('---------parse_picture_from_db--load-------------')
+        # Fetch all images from the database ---> add them to the in-memory keyword store
+        images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None}
+        images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
+        print(f"----- image library list ----->: {images_datas}")
+        for imagedata in images_datas:
+            self.add(imagedata.get('key_word'))
+    # Load Q&A pairs from the database
+    async def parse_question_from_db(self, db: AsyncSession):
+        # db: AsyncSession = Depends(db_getter)
+        print('---------parse_question_from_db--load-------------')
+        # Fetch all Q&A pairs from the database ---> add them to the in-memory keyword store
+        question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None}
+        question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
+        print(f"----- Q&A pair list ----->: {question_datas}")
+        for questiondata in question_datas:
+            self.add(questiondata.get('key_word'))
+    # Load video resources from the database
+    async def parse_video_from_db(self, db: AsyncSession):
         # db: AsyncSession = Depends(db_getter)
-        print('---------sensitive-load-------------')
+        print('---------parse_video_from_db--load-------------')
+        # Fetch all videos from the database ---> add them to the in-memory keyword store
+        video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None}
+        video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
+        print(f"----- video library list ----->: {video_datas}")
+        for videodata in video_datas:
+            self.add(videodata.get('key_word'))
+    # Load sensitive words from the database
+    async def parse_from_db(self, db: AsyncSession):
+        # db: AsyncSession = Depends(db_getter)
+        print('---------parse_from_db-load-------------')
         sdl = SensitiveDal(db)
         datas = await sdl.get_sensitives()
         for keyword in datas:
@@ -94,8 +130,19 @@ class DFAFilter():
         # return three values
         return ''.join(ret), is_sensitive, matched_sensitives
+# Initialize global filter objects
+# Sensitive words
 mydfafiter = DFAFilter()
+# Image keywords
+mydfafiter_picture = DFAFilter()
+# Video keywords
+mydfafiter_video = DFAFilter()
+# Q&A keywords
+mydfafiter_question = DFAFilter()
 """
 if __name__ == "__main__":
     gfw = DFAFilter()
......
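Taken together, the new loader methods and the four module-level filters imply a refresh-then-match flow: each in-memory trie is filled from the database, and filter() then returns the masked text, a sensitivity flag, and the matched keywords. The following is a hedged sketch of that flow, not part of this commit; the db argument stands in for whatever AsyncSession the caller already holds (e.g. auth.db in the chat views), and nothing beyond the methods shown above is assumed.

from dbgpt.app.apps.utils.filter import (
    mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video,
)

async def refresh_keyword_filters(db):
    # Refill every in-memory trie from the database (AsyncSession supplied by the caller).
    await mydfafiter.parse_from_db(db)                      # sensitive words
    await mydfafiter_picture.parse_picture_from_db(db)      # image keywords
    await mydfafiter_video.parse_video_from_db(db)          # video keywords
    await mydfafiter_question.parse_question_from_db(db)    # Q&A keywords

# Matching mirrors the (now commented-out) doctest and the chat views:
mydfafiter.add("sexy")
masked, is_sensitive, hits = mydfafiter.filter("hello sexy baby", "*")
# masked == "hello **** baby", is_sensitive is True, hits holds the matched keyword(s)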
(The diff for one additional file in this commit is collapsed and not shown.)
@@ -45,7 +45,7 @@ from dbgpt.util.utils import (
 from dbgpt.app.apps.utils.tools import import_modules
 from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
-from dbgpt.app.apps.utils.filter import mydfafiter
+from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
 REQUEST_LOG_RECORD = False
 MIDDLEWARES = [
@@ -104,6 +104,7 @@ def mount_routers(app: FastAPI):
     from dbgpt.app.apps.vadmin.media.views import router as media_views
     from dbgpt.app.apps.vadmin.word.views import router as word_views
     from dbgpt.app.apps.vadmin.keywordsviews import router as keywords_views
+    from dbgpt.app.apps.vadmin.chathistory.views import router as chathistory_views
     app.include_router(api_v1, prefix="/api", tags=["Chat"])
     app.include_router(api_v2, prefix="/api", tags=["ChatV2"])
@@ -119,6 +120,7 @@ def mount_routers(app: FastAPI):
     app.include_router(media_views, prefix="/api/v2", tags=["System"])
     app.include_router(word_views, prefix="/api/v2/vadmin/word", tags=["Word"])
     app.include_router(keywords_views, prefix="/api/v2/vadmin", tags=["vadmin"])
+    app.include_router(chathistory_views, prefix="/api/v2/vadmin", tags=["chathistory"])
 def mount_static_files(app: FastAPI):
......
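The new chathistory views module referenced above does not appear in this diff. All the mounting code requires is a module that exposes an APIRouter named router; the sketch below is purely illustrative of that shape, and the route and its path are assumptions rather than part of the commit.

# dbgpt/app/apps/vadmin/chathistory/views.py -- illustrative shape only
from fastapi import APIRouter

router = APIRouter()

@router.get("/chathistory/ping")  # hypothetical route; only the module-level `router` is required by mount_routers
async def ping():
    return {"status": "ok"}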
@@ -46,6 +46,7 @@ from dbgpt.util.tracer import SpanType, root_tracer
 from dbgpt.app.apps.utils.filter import mydfafiter
 from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
 from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
+from sqlalchemy.ext.asyncio import AsyncSession
 router = APIRouter()
@@ -263,6 +264,7 @@ async def params_load(
     user_name: Optional[str] = None,
     sys_code: Optional[str] = None,
     doc_file: UploadFile = File(...),
+    auth: Auth = Depends(OpenAuth()),
 ):
     logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
@@ -287,11 +289,19 @@ async def params_load(
         resp = await chat.prepare()
         # Refresh messages
-        return Result.succ(get_hist_messages(conv_uid))
+        #ret_his_msg = get_hist_messages(conv_uid)
+        ret_his_msg = await get_hist_messages_datas(conv_uid, auth.db)
+        return Result.succ(ret_his_msg)
     except Exception as e:
         logger.error("excel load error!", e)
         return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
+async def get_hist_messages_datas(conv_uid: str, db: AsyncSession):
+    from dbgpt.serve.conversation.serve import Service as ConversationService
+    instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
+    result = await instance.get_history_messages2(db, {"conv_uid": conv_uid})
+    return result
 def get_hist_messages(conv_uid: str):
     from dbgpt.serve.conversation.serve import Service as ConversationService
@@ -360,7 +370,7 @@ async def chat_completions(
     }
     # Load the sensitive words from the database and initialize them --> into memory
-    await mydfafiter.parse_from_db(auth.db)
+    # await mydfafiter.parse_from_db(auth.db)
     # Check for sensitive words first
     dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
......
@@ -99,7 +99,7 @@ async def chat_completions(
     }
     # Load the sensitive words from the database and initialize them --> into memory
-    await mydfafiter.parse_from_db(auth.db)
+    # await mydfafiter.parse_from_db(auth.db)
     # Check for sensitive words first
     dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(request.user_input, "*")
......
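With the per-request parse_from_db call commented out in both chat entry points, the in-memory tries have to be refreshed somewhere else. The startup script at the end of this commit references /api/v2/vadmin/load_parse_from_db, which points at a dedicated reload route; that route's implementation is not shown in this diff, so the following is only a hedged sketch of what it could look like, reusing the auth pattern and the filter globals that do appear here.

# Illustrative only -- the real route presumably lives in dbgpt.app.apps.vadmin.keywordsviews,
# whose diff is not part of this commit.
from fastapi import APIRouter, Depends
from dbgpt.app.apps.utils.filter import (
    mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video,
)
from dbgpt.app.apps.vadmin.auth.utils.current import OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth

router = APIRouter()

@router.get("/load_parse_from_db")  # path taken from the startup-script comment below
async def load_parse_from_db(auth: Auth = Depends(OpenAuth())):
    await mydfafiter.parse_from_db(auth.db)                     # sensitive words
    await mydfafiter_picture.parse_picture_from_db(auth.db)     # image keywords
    await mydfafiter_video.parse_video_from_db(auth.db)         # video keywords
    await mydfafiter_question.parse_question_from_db(auth.db)   # Q&A keywords
    return {"code": 200, "message": "keyword caches reloaded"}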
@@ -12,6 +12,9 @@ from dbgpt.util import PaginationResult
 from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
 from ..service.service import Service
 from .schemas import MessageVo, ServeRequest, ServerResponse
+from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
+from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
+from sqlalchemy.ext.asyncio import AsyncSession
 router = APIRouter()
@@ -207,9 +210,13 @@ async def list_latest_conv(
     response_model=Result[List[MessageVo]],
     dependencies=[Depends(check_api_key)],
 )
-async def get_history_messages(con_uid: str, service: Service = Depends(get_service)):
+async def get_history_messages(con_uid: str,
+                               service: Service = Depends(get_service),
+                               auth: Auth = Depends(OpenAuth())):
     """Get the history messages of a conversation"""
-    return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
+    #result = service.get_history_messages(ServeRequest(conv_uid=con_uid))
+    result = await service.get_history_messages2(auth.db, ServeRequest(conv_uid=con_uid))
+    return Result.succ(result)
 def init_endpoints(system_app: SystemApp) -> None:
......
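The reworked endpoint now needs an AsyncSession (taken from the OpenAuth dependency as auth.db) because get_history_messages2 looks up related images and videos while assembling the history. For contrast, a hypothetical standalone caller outside FastAPI would have to supply its own session; in the sketch below, session_factory is an assumed SQLAlchemy async_sessionmaker and is not part of this commit.

async def dump_history(service: Service, session_factory, conv_uid: str):
    # session_factory() yields an AsyncSession, playing the role of auth.db in the view above.
    async with session_factory() as db:
        message_vos = await service.get_history_messages2(db, ServeRequest(conv_uid=conv_uid))
        for vo in message_vos:
            print(vo.to_dict())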
 # Define your Pydantic schemas here
-from typing import Any, Dict, Optional
+from typing import Optional, List, Dict, Any
 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
@@ -149,6 +149,14 @@ class MessageVo(BaseModel):
         ],
     )
+    extra: Optional[List[Dict[str, Any]]] = Field(  # changed to a list type
+        default=None,
+        description="A field to store additional JSON data.",
+        examples=[
+            [{"type": 1, "file_name": "111.png"}],
+        ],
+    )
     def to_dict(self, **kwargs) -> Dict[str, Any]:
         """Convert the model to a dictionary"""
         return model_to_dict(self, **kwargs)
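The schema example above shows the intended shape of extra: a list of small dicts describing media attached to a message. A construction sketch follows; the field values are illustrative, the second entry is invented, and the type codes are presumed to mirror the 1=image / 2=video convention used with MediaDal earlier in this commit. The field names match the MessageVo construction in the service diff further below.

extra_payload = [
    {"type": 1, "file_name": "111.png"},   # image attachment (file name from the schema example)
    {"type": 2, "file_name": "demo.mp4"},  # hypothetical video attachment
]
vo = MessageVo(context="answer text", order=1, model_name="the-default-model", extra=extra_payload)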
@@ -17,6 +17,8 @@ from dbgpt.util.pagination_utils import PaginationResult
 from ..api.schemas import MessageVo, ServeRequest, ServerResponse
 from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
 from ..models.models import ServeDao, ServeEntity
+from sqlalchemy.ext.asyncio import AsyncSession
+from dbgpt.app.apps.vadmin.keywordsviews import get_media_datas_all
 class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
@@ -178,6 +180,19 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         return self.dao.get_conv_by_page(request, page, page_size)
+    async def get_history_messages2(
+        self, db: AsyncSession, request: Union[ServeRequest, Dict[str, Any]]
+    ) -> List[MessageVo]:
+        print(f"------------ history request ------------->: {request}")
+        conv: StorageConversation = self.create_storage_conv(request)
+        messages = _append_view_messages(conv.messages)
+        # Look up any related images or videos
+        result = await get_media_datas_all(request.conv_uid, self.config.default_model, db, messages)
+        print(f"------------ history result ------------->: {result}")
+        return result
     def get_history_messages(
         self, request: Union[ServeRequest, Dict[str, Any]]
     ) -> List[MessageVo]:
@@ -199,6 +214,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
                     context=msg.content,
                     order=msg.round_index,
                     model_name=self.config.default_model,
+                    extra="----====---->",
                 )
             )
         return result
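One review-style note: the placeholder extra="----====---->" passed in the synchronous get_history_messages is a plain string, while the schema now declares extra as Optional[List[Dict[str, Any]]], so strict pydantic validation would reject it. A value consistent with the new schema would look like the following hedged fix-up sketch (same names as in the loop above, file name reused from the schema example; not part of the commit):

result.append(
    MessageVo(
        context=msg.content,
        order=msg.round_index,
        model_name=self.config.default_model,
        extra=[{"type": 1, "file_name": "111.png"}],  # list-shaped, matching the schema; use None when no media is attached
    )
)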
 #!/bin/bash
 echo Starting...
+if [ ! -d datamining ] ; then
+mkdir datamining
+fi
+if [ ! -d jfxt ] ; then
+mkdir jfxt
+fi
+export LC_ALL=C
+echo "`pwd`" >> ./pwd.txt
 nohup ./BufferServer > /dev/null 2>&1 &
 sleep 1
-nohup ./LogicServer > /dev/null 2>&1 &
-sleep 1
-nohup ./GatewayServer > /dev/null 2>&1 &
+# Call the sensitive-word loading endpoint
+#http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db
 echo Done!
\ No newline at end of file
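If the commented-out URL above is meant to be hit once the servers are up, a minimal Python equivalent of that call is sketched below; the address is taken verbatim from the script comment, while the assumption that the route answers a simple GET is mine.

# Hedged sketch: trigger the keyword reload endpoint after startup.
import urllib.request

RELOAD_URL = "http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db"
with urllib.request.urlopen(RELOAD_URL, timeout=10) as resp:  # assumes a plain GET route
    print(resp.status, resp.read().decode("utf-8", errors="replace"))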