Commit 12844067 authored by 于飞's avatar 于飞

role权限绑定知识库

parent 29d47f96
...@@ -31,6 +31,9 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep ...@@ -31,6 +31,9 @@ async def get_spacy_keywords(dialogue: ConversationVo = Body(), auth: Auth = Dep
print(f"用户输入的问题:{dialogue.user_input} ") print(f"用户输入的问题:{dialogue.user_input} ")
print('----------------begin---------------->') print('----------------begin---------------->')
# 从数据库中加载 并且初始化敏感词-->到内存中
await mydfafiter.parse_from_db(auth.db)
#先判断敏感词 #先判断敏感词
dfa_result, is_sensitive = mydfafiter.filter(dialogue.user_input, "*") dfa_result, is_sensitive = mydfafiter.filter(dialogue.user_input, "*")
print(dfa_result) print(dfa_result)
......
...@@ -46,6 +46,11 @@ from dbgpt.storage.vector_store.base import VectorStoreConfig ...@@ -46,6 +46,11 @@ from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.util.i18n_utils import _ from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import SpanType, root_tracer from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth
from dbgpt.app.apps.vadmin.auth import schemas, crud, models
from sqlalchemy.orm import joinedload, aliased
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
...@@ -76,10 +81,30 @@ def space_add(request: KnowledgeSpaceRequest): ...@@ -76,10 +81,30 @@ def space_add(request: KnowledgeSpaceRequest):
@router.post("/knowledge/space/list")
async def space_list(request: KnowledgeSpaceRequest, auth: Auth = Depends(FullAdminAuth())):
    """List the knowledge spaces visible to the current user.

    A super-admin (role id 1) sees every space; any other user only sees the
    spaces bound to their role via the role→knowledge mapping.

    - request: KnowledgeSpaceRequest filter forwarded to the service layer.
    - auth: injected auth context carrying the current user and a DB session.

    Returns Result.succ(list of spaces) on success, Result.failed on error.
    """
    print(f"/space/list params:{request}")
    try:
        # Start from the full list; filtered below for non-super-admins.
        responses = knowledge_space_service.get_knowledge_space(request)

        # Only the user's first role is considered for knowledge binding.
        role_id = 0
        print(f"------>user_id:{auth.user.id}")
        for vRole in auth.user.roles:
            print(f"----role->:{vRole.id}:{vRole.name}")
            role_id = vRole.id
            break

        # Role id 1 is assumed to be the super-admin role -- TODO confirm
        # against the roles table. Everyone else is restricted to the
        # knowledge spaces explicitly bound to their role.
        if role_id != 1:
            role_knowledge_list = await crud.RoleDal(auth.db).get_role_knowledge_list(role_id)
            # BUGFIX: filter once against the whole allowed-id set. The
            # original reassigned `responses` inside the loop with an
            # `element.id == know_id` filter, so with more than one bound
            # space the successive passes intersected to an empty result.
            allowed_ids = set(role_knowledge_list)
            print(f"----->knowledge space allowed ids:{allowed_ids}")
            responses = [element for element in responses if element.id in allowed_ids]
        return Result.succ(responses)
    except Exception as e:
        return Result.failed(code="E000X", msg=f"space list error {e}")
......
...@@ -121,7 +121,7 @@ class KnowledgeService: ...@@ -121,7 +121,7 @@ class KnowledgeService:
- request: KnowledgeSpaceRequest - request: KnowledgeSpaceRequest
""" """
query = KnowledgeSpaceEntity( query = KnowledgeSpaceEntity(
name=request.name, vector_type=request.vector_type, owner=request.owner id=request.id, name=request.name, vector_type=request.vector_type, owner=request.owner
) )
spaces = knowledge_space_dao.get_knowledge_space(query) spaces = knowledge_space_dao.get_knowledge_space(query)
space_names = [space.name for space in spaces] space_names = [space.name for space in spaces]
......
...@@ -156,6 +156,11 @@ class DefaultModelWorker(ModelWorker): ...@@ -156,6 +156,11 @@ class DefaultModelWorker(ModelWorker):
last_metrics = ModelInferenceMetrics.create_metrics() last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True is_first_generate = True
#----------------------------------------------------
#拼接图片和视频等资源
#datas, count = await crud.MediaDal(auth.db).get_datas(**mydic, v_return_count=True)
#yield ModelOutput(text='----AAAA---->', error_code=0)
context_len = params.get("context_len") or self.context_len context_len = params.get("context_len") or self.context_len
for output in generate_stream_func( for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len self.model, self.tokenizer, params, get_device(), context_len
...@@ -178,8 +183,10 @@ class DefaultModelWorker(ModelWorker): ...@@ -178,8 +183,10 @@ class DefaultModelWorker(ModelWorker):
last_metrics = current_metrics last_metrics = current_metrics
yield model_output yield model_output
print( print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}" f"\n\ngenerate_stream full stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
) )
print(f"----generate_stream-------params:{params}")
model_span.end(metadata={"output": previous_response}) model_span.end(metadata={"output": previous_response})
span.end() span.end()
except Exception as e: except Exception as e:
...@@ -245,6 +252,8 @@ class DefaultModelWorker(ModelWorker): ...@@ -245,6 +252,8 @@ class DefaultModelWorker(ModelWorker):
last_metrics = ModelInferenceMetrics.create_metrics() last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True is_first_generate = True
async for output in generate_stream_func( async for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len self.model, self.tokenizer, params, get_device(), context_len
): ):
......
...@@ -97,6 +97,11 @@ class TongyiLLMClient(ProxyLLMClient): ...@@ -97,6 +97,11 @@ class TongyiLLMClient(ProxyLLMClient):
stream=True, stream=True,
result_format="message", result_format="message",
) )
#-------------------------------------------------
#流式的 在dbgpt/model/worker/default_worker.py中统一调用
#yield ModelOutput(text='----AAAA---->', error_code=0)
for r in res: for r in res:
if r: if r:
if r["status_code"] == 200: if r["status_code"] == 200:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment