Commit 1c71562d authored by 林洋洋's avatar 林洋洋

代码优化

parent 717ac795
package com.ask.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.scheduling.annotation.AsyncConfigurer;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
......@@ -16,10 +23,42 @@ import java.util.concurrent.ThreadPoolExecutor;
* @date 2024/12/20
*/
@Slf4j
@Configuration
@EnableAsync
public class AsyncConfig {
@AutoConfiguration
public class AsyncConfig implements AsyncConfigurer {
/**
* Number of processors reported by the JVM for this machine. It is only a rough sizing hint
* and may not be accurate; tune the pool for the real workload (CPU-bound vs IO-bound).
*/
public static final int cpuNum = Runtime.getRuntime().availableProcessors();
// Optional override for the core pool size (property thread.pool.corePoolSize, empty default).
// getAsyncExecutor() falls back to cpuNum when no value is present.
@Value("${thread.pool.corePoolSize:}")
private Optional<Integer> corePoolSize;
// Optional override for the maximum pool size (property thread.pool.maxPoolSize, empty default).
// getAsyncExecutor() falls back to cpuNum * 2 when no value is present.
@Value("${thread.pool.maxPoolSize:}")
private Optional<Integer> maxPoolSize;
// Optional override for the task queue capacity (property thread.pool.queueCapacity, empty default).
// getAsyncExecutor() falls back to 500 when no value is present.
@Value("${thread.pool.queueCapacity:}")
private Optional<Integer> queueCapacity;
// Optional override for the shutdown wait, in seconds (property thread.pool.awaitTerminationSeconds,
// empty default). getAsyncExecutor() falls back to 60 when no value is present.
@Value("${thread.pool.awaitTerminationSeconds:}")
private Optional<Integer> awaitTerminationSeconds;
/**
 * Builds the default executor used for {@code @Async} methods.
 * <p>
 * Sizes come from the {@code thread.pool.*} properties when present; otherwise the core size
 * defaults to the CPU count, the maximum size to twice the CPU count, the queue capacity to
 * 500 and the shutdown wait to 60 seconds. Rejected tasks run on the submitting thread.
 *
 * @return the initialized executor
 */
@Override
@Bean
public Executor getAsyncExecutor() {
    // Core pool size: configured value, otherwise the CPU count.
    int coreSize = corePoolSize.orElse(cpuNum);
    // Maximum pool size: configured value, otherwise twice the CPU count.
    int maxSize = maxPoolSize.orElse(cpuNum * 2);
    // The underlying ThreadPoolExecutor rejects max < core with an IllegalArgumentException,
    // which can happen when only the core size is overridden. Clamp instead of failing startup.
    if (maxSize < coreSize) {
        log.warn("thread.pool maxPoolSize ({}) is smaller than corePoolSize ({}); using {} for both",
                maxSize, coreSize, coreSize);
        maxSize = coreSize;
    }

    ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
    taskExecutor.setCorePoolSize(coreSize);
    taskExecutor.setMaxPoolSize(maxSize);
    // Maximum number of queued tasks.
    taskExecutor.setQueueCapacity(queueCapacity.orElse(500));
    taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
    // Let queued and running tasks finish on shutdown, bounded by the await timeout.
    taskExecutor.setWaitForTasksToCompleteOnShutdown(true);
    taskExecutor.setAwaitTerminationSeconds(awaitTerminationSeconds.orElse(60));
    taskExecutor.setThreadNamePrefix("ASK-Thread-");
    taskExecutor.initialize();
    return taskExecutor;
}
/**
* 向量化专用线程池
*
......@@ -61,31 +100,4 @@ public class AsyncConfig {
return executor;
}
// /**
// * 通用异步线程池
// *
// * @return 通用任务执行器
// */
// @Bean("taskExecutor")
// public Executor taskExecutor() {
// ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
//
// // 异步线程配置
// executor.setCorePoolSize(5);
// executor.setMaxPoolSize(5);
// executor.setQueueCapacity(99999);
// executor.setThreadNamePrefix("async-service-");
//
// executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// executor.setKeepAliveSeconds(60);
// executor.setWaitForTasksToCompleteOnShutdown(true);
// executor.setAwaitTerminationSeconds(60);
//
// executor.initialize();
//
// log.info("通用线程池初始化完成:核心线程={}, 最大线程={}, 队列容量={}",
// executor.getCorePoolSize(), executor.getMaxPoolSize(), executor.getQueueCapacity());
//
// return executor;
// }
}
\ No newline at end of file
......@@ -31,7 +31,6 @@ public class CommonConfiguration {
.jdbcTemplate(jdbcTemplate)
.dialect(postgresChatMemoryDialect)
.build();
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(5)
......@@ -39,7 +38,7 @@ public class CommonConfiguration {
}
@Bean
public ChatClient chatClient(OpenAiChatModel model, ChatMemory chatMemory) {
public ChatClient chatClient(OpenAiChatModel model) {
return ChatClient.builder(model)
.defaultAdvisors()
.build();
......
......@@ -7,7 +7,7 @@ public class PostgresChatMemoryDialect implements JdbcChatMemoryRepositoryDialec
@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM ask_chat_conversation_detail WHERE conversation_id = ? ORDER BY \"timestamp\"";
return "SELECT content, type FROM ask_chat_conversation_detail WHERE conversation_id = ? and del_flag = '0' ORDER BY \"timestamp\"";
}
@Override
public String getInsertMessageSql() {
......
package com.ask.controller;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.ask.api.entity.ChatConversation;
import com.ask.common.core.R;
import com.ask.service.ChatConversationService;
import com.ask.service.impl.ChatService;
import com.ask.service.impl.RagPromptService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
......@@ -13,21 +17,20 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.Arrays;
import java.util.Objects;
import java.util.UUID;
import java.util.*;
@Slf4j
@RestController
......@@ -38,55 +41,56 @@ public class ChatController {
private final ChatClient chatClient;
private final ChatConversationService chatConversationService;
private final ChatConversationService chatConversationService;
private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;
private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;
private final MessageChatMemoryAdvisor messageChatMemoryAdvisor;
private final MessageChatMemoryAdvisor messageChatMemoryAdvisor;
private final PromptChatMemoryAdvisor promptChatMemoryAdvisor;
private final VectorStore vectorStore;
private final ChatService chatService;
private final RagPromptService ragPromptService;
private final ChatMemory chatMemory;
private final PromptChatMemoryAdvisor promptChatMemoryAdvisor;
/**
* 获取会话ID
*
* @return 新的会话ID
*/
@Operation(summary = "创建对话", description = "创建对话")
@Operation(summary = "创建对话", description = "创建对话")
@GetMapping("/create/client")
public R<ChatConversation> getConversationId(@Parameter(description = "智能体ID") @RequestParam Integer agentId,
@Parameter(description = "用户ID") @RequestParam Long userId) {
if(Objects.isNull(agentId)) {
throw new RuntimeException("userID不能为NULL!");
}
ChatConversation chatConversation =new ChatConversation();
String conversationId= UUID.randomUUID().toString().replaceAll("-","");
chatConversation.setConversationId(conversationId);
chatConversation.setAgentId(agentId);
chatConversation.setUserId(userId);
chatConversationService.save(chatConversation);
public R<ChatConversation> getConversationId(@Parameter(description = "智能体ID") @RequestParam Integer agentId, @Parameter(description = "用户ID") @RequestParam Long userId) {
    // Validate both identifiers and report each one under its own name. The previous check
    // tested agentId but reported "userID", and never checked userId at all.
    if (Objects.isNull(agentId)) {
        throw new IllegalArgumentException("agentId不能为NULL!");
    }
    if (Objects.isNull(userId)) {
        throw new IllegalArgumentException("userID不能为NULL!");
    }

    // Conversation id is a random UUID with the dashes removed (literal replace, no regex needed).
    String conversationId = UUID.randomUUID().toString().replace("-", "");

    ChatConversation chatConversation = new ChatConversation();
    chatConversation.setConversationId(conversationId);
    chatConversation.setAgentId(agentId);
    chatConversation.setUserId(userId);
    chatConversationService.save(chatConversation);
    return R.ok(chatConversation);
}
/**
* 最基本的AI流式输出对话
* <p>
* * @param message
*
* * @param message
* @return
*/
@Operation(summary = "普通对话", description = "普通对话")
@Operation(summary = "普通对话", description = "普通对话")
@GetMapping(value = "/chat", produces = "text/html;charset=utf-8")
public Flux<String> chat(@Parameter(description = "对话内容") @RequestParam String message,
@Parameter(description = "会话ID") @RequestParam String conversationId) {
public Flux<String> chat(@Parameter(description = "对话内容") @RequestParam String message, @Parameter(description = "会话ID") @RequestParam String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
// 用户消息
String question = "请严格按以下格式回答:\n" +
"<think>\n" +
"[你的逐步推理过程]\n" +
"</think>\n" +
"<answer>\n" +
"[最终答案]\n" +
"</answer>\n" +
"推理过程不要设计`<think>` 和 `<answer>` \n" +
"问题:"+message+"\n" ;
String question = "请严格按以下格式回答:\n" + "<think>\n" + "[你的逐步推理过程]\n" + "</think>\n" + "<answer>\n" + "[最终答案]\n" + "</answer>\n" + "推理过程不要设计`<think>` 和 `<answer>` \n" + "问题:" + message + "\n";
Message userMessage = new UserMessage(question);
// 创建提示,包含系统消息和用户消息
......@@ -95,31 +99,52 @@ public class ChatController {
return chatClient.prompt(prompt).advisors(messageChatMemoryAdvisor).advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)).stream().content();
}
/**
* 知识库对话
*
* * @param message
* @return
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8")
public Flux<String> ragChat(String message, String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
// Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
//
// Message userMessage = new UserMessage(message);
// // 创建提示,包含系统消息和用户消息
// Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
FilterExpressionBuilder builder = new FilterExpressionBuilder();
Filter.Expression filter = builder.eq("isEnabled",1).build();
// List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder().query(message).filterExpression(filter).similarityThreshold(0.75).topK(5).build());
return chatClient.prompt(message)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.advisors(promptChatMemoryAdvisor) //会话记忆
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter))
.advisors(retrievalAugmentationAdvisor) //知识库召回
.stream().content();
}
/**
 * Knowledge-base (RAG) chat with streamed output.
 * <p>
 * Reads the conversation history, stores the new question, retrieves matching knowledge-base
 * documents, builds the prompt and streams the model output followed by a
 * {@code <reference>} block. When the stream completes, the collected text (including the
 * reference block) is stored as the assistant message of the conversation.
 *
 * @param message        the user's question
 * @param conversationId id of the conversation whose memory is read and appended to
 * @return the streamed answer chunks followed by the reference block
 */
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8")
public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message, @RequestParam @Parameter(description = "会话ID") String conversationId) {
    // Read the history BEFORE storing the new question so the history section of the prompt
    // does not repeat the current question.
    String historyMemory = chatService.getHistoryMemoryAsString(conversationId);
    chatService.saveHistoryMemory(conversationId, new UserMessage(message));

    // Vector recall restricted to entries whose isEnabled metadata equals 1.
    Filter.Expression filter = new FilterExpressionBuilder().eq("isEnabled", 1).build();
    List<Document> documents = chatService.retrieveDocuments(message, 0.75, 5, filter);

    // File references appended after the answer, and the text context fed to the model.
    String reference = chatService.getReference(documents);
    String context = chatService.convertDocumentsToString(documents);
    String userPrompt = ragPromptService.createRagPrompt(message, context, historyMemory);

    StringBuilder contentBuilder = new StringBuilder();
    return chatClient.prompt()
            .user(userPrompt)
            .system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" +
                    "1.以 Markdown 格式输出" +
                    "2.请务必将你的思考过程放在 <think></think> 标签内" +
                    "3.请务必将生成最终答案放在 <answer></answer> 标签内")
            .stream()
            .content()
            .concatWith(Mono.just(reference))
            // Collect every emitted chunk so the full reply can be stored at the end.
            .doOnNext(chunk -> contentBuilder.append(chunk))
            .doOnComplete(() -> {
                // An exception thrown here would be delivered to the client as a stream error
                // after the whole answer was already sent, so a failed save is logged instead.
                try {
                    chatService.saveHistoryMemory(conversationId, new AssistantMessage(contentBuilder.toString()));
                } catch (Exception e) {
                    log.error("Failed to save assistant reply to chat memory, conversationId={}", conversationId, e);
                }
            });
}
}
\ No newline at end of file
package com.ask.service;
import com.ask.api.entity.AskVectorStore;
import com.ask.api.entity.KnowledgeDocument;
import java.util.List;
import java.util.concurrent.CompletableFuture;
/**
* 异步向量化服务接口
*
......@@ -16,9 +12,8 @@ public interface AsyncVectorizationService {
/**
* 根据文档异步向量化处理
*
*
* @param document 文档对象
* @return 异步任务结果
*/
CompletableFuture<Integer> vectorizeByDocumentIdAsync(KnowledgeDocument document);
void vectorizeByDocumentIdAsync(KnowledgeDocument document);
}
\ No newline at end of file
......@@ -5,17 +5,16 @@ import com.ask.api.entity.KnowledgeDocument;
import com.ask.mapper.KnowledgeDocumentMapper;
import com.ask.service.AskVectorStoreService;
import com.ask.service.AsyncVectorizationService;
import com.ask.service.KnowledgeDocumentService;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* 异步向量化服务实现类
......@@ -25,26 +24,47 @@ import java.util.concurrent.CompletableFuture;
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class AsyncVectorizationServiceImpl implements AsyncVectorizationService {
@Autowired
private AskVectorStoreService askVectorStoreService;
private final AskVectorStoreService askVectorStoreService;
@Autowired
private KnowledgeDocumentMapper knowledgeDocumentMapper;
private final KnowledgeDocumentMapper knowledgeDocumentMapper;
// Dedicated fixed-size pool (3 threads, queue of 100) on which processData(...) is run.
private final ThreadPoolExecutor processThreadPool = new ThreadPoolExecutor(
3, // core pool size
3, // maximum pool size
60L, // keep-alive time for idle threads
TimeUnit.SECONDS, // unit of the keep-alive time
new LinkedBlockingQueue<>(100), // bounded task queue
new ThreadFactory() {
// Counter used to give each worker thread a unique, readable name.
private final AtomicInteger threadNumber = new AtomicInteger(1);
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r, "data-process-thread-" + threadNumber.getAndIncrement());
// Daemon threads: these workers do not keep the JVM alive on shutdown.
t.setDaemon(true);
return t;
}
},
new ThreadPoolExecutor.CallerRunsPolicy() // rejection policy: run the task on the submitting thread
);
/**
* 根据文档异步向量化处理
*
* @param document 文档对象
* @return 异步任务结果
*/
@Override
@Async("vectorizationExecutor")
public CompletableFuture<Integer> vectorizeByDocumentIdAsync(KnowledgeDocument document) {
public void vectorizeByDocumentIdAsync(KnowledgeDocument document) {
    // execute() instead of submit(): submit() wraps the task in a Future that nobody reads, so
    // anything escaping processData (for example a failure inside its own catch block) would
    // vanish silently. With execute() it reaches the thread's uncaught-exception handler.
    // NOTE(review): this method appears to be annotated @Async as well, so the work hops through
    // two executors — confirm whether both are intended.
    processThreadPool.execute(() -> processData(document));
}
public void processData(KnowledgeDocument document) {
if (document == null) {
log.warn("异步向量化:文档对象为空");
return CompletableFuture.completedFuture(0);
return;
}
try {
......@@ -52,16 +72,16 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
Thread.currentThread().getName(), document.getId());
long startTime = System.currentTimeMillis();
// 更新文档状态为处理中
knowledgeDocumentMapper.update(Wrappers.<KnowledgeDocument>lambdaUpdate()
.eq(KnowledgeDocument::getId, document.getId())
.set(KnowledgeDocument::getStatus, 1));
// 查询该文档下所有未向量化的数据
LambdaQueryWrapper<AskVectorStore> wrapper = new LambdaQueryWrapper<AskVectorStore>()
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(document.getId()));
; // 假设embedding为null表示未向量化
; // 假设embedding为null表示未向量化
List<AskVectorStore> vectorStores = askVectorStoreService.list(wrapper);
......@@ -71,7 +91,7 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
knowledgeDocumentMapper.update(Wrappers.<KnowledgeDocument>lambdaUpdate()
.eq(KnowledgeDocument::getId, document.getId())
.set(KnowledgeDocument::getStatus, 3));
return CompletableFuture.completedFuture(0);
return;
}
// 批量处理向量化
......@@ -81,13 +101,11 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
log.info("根据文档异步向量化完成,线程:{},文档ID:{},耗时:{}ms,成功:{}/{}",
Thread.currentThread().getName(), document.getId(), (endTime - startTime),
successCount, vectorStores.size());
// 更新文档状态为处理完成
knowledgeDocumentMapper.update(Wrappers.<KnowledgeDocument>lambdaUpdate()
.eq(KnowledgeDocument::getId, document.getId())
.set(KnowledgeDocument::getStatus, 2));
return CompletableFuture.completedFuture(successCount);
} catch (Exception e) {
// 更新文档状态为处理失败
......@@ -96,7 +114,7 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
.set(KnowledgeDocument::getStatus, 3));
log.error("根据文档异步向量化失败,线程:{},文档ID:{},错误:{}",
Thread.currentThread().getName(), document.getId(), e.getMessage(), e);
return CompletableFuture.completedFuture(0);
}
}
}
\ No newline at end of file
package com.ask.service.impl;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.springframework.ai.chat.messages.MessageType.ASSISTANT;
@Slf4j
@RequiredArgsConstructor
@Service
public class ChatService {
private final VectorStore vectorStore;
private final ChatMemory chatMemory;
/**
 * Runs a similarity search against the vector store and drops duplicate hits.
 *
 * @param query     the question to search for
 * @param threshold minimum similarity score a hit must reach
 * @param topK      maximum number of hits to request
 * @param filter    metadata filter expression applied to the search
 * @return the hits with duplicate ids removed (first occurrence kept); an empty list when
 *         the store returns nothing
 */
public List<Document> retrieveDocuments(String query, double threshold, int topK, Filter.Expression filter) {
    SearchRequest request = SearchRequest.builder()
            .query(query)
            .filterExpression(filter)
            .similarityThreshold(threshold)
            .topK(topK)
            .build();
    List<Document> hits = vectorStore.similaritySearch(request);
    if (CollectionUtils.isEmpty(hits)) {
        return List.of();
    }

    // Keep only the first document seen for each id.
    Predicate<Document> firstWithId = distinctByKey(Document::getId);
    return hits.stream()
            .filter(firstWithId)
            .collect(Collectors.toList());
}
/**
 * Creates a stateful predicate that accepts an element only the first time its key is seen.
 *
 * @param keyExtractor function producing the key that identifies an element
 * @param <T>          element type
 * @return predicate that is {@code true} for the first element carrying each key
 */
private static <T> Predicate<T> distinctByKey(Function<? super T, ?> keyExtractor) {
    // Set.add reports whether the key was newly inserted, which is exactly the "first time" test.
    final Set<Object> observedKeys = ConcurrentHashMap.newKeySet();
    return element -> {
        Object key = keyExtractor.apply(element);
        return observedKeys.add(key);
    };
}
/**
 * Joins the text of the retrieved documents into a single string.
 *
 * @param documents the retrieved documents
 * @return the document texts in list order, separated by line breaks
 */
public String convertDocumentsToString(List<Document> documents) {
    // Pull out the text of every document first, then join with a newline between entries.
    List<String> passages = documents.stream()
            .map(Document::getText)
            .collect(Collectors.toList());
    return String.join("\n", passages);
}
/**
 * Appends one message to the chat memory of a conversation.
 *
 * @param conversationId id of the conversation to append to
 * @param message        the message to store
 */
public void saveHistoryMemory(String conversationId, Message message) {
    this.chatMemory.add(conversationId, message);
}
/**
 * Renders the stored history of a conversation as text, one formatted message per line.
 *
 * @param conversationId id of the conversation
 * @return the formatted history, or an empty string when nothing is stored
 */
public String getHistoryMemoryAsString(String conversationId) {
    List<Message> history = chatMemory.get(conversationId);
    if (CollectionUtils.isEmpty(history)) {
        return "";
    }
    // Diagnostic output goes through the logger rather than stdout.
    log.debug("Loaded {} history message(s) for conversation {}", history.size(), conversationId);

    return history.stream()
            .map(this::formatMessage)
            .collect(Collectors.joining("\n"));
}
/**
 * Turns a single chat message into a line prefixed with the speaker's role.
 * Assistant messages are reduced to their answer content via {@code extractAnswerContent}.
 *
 * @param message the message to format
 * @return the role prefix followed by the message text
 */
private String formatMessage(Message message) {
    switch (message.getMessageType()) {
        case USER:
            return "用户: " + message.getText();
        case ASSISTANT:
            return "助手: " + extractAnswerContent(message);
        case SYSTEM:
            return "系统: " + message.getText();
        default:
            return "未知: " + message.getText();
    }
}
/**
 * Builds the {@code <reference>} block listing the source files of the retrieved documents.
 * <p>
 * Documents without a {@code filePath} metadata entry are skipped. Each file path is listed
 * once (first occurrence wins) and entries keep the order of the input list.
 *
 * @param documents the retrieved documents
 * @return a line break followed by a {@code <reference>} element wrapping pretty-printed JSON
 *         of the form {@code {"data": [{"fileName": ..., "filePath": ...}]}}
 */
public String getReference(List<Document> documents) {
    // Reads the filePath metadata of a document; null when metadata or the entry is missing.
    Function<Document, Object> filePathOf = document ->
            document.getMetadata() == null ? null : document.getMetadata().get("filePath");

    JSONArray references = documents.stream()
            // Skip documents that cannot be attributed to a file.
            .filter(document -> filePathOf.apply(document) != null)
            // One entry per file; filtering the ordered stream keeps retrieval order.
            .filter(distinctByKey(filePathOf))
            .map(document -> {
                JSONObject entry = new JSONObject();
                entry.set("fileName", document.getMetadata().get("fileName"));
                entry.set("filePath", document.getMetadata().get("filePath"));
                return entry;
            })
            .collect(Collectors.toCollection(JSONArray::new));

    JSONObject result = new JSONObject();
    result.set("data", references);
    // Pretty-printed JSON wrapped in the reference tag that is appended after the answer.
    return "\n<reference>" + "\n" + result.toJSONString(2) + "\n" + "</reference>";
}
/** Matches the content of the first {@code <answer>...</answer>} pair, across line breaks. */
private static final Pattern ANSWER_PATTERN = Pattern.compile("<answer>(.*?)</answer>", Pattern.DOTALL);

/**
 * Extracts the content of the first {@code <answer>} tag from a message's text.
 *
 * @param message the message to read
 * @return the text between {@code <answer>} and {@code </answer>}; when the text has no such
 *         tag pair (or is {@code null}) the message text is returned unchanged
 */
public static String extractAnswerContent(Message message) {
    String text = message.getText();
    if (text == null) {
        // Nothing to scan; hand the text back as-is instead of failing inside the matcher.
        return text;
    }
    Matcher matcher = ANSWER_PATTERN.matcher(text);
    // Group 1 is the tag content; fall back to the full text when the tag is absent.
    return matcher.find() ? matcher.group(1) : text;
}
}
package com.ask.service.impl;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;
import java.util.Map;
@Service
public class RagPromptService {

    /** Fixed instruction line placed at the top of every RAG prompt. */
    private static final String SYSTEM_MESSAGE = "系统消息:请基于上下文和历史对话回答问题。";

    /**
     * Prompt layout. The four {@code %s} slots are, in order: instruction line, conversation
     * history, retrieved context and the user's question.
     */
    private static final String PROMPT_TEMPLATE = """
            %s
            历史对话:
            %s
            上下文:
            %s
            问题:
            %s
            """;

    /**
     * Builds the RAG prompt from the question, the retrieved context and the chat history.
     * <p>
     * All values are inserted in a single formatting pass, so placeholder-looking text inside
     * an argument (for example {@code {context}} typed by the user) stays literal instead of
     * being expanded by a later substitution. {@code null} arguments are treated as empty.
     *
     * @param query   the user's question
     * @param context the retrieved knowledge-base content
     * @param history the formatted conversation history
     * @return the assembled prompt text
     */
    public String createRagPrompt(String query, String context, String history) {
        return PROMPT_TEMPLATE.formatted(
                SYSTEM_MESSAGE,
                nullToEmpty(history),
                nullToEmpty(context),
                nullToEmpty(query));
    }

    /** Maps {@code null} to the empty string so a missing section renders as blank. */
    private static String nullToEmpty(String value) {
        return value == null ? "" : value;
    }
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment