Commit b5bea16a authored by 于飞's avatar 于飞

根据知识库名称,获取资源

parent 63ad9bc1
...@@ -6,7 +6,8 @@ from dbgpt.app.apps.utils.file.file_manage import FileManage ...@@ -6,7 +6,8 @@ from dbgpt.app.apps.utils.file.file_manage import FileManage
from dbgpt.app.apps.utils.response import SuccessResponse, ErrorResponse from dbgpt.app.apps.utils.response import SuccessResponse, ErrorResponse
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from dbgpt.app.apps.vadmin.media import schemas from dbgpt.app.apps.vadmin.media import schemas
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal,CorrelationDal
from dbgpt.app.apps.vadmin.media.models import VadminCorrelation
from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, \ from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, \
QuestionListParams, \ QuestionListParams, \
QuestionEditParams, CorrelationListParams QuestionEditParams, CorrelationListParams
...@@ -30,6 +31,7 @@ from dbgpt.app.openapi.api_view_model import ( ...@@ -30,6 +31,7 @@ from dbgpt.app.openapi.api_view_model import (
ConversationVo, ConversationVo,
Result, Result,
) )
from sqlalchemy import BinaryExpression
router = APIRouter() router = APIRouter()
...@@ -73,6 +75,82 @@ def get_key_words_nlp(user_input: str) -> list: ...@@ -73,6 +75,82 @@ def get_key_words_nlp(user_input: str) -> list:
#print(words) #print(words)
return words return words
async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownledge: str) -> list:
# 去拿出group_id
datas = []
if knownledge != None:
datas, count = await CorrelationDal(db).get_datas(name=knownledge, v_return_count=True)
if len(datas) > 0:
result = []
corrdata = datas[0]
image_datas = corrdata.get('image_group')
for image_groups in image_datas:
image_groupid = image_groups.get('group_id')
print(f"===========>image_groupid:{image_groupid}")
# 取出匹配到的关键词,获取数据库中的图片
images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1,
'group_id': image_groupid,
'key_word': ('like', words)}
images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
print(f"-----查询到的图片为:---->:{images_datas}")
for data in images_datas:
json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')}
result.append(json_image)
video_datas = corrdata.get('video_group')
for video_groups in video_datas:
video_groupid = video_groups.get('group_id')
print(f"===========>video_groupid:{video_groupid}")
# 取出匹配到的关键词,获取数据库中的视频
video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2,
'group_id': video_groupid,
'key_word': ('like', words)}
video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
print(f"-----查询到的视频为:---->:{video_datas}")
for videodata in video_datas:
json_video = {'type': 2, 'file_name': videodata.get('file_name'),
'key_word': videodata.get('key_word'),
'local_path': videodata.get('local_path'),
'remote_path': videodata.get('remote_path')}
result.append(json_video)
question_datas = corrdata.get('question_group')
for question_groups in question_datas:
question_groupid = question_groups.get('group_id')
print(f"===========>question_groupid:{question_groupid}")
# 匹配到的问答对有
question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None,
'group_id': question_groupid,
'key_word': ('like', words)}
question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
print(f"-----查询到的问答对为:---->:{question_datas}")
for questiondata in question_datas:
json_question = {'type': 4, 'title': questiondata.get('title'),
'key_word': questiondata.get('key_word'),
'answer': questiondata.get('answer')}
result.append(json_question)
# 保存到聊天历史资源数据库中
if len(result) > 0:
print(f"-----保存数据的时候打印会话ID:---->:{conv_uid}")
# 保存到聊天历史资料表中
json_string = json.dumps(result) # 转换为字符串
simi_data = ChatHistorySchemas()
simi_data.conv_uid = conv_uid
simi_data.message_medias = json_string
await ChatHistoryDal(db).create_data(data=simi_data)
return result
else:
#return await get_media_datas(conv_uid, words, db)
result = []
return result
async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list: async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
# 取出匹配到的关键词,获取数据库中的图片 # 取出匹配到的关键词,获取数据库中的图片
...@@ -153,12 +231,9 @@ async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSessio ...@@ -153,12 +231,9 @@ async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSessio
@router.post("/get_spacy_keywords", summary="资源列表(图片、视频)") @router.post("/get_spacy_keywords", summary="资源列表(图片、视频)")
async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())): async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())):
print(f"用户输入的问题:{dialogue.user_input} ") print(f"用户输入的问题:{dialogue.user_input} -- 选择的知识库为:{dialogue.select_param}")
print('----------------begin---------------->') print('----------------begin---------------->')
# 从数据库中加载 并且初始化敏感词-->到内存中
# await mydfafiter.parse_from_db(auth.db)
#先判断敏感词 #先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*") dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result) print(dfa_result)
...@@ -173,7 +248,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -173,7 +248,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
words = get_key_words_nlp(dialogue.user_input) #100%匹配算法 | 只取匹配到的第一个 words = get_key_words_nlp(dialogue.user_input) #100%匹配算法 | 只取匹配到的第一个
if len(words) > 0: if len(words) > 0:
print(f"---算法1-匹配到的关键词--->:{words[0]}") print(f"---算法1-匹配到的关键词--->:{words[0]}")
result = await get_media_datas(dialogue.conv_uid, words[0], auth.db) result = await get_media_datas_by(dialogue.conv_uid, words[0], auth.db, dialogue.select_param)
return SuccessResponse(result) return SuccessResponse(result)
else: else:
print(f"---算法2-begin--->") print(f"---算法2-begin--->")
...@@ -181,7 +256,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -181,7 +256,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
words2 = get_key_words(dialogue.user_input) words2 = get_key_words(dialogue.user_input)
if len(words2) > 0: if len(words2) > 0:
print(f"---算法2-匹配到的关键词--->:{words[0]}") print(f"---算法2-匹配到的关键词--->:{words[0]}")
result = await get_media_datas(dialogue.conv_uid, words2[0], auth.db) result = await get_media_datas_by(dialogue.conv_uid, words2[0], auth.db, dialogue.select_param)
return SuccessResponse(result) return SuccessResponse(result)
else: else:
print(f"-----没有找到需要查询的内容:---->") print(f"-----没有找到需要查询的内容:---->")
......
...@@ -359,6 +359,7 @@ async def chat_completions( ...@@ -359,6 +359,7 @@ async def chat_completions(
flow_service: FlowService = Depends(get_chat_flow), flow_service: FlowService = Depends(get_chat_flow),
auth: Auth = Depends(OpenAuth()), auth: Auth = Depends(OpenAuth()),
): ):
#dialogue.select_param 这个参数为 : 知识库名称
logger.info( logger.info(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
) )
......
#!/bin/bash #!/bin/bash
echo Starting... echo Starting...
nohup ./BufferServer > /dev/null 2>&1 & nohup dbgpt start webserver --port 6006 > /dev/null 2>&1 &
sleep 1 sleep 10
#加载一下敏感词接口 #加载一下敏感词接口
#http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db curl http://127.0.0.1:6006/api/v2/vadmin/load_parse_from_db
echo Done! echo Done!
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment