Commit aba26463 authored by 林洋洋's avatar 林洋洋

模型切换deepseek 使用deepseek的思考

parent 1c71562d
...@@ -79,6 +79,10 @@ ...@@ -79,6 +79,10 @@
</dependency> </dependency>
<!-- Spring AI --> <!-- Spring AI -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-deepseek</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId> <artifactId>spring-ai-starter-model-openai</artifactId>
......
...@@ -9,6 +9,7 @@ import org.springframework.ai.chat.memory.ChatMemoryRepository; ...@@ -9,6 +9,7 @@ import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository; import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
...@@ -38,7 +39,14 @@ public class CommonConfiguration { ...@@ -38,7 +39,14 @@ public class CommonConfiguration {
} }
@Bean @Bean
public ChatClient chatClient(OpenAiChatModel model) { public ChatClient openAiChatClient(OpenAiChatModel model) {
return ChatClient.builder(model)
.defaultAdvisors()
.build();
}
@Bean
public ChatClient deepseekChatClient(DeepSeekChatModel model) {
return ChatClient.builder(model) return ChatClient.builder(model)
.defaultAdvisors() .defaultAdvisors()
.build(); .build();
......
...@@ -8,29 +8,37 @@ import com.ask.common.core.R; ...@@ -8,29 +8,37 @@ import com.ask.common.core.R;
import com.ask.service.ChatConversationService; import com.ask.service.ChatConversationService;
import com.ask.service.impl.ChatService; import com.ask.service.impl.ChatService;
import com.ask.service.impl.RagPromptService; import com.ask.service.impl.RagPromptService;
import com.ask.utils.FluxUtils;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j @Slf4j
@RestController @RestController
...@@ -39,7 +47,9 @@ import java.util.*; ...@@ -39,7 +47,9 @@ import java.util.*;
@Tag(description = "ai", name = "AI对话模块") @Tag(description = "ai", name = "AI对话模块")
public class ChatController { public class ChatController {
private final ChatClient chatClient; private final ChatClient openAiChatClient;
private final ChatClient deepseekChatClient;
private final ChatConversationService chatConversationService; private final ChatConversationService chatConversationService;
...@@ -56,6 +66,7 @@ public class ChatController { ...@@ -56,6 +66,7 @@ public class ChatController {
private final RagPromptService ragPromptService; private final RagPromptService ragPromptService;
private final ChatMemory chatMemory; private final ChatMemory chatMemory;
private final OpenAiChatModel openAiChatModel;
/** /**
* 获取会话ID * 获取会话ID
...@@ -84,21 +95,24 @@ public class ChatController { ...@@ -84,21 +95,24 @@ public class ChatController {
* *
* @return * @return
*/ */
@Operation(summary = "普通对话", description = "普通对话") @GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@GetMapping(value = "/chat", produces = "text/html;charset=utf-8") public Flux<String> chat(@RequestParam String message,
public Flux<String> chat(@Parameter(description = "对话内容") @RequestParam String message, @Parameter(description = "会话ID") @RequestParam String conversationId) { @RequestParam String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:"); Message systemMessage = new SystemMessage("你是一个AI问答助手,请用回答用户问题");
// 用户消息 Message userMessage = new UserMessage("问题:" + message + "\n回答要求:请使用markdown格式输出");
String question = "请严格按以下格式回答:\n" + "<think>\n" + "[你的逐步推理过程]\n" + "</think>\n" + "<answer>\n" + "[最终答案]\n" + "</answer>\n" + "推理过程不要设计`<think>` 和 `<answer>` \n" + "问题:" + message + "\n"; Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
Message userMessage = new UserMessage(question);
// 创建提示,包含系统消息和用户消息
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
return chatClient.prompt(prompt).advisors(messageChatMemoryAdvisor).advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)).stream().content();
}
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.chatResponse());
}
/** /**
* 知识库对话 * 知识库对话
* <p> * <p>
...@@ -107,7 +121,7 @@ public class ChatController { ...@@ -107,7 +121,7 @@ public class ChatController {
* @return * @return
*/ */
@Operation(summary = "知识库对话", description = "知识库对话") @Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8") @GetMapping(value = "/rag/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message, @RequestParam @Parameter(description = "会话ID") String conversationId) { public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message, @RequestParam @Parameter(description = "会话ID") String conversationId) {
//获取对话历史 //获取对话历史
...@@ -125,21 +139,15 @@ public class ChatController { ...@@ -125,21 +139,15 @@ public class ChatController {
String context = chatService.convertDocumentsToString(documents); String context = chatService.convertDocumentsToString(documents);
//创建提示词 //创建提示词
String userPrompt = ragPromptService.createRagPrompt(message, context, historyMemory); String userPrompt = ragPromptService.createRagPrompt(message, context, historyMemory);
StringBuilder contentBuilder = new StringBuilder();
return chatClient.prompt() StringBuilder contentBuilder = new StringBuilder();
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt()
.user(userPrompt) .user(userPrompt)
.system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" + .system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" +
"1.以 Markdown 格式输出" + "1.以 Markdown 格式输出" )
"2.请务必将你的思考过程放在 <think></think> 标签内" +
"3.请务必将生成最终答案放在 <answer></answer> 标签内")
.stream() .stream()
.content() .chatResponse(),contentBuilder)
.concatWith(Mono.just(reference)) .concatWith(Flux.just(reference))
.doOnNext(chunk -> {
// 实时收集每个流片段
contentBuilder.append(chunk);
})
.doOnComplete(() -> { .doOnComplete(() -> {
// 流结束时获取完整内容 // 流结束时获取完整内容
String fullResponse = contentBuilder.toString(); String fullResponse = contentBuilder.toString();
......
package com.ask.utils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
public class FluxUtils {
/**
* 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
* @param upstream 原始 SSE 流
* @return 带标签的逐块流
*/
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
DeepSeekAssistantMessage msg =
(DeepSeekAssistantMessage) resp.getResult().getOutput();
StringBuilder sb = new StringBuilder();
// 推理阶段:第一次出现推理内容时输出 <think>
if (StringUtils.isNotBlank(msg.getReasoningContent())) {
if (reasoningStarted.compareAndSet(false, true)) {
sb.append("<think>");
}
sb.append(msg.getReasoningContent());
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if (StringUtils.isNotBlank(msg.getText())) {
if (answerStarted.compareAndSet(false, true)) {
sb.append("</think><answer>");
}
sb.append(msg.getText());
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>")); // 末尾补一次关闭标签
}
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream,StringBuilder
stringBuilder) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
DeepSeekAssistantMessage msg =
(DeepSeekAssistantMessage) resp.getResult().getOutput();
StringBuilder sb = new StringBuilder();
// 推理阶段:第一次出现推理内容时输出 <think>
if (StringUtils.isNotBlank(msg.getReasoningContent())) {
if (reasoningStarted.compareAndSet(false, true)) {
sb.append("<think>");
}
sb.append(msg.getReasoningContent());
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if (StringUtils.isNotBlank(msg.getText())) {
stringBuilder.append(msg.getText());
if (answerStarted.compareAndSet(false, true)) {
sb.append("</think><answer>");
}
sb.append(msg.getText());
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>")); // 末尾补一次关闭标签
}
}
...@@ -37,12 +37,19 @@ spring: ...@@ -37,12 +37,19 @@ spring:
api-key: sk-ae96ff281ff644c992843c64a711a950 api-key: sk-ae96ff281ff644c992843c64a711a950
chat: chat:
options: options:
model: qwen-plus model: deepseek-r1
embedding: embedding:
base-url: https://dashscope.aliyuncs.com/compatible-mode base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950 api-key: sk-ae96ff281ff644c992843c64a711a950
options: options:
model: text-embedding-v4 model: text-embedding-v4
deepseek:
  base-url: https://dashscope.aliyuncs.com/compatible-mode/v1
  # Credentials must not be committed to version control. The key is resolved from the
  # DEEPSEEK_API_KEY environment variable at startup; the previously committed key
  # should be treated as leaked and rotated.
  api-key: ${DEEPSEEK_API_KEY}
  chat:
    enabled: true
    options:
      model: deepseek-r1
mybatis-plus: mybatis-plus:
mapper-locations: classpath*:/mapper/*Mapper.xml # mapper文件位置 mapper-locations: classpath*:/mapper/*Mapper.xml # mapper文件位置
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment