Commit f793c556 authored by 于飞's avatar 于飞

添加关键词提取代码

parent 563c60e8
......@@ -4,7 +4,10 @@ import re
from fastapi import Request, Depends
from dbgpt.app.apps.core.database import db_getter
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.word.crud import SensitiveDal
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
......@@ -13,9 +16,9 @@ class DFAFilter():
Use DFA to keep algorithm perform constantly
敏感词过滤
>>> f = DFAFilter()
>>> f.add("sexy")
>>> f.filter("hello sexy baby")
#>>> f = DFAFilter()
#>>> f.add("sexy")
#>>> f.filter("hello sexy baby")
hello **** baby
'''
......@@ -52,10 +55,43 @@ class DFAFilter():
for keyword in f:
self.add(keyword.strip())
#从数据库中加载 敏感词
async def parse_from_db(self, db: AsyncSession):
#从数据库中加载 图片资源
async def parse_picture_from_db(self, db: AsyncSession):
    """Load picture keywords from the database into the in-memory DFA trie.

    Fetches every media row of type 1 (pictures) and registers each row's
    ``key_word`` via :meth:`add` so later ``filter`` calls can match them.

    Args:
        db: async SQLAlchemy session used by ``MediaDal``.
    """
    print('---------parse_picture_from_db--load-------------')
    # page=1, limit=0 -> fetch all rows; type=1 selects picture media.
    query = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None}
    image_rows, count = await MediaDal(db).get_datas(**query, v_return_count=True)
    print(f"-----图片库列表为:---->:{image_rows}")
    for row in image_rows:
        keyword = row.get('key_word')
        # Skip rows with a missing/empty keyword so None is never fed to add().
        if keyword:
            self.add(keyword)
#从数据库中加载 问答对
async def parse_question_from_db(self, db: AsyncSession):
    """Load question/answer-pair keywords from the database into the DFA trie.

    Fetches all Q&A rows and registers each row's ``key_word`` via
    :meth:`add` so later ``filter`` calls can match them.

    Args:
        db: async SQLAlchemy session used by ``QuestionDal``.
    """
    print('---------parse_question_from_db--load-------------')
    # page=1, limit=0 -> fetch all rows.
    query = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None}
    question_rows, count = await QuestionDal(db).get_datas(**query, v_return_count=True)
    print(f"-----问答对库列表为:---->:{question_rows}")
    for row in question_rows:
        keyword = row.get('key_word')
        # Skip rows with a missing/empty keyword so None is never fed to add().
        if keyword:
            self.add(keyword)
#从数据库中加载 视频资源
async def parse_video_from_db(self, db: AsyncSession):
    """Load video keywords from the database into the in-memory DFA trie.

    Fetches every media row of type 2 (videos) and registers each row's
    ``key_word`` via :meth:`add` so later ``filter`` calls can match them.

    Args:
        db: async SQLAlchemy session used by ``MediaDal``.
    """
    # NOTE(review): dropped the stray 'sensitive-load' print — it was a
    # copy-paste leftover from parse_from_db and mislabeled this loader.
    print('---------parse_video_from_db--load-------------')
    # page=1, limit=0 -> fetch all rows; type=2 selects video media.
    query = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None}
    video_rows, count = await MediaDal(db).get_datas(**query, v_return_count=True)
    print(f"-----视频库列表为:---->:{video_rows}")
    for row in video_rows:
        keyword = row.get('key_word')
        # Skip rows with a missing/empty keyword so None is never fed to add().
        if keyword:
            self.add(keyword)
# 从数据库中加载 敏感词
async def parse_from_db(self, db: AsyncSession):
# db: AsyncSession = Depends(db_getter)
print('---------parse_from_db-load-------------')
sdl = SensitiveDal(db)
datas = await sdl.get_sensitives()
for keyword in datas:
......@@ -94,8 +130,19 @@ class DFAFilter():
# return 返回三个参数
return ''.join(ret), is_sensitive, matched_sensitives
# Global singleton DFA filters, shared across the app and populated lazily
# from the database by the parse_*_from_db loaders.
# Sensitive-word filter
mydfafiter = DFAFilter()
# Picture-keyword filter
mydfafiter_picture = DFAFilter()
# Video-keyword filter
mydfafiter_video = DFAFilter()
# Question/answer-pair keyword filter
mydfafiter_question = DFAFilter()
"""
if __name__ == "__main__":
gfw = DFAFilter()
......
This diff is collapsed.
......@@ -45,7 +45,7 @@ from dbgpt.util.utils import (
from dbgpt.app.apps.utils.tools import import_modules
from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
REQUEST_LOG_RECORD = False
MIDDLEWARES = [
......@@ -104,6 +104,7 @@ def mount_routers(app: FastAPI):
from dbgpt.app.apps.vadmin.media.views import router as media_views
from dbgpt.app.apps.vadmin.word.views import router as word_views
from dbgpt.app.apps.vadmin.keywordsviews import router as keywords_views
from dbgpt.app.apps.vadmin.chathistory.views import router as chathistory_views
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_v2, prefix="/api", tags=["ChatV2"])
......@@ -119,6 +120,7 @@ def mount_routers(app: FastAPI):
app.include_router(media_views, prefix="/api/v2", tags=["System"])
app.include_router(word_views, prefix="/api/v2/vadmin/word", tags=["Word"])
app.include_router(keywords_views, prefix="/api/v2/vadmin", tags=["vadmin"])
app.include_router(chathistory_views, prefix="/api/v2/vadmin", tags=["chathistory"])
def mount_static_files(app: FastAPI):
......
......@@ -46,6 +46,7 @@ from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
......@@ -263,6 +264,7 @@ async def params_load(
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
auth: Auth = Depends(OpenAuth()),
):
logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
......@@ -287,11 +289,19 @@ async def params_load(
resp = await chat.prepare()
# Refresh messages
return Result.succ(get_hist_messages(conv_uid))
#ret_his_msg = get_hist_messages(conv_uid)
ret_his_msg = await get_hist_messages_datas(conv_uid, auth.db)
return Result.succ(ret_his_msg)
except Exception as e:
logger.error("excel load error!", e)
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
async def get_hist_messages_datas(conv_uid: str, db: AsyncSession):
    """Return the history messages of a conversation, enriched with media data.

    Args:
        conv_uid: conversation identifier.
        db: async DB session forwarded to the conversation service.
    """
    # Imported locally to avoid a circular import at module load time.
    from dbgpt.serve.conversation.serve import Service as ConversationService

    conv_service: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
    return await conv_service.get_history_messages2(db, {"conv_uid": conv_uid})
def get_hist_messages(conv_uid: str):
from dbgpt.serve.conversation.serve import Service as ConversationService
......@@ -360,7 +370,7 @@ async def chat_completions(
}
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
# await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
......
......@@ -99,7 +99,7 @@ async def chat_completions(
}
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
# await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(request.user_input, "*")
......
......@@ -12,6 +12,9 @@ from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import MessageVo, ServeRequest, ServerResponse
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
......@@ -207,9 +210,13 @@ async def list_latest_conv(
response_model=Result[List[MessageVo]],
dependencies=[Depends(check_api_key)],
)
async def get_history_messages(con_uid: str,
                               service: Service = Depends(get_service),
                               auth: Auth = Depends(OpenAuth())):
    """Get the history messages of a conversation, enriched with media data.

    Args:
        con_uid: conversation identifier.
        service: injected conversation service.
        auth: injected auth context; supplies the async DB session.

    Returns:
        ``Result`` wrapping the list of message VOs for the conversation.
    """
    # Dead code removed: the superseded sync call and its commented-out copy.
    result = await service.get_history_messages2(auth.db, ServeRequest(conv_uid=con_uid))
    return Result.succ(result)
def init_endpoints(system_app: SystemApp) -> None:
......
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from typing import Optional, List, Dict, Any
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
......@@ -149,6 +149,14 @@ class MessageVo(BaseModel):
],
)
extra: Optional[List[Dict[str, Any]]] = Field( # 修改为列表类型
default=None,
description="A field to store additional JSON data.",
examples=[
[{"type": 1, "file_name": "111.png"}],
],
)
def to_dict(self, **kwargs) -> Dict[str, Any]:
    """Convert the model to a dictionary.

    Extra keyword arguments are forwarded to ``model_to_dict`` unchanged.
    """
    return model_to_dict(self, **kwargs)
......@@ -17,6 +17,8 @@ from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import MessageVo, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.keywordsviews import get_media_datas_all
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
......@@ -178,6 +180,19 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""
return self.dao.get_conv_by_page(request, page, page_size)
async def get_history_messages2(
    self, db: AsyncSession, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
    """Async variant of ``get_history_messages`` that attaches media data.

    Args:
        db: async DB session used to look up related pictures/videos.
        request: request carrying ``conv_uid``, either as a ``ServeRequest``
            or as a plain ``{"conv_uid": ...}`` dict.

    Returns:
        The conversation's messages enriched with any matching media entries.
    """
    # Bug fix: request may be a dict (callers pass {"conv_uid": ...}), in
    # which case attribute access would raise AttributeError.
    conv_uid = request.get("conv_uid") if isinstance(request, dict) else request.conv_uid
    conv: StorageConversation = self.create_storage_conv(request)
    messages = _append_view_messages(conv.messages)
    # Attach any pictures/videos related to this conversation's messages.
    result = await get_media_datas_all(conv_uid, self.config.default_model, db, messages)
    return result
def get_history_messages(
self, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
......@@ -199,6 +214,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
context=msg.content,
order=msg.round_index,
model_name=self.config.default_model,
extra="----====---->",
)
)
return result
#!/bin/bash
# Launch the buffer/logic/gateway servers in the background.
echo Starting...

# Ensure working directories exist (-p: no error if already present).
mkdir -p datamining
mkdir -p jfxt

export LC_ALL=C

# Record the launch directory for later debugging ($(...) over backticks).
echo "$(pwd)" >> ./pwd.txt

nohup ./BufferServer > /dev/null 2>&1 &
sleep 1
nohup ./LogicServer > /dev/null 2>&1 &
sleep 1
nohup ./GatewayServer > /dev/null 2>&1 &

# To reload the sensitive-word cache afterwards, hit:
# http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db
echo Done!
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment