Commit f793c556 authored by 于飞's avatar 于飞

添加关键词提取代码

parent 563c60e8
......@@ -4,7 +4,10 @@ import re
from fastapi import Request, Depends
from dbgpt.app.apps.core.database import db_getter
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.word.crud import SensitiveDal
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
......@@ -13,9 +16,9 @@ class DFAFilter():
Use DFA to keep algorithm perform constantly
敏感词过滤
>>> f = DFAFilter()
>>> f.add("sexy")
>>> f.filter("hello sexy baby")
#>>> f = DFAFilter()
#>>> f.add("sexy")
#>>> f.filter("hello sexy baby")
hello **** baby
'''
......@@ -52,10 +55,43 @@ class DFAFilter():
for keyword in f:
self.add(keyword.strip())
#从数据库中加载 敏感词
async def parse_from_db(self, db: AsyncSession):
# Load picture resources from the database
async def parse_picture_from_db(self, db: AsyncSession):
    """Load all picture keywords from the media table into this DFA tree.

    Fetches every media row of type 1 (pictures; ``limit=0`` disables
    paging) and registers each row's ``key_word`` so later ``filter()``
    calls can match picture keywords in user input.

    :param db: async SQLAlchemy session used for the query.
    """
    print('---------parse_picture_from_db--load-------------')
    # Fetch all pictures and add their keywords to the in-memory tree.
    images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None}
    images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
    print(f"-----图片库列表为:---->:{images_datas}")
    for imagedata in images_datas:
        keyword = imagedata.get('key_word')
        # Guard: rows without a key_word would otherwise crash self.add(None).
        if keyword:
            self.add(keyword)
# Load question/answer pairs from the database
async def parse_question_from_db(self, db: AsyncSession):
    """Load all Q&A-pair keywords from the question table into this DFA tree.

    Fetches every question row (``limit=0`` disables paging) and registers
    each row's ``key_word`` so later ``filter()`` calls can match
    question keywords in user input.

    :param db: async SQLAlchemy session used for the query.
    """
    print('---------parse_question_from_db--load-------------')
    # Fetch all Q&A pairs and add their keywords to the in-memory tree.
    question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None}
    question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
    print(f"-----问答对库列表为:---->:{question_datas}")
    for questiondata in question_datas:
        keyword = questiondata.get('key_word')
        # Guard: rows without a key_word would otherwise crash self.add(None).
        if keyword:
            self.add(keyword)
# Load video resources from the database
async def parse_video_from_db(self, db: AsyncSession):
    """Load all video keywords from the media table into this DFA tree.

    Fetches every media row of type 2 (videos; ``limit=0`` disables
    paging) and registers each row's ``key_word`` so later ``filter()``
    calls can match video keywords in user input.

    :param db: async SQLAlchemy session used for the query.
    """
    # Dropped the stale "sensitive-load" print: it was a leftover from the
    # old parse_from_db and mislabeled this loader's log output.
    print('---------parse_video_from_db--load-------------')
    # Fetch all videos and add their keywords to the in-memory tree.
    video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None}
    video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
    print(f"-----视频库列表为:---->:{video_datas}")
    for videodata in video_datas:
        keyword = videodata.get('key_word')
        # Guard: rows without a key_word would otherwise crash self.add(None).
        if keyword:
            self.add(keyword)
# 从数据库中加载 敏感词
async def parse_from_db(self, db: AsyncSession):
# db: AsyncSession = Depends(db_getter)
print('---------parse_from_db-load-------------')
sdl = SensitiveDal(db)
datas = await sdl.get_sensitives()
for keyword in datas:
......@@ -94,8 +130,19 @@ class DFAFilter():
# return 返回三个参数
return ''.join(ret), is_sensitive, matched_sensitives
# Module-level DFA filter instances shared by the API views.
# They start empty and are populated from the database via the
# /load_parse_from_db endpoint (see keywordsviews.load_parse_from_db).
# Sensitive words
mydfafiter = DFAFilter()
# Picture keywords
mydfafiter_picture = DFAFilter()
# Video keywords
mydfafiter_video = DFAFilter()
# Q&A-pair keywords
mydfafiter_question = DFAFilter()
"""
if __name__ == "__main__":
gfw = DFAFilter()
......
......@@ -3,50 +3,43 @@ import json
from fastapi import APIRouter, Depends, Body, UploadFile, Form, Request
from dbgpt.app.apps.utils.file.file_manage import FileManage
from dbgpt.app.apps.utils.response import SuccessResponse,ErrorResponse
from dbgpt.app.apps.utils.response import SuccessResponse, ErrorResponse
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from dbgpt.app.apps.vadmin.media import schemas, crud
from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, QuestionListParams, \
from dbgpt.app.apps.vadmin.media import schemas
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal
from dbgpt.app.apps.vadmin.media.params.media_list import MediaListParams, GroupListParams, MediaEditParams, \
QuestionListParams, \
QuestionEditParams, CorrelationListParams
from dbgpt.app.apps.core.dependencies import IdList
from dbgpt.app.apps.vadmin.chathistory.crud import ChatHistoryDal
from dbgpt.app.apps.vadmin.chathistory.schemas.chathistory import ChatHistorySchemas
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, Dict, Generic, Optional, TypeVar
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
from dbgpt.serve.conversation.api.schemas import MessageVo, ServeRequest, ServerResponse
from dbgpt.app.openapi.api_view_model import (
ChatSceneVo,
ConversationVo,
Result,
)
router = APIRouter()
class ConversationVo(BaseModel):
    """Request body for the keyword/resource lookup endpoint."""
    model_config = ConfigDict(protected_namespaces=())
    # Raw question text entered by the user.
    user_input: str = ""
    # Conversation id; read as ``dialogue.conv_uid`` by get_spacy_keywords
    # to persist matched media into chat history. Added with a default so
    # existing callers that only send user_input keep working (previously
    # accessing conv_uid raised AttributeError).
    conv_uid: str = ""
@router.post("/get_spacy_keywords", summary="资源列表(图片、视频)")
async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())):
print(f"用户输入的问题:{dialogue.user_input} ")
print('----------------begin---------------->')
# 从数据库中加载 并且初始化敏感词-->到内存中
# await mydfafiter.parse_from_db(auth.db)
#先判断敏感词
dfa_result, is_sensitive,matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result)
if is_sensitive:
print('用户输入有敏感词')
result = {'code': 200, 'message': 'success', 'data': [{'type': 3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input':dfa_result}]}
return SuccessResponse(result) #返回type=3
#没有敏感词的时候,查找是否有相关图片 或者 视频
doc = my_spacy_nlp.nlp(dialogue.user_input)
def get_key_words(user_input: str) -> list:
"""
接受一个字符串输入,提取关键词,并返回其中的单词列表。
"""
words = []
doc = my_spacy_nlp.nlp(user_input)
# examine the top-ranked phrases in the document
for phrase in doc._.phrases:
# logger.info(f"----1--->:{phrase.rank}--->:{phrase.count}")
......@@ -54,16 +47,40 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
words.append(phrase.chunks[0])
print(words)
return words
if len(words) > 0:
print(words[0])
def get_key_words_nlp(user_input: str) -> list:
    """
    Extract keywords from *user_input* by exact DFA matching against the
    in-memory keyword trees (pictures, Q&A pairs, videos).

    More precise than the spaCy-based extractor: only terms previously
    loaded from the database can match.

    Note: removed the interleaved diff-residue lines from the old
    implementation (they contained ``await`` calls inside this sync
    function, which is a syntax error).

    :param user_input: raw question text from the user.
    :return: list of matched keyword strings, possibly empty; picture
        matches first, then Q&A matches, then video matches.
    """
    words = []
    # Picture keywords
    _, _, matched_pictures = mydfafiter_picture.filter(user_input, "*")
    words.extend(matched_pictures)
    # Q&A-pair keywords
    _, _, matched_questions = mydfafiter_question.filter(user_input, "*")
    words.extend(matched_questions)
    # Video keywords
    _, _, matched_videos = mydfafiter_video.filter(user_input, "*")
    words.extend(matched_videos)
    return words
async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
# 取出匹配到的关键词,获取数据库中的图片
images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None,
'key_word': ('like', words)}
images_datas, count = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
print(f"-----查询到的图片为:---->:{images_datas}")
result = []
for data in images_datas:
json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
......@@ -71,35 +88,105 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
result.append(json_image)
# 取出匹配到的关键词,获取数据库中的视频
video_dic = {'page': 1, 'limit': 10, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None,
'file_name': ('like', words[0])}
video_datas, count = await crud.MediaDal(auth.db).get_datas(**video_dic, v_return_count=True)
#print(f"-----查询到的视频为:---->:{video_datas}")
for data in video_datas:
json_video = {'type': 2, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')}
video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None,
'key_word': ('like', words)}
video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
print(f"-----查询到的视频为:---->:{video_datas}")
for videodata in video_datas:
json_video = {'type': 2, 'file_name': videodata.get('file_name'), 'key_word': videodata.get('key_word'),
'local_path': videodata.get('local_path'), 'remote_path': videodata.get('remote_path')}
result.append(json_video)
return SuccessResponse(result)
else:
print(f"-----没有找到需要查询的内容:---->")
return ErrorResponse("没有找到需要查询的内容")
@router.get("/load_parse_from_db", summary="加载敏感词")
async def load_parse_from_db(auth: Auth = Depends(OpenAuth())):
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
return SuccessResponse("sensitive load OK")
#匹配到的问答对有
question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None,'key_word': ('like', words)}
question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
print(f"-----查询到的问答对为:---->:{question_datas}")
for questiondata in question_datas:
json_question = {'type': 4, 'title': questiondata.get('title'), 'key_word': questiondata.get('key_word'),
'answer': questiondata.get('answer')}
result.append(json_question)
#保存到聊天历史资源数据库中
if len(result) > 0:
print(f"-----保存数据的时候打印会话ID:---->:{conv_uid}")
#保存到聊天历史资料表中
json_string = json.dumps(result) #转换为字符串
simi_data = ChatHistorySchemas()
simi_data.conv_uid = conv_uid
simi_data.message_medias = json_string
await ChatHistoryDal(db).create_data(data=simi_data)
return result
async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSession, messages: list) -> list:
    """
    Build MessageVo history entries for a conversation, attaching the
    picture/video payload previously stored for it in the chat-history table.

    :param conv_uid: conversation id to look up stored media for.
    :param default_model: model name stamped on every MessageVo.
    :param db: async SQLAlchemy session for the history lookup.
    :param messages: conversation messages (objects with ``type``,
        ``content`` and ``round_index`` attributes).
    :return: list of MessageVo, one per input message, each carrying the
        same media payload in ``extra``.
    """
    ret_media_datas = []
    # Fetch stored media records for this conversation.
    history_dic = {'page': 1, 'limit': 10, 'v_order': None, 'v_order_field': None, 'conv_uid': conv_uid,
                   'message_medias': ('like', None)}
    datas, count = await ChatHistoryDal(db).get_datas(**history_dic, v_return_count=True)
    # NOTE(review): this fallback payload looks like a leftover test
    # fixture; it is returned verbatim whenever no history row exists —
    # confirm whether an empty list is intended instead.
    json_data = [{"type": 1, "file_name": "723629348SJfHgjzD.png"}]
    if count > 0:
        # Only the first matching history row is used; message_medias is a
        # JSON-encoded string written by get_media_datas.
        history_datas = datas[0].get('message_medias')
        json_data = json.loads(history_datas)
        print(f"----将字符串转换为json格式---->:{json_data}")
    for msg in messages:
        # The same media payload is attached to every message of the round.
        ret_media_datas.append(
            MessageVo(
                role=msg.type,
                context=msg.content,
                order=msg.round_index,
                model_name=default_model,
                extra=json_data,
            )
        )
    return ret_media_datas
@router.post("/get_spacy_keywords", summary="资源列表(图片、视频)")
async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Depends(OpenAuth())):
    """
    Screen the user's input for sensitive words, then look up related
    picture/video/Q&A resources by exact keyword match.

    Returns a type=3 payload when a sensitive word is hit, the matched
    media list on a keyword hit, or an error response when nothing matches.
    The DFA trees are preloaded via /load_parse_from_db.
    """
    print(f"用户输入的问题:{dialogue.user_input} ")
    print('----------------begin---------------->')
    # Sensitive-word check first.
    dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
    print(dfa_result)
    if is_sensitive:
        print('用户输入有敏感词')
        # type=3 marks a sensitive-word hit; dfa_result is the input with
        # the matched words masked by '*'.
        result = {'code': 200, 'message': 'success',
                  'data': [{'type': 3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input': dfa_result}]}
        return SuccessResponse(result)
    # No sensitive words: look for related pictures, videos or Q&A pairs.
    words = get_key_words_nlp(dialogue.user_input)  # exact-match algorithm; only the first hit is used
    if len(words) > 0:
        print(f"----匹配到的关键词--->:{words[0]}")
        result = await get_media_datas(dialogue.conv_uid, words[0], auth.db)
        return SuccessResponse(result)
    else:
        print(f"-----没有找到需要查询的内容:---->")
        return ErrorResponse("没有找到需要查询的内容")
@router.get("/load_parse_from_db", summary="加载敏感词和资源关键词")
async def load_parse_from_db(auth: Auth = Depends(OpenAuth())):
    """
    Reload every in-memory DFA keyword tree from the database: sensitive
    words, picture keywords, video keywords and Q&A-pair keywords.

    Call this after editing any of those tables so the filters pick up
    the changes without a restart.
    """
    await mydfafiter.parse_from_db(auth.db)
    await mydfafiter_picture.parse_picture_from_db(auth.db)
    await mydfafiter_video.parse_video_from_db(auth.db)
    await mydfafiter_question.parse_question_from_db(auth.db)
    return SuccessResponse("media and sensitive all load OK")
......@@ -45,7 +45,7 @@ from dbgpt.util.utils import (
from dbgpt.app.apps.utils.tools import import_modules
from dbgpt.app.apps.utils.spach_keywords import my_spacy_nlp
from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.utils.filter import mydfafiter, mydfafiter_picture, mydfafiter_question, mydfafiter_video
REQUEST_LOG_RECORD = False
MIDDLEWARES = [
......@@ -104,6 +104,7 @@ def mount_routers(app: FastAPI):
from dbgpt.app.apps.vadmin.media.views import router as media_views
from dbgpt.app.apps.vadmin.word.views import router as word_views
from dbgpt.app.apps.vadmin.keywordsviews import router as keywords_views
from dbgpt.app.apps.vadmin.chathistory.views import router as chathistory_views
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_v2, prefix="/api", tags=["ChatV2"])
......@@ -119,6 +120,7 @@ def mount_routers(app: FastAPI):
app.include_router(media_views, prefix="/api/v2", tags=["System"])
app.include_router(word_views, prefix="/api/v2/vadmin/word", tags=["Word"])
app.include_router(keywords_views, prefix="/api/v2/vadmin", tags=["vadmin"])
app.include_router(chathistory_views, prefix="/api/v2/vadmin", tags=["chathistory"])
def mount_static_files(app: FastAPI):
......
......@@ -46,6 +46,7 @@ from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.app.apps.utils.filter import mydfafiter
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
......@@ -263,6 +264,7 @@ async def params_load(
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
auth: Auth = Depends(OpenAuth()),
):
logger.info(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
......@@ -287,11 +289,19 @@ async def params_load(
resp = await chat.prepare()
# Refresh messages
return Result.succ(get_hist_messages(conv_uid))
#ret_his_msg = get_hist_messages(conv_uid)
ret_his_msg = await get_hist_messages_datas(conv_uid, auth.db)
return Result.succ(ret_his_msg)
except Exception as e:
logger.error("excel load error!", e)
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
async def get_hist_messages_datas(conv_uid: str, db: AsyncSession):
    """
    Return the history messages of *conv_uid* with attached media payloads,
    via the conversation service's async ``get_history_messages2``.

    :param conv_uid: conversation id to fetch.
    :param db: async SQLAlchemy session forwarded to the service.
    """
    # Local import mirrors the sibling get_hist_messages; presumably it
    # avoids a circular import at module load time — confirm.
    from dbgpt.serve.conversation.serve import Service as ConversationService
    instance: ConversationService = ConversationService.get_instance(CFG.SYSTEM_APP)
    result = await instance.get_history_messages2(db, {"conv_uid": conv_uid})
    return result
def get_hist_messages(conv_uid: str):
from dbgpt.serve.conversation.serve import Service as ConversationService
......@@ -360,7 +370,7 @@ async def chat_completions(
}
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
# await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
......
......@@ -99,7 +99,7 @@ async def chat_completions(
}
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
# await mydfafiter.parse_from_db(auth.db)
# 先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(request.user_input, "*")
......
......@@ -12,6 +12,9 @@ from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import MessageVo, ServeRequest, ServerResponse
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
......@@ -207,9 +210,13 @@ async def list_latest_conv(
response_model=Result[List[MessageVo]],
dependencies=[Depends(check_api_key)],
)
async def get_history_messages(con_uid: str,
                               service: Service = Depends(get_service),
                               auth: Auth = Depends(OpenAuth())):
    """Get the history messages of a conversation, including any stored
    picture/video payloads (via the async get_history_messages2 path).

    Note: removed the diff-residue lines (the duplicated old one-line
    signature and the old synchronous return) that made the displayed
    block invalid.
    """
    result = await service.get_history_messages2(auth.db, ServeRequest(conv_uid=con_uid))
    return Result.succ(result)
def init_endpoints(system_app: SystemApp) -> None:
......
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from typing import Optional, List, Dict, Any
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
......@@ -149,6 +149,14 @@ class MessageVo(BaseModel):
],
)
extra: Optional[List[Dict[str, Any]]] = Field( # 修改为列表类型
default=None,
description="A field to store additional JSON data.",
examples=[
[{"type": 1, "file_name": "111.png"}],
],
)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary; kwargs are forwarded to model_to_dict."""
return model_to_dict(self, **kwargs)
......@@ -17,6 +17,8 @@ from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import MessageVo, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
from sqlalchemy.ext.asyncio import AsyncSession
from dbgpt.app.apps.vadmin.keywordsviews import get_media_datas_all
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
......@@ -178,6 +180,19 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""
return self.dao.get_conv_by_page(request, page, page_size)
async def get_history_messages2(
    self, db: AsyncSession, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
    """Async variant of get_history_messages that also attaches stored
    picture/video payloads from the chat-history table.

    Args:
        db: async SQLAlchemy session used for the media lookup.
        request: request object with a ``conv_uid`` attribute, or a plain
            dict containing a ``"conv_uid"`` key (get_hist_messages_datas
            passes a dict).

    Returns:
        List[MessageVo]: the conversation's messages with ``extra`` set.
    """
    print(f"------------历史记录request-------------->:{request}")
    conv: StorageConversation = self.create_storage_conv(request)
    messages = _append_view_messages(conv.messages)
    # Fix: ``request.conv_uid`` raised AttributeError for dict requests,
    # which is exactly what get_hist_messages_datas sends.
    conv_uid = request["conv_uid"] if isinstance(request, dict) else request.conv_uid
    # Attach any stored media for this conversation to every message.
    result = await get_media_datas_all(conv_uid, self.config.default_model, db, messages)
    print(f"------------历史记录result-------------->:{result}")
    return result
def get_history_messages(
self, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
......@@ -199,6 +214,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
context=msg.content,
order=msg.round_index,
model_name=self.config.default_model,
extra="----====---->",
)
)
return result
#!/bin/bash
# Start the Buffer/Logic/Gateway server chain in the background.
echo Starting...

# Create working directories on first run (-p: no error if they exist).
mkdir -p datamining
mkdir -p jfxt

export LC_ALL=C
# Record the launch directory for debugging ($(..) and quoting replace the
# deprecated unquoted backticks). NOTE: this appends on every run.
echo "$(pwd)" >> ./pwd.txt

# Launch the servers in dependency order, one second apart, detached from
# the terminal with output discarded.
nohup ./BufferServer > /dev/null 2>&1 &
sleep 1
nohup ./LogicServer > /dev/null 2>&1 &
sleep 1
nohup ./GatewayServer > /dev/null 2>&1 &

# After startup, reload the keyword/sensitive-word caches via:
# http://192.168.11.46:5670/api/v2/vadmin/load_parse_from_db
echo Done!
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment