Commit 4325e913 authored by 于飞's avatar 于飞

修改铭感词匹配规则

parent 8e040ad5
...@@ -123,6 +123,8 @@ class DFAFilter(): ...@@ -123,6 +123,8 @@ class DFAFilter():
for keyword in datas: for keyword in datas:
self.add(keyword.word_name) self.add(keyword.word_name)
#最长匹配模式,确保敏感词过滤器优先匹配和替换较长的敏感词
def filter(self, message, repl="*"): def filter(self, message, repl="*"):
is_sensitive = False is_sensitive = False
if not isinstance(message, str): if not isinstance(message, str):
...@@ -130,30 +132,37 @@ class DFAFilter(): ...@@ -130,30 +132,37 @@ class DFAFilter():
message = message.lower() message = message.lower()
ret = [] ret = []
start = 0 start = 0
matched_sensitives = [] # List to store matched sensitive words matched_sensitives = [] # 用来存储匹配的敏感词
while start < len(message): while start < len(message):
level = self.keyword_chains level = self.keyword_chains
longest_match_len = 0 # 记录最长匹配长度
longest_match_word = None # 记录最长匹配的敏感词
step_ins = 0 step_ins = 0
for char in message[start:]:
for i, char in enumerate(message[start:], start=1):
if char in level: if char in level:
step_ins += 1
if self.delimit not in level[char]:
level = level[char] level = level[char]
step_ins += 1
if self.delimit in level:
# 找到一个完整的敏感词
longest_match_len = step_ins
longest_match_word = message[start:start + step_ins]
else: else:
# Store the matched sensitive word
matched_sensitives.append(message[start:start + step_ins])
ret.append(repl * step_ins)
start += step_ins - 1
is_sensitive = True
break
else:
ret.append(message[start])
break break
if longest_match_len > 0:
# 进行最长匹配替换
matched_sensitives.append(longest_match_word)
ret.append(repl * longest_match_len)
start += longest_match_len
is_sensitive = True
else: else:
# 无匹配,直接保留原字符
ret.append(message[start]) ret.append(message[start])
start += 1 start += 1
# return 返回三个参数 # 返回三个参数
return ''.join(ret), is_sensitive, matched_sensitives return ''.join(ret), is_sensitive, matched_sensitives
#初始化全局对象 #初始化全局对象
......
...@@ -35,6 +35,15 @@ from sqlalchemy import BinaryExpression ...@@ -35,6 +35,15 @@ from sqlalchemy import BinaryExpression
router = APIRouter() router = APIRouter()
#提问次数
question_count = 0
#关键词匹配到的类型
MEDIA_TYPE1 = 1 #图片
MEDIA_TYPE2 = 2 #视频
MEDIA_TYPE3 = 3 #铭感词
MEDIA_TYPE4 = 4 #问答对
MEDIA_TYPE5 = 5 #统计次数3->留下联系电话
def get_key_words(user_input: str) -> list: def get_key_words(user_input: str) -> list:
""" """
...@@ -97,7 +106,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl ...@@ -97,7 +106,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl
print(f"-----查询到的图片为:---->:{images_datas}") print(f"-----查询到的图片为:---->:{images_datas}")
for data in images_datas: for data in images_datas:
json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'), json_image = {'type': MEDIA_TYPE1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')} 'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')}
result.append(json_image) result.append(json_image)
...@@ -113,7 +122,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl ...@@ -113,7 +122,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl
video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True) video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
print(f"-----查询到的视频为:---->:{video_datas}") print(f"-----查询到的视频为:---->:{video_datas}")
for videodata in video_datas: for videodata in video_datas:
json_video = {'type': 2, 'file_name': videodata.get('file_name'), json_video = {'type': MEDIA_TYPE2, 'file_name': videodata.get('file_name'),
'key_word': videodata.get('key_word'), 'key_word': videodata.get('key_word'),
'local_path': videodata.get('local_path'), 'local_path': videodata.get('local_path'),
'remote_path': videodata.get('remote_path')} 'remote_path': videodata.get('remote_path')}
...@@ -130,7 +139,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl ...@@ -130,7 +139,7 @@ async def get_media_datas_by(conv_uid: str, words: str, db: AsyncSession, knownl
question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True) question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
print(f"-----查询到的问答对为:---->:{question_datas}") print(f"-----查询到的问答对为:---->:{question_datas}")
for questiondata in question_datas: for questiondata in question_datas:
json_question = {'type': 4, 'title': questiondata.get('title'), json_question = {'type': MEDIA_TYPE4, 'title': questiondata.get('title'),
'key_word': questiondata.get('key_word'), 'key_word': questiondata.get('key_word'),
'answer': questiondata.get('answer')} 'answer': questiondata.get('answer')}
result.append(json_question) result.append(json_question)
...@@ -161,7 +170,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list: ...@@ -161,7 +170,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
print(f"-----查询到的图片为:---->:{images_datas}") print(f"-----查询到的图片为:---->:{images_datas}")
result = [] result = []
for data in images_datas: for data in images_datas:
json_image = {'type': 1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'), json_image = {'type': MEDIA_TYPE1, 'file_name': data.get('file_name'), 'key_word': data.get('key_word'),
'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')} 'local_path': data.get('local_path'), 'remote_path': data.get('remote_path')}
result.append(json_image) result.append(json_image)
...@@ -172,7 +181,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list: ...@@ -172,7 +181,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True) video_datas, count = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
print(f"-----查询到的视频为:---->:{video_datas}") print(f"-----查询到的视频为:---->:{video_datas}")
for videodata in video_datas: for videodata in video_datas:
json_video = {'type': 2, 'file_name': videodata.get('file_name'), 'key_word': videodata.get('key_word'), json_video = {'type': MEDIA_TYPE2, 'file_name': videodata.get('file_name'), 'key_word': videodata.get('key_word'),
'local_path': videodata.get('local_path'), 'remote_path': videodata.get('remote_path')} 'local_path': videodata.get('local_path'), 'remote_path': videodata.get('remote_path')}
result.append(json_video) result.append(json_video)
...@@ -181,7 +190,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list: ...@@ -181,7 +190,7 @@ async def get_media_datas(conv_uid: str, words: str, db: AsyncSession) -> list:
question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True) question_datas, count = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
print(f"-----查询到的问答对为:---->:{question_datas}") print(f"-----查询到的问答对为:---->:{question_datas}")
for questiondata in question_datas: for questiondata in question_datas:
json_question = {'type': 4, 'title': questiondata.get('title'), 'key_word': questiondata.get('key_word'), json_question = {'type': MEDIA_TYPE4, 'title': questiondata.get('title'), 'key_word': questiondata.get('key_word'),
'answer': questiondata.get('answer')} 'answer': questiondata.get('answer')}
result.append(json_question) result.append(json_question)
...@@ -208,7 +217,7 @@ async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSessio ...@@ -208,7 +217,7 @@ async def get_media_datas_all(conv_uid: str, default_model: str, db: AsyncSessio
'message_medias': ('like', None)} 'message_medias': ('like', None)}
datas, count = await ChatHistoryDal(db).get_datas(**history_dic, v_return_count=True) datas, count = await ChatHistoryDal(db).get_datas(**history_dic, v_return_count=True)
json_data = [{"type": 1, "file_name": "723629348SJfHgjzD.png"}] # 传入列表 json_data = [{"type": MEDIA_TYPE1, "file_name": "723629348SJfHgjzD.png"}] # 传入列表
if count > 0: if count > 0:
history_datas = datas[0].get('message_medias') history_datas = datas[0].get('message_medias')
json_data = json.loads(history_datas) json_data = json.loads(history_datas)
...@@ -234,6 +243,19 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -234,6 +243,19 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
print(f"用户输入的问题:{dialogue.user_input} -- 选择的知识库为:{dialogue.select_param}") print(f"用户输入的问题:{dialogue.user_input} -- 选择的知识库为:{dialogue.select_param}")
print('----------------begin---------------->') print('----------------begin---------------->')
#统计提问次数
"""
global question_count
question_count += 1
print(f"=====>question_count:{question_count}")
if question_count >= 3:
print('=====触发留下联系方式=====')
result = {'code': 200, 'message': 'success',
'data': [{'type': MEDIA_TYPE5, 'answer': "请留下您的联系方式,后续给您安排技术人员给您详细讲解一下:"}]}
question_count = 0
return SuccessResponse(result) # 返回type=5
"""
#先判断敏感词 #先判断敏感词
dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*") dfa_result, is_sensitive, matched_sensitives = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result) print(dfa_result)
...@@ -241,7 +263,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -241,7 +263,7 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
if is_sensitive: if is_sensitive:
print('用户输入有敏感词') print('用户输入有敏感词')
result = {'code': 200, 'message': 'success', result = {'code': 200, 'message': 'success',
'data': [{'type': 3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input': dfa_result}]} 'data': [{'type': MEDIA_TYPE3, 'word_name': matched_sensitives, 'is_sensitive': 1, 'user_input': dfa_result}]}
return SuccessResponse(result) #返回type=3 return SuccessResponse(result) #返回type=3
#没有敏感词的时候,查找是否有相关图片 或者 视频 #没有敏感词的时候,查找是否有相关图片 或者 视频
......
...@@ -544,6 +544,7 @@ async def stream_generator(chat, incremental: bool, model_name: str): ...@@ -544,6 +544,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
json_chunk = model_to_json( json_chunk = model_to_json(
chunk, exclude_unset=True, ensure_ascii=False chunk, exclude_unset=True, ensure_ascii=False
) )
#把\\n --> 转换为 \n
json_chunk = json_chunk.replace("\\n", "n") json_chunk = json_chunk.replace("\\n", "n")
#print(f"===>:{json_chunk}") #print(f"===>:{json_chunk}")
yield f"data: {json_chunk}\n\n" yield f"data: {json_chunk}\n\n"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment