Commit f793c556 authored by 于飞's avatar 于飞

添加关键词提出代码

parent 563c60e8
...@@ -4,7 +4,10 @@ import re ...@@ -4,7 +4,10 @@ import re
from fastapi import Request, Depends from fastapi import Request, Depends
from dbgpt.app.apps.core.database import db_getter from dbgpt.app.apps.core.database import db_getter
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.word.crud import SensitiveDal from dbgpt.app.apps.vadmin.word.crud import SensitiveDal
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
...@@ -13,9 +16,9 @@ class DFAFilter(): ...@@ -13,9 +16,9 @@ class DFAFilter():
Use DFA to keep algorithm perform constantly Use DFA to keep algorithm perform constantly
敏感词过滤 敏感词过滤
>>> f = DFAFilter() #>>> f = DFAFilter()
>>> f.add("sexy") #>>> f.add("sexy")
>>> f.filter("hello sexy baby") #>>> f.filter("hello sexy baby")
hello **** baby hello **** baby
''' '''
...@@ -52,10 +55,43 @@ class DFAFilter(): ...@@ -52,10 +55,43 @@ class DFAFilter():
for keyword in f: for keyword in f:
self.add(keyword.strip()) self.add(keyword.strip())
#从数据库中加载 敏感词 #从数据库中加载 图片资源
async def parse_from_db(self, db: AsyncSession): async def parse_picture_from_db(self, db: AsyncSession):
#db: AsyncSession = Depends(db_getter)
print('---------parse_picture_from_db--load-------------')
# 获取数据库中所有的图片--->添加到词库内存中
images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None}
images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
print(f"-----图片库列表为:---->:{images_datas}")
for imagedata in images_datas:
self.add(imagedata.get('key_word'))
#从数据库中加载 问答对
async def parse_question_from_db(self, db: AsyncSession):
#db: AsyncSession = Depends(db_getter)
print('---------parse_question_from_db--load-------------')
# 获取数据库中所有的问答对--->添加到词库内存中
question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None}
question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
print(f"-----问答对库列表为:---->:{question_datas}")
for questiondata in question_datas:
self.add(questiondata.get('key_word'))
#从数据库中加载 视频资源
async def parse_video_from_db(self, db: AsyncSession):
#db: AsyncSession = Depends(db_getter) #db: AsyncSession = Depends(db_getter)
print('---------sensitive-load-------------') print('---------parse_video_from_db--load-------------')
# 获取数据库中所有的视频--->添加到词库内存中
video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None}
video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
print(f"-----视频库列表为:---->:{video_datas}")
for videodata in video_datas:
self.add(videodata.get('key_word'))
# 从数据库中加载 敏感词
async def parse_from_db(self, db: AsyncSession):
# db: AsyncSession = Depends(db_getter)
print('---------parse_from_db-load-------------')
sdl = SensitiveDal(db) sdl = SensitiveDal(db)
datas = await sdl.get_sensitives() datas = await sdl.get_sensitives()
for keyword in datas: for keyword in datas:
...@@ -94,8 +130,19 @@ class DFAFilter(): ...@@ -94,8 +130,19 @@ class DFAFilter():
# return 返回三个参数 # return 返回三个参数
return ''.join(ret), is_sensitive, matched_sensitives return ''.join(ret), is_sensitive, matched_sensitives
#初始化全局对象
#初始化-敏感词
mydfafiter = DFAFilter() mydfafiter = DFAFilter()
#初始化-图片关键词
mydfafiter_picture = DFAFilter()
#初始化-视频关键词
mydfafiter_video = DFAFilter()
#初始化-问答对关键词
mydfafiter_question = DFAFilter()
""" """
if __name__ == "__main__": if __name__ == "__main__":
gfw = DFAFilter() gfw = DFAFilter()
......
...@@ -3,50 +3,43 @@ import json ...@@ -3,50 +3,43 @@ import json
from fastapi import APIRouter, Depends, Body, UploadFile, Form, Request from fastapi import APIRouter, Depends, Body, UploadFile, Form, Request
from dbgpt.app.apps.utils.file.file_manage import FileManage from dbgpt.app.apps.utils.file.file_manage import FileManage
from dbgpt.app.apps.utils.response import SuccessResponse,ErrorResponse from dbgpt.app.apps.utils.response import SuccessResponse, ErrorResponse
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from dbgpt.app.apps.vadmin.media import schemas, crud from dbgpt.app.apps.vadmin.media import schemas
from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, QuestionListParams, \ from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal
from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, \
QuestionListParams, \
QuestionEditParams, CorrelationListParams QuestionEditParams, CorrelationListParams
from dbgpt.app.apps.core.dependencies import IdList from dbgpt.app.apps.core.dependencies import IdList
from dbgpt.app.apps.vadmin.chathistory.crud import ChatHistoryDal
from dbgpt.app.apps.vadmin.chathistory.schemas.chathistory import ChatHistorySchemas
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, Dict, Generic, Optional, TypeVar from typing import Any, Dict, Generic, Optional, TypeVar
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
from dbgpt.app.apps.utils.filter import mydfafiter from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
from dbgpt.serve.conversation.api.schemas import MessageVo, ServeRequest, ServerResponse
from dbgpt.app.openapi.api_view_model import (
ChatSceneVo,
ConversationVo,
Result,
)
router = APIRouter() router = APIRouter()
class ConversationVo(BaseModel):
model_config = ConfigDict(protected_namespaces=())
user_input: str = ""
@router.post("/get_spacy_keywords", summary="资源列表(图片、视频)")
async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())):
print(f"用户输入的问题:{dialogue.user_input} ")
print('----------------begin---------------->')
# 从数据库中加载 并且初始化敏感词-->到内存中
# await mydfafiter.parse_from_db(auth.db)
#先判断敏感词
dfa_result, is_sensitive,matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result)
if is_sensitive:
print('用户输入有敏感词')
result = {'code': 200, 'message': 'success', 'data': [{'type': 3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input':dfa_result}]}
return SuccessResponse(result) #返回type=3
#没有敏感词的时候,查找是否有相关图片 或者 视频
doc = my_spacy_nlp.nlp(dialogue.user_input)
def get_key_words(user_input: str) -> list:
"""
接受一个字符串输入,提取关键词,并返回其中的单词列表。
"""
words = [] words = []
doc = my_spacy_nlp.nlp(user_input)
# examine the top-ranked phrases in the document # examine the top-ranked phrases in the document
for phrase in doc._.phrases: for phrase in doc._.phrases:
# logger.info(f"----1--->:{phrase.rank}--->:{phrase.count}") # logger.info(f"----1--->:{phrase.rank}--->:{phrase.count}")
...@@ -54,16 +47,40 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -54,16 +47,40 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
words.append(phrase.chunks[0]) words.append(phrase.chunks[0])
print(words) print(words)
return words
if len(words) > 0: def get_key_words_nlp(user_input: str) -> list:
print(words[0]) """
另外一种算法提取关键词,比上面的算法更加准确
"""
words = []
dfa_result, is_sensitive, matched_medias = mydfafiter_picture.filter(user_input, "*")
#print(matched_medias)
for phrase in matched_medias:
words.append(phrase)
dfa_result2, is_sensitive2, matched_medias2 = mydfafiter_question.filter(user_input, "*")
# print(matched_medias2)
for phrase2 in matched_medias2:
words.append(phrase2)
#取出匹配到的关键词,获取数据库中的图片 dfa_result3, is_sensitive3, matched_medias3 = mydfafiter_video.filter(user_input, "*")
images_dic = {'page': 1, 'limit': 10, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None, # print(matched_medias3)
'file_name': ('like', words[0])} for phrase3 in matched_medias3:
words.append(phrase3)
images_datas, count = await crud.MediaDal(auth.db).get_datas(**images_dic, v_return_count=True) #print(words)
#print(f"-----查询到的图片为:---->:{images_datas}") return words
async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
# 取出匹配到的关键词,获取数据库中的图片
images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None,
'key_word': ('like', words)}
images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
print(f"-----查询到的图片为:---->:{images_datas}")
result = [] result = []
for data in images_datas: for data in images_datas:
json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'), json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
...@@ -71,35 +88,105 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -71,35 +88,105 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
result.append(json_image) result.append(json_image)
# 取出匹配到的关键词,获取数据库中的视频 # 取出匹配到的关键词,获取数据库中的视频
video_dic = {'page': 1, 'limit': 10, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None, video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None,
'file_name': ('like', words[0])} 'key_word': ('like', words)}
video_datas, count = await crud.MediaDal(auth.db).get_datas(**video_dic, v_return_count=True) video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
#print(f"-----查询到的视频为:---->:{video_datas}") print(f"-----查询到的视频为:---->:{video_datas}")
for data in video_datas: for videodata in video_datas:
json_video = {'type': 2, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'), json_video = {'type': 2, 'file_name': videodata.get('file_name'), 'key_word': videodata.get('key_word'),
'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')} 'local_path': videodata.get('local_path'), 'remote_path': videodata.get('remote_path')}
result.append(json_video) result.append(json_video)
return SuccessResponse(result) #匹配到的问答对有
else: question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None,'key_word': ('like', words)}
print(f"-----没有找到需要查询的内容:---->") question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
return ErrorResponse("没有找到需要查询的内容") print(f"-----查询到的问答对为:---->:{question_datas}")
for questiondata in question_datas:
@router.get("/load_parse_from_db", summary="加载敏感词") json_question = {'type': 4, 'title': questiondata.get('title'), 'key_word': questiondata.get('key_word'),
async def load_parse_from_db(auth: Auth = Depends(OpenAuth())): 'answer': questiondata.get('answer')}
# 从数据库中加载 并且初始化敏感词-->到内存中 result.append(json_question)
await mydfafiter.parse_from_db(auth.db)
return SuccessResponse("sensitive load OK") #保存到聊天历史资源数据库中
if len(result) > 0:
print(f"-----保存数据的时候打印会话ID:---->:{conv_uid}")
#保存到聊天历史资料表中
json_string = json.dumps(result) #转换为字符串
simi_data = ChatHistorySchemas()
simi_data.conv_uid = conv_uid
simi_data.message_medias = json_string
await ChatHistoryDal(db).create_data(data=simi_data)
return result
async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSession, messages:list) -> list:
"""
根据会话ID-->获取聊天历史的图片或视频资源
"""
ret_media_datas = []
#获取聊天历史资源数据
history_dic = {'page': 1, 'limit': 10, 'v_order': None, 'v_order_field': None, 'conv_uid': conv_uid,
'message_medias': ('like', None)}
datas, count = await ChatHistoryDal(db).get_datas(**history_dic, v_return_count=True)
json_data = [{"type": 1, "file_name": "723629348SJfHgjzD.png"}] # 传入列表
if count > 0:
history_datas = datas[0].get('message_medias')
json_data = json.loads(history_datas)
print(f"----将字符串转换为json格式---->:{json_data}")
for msg in messages:
# 根据历史聊天记录msg
ret_media_datas.append(
MessageVo(
role=msg.type,
context=msg.content,
order=msg.round_index,
model_name=default_model,
extra=json_data,
)
)
return ret_media_datas
@router.post("/get_spacy_keywords", summary="资源列表(图片、视频)")
async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())):
print(f"用户输入的问题:{dialogue.user_input} ")
print('----------------begin---------------->')
# 从数据库中加载 并且初始化敏感词-->到内存中
# await mydfafiter.parse_from_db(auth.db)
#先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result)
if is_sensitive:
print('用户输入有敏感词')
result = {'code': 200, 'message': 'success',
'data': [{'type': 3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input': dfa_result}]}
return SuccessResponse(result) #返回type=3
#没有敏感词的时候,查找是否有相关图片 或者 视频
#words = get_key_words(dialogue.user_input)
words = get_key_words_nlp(dialogue.user_input) #100%匹配算法 | 只取匹配到的第一个
if len(words) > 0:
print(f"----匹配到的关键词--->:{words[0]}")
result = await get_media_datas(dialogue.conv_uid, words[0], auth.db)
return SuccessResponse(result)
else:
print(f"-----没有找到需要查询的内容:---->")
return ErrorResponse("没有找到需要查询的内容")
@router.get("/load_parse_from_db", summary="加载敏感词和资源关键词")
async def load_parse_from_db(auth: Auth = Depends(OpenAuth())):
# 从数据库中加载 并且初始化敏感词,图片,视频,问答对-->到内存中
await mydfafiter.parse_from_db(auth.db)
await mydfafiter_picture.parse_picture_from_db(auth.db)
await mydfafiter_video.parse_video_from_db(auth.db)
await mydfafiter_question.parse_question_from_db(auth.db)
return SuccessResponse("media and sensitive all load OK")
...@@ -45,7 +45,7 @@ from dbgpt.util.utils import ( ...@@ -45,7 +45,7 @@ from dbgpt.util.utils import (
from dbgpt.app.apps.utils.tools import import_modules from dbgpt.app.apps.utils.tools import import_modules
from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
from dbgpt.app.apps.utils.filter import mydfafiter from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
REQUEST_LOG_RECORD = False REQUEST_LOG_RECORD = False
MIDDLEWARES = [ MIDDLEWARES = [
...@@ -104,6 +104,7 @@ def mount_routers(app: FastAPI): ...@@ -104,6 +104,7 @@ def mount_routers(app: FastAPI):
from dbgpt.app.apps.vadmin.media.views import router as media_views from dbgpt.app.apps.vadmin.media.views import router as media_views
from dbgpt.app.apps.vadmin.word.views import router as word_views from dbgpt.app.apps.vadmin.word.views import router as word_views
from dbgpt.app.apps.vadmin.keywordsviews import router as keywords_views from dbgpt.app.apps.vadmin.keywordsviews import router as keywords_views
from dbgpt.app.apps.vadmin.chathistory.views import router as chathistory_views
app.include_router(api_v1, prefix="/api", tags=["Chat"]) app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_v2, prefix="/api", tags=["ChatV2"]) app.include_router(api_v2, prefix="/api", tags=["ChatV2"])
...@@ -119,6 +120,7 @@ def mount_routers(app: FastAPI): ...@@ -119,6 +120,7 @@ def mount_routers(app: FastAPI):
app.include_router(media_views, prefix="/api/v2", tags=["System"]) app.include_router(media_views, prefix="/api/v2", tags=["System"])
app.include_router(word_views, prefix="/api/v2/vadmin/word", tags=["Word"]) app.include_router(word_views, prefix="/api/v2/vadmin/word", tags=["Word"])
app.include_router(keywords_views, prefix="/api/v2/vadmin", tags=["vadmin"]) app.include_router(keywords_views, prefix="/api/v2/vadmin", tags=["vadmin"])
app.include_router(chathistory_views, prefix="/api/v2/vadmin", tags=["chathistory"])
def mount_static_files(app: FastAPI): def mount_static_files(app: FastAPI):
......
...@@ -46,6 +46,7 @@ from dbgpt.util.tracer import SpanType, root_tracer ...@@ -46,6 +46,7 @@ from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.app.apps.utils.filter import mydfafiter from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter() router = APIRouter()
...@@ -263,6 +264,7 @@ async def params_load( ...@@ -263,6 +264,7 @@ async def params_load(
user_name: Optional[str] = None, user_name: Optional[str] = None,
sys_code: Optional[str] = None, sys_code: Optional[str] = None,
doc_file: UploadFile = File(...), doc_file: UploadFile = File(...),
auth: Auth = Depends(OpenAuth()),
): ):
logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}") logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
try: try:
...@@ -287,11 +289,19 @@ async def params_load( ...@@ -287,11 +289,19 @@ async def params_load(
resp = await chat.prepare() resp = await chat.prepare()
# Refresh messages # Refresh messages
return Result.succ(get_hist_messages(conv_uid)) #ret_his_msg = get_hist_messages(conv_uid)
ret_his_msg = await get_hist_messages_datas(conv_uid, auth.db)
return Result.succ(ret_his_msg)
except Exception as e: except Exception as e:
logger.error("excel load error!", e) logger.error("excel load error!", e)
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}") return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
async def get_hist_messages_datas(conv_uid: str, db: AsyncSession):
from dbgpt.serve.conversation.serve import Service as ConversationService
instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
result = await instance.get_history_messages2(db, {"conv_uid": conv_uid})
return result
def get_hist_messages(conv_uid: str): def get_hist_messages(conv_uid: str):
from dbgpt.serve.conversation.serve import Service as ConversationService from dbgpt.serve.conversation.serve import Service as ConversationService
...@@ -360,7 +370,7 @@ async def chat_completions( ...@@ -360,7 +370,7 @@ async def chat_completions(
} }
# 从数据库中加载 并且初始化敏感词-->到内存中 # 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db) # await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词 # 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*") dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
......
...@@ -99,7 +99,7 @@ async def chat_completions( ...@@ -99,7 +99,7 @@ async def chat_completions(
} }
# 从数据库中加载 并且初始化敏感词-->到内存中 # 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db) # await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词 # 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(request.user_input, "*") dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(request.user_input, "*")
......
...@@ -12,6 +12,9 @@ from dbgpt.util import PaginationResult ...@@ -12,6 +12,9 @@ from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service from ..service.service import Service
from .schemas import MessageVo, ServeRequest, ServerResponse from .schemas import MessageVo, ServeRequest, ServerResponse
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter() router = APIRouter()
...@@ -207,9 +210,13 @@ async def list_latest_conv( ...@@ -207,9 +210,13 @@ async def list_latest_conv(
response_model=Result[List[MessageVo]], response_model=Result[List[MessageVo]],
dependencies=[Depends(check_api_key)], dependencies=[Depends(check_api_key)],
) )
async def get_history_messages(con_uid: str, service: Service = Depends(get_service)): async def get_history_messages(con_uid: str,
service: Service = Depends(get_service),
auth: Auth = Depends(OpenAuth())):
"""Get the history messages of a conversation""" """Get the history messages of a conversation"""
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid))) #result = service.get_history_messages(ServeRequest(conv_uid=con_uid))
result = await service.get_history_messages2(auth.db, ServeRequest(conv_uid=con_uid))
return Result.succ(result)
def init_endpoints(system_app: SystemApp) -> None: def init_endpoints(system_app: SystemApp) -> None:
......
# Define your Pydantic schemas here # Define your Pydantic schemas here
from typing import Any, Dict, Optional from typing import Optional, List, Dict, Any
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
...@@ -149,6 +149,14 @@ class MessageVo(BaseModel): ...@@ -149,6 +149,14 @@ class MessageVo(BaseModel):
], ],
) )
extra: Optional[List[Dict[str, Any]]] = Field( # 修改为列表类型
default=None,
description="A field to store additional JSON data.",
examples=[
[{"type": 1, "file_name": "111.png"}],
],
)
def to_dict(self, **kwargs) -> Dict[str, Any]: def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary""" """Convert the model to a dictionary"""
return model_to_dict(self, **kwargs) return model_to_dict(self, **kwargs)
...@@ -17,6 +17,8 @@ from dbgpt.util.pagination_utils import PaginationResult ...@@ -17,6 +17,8 @@ from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import MessageVo, ServeRequest, ServerResponse from ..api.schemas import MessageVo, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity from ..models.models import ServeDao, ServeEntity
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.keywordsviews import get_media_datas_all
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
...@@ -178,6 +180,19 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): ...@@ -178,6 +180,19 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
""" """
return self.dao.get_conv_by_page(request, page, page_size) return self.dao.get_conv_by_page(request, page, page_size)
async def get_history_messages2(
self, db: AsyncSession, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
print(f"------------历史记录request-------------->:{request}")
conv: StorageConversation = self.create_storage_conv(request)
messages = _append_view_messages(conv.messages)
# 查找是否有相关图片 或者 视频
result = await get_media_datas_all(request.conv_uid, self.config.default_model, db, messages)
print(f"------------历史记录result-------------->:{result}")
return result
def get_history_messages( def get_history_messages(
self, request: Union[ServeRequest, Dict[str, Any]] self, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]: ) -> List[MessageVo]:
...@@ -199,6 +214,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): ...@@ -199,6 +214,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
context=msg.content, context=msg.content,
order=msg.round_index, order=msg.round_index,
model_name=self.config.default_model, model_name=self.config.default_model,
extra="----====---->",
) )
) )
return result return result
#!/bin/bash #!/bin/bash
echo Starting... echo Starting...
if [ ! -d datamining ] ; then
mkdir datamining
fi
if [ ! -d jfxt ] ; then
mkdir jfxt
fi
export LC_ALL=C
echo "`pwd`" >> ./pwd.txt
nohup ./BufferServer > /dev/null 2>&1 & nohup ./BufferServer > /dev/null 2>&1 &
sleep 1 sleep 1
nohup ./LogicServer > /dev/null 2>&1 &
sleep 1 #加载一下敏感词接口
nohup ./GatewayServer > /dev/null 2>&1 & #http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db
echo Done! echo Done!
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment