Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
A
ask_data_ai_admin
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
linyangyang
ask_data_ai_admin
Commits
1c71562d
Commit
1c71562d
authored
Jul 18, 2025
by
林洋洋
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
代码优化
parent
717ac795
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
394 additions
and
124 deletions
+394
-124
AsyncConfig.java
...data-ai-biz/src/main/java/com/ask/config/AsyncConfig.java
+42
-30
CommonConfiguration.java
...biz/src/main/java/com/ask/config/CommonConfiguration.java
+1
-2
PostgresChatMemoryDialect.java
...c/main/java/com/ask/config/PostgresChatMemoryDialect.java
+1
-1
ChatController.java
...-biz/src/main/java/com/ask/controller/ChatController.java
+89
-64
AsyncVectorizationService.java
.../main/java/com/ask/service/AsyncVectorizationService.java
+2
-7
AsyncVectorizationServiceImpl.java
...a/com/ask/service/impl/AsyncVectorizationServiceImpl.java
+38
-20
ChatService.java
...i-biz/src/main/java/com/ask/service/impl/ChatService.java
+178
-0
RagPromptService.java
.../src/main/java/com/ask/service/impl/RagPromptService.java
+43
-0
No files found.
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/config/AsyncConfig.java
View file @
1c71562d
package
com
.
ask
.
config
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.beans.factory.annotation.Value
;
import
org.springframework.boot.autoconfigure.AutoConfiguration
;
import
org.springframework.context.annotation.Bean
;
import
org.springframework.context.annotation.Configuration
;
import
org.springframework.core.task.AsyncTaskExecutor
;
import
org.springframework.scheduling.annotation.AsyncConfigurer
;
import
org.springframework.scheduling.annotation.EnableAsync
;
import
org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor
;
import
org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer
;
import
org.springframework.web.servlet.config.annotation.WebMvcConfigurer
;
import
java.util.Optional
;
import
java.util.concurrent.Executor
;
import
java.util.concurrent.ThreadPoolExecutor
;
...
...
@@ -16,10 +23,42 @@ import java.util.concurrent.ThreadPoolExecutor;
* @date 2024/12/20
*/
@Slf4j
@Configuration
@EnableAsync
public
class
AsyncConfig
{
@AutoConfiguration
public
class
AsyncConfig
implements
AsyncConfigurer
{
/**
* 获取当前机器的核数, 不一定准确 请根据实际场景 CPU密集 || IO 密集
*/
public
static
final
int
cpuNum
=
Runtime
.
getRuntime
().
availableProcessors
();
@Value
(
"${thread.pool.corePoolSize:}"
)
private
Optional
<
Integer
>
corePoolSize
;
@Value
(
"${thread.pool.maxPoolSize:}"
)
private
Optional
<
Integer
>
maxPoolSize
;
@Value
(
"${thread.pool.queueCapacity:}"
)
private
Optional
<
Integer
>
queueCapacity
;
@Value
(
"${thread.pool.awaitTerminationSeconds:}"
)
private
Optional
<
Integer
>
awaitTerminationSeconds
;
@Override
@Bean
public
Executor
getAsyncExecutor
()
{
ThreadPoolTaskExecutor
taskExecutor
=
new
ThreadPoolTaskExecutor
();
// 核心线程大小 默认区 CPU 数量
taskExecutor
.
setCorePoolSize
(
corePoolSize
.
orElse
(
cpuNum
));
// 最大线程大小 默认区 CPU * 2 数量
taskExecutor
.
setMaxPoolSize
(
maxPoolSize
.
orElse
(
cpuNum
*
2
));
// 队列最大容量
taskExecutor
.
setQueueCapacity
(
queueCapacity
.
orElse
(
500
));
taskExecutor
.
setRejectedExecutionHandler
(
new
ThreadPoolExecutor
.
CallerRunsPolicy
());
taskExecutor
.
setWaitForTasksToCompleteOnShutdown
(
true
);
taskExecutor
.
setAwaitTerminationSeconds
(
awaitTerminationSeconds
.
orElse
(
60
));
taskExecutor
.
setThreadNamePrefix
(
"ASK-Thread-"
);
taskExecutor
.
initialize
();
return
taskExecutor
;
}
/**
* 向量化专用线程池
*
...
...
@@ -61,31 +100,4 @@ public class AsyncConfig {
return
executor
;
}
// /**
// * 通用异步线程池
// *
// * @return 通用任务执行器
// */
// @Bean("taskExecutor")
// public Executor taskExecutor() {
// ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
//
// // 异步线程配置
// executor.setCorePoolSize(5);
// executor.setMaxPoolSize(5);
// executor.setQueueCapacity(99999);
// executor.setThreadNamePrefix("async-service-");
//
// executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// executor.setKeepAliveSeconds(60);
// executor.setWaitForTasksToCompleteOnShutdown(true);
// executor.setAwaitTerminationSeconds(60);
//
// executor.initialize();
//
// log.info("通用线程池初始化完成:核心线程={}, 最大线程={}, 队列容量={}",
// executor.getCorePoolSize(), executor.getMaxPoolSize(), executor.getQueueCapacity());
//
// return executor;
// }
}
\ No newline at end of file
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/config/CommonConfiguration.java
View file @
1c71562d
...
...
@@ -31,7 +31,6 @@ public class CommonConfiguration {
.
jdbcTemplate
(
jdbcTemplate
)
.
dialect
(
postgresChatMemoryDialect
)
.
build
();
return
MessageWindowChatMemory
.
builder
()
.
chatMemoryRepository
(
chatMemoryRepository
)
.
maxMessages
(
5
)
...
...
@@ -39,7 +38,7 @@ public class CommonConfiguration {
}
@Bean
public
ChatClient
chatClient
(
OpenAiChatModel
model
,
ChatMemory
chatMemory
)
{
public
ChatClient
chatClient
(
OpenAiChatModel
model
)
{
return
ChatClient
.
builder
(
model
)
.
defaultAdvisors
()
.
build
();
...
...
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/config/PostgresChatMemoryDialect.java
View file @
1c71562d
...
...
@@ -7,7 +7,7 @@ public class PostgresChatMemoryDialect implements JdbcChatMemoryRepositoryDialec
@Override
public
String
getSelectMessagesSql
()
{
return
"SELECT content, type FROM ask_chat_conversation_detail WHERE conversation_id = ? ORDER BY \"timestamp\""
;
return
"SELECT content, type FROM ask_chat_conversation_detail WHERE conversation_id = ?
and del_flag = '0'
ORDER BY \"timestamp\""
;
}
@Override
public
String
getInsertMessageSql
()
{
...
...
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/controller/ChatController.java
View file @
1c71562d
package
com
.
ask
.
controller
;
import
cn.hutool.json.JSONArray
;
import
cn.hutool.json.JSONObject
;
import
com.ask.api.entity.ChatConversation
;
import
com.ask.common.core.R
;
import
com.ask.service.ChatConversationService
;
import
com.ask.service.impl.ChatService
;
import
com.ask.service.impl.RagPromptService
;
import
io.swagger.v3.oas.annotations.Operation
;
import
io.swagger.v3.oas.annotations.Parameter
;
import
io.swagger.v3.oas.annotations.tags.Tag
;
...
...
@@ -13,21 +17,20 @@ import org.springframework.ai.chat.client.ChatClient;
import
org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor
;
import
org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor
;
import
org.springframework.ai.chat.memory.ChatMemory
;
import
org.springframework.ai.chat.messages.Message
;
import
org.springframework.ai.chat.messages.SystemMessage
;
import
org.springframework.ai.chat.messages.UserMessage
;
import
org.springframework.ai.chat.model.ChatResponse
;
import
org.springframework.ai.chat.messages.*
;
import
org.springframework.ai.chat.prompt.Prompt
;
import
org.springframework.ai.document.Document
;
import
org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor
;
import
org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever
;
import
org.springframework.ai.vectorstore.SearchRequest
;
import
org.springframework.ai.vectorstore.VectorStore
;
import
org.springframework.ai.vectorstore.filter.Filter
;
import
org.springframework.ai.vectorstore.filter.FilterExpressionBuilder
;
import
org.springframework.web.bind.annotation.*
;
import
reactor.core.publisher.Flux
;
import
reactor.core.publisher.Mono
;
import
java.util.Arrays
;
import
java.util.Objects
;
import
java.util.UUID
;
import
java.util.*
;
@Slf4j
@RestController
...
...
@@ -38,55 +41,56 @@ public class ChatController {
private
final
ChatClient
chatClient
;
private
final
ChatConversationService
chatConversationService
;
private
final
ChatConversationService
chatConversationService
;
private
final
RetrievalAugmentationAdvisor
retrievalAugmentationAdvisor
;
private
final
RetrievalAugmentationAdvisor
retrievalAugmentationAdvisor
;
private
final
MessageChatMemoryAdvisor
messageChatMemoryAdvisor
;
private
final
MessageChatMemoryAdvisor
messageChatMemoryAdvisor
;
private
final
PromptChatMemoryAdvisor
promptChatMemoryAdvisor
;
private
final
VectorStore
vectorStore
;
private
final
ChatService
chatService
;
private
final
RagPromptService
ragPromptService
;
private
final
ChatMemory
chatMemory
;
private
final
PromptChatMemoryAdvisor
promptChatMemoryAdvisor
;
/**
* 获取会话ID
*
* @return 新的会话ID
*/
@Operation
(
summary
=
"创建对话"
,
description
=
"创建对话"
)
@Operation
(
summary
=
"创建对话"
,
description
=
"创建对话"
)
@GetMapping
(
"/create/client"
)
public
R
<
ChatConversation
>
getConversationId
(
@Parameter
(
description
=
"智能体ID"
)
@RequestParam
Integer
agentId
,
@Parameter
(
description
=
"用户ID"
)
@RequestParam
Long
userId
)
{
if
(
Objects
.
isNull
(
agentId
))
{
throw
new
RuntimeException
(
"userID不能为NULL!"
);
}
ChatConversation
chatConversation
=
new
ChatConversation
();
String
conversationId
=
UUID
.
randomUUID
().
toString
().
replaceAll
(
"-"
,
""
);
chatConversation
.
setConversationId
(
conversationId
);
chatConversation
.
setAgentId
(
agentId
);
chatConversation
.
setUserId
(
userId
);
chatConversationService
.
save
(
chatConversation
);
public
R
<
ChatConversation
>
getConversationId
(
@Parameter
(
description
=
"智能体ID"
)
@RequestParam
Integer
agentId
,
@Parameter
(
description
=
"用户ID"
)
@RequestParam
Long
userId
)
{
if
(
Objects
.
isNull
(
agentId
))
{
throw
new
RuntimeException
(
"userID不能为NULL!"
);
}
ChatConversation
chatConversation
=
new
ChatConversation
();
String
conversationId
=
UUID
.
randomUUID
().
toString
().
replaceAll
(
"-"
,
""
);
chatConversation
.
setConversationId
(
conversationId
);
chatConversation
.
setAgentId
(
agentId
);
chatConversation
.
setUserId
(
userId
);
chatConversationService
.
save
(
chatConversation
);
return
R
.
ok
(
chatConversation
);
}
/**
* 最基本的AI流式输出对话
* <p>
* * @param message
*
* * @param message
* @return
*/
@Operation
(
summary
=
"普通对话"
,
description
=
"普通对话"
)
@Operation
(
summary
=
"普通对话"
,
description
=
"普通对话"
)
@GetMapping
(
value
=
"/chat"
,
produces
=
"text/html;charset=utf-8"
)
public
Flux
<
String
>
chat
(
@Parameter
(
description
=
"对话内容"
)
@RequestParam
String
message
,
@Parameter
(
description
=
"会话ID"
)
@RequestParam
String
conversationId
)
{
public
Flux
<
String
>
chat
(
@Parameter
(
description
=
"对话内容"
)
@RequestParam
String
message
,
@Parameter
(
description
=
"会话ID"
)
@RequestParam
String
conversationId
)
{
// 创建系统消息,告诉大模型只返回工具名和参数
Message
systemMessage
=
new
SystemMessage
(
"你是一个AI客服助手。请严格按照以下格式回答每个问题:"
);
Message
systemMessage
=
new
SystemMessage
(
"你是一个AI客服助手。请严格按照以下格式回答每个问题:"
);
// 用户消息
String
question
=
"请严格按以下格式回答:\n"
+
"<think>\n"
+
"[你的逐步推理过程]\n"
+
"</think>\n"
+
"<answer>\n"
+
"[最终答案]\n"
+
"</answer>\n"
+
"推理过程不要设计`<think>` 和 `<answer>` \n"
+
"问题:"
+
message
+
"\n"
;
String
question
=
"请严格按以下格式回答:\n"
+
"<think>\n"
+
"[你的逐步推理过程]\n"
+
"</think>\n"
+
"<answer>\n"
+
"[最终答案]\n"
+
"</answer>\n"
+
"推理过程不要设计`<think>` 和 `<answer>` \n"
+
"问题:"
+
message
+
"\n"
;
Message
userMessage
=
new
UserMessage
(
question
);
// 创建提示,包含系统消息和用户消息
...
...
@@ -95,31 +99,52 @@ public class ChatController {
return
chatClient
.
prompt
(
prompt
).
advisors
(
messageChatMemoryAdvisor
).
advisors
(
a
->
a
.
param
(
ChatMemory
.
CONVERSATION_ID
,
conversationId
)).
stream
().
content
();
}
/**
* 知识库对话
*
* * @param message
* @return
*/
@Operation
(
summary
=
"知识库对话"
,
description
=
"知识库对话"
)
@GetMapping
(
value
=
"/rag/chat"
,
produces
=
"text/html;charset=utf-8"
)
public
Flux
<
String
>
ragChat
(
String
message
,
String
conversationId
)
{
// 创建系统消息,告诉大模型只返回工具名和参数
// Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
//
// Message userMessage = new UserMessage(message);
// // 创建提示,包含系统消息和用户消息
// Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
FilterExpressionBuilder
builder
=
new
FilterExpressionBuilder
();
Filter
.
Expression
filter
=
builder
.
eq
(
"isEnabled"
,
1
).
build
();
// List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder().query(message).filterExpression(filter).similarityThreshold(0.75).topK(5).build());
return
chatClient
.
prompt
(
message
)
.
advisors
(
a
->
a
.
param
(
ChatMemory
.
CONVERSATION_ID
,
conversationId
))
.
advisors
(
promptChatMemoryAdvisor
)
//会话记忆
.
advisors
(
a
->
a
.
param
(
VectorStoreDocumentRetriever
.
FILTER_EXPRESSION
,
filter
))
.
advisors
(
retrievalAugmentationAdvisor
)
//知识库召回
.
stream
().
content
();
}
/**
* 知识库对话
* <p>
* * @param message
*
* @return
*/
@Operation
(
summary
=
"知识库对话"
,
description
=
"知识库对话"
)
@GetMapping
(
value
=
"/rag/chat"
,
produces
=
"text/html;charset=utf-8"
)
public
Flux
<
String
>
ragChat
(
@RequestParam
@Parameter
(
description
=
"对话内容"
)
String
message
,
@RequestParam
@Parameter
(
description
=
"会话ID"
)
String
conversationId
)
{
//获取对话历史
String
historyMemory
=
chatService
.
getHistoryMemoryAsString
(
conversationId
);
//新增问题到对话记录
UserMessage
userMessage
=
new
UserMessage
(
message
);
chatService
.
saveHistoryMemory
(
conversationId
,
userMessage
);
//向量数据召回
FilterExpressionBuilder
builder
=
new
FilterExpressionBuilder
();
Filter
.
Expression
filter
=
builder
.
eq
(
"isEnabled"
,
1
).
build
();
List
<
Document
>
documents
=
chatService
.
retrieveDocuments
(
message
,
0.75
,
5
,
filter
);
//获取文件引用
String
reference
=
chatService
.
getReference
(
documents
);
//拼装知识库上下文内容
String
context
=
chatService
.
convertDocumentsToString
(
documents
);
//创建提示词
String
userPrompt
=
ragPromptService
.
createRagPrompt
(
message
,
context
,
historyMemory
);
StringBuilder
contentBuilder
=
new
StringBuilder
();
return
chatClient
.
prompt
()
.
user
(
userPrompt
)
.
system
(
"你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求"
+
"1.以 Markdown 格式输出"
+
"2.请务必将你的思考过程放在 <think></think> 标签内"
+
"3.请务必将生成最终答案放在 <answer></answer> 标签内"
)
.
stream
()
.
content
()
.
concatWith
(
Mono
.
just
(
reference
))
.
doOnNext
(
chunk
->
{
// 实时收集每个流片段
contentBuilder
.
append
(
chunk
);
})
.
doOnComplete
(()
->
{
// 流结束时获取完整内容
String
fullResponse
=
contentBuilder
.
toString
();
// 异步保存到数据库(添加错误处理)
chatService
.
saveHistoryMemory
(
conversationId
,
new
AssistantMessage
(
fullResponse
));
});
}
}
\ No newline at end of file
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/service/AsyncVectorizationService.java
View file @
1c71562d
package
com
.
ask
.
service
;
import
com.ask.api.entity.AskVectorStore
;
import
com.ask.api.entity.KnowledgeDocument
;
import
java.util.List
;
import
java.util.concurrent.CompletableFuture
;
/**
* 异步向量化服务接口
*
...
...
@@ -16,9 +12,8 @@ public interface AsyncVectorizationService {
/**
* 根据文档异步向量化处理
*
*
* @param document 文档对象
* @return 异步任务结果
*/
CompletableFuture
<
Integer
>
vectorizeByDocumentIdAsync
(
KnowledgeDocument
document
);
void
vectorizeByDocumentIdAsync
(
KnowledgeDocument
document
);
}
\ No newline at end of file
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/service/impl/AsyncVectorizationServiceImpl.java
View file @
1c71562d
...
...
@@ -5,17 +5,16 @@ import com.ask.api.entity.KnowledgeDocument;
import
com.ask.mapper.KnowledgeDocumentMapper
;
import
com.ask.service.AskVectorStoreService
;
import
com.ask.service.AsyncVectorizationService
;
import
com.ask.service.KnowledgeDocumentService
;
import
com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper
;
import
com.baomidou.mybatisplus.core.toolkit.Wrappers
;
import
lombok.RequiredArgsConstructor
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.beans.factory.annotation.Autowired
;
import
org.springframework.scheduling.annotation.Async
;
import
org.springframework.stereotype.Service
;
import
org.springframework.util.CollectionUtils
;
import
java.util.List
;
import
java.util.concurrent.CompletableFuture
;
import
java.util.concurrent.*
;
import
java.util.concurrent.atomic.AtomicInteger
;
/**
* 异步向量化服务实现类
...
...
@@ -25,26 +24,47 @@ import java.util.concurrent.CompletableFuture;
*/
@Slf4j
@Service
@RequiredArgsConstructor
public
class
AsyncVectorizationServiceImpl
implements
AsyncVectorizationService
{
@Autowired
private
AskVectorStoreService
askVectorStoreService
;
private
final
AskVectorStoreService
askVectorStoreService
;
@Autowired
private
KnowledgeDocumentMapper
knowledgeDocumentMapper
;
private
final
KnowledgeDocumentMapper
knowledgeDocumentMapper
;
private
final
ThreadPoolExecutor
processThreadPool
=
new
ThreadPoolExecutor
(
3
,
// 核心线程数
3
,
// 最大线程数
60L
,
// 空闲线程存活时间
TimeUnit
.
SECONDS
,
// 时间单位
new
LinkedBlockingQueue
<>(
100
),
// 任务队列
new
ThreadFactory
()
{
private
final
AtomicInteger
threadNumber
=
new
AtomicInteger
(
1
);
@Override
public
Thread
newThread
(
Runnable
r
)
{
Thread
t
=
new
Thread
(
r
,
"data-process-thread-"
+
threadNumber
.
getAndIncrement
());
t
.
setDaemon
(
true
);
return
t
;
}
},
new
ThreadPoolExecutor
.
CallerRunsPolicy
()
// 拒绝策略:由调用线程执行
);
/**
* 根据文档异步向量化处理
*
* @param document 文档对象
* @return 异步任务结果
*/
@Override
@Async
(
"vectorizationExecutor"
)
public
CompletableFuture
<
Integer
>
vectorizeByDocumentIdAsync
(
KnowledgeDocument
document
)
{
public
void
vectorizeByDocumentIdAsync
(
KnowledgeDocument
document
)
{
processThreadPool
.
submit
(()
->
processData
(
document
));
}
public
void
processData
(
KnowledgeDocument
document
)
{
if
(
document
==
null
)
{
log
.
warn
(
"异步向量化:文档对象为空"
);
return
CompletableFuture
.
completedFuture
(
0
)
;
return
;
}
try
{
...
...
@@ -52,16 +72,16 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
Thread
.
currentThread
().
getName
(),
document
.
getId
());
long
startTime
=
System
.
currentTimeMillis
();
// 更新文档状态为处理中
knowledgeDocumentMapper
.
update
(
Wrappers
.<
KnowledgeDocument
>
lambdaUpdate
()
.
eq
(
KnowledgeDocument:
:
getId
,
document
.
getId
())
.
set
(
KnowledgeDocument:
:
getStatus
,
1
));
// 查询该文档下所有未向量化的数据
LambdaQueryWrapper
<
AskVectorStore
>
wrapper
=
new
LambdaQueryWrapper
<
AskVectorStore
>()
.
apply
(
"metadata::jsonb ->> 'documentId' = {0}"
,
String
.
valueOf
(
document
.
getId
()));
;
// 假设embedding为null表示未向量化
;
// 假设embedding为null表示未向量化
List
<
AskVectorStore
>
vectorStores
=
askVectorStoreService
.
list
(
wrapper
);
...
...
@@ -71,7 +91,7 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
knowledgeDocumentMapper
.
update
(
Wrappers
.<
KnowledgeDocument
>
lambdaUpdate
()
.
eq
(
KnowledgeDocument:
:
getId
,
document
.
getId
())
.
set
(
KnowledgeDocument:
:
getStatus
,
3
));
return
CompletableFuture
.
completedFuture
(
0
)
;
return
;
}
// 批量处理向量化
...
...
@@ -81,13 +101,11 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
log
.
info
(
"根据文档异步向量化完成,线程:{},文档ID:{},耗时:{}ms,成功:{}/{}"
,
Thread
.
currentThread
().
getName
(),
document
.
getId
(),
(
endTime
-
startTime
),
successCount
,
vectorStores
.
size
());
// 更新文档状态为处理完成
knowledgeDocumentMapper
.
update
(
Wrappers
.<
KnowledgeDocument
>
lambdaUpdate
()
.
eq
(
KnowledgeDocument:
:
getId
,
document
.
getId
())
.
set
(
KnowledgeDocument:
:
getStatus
,
2
));
return
CompletableFuture
.
completedFuture
(
successCount
);
}
catch
(
Exception
e
)
{
// 更新文档状态为处理失败
...
...
@@ -96,7 +114,7 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
.
set
(
KnowledgeDocument:
:
getStatus
,
3
));
log
.
error
(
"根据文档异步向量化失败,线程:{},文档ID:{},错误:{}"
,
Thread
.
currentThread
().
getName
(),
document
.
getId
(),
e
.
getMessage
(),
e
);
return
CompletableFuture
.
completedFuture
(
0
);
}
}
}
\ No newline at end of file
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/service/impl/ChatService.java
0 → 100644
View file @
1c71562d
package
com
.
ask
.
service
.
impl
;
import
cn.hutool.json.JSONArray
;
import
cn.hutool.json.JSONObject
;
import
lombok.RequiredArgsConstructor
;
import
lombok.extern.slf4j.Slf4j
;
import
org.apache.commons.lang3.StringUtils
;
import
org.springframework.ai.chat.memory.ChatMemory
;
import
org.springframework.ai.chat.messages.AssistantMessage
;
import
org.springframework.ai.chat.messages.Message
;
import
org.springframework.ai.document.Document
;
import
org.springframework.ai.vectorstore.SearchRequest
;
import
org.springframework.ai.vectorstore.VectorStore
;
import
org.springframework.ai.vectorstore.filter.Filter
;
import
org.springframework.stereotype.Service
;
import
org.springframework.util.CollectionUtils
;
import
java.util.Arrays
;
import
java.util.List
;
import
java.util.Set
;
import
java.util.concurrent.ConcurrentHashMap
;
import
java.util.function.Function
;
import
java.util.function.Predicate
;
import
java.util.regex.Matcher
;
import
java.util.regex.Pattern
;
import
java.util.stream.Collectors
;
import
static
org
.
springframework
.
ai
.
chat
.
messages
.
MessageType
.
ASSISTANT
;
@Slf4j
@RequiredArgsConstructor
@Service
public
class
ChatService
{
private
final
VectorStore
vectorStore
;
private
final
ChatMemory
chatMemory
;
/**
* rag召回
*
* @param query 问题
* @param threshold 相似度
* @param topK 召回数量
* @param filter 过滤
* @return 召回文档
*/
public
List
<
Document
>
retrieveDocuments
(
String
query
,
double
threshold
,
int
topK
,
Filter
.
Expression
filter
)
{
List
<
Document
>
documents
=
vectorStore
.
similaritySearch
(
SearchRequest
.
builder
()
.
query
(
query
)
.
filterExpression
(
filter
)
.
similarityThreshold
(
threshold
)
.
topK
(
topK
)
.
build
());
if
(
CollectionUtils
.
isEmpty
(
documents
))
{
return
List
.
of
();
}
// 根据ID去重
return
documents
.
stream
()
.
filter
(
distinctByKey
(
Document:
:
getId
))
// 根据ID去重
.
collect
(
Collectors
.
toList
());
}
/**
* 根据指定字段去重
*
* @param keyExtractor 提取字段的函数
* @param <T> 元素类型
* @return 去重的Predicate
*/
private
static
<
T
>
Predicate
<
T
>
distinctByKey
(
Function
<?
super
T
,
?>
keyExtractor
)
{
Set
<
Object
>
seen
=
ConcurrentHashMap
.
newKeySet
();
return
t
->
seen
.
add
(
keyExtractor
.
apply
(
t
));
}
/**
* 将召回的数据转换为String
*
* @param documents 召回的数据
* @return 拼接后的字符串
*/
public
String
convertDocumentsToString
(
List
<
Document
>
documents
)
{
return
documents
.
stream
()
.
map
(
Document:
:
getText
)
// 提取每个Document的内容
.
collect
(
Collectors
.
joining
(
"\n"
));
// 用换行符拼接
}
/**
* 保存会话记忆
*
* @param conversationId
* @param message
*/
public
void
saveHistoryMemory
(
String
conversationId
,
Message
message
)
{
chatMemory
.
add
(
conversationId
,
message
);
}
/**
* 获取历史对话并组装成字符串
*
* @param conversationId 会话ID
* @return 历史对话的字符串表示
*/
public
String
getHistoryMemoryAsString
(
String
conversationId
)
{
// 获取历史对话列表
List
<
Message
>
history
=
chatMemory
.
get
(
conversationId
);
System
.
out
.
println
(
history
.
size
());
// 将历史对话组装成字符串
return
history
.
stream
()
.
map
(
this
::
formatMessage
)
// 格式化每条消息
.
collect
(
Collectors
.
joining
(
"\n"
));
// 用换行符拼接
}
/**
* 格式化消息
*
* @param message 消息
* @return 格式化后的字符串
*/
private
String
formatMessage
(
Message
message
)
{
// 根据消息类型格式化
return
switch
(
message
.
getMessageType
())
{
case
USER
->
"用户: "
+
message
.
getText
();
case
ASSISTANT
->
"助手: "
+
extractAnswerContent
(
message
);
case
SYSTEM
->
"系统: "
+
message
.
getText
();
default
->
"未知: "
+
message
.
getText
();
};
}
public
String
getReference
(
List
<
Document
>
documents
)
{
// 使用 Stream 处理 documents,根据 filePath 去重
JSONArray
jsonArray
=
documents
.
stream
()
.
filter
(
document
->
{
document
.
getMetadata
();
return
true
;
})
// 过滤掉 metadata 为空的文档
.
collect
(
Collectors
.
toMap
(
document
->
document
.
getMetadata
().
get
(
"filePath"
),
// 以 filePath 作为键
document
->
{
JSONObject
jsonObject
=
new
JSONObject
();
jsonObject
.
set
(
"fileName"
,
document
.
getMetadata
().
get
(
"fileName"
));
jsonObject
.
set
(
"filePath"
,
document
.
getMetadata
().
get
(
"filePath"
));
return
jsonObject
;
},
(
existing
,
replacement
)
->
existing
// 如果 filePath 重复,保留已有的 JSONObject
))
.
values
()
// 获取去重后的 JSONObject
.
stream
()
.
collect
(
Collectors
.
toCollection
(
JSONArray:
:
new
));
// 转换为 JSONArray
// 构建最终结果
JSONObject
result
=
new
JSONObject
();
result
.
put
(
"data"
,
jsonArray
);
// 返回格式化后的引用信息
return
"\n<reference>"
+
"\n"
+
result
.
toJSONString
(
2
)
+
"\n"
+
"</reference>"
;
}
/**
* 从消息文本中提取 <answer> 标签内的内容
*
* @param message 消息对象
* @return <answer> 标签内的内容,如果没有找到则返回 null
*/
public
static
String
extractAnswerContent
(
Message
message
)
{
String
text
=
message
.
getText
();
Pattern
pattern
=
Pattern
.
compile
(
"<answer>(.*?)</answer>"
,
Pattern
.
DOTALL
);
Matcher
matcher
=
pattern
.
matcher
(
text
);
if
(
matcher
.
find
())
{
return
matcher
.
group
(
1
);
// 返回第一个捕获组,即 <answer> 标签内的内容
}
return
message
.
getText
();
// 如果没有找到,返回 null
}
}
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/service/impl/RagPromptService.java
0 → 100644
View file @
1c71562d
package
com
.
ask
.
service
.
impl
;
import
org.springframework.ai.chat.prompt.Prompt
;
import
org.springframework.ai.chat.prompt.PromptTemplate
;
import
org.springframework.stereotype.Service
;
import
java.util.Map
;
@Service
public
class
RagPromptService
{
/**
* 生成提示词
* @param query 用户问题
* @param context RAG 召回内容
* @param history 历史对话
* @return 生成的提示词
*/
public
String
createRagPrompt
(
String
query
,
String
context
,
String
history
)
{
// 定义系统消息和用户消息
String
systemMessage
=
"系统消息:请基于上下文和历史对话回答问题。"
;
String
userMessage
=
"用户问题:"
+
query
;
// 定义提示词模板
String
promptTemplate
=
"""
{systemMessage}
历史对话:
{history}
上下文:
{context}
问题:
{query}
"""
;
// 填充模板参数
return
promptTemplate
.
replace
(
"{systemMessage}"
,
systemMessage
)
.
replace
(
"{query}"
,
query
)
.
replace
(
"{history}"
,
history
)
.
replace
(
"{context}"
,
context
);
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment