Commit 18813de1 authored by 林洋洋's avatar 林洋洋

修改工具和对话返回

parent 51f84d66
package com.ask.api.dto;
import lombok.Data;
@Data
public class ChatResult {
private String message;
private String reasoningContent;
}
......@@ -3,15 +3,16 @@ package com.ask.controller;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.ask.api.dto.ChatResult;
import com.ask.api.entity.ChatConversation;
import com.ask.common.core.R;
import com.ask.service.AskModelService;
import com.ask.service.ChatConversationService;
import com.ask.service.impl.ChatService;
import com.ask.service.impl.RagPromptService;
import com.ask.tools.EchartsTools;
import com.ask.tools.ExcelTools;
import com.ask.tools.SqlTools;
import com.ask.utils.FluxUtils;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
......@@ -28,6 +29,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
......@@ -65,15 +67,14 @@ public class ChatController {
private final MessageChatMemoryAdvisor messageChatMemoryAdvisor;
private final ChatService chatService;
private final RagPromptService ragPromptService;
private final ExcelTools excelTools;
private final SqlTools sqlTools;
private final EchartsTools echartsTools;
/**
......@@ -105,9 +106,9 @@ public class ChatController {
*/
@Operation(summary = "普通对话", description = "普通对话")
@GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> chat(@RequestParam String message,
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
public Flux<ChatResult> chat(@RequestParam String message,
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
Message systemMessage = new SystemMessage("你是一个AI问答助手,请准确回答用户问题,回答要求:请使用markdown格式输出");
......@@ -117,11 +118,20 @@ public class ChatController {
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt(prompt)
return chatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.chatResponse());
.chatResponse()
.map(response -> {
AssistantMessage assistantMessage = response.getResult().getOutput();
ChatResult result = new ChatResult();
if (assistantMessage instanceof DeepSeekAssistantMessage) {
result.setReasoningContent(((DeepSeekAssistantMessage) assistantMessage).getReasoningContent());
}
result.setMessage(assistantMessage.getText());
return result;
});
}
......@@ -134,9 +144,9 @@ public class ChatController {
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message,
@RequestParam @Parameter(description = "会话ID") String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
public Flux<ChatResult> ragChat(@RequestParam @Parameter(description = "对话内容") String message,
@RequestParam @Parameter(description = "会话ID") String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
//获取对话历史
......@@ -160,26 +170,34 @@ public class ChatController {
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt()
.user(userPrompt)
.system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" +
"1.以 Markdown 格式输出")
.stream()
.chatResponse(), contentBuilder)
.concatWith(Flux.just(reference))
.doOnComplete(() -> {
// 流结束时获取完整内容
return chatClient.prompt()
.user(userPrompt)
.system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" +
"1.以 Markdown 格式输出")
.stream()
.chatResponse()
.map(response -> {
AssistantMessage assistantMessage = response.getResult().getOutput();
ChatResult result = new ChatResult();
if (assistantMessage instanceof DeepSeekAssistantMessage) {
result.setReasoningContent(((DeepSeekAssistantMessage) assistantMessage).getReasoningContent());
}
result.setMessage(assistantMessage.getText());
if (StringUtils.isNotBlank(assistantMessage.getText())) {
contentBuilder.append(assistantMessage.getText());
}
return result;
}).doOnComplete(() -> {
String fullResponse = contentBuilder.toString();
// 异步保存到数据库(添加错误处理)
chatService.saveHistoryMemory(conversationId, new AssistantMessage(fullResponse));
});
}
@Operation(summary = "智能数据报表对话", description = "智能数据报表对话")
@GetMapping(value = "/chat/report", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> reportChat(@RequestParam String message,
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
public Flux<ChatResult> reportChat(@RequestParam String message,
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
Message systemMessage = new SystemMessage("你是一个AI问答助手,请用回答用户问题,使用相关工具");
Message userMessage = new UserMessage(message);
......@@ -188,25 +206,74 @@ public class ChatController {
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt(prompt)
return chatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.tools(excelTools)
.advisors()
.stream()
.chatResponse());
.chatResponse()
.map(response -> {
AssistantMessage assistantMessage = response.getResult().getOutput();
ChatResult result = new ChatResult();
if (assistantMessage instanceof DeepSeekAssistantMessage) {
result.setReasoningContent(((DeepSeekAssistantMessage) assistantMessage).getReasoningContent());
}
result.setMessage(assistantMessage.getText());
return result;
});
}
public void test() {
ChatModel chatModel = DeepSeekChatModel.builder()
.deepSeekApi(DeepSeekApi.builder().baseUrl("").apiKey("TEST").build())
.defaultOptions(DeepSeekChatOptions.builder().model("deepseek-r1").temperature(66.6).maxTokens(10000).build())
.build();
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors()
.build();
@Operation(summary = "智能问数对话", description = "智能问数对话")
@GetMapping(value = "/chat/data", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ChatResult> dataChat(@RequestParam String message,
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
Message systemMessage = new SystemMessage("" +
"【生产记录表 - ask_production_records】\n" +
"用途:存储各公司及其下属项目的月度生产量数据。\n" +
"\n" +
"字段说明:\n" +
"1. company_name (varchar100) → 公司名称(如:龙源环保) \n" +
"2. subsidiary_name(varchar100) → 项目(如:脱硝催化剂) \n" +
"3. unit (varchar50) → 单位(如:立方米、件、平方米) \n" +
"4. value (numeric10,2) → 月生产量数值(保留两位小数) \n" +
"5. year (int4) → 年份(如:2024) \n" +
"6. month (int4) → 月份(1-12) \n" +
"7. id (int4 PK) → 主键,自增\n" +
"\n" +
"【查询规则】\n" +
"- 每次回答必须先通过 SQL 工具查询此表,禁止口算或推测。 \n" +
"- 返回结果需附带“单位”字段。 \n" +
"- 若数据不存在,直接回复“暂无记录”,禁止编造。" +
"【回答要求】\n" +
"- 可以采用表格+图表的形式展示数据");
Message userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatClient chatClient = askModelService.getChatClientById(actualModelId);
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return chatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.tools(sqlTools)
.tools(echartsTools)
.stream()
.chatResponse()
.map(response -> {
AssistantMessage assistantMessage = response.getResult().getOutput();
ChatResult result = new ChatResult();
if (assistantMessage instanceof DeepSeekAssistantMessage) {
result.setReasoningContent(((DeepSeekAssistantMessage) assistantMessage).getReasoningContent());
}
result.setMessage(assistantMessage.getText());
return result;
});
}
......
......@@ -52,7 +52,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
if(askModel.getStatus()==0){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
......@@ -67,7 +67,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
if(askModel.getStatus()==0){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
......
package com.ask.tools;
import com.ask.common.core.FileTemplate;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StopWatch;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.regex.Pattern;
/**
* SQL查询工具类
* 提供安全的数据库查询功能,支持参数化查询和结果限制
*
* @author AI Assistant
*/
@Component
@Slf4j
public class SqlTools {
@Autowired
private JdbcTemplate jdbcTemplate;
@Autowired
private NamedParameterJdbcTemplate namedParameterJdbcTemplate;
@Value("${sql.tools.max-results:1000}")
private int maxResults;
@Value("${sql.tools.query-timeout:30}")
private int queryTimeoutSeconds;
// SQL注入防护:只允许SELECT语句,禁止危险关键字
private static final Pattern ALLOWED_SQL_PATTERN = Pattern.compile(
"^\\s*SELECT\\s+.*", Pattern.CASE_INSENSITIVE | Pattern.DOTALL
);
private static final Pattern DANGEROUS_KEYWORDS = Pattern.compile(
".*\\b(DELETE|UPDATE|INSERT|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE|UNION|SCRIPT|DECLARE)\\b.*",
Pattern.CASE_INSENSITIVE
);
@Tool(description = "查询数据库 入参SQL (String)")
public List<Map<String, Object>> selectBySql(@ToolParam(description = "SQL query string") String sql) {
// /**
// * 执行SQL查询
// *
// * @param sql SQL查询语句,只支持SELECT语句
// * @return 查询结果列表,最多返回配置的最大行数
// */
// @Tool(description = "安全执行数据库查询,只支持SELECT语句,返回结果有数量限制")
// public List<Map<String, Object>> selectBySql(
// @ToolParam(description = "SQL查询语句,只支持SELECT语句") String sql) {
//
// return executeQuery(sql, Collections.emptyMap());
// }
/**
* 执行参数化SQL查询
*
* @param sql SQL查询语句,支持命名参数 :paramName
* @param params 参数Map
* @return 查询结果列表
*/
@Tool(description = "Postgres执行参数化SQL查询,支持命名参数,更安全")
public List<Map<String, Object>> selectBySqlWithParams(
@ToolParam(description = "SQL查询语句,支持命名参数如 :name") String sql,
@ToolParam(description = "参数Map,key为参数名,value为参数值") Map<String, Object> params) {
log.info("selectBySqlWithParams:{}", sql);
return executeQuery(sql, params != null ? params : Collections.emptyMap());
}
// /**
// * 执行分页查询
// *
// * @param sql SQL查询语句
// * @param offset 偏移量
// * @param limit 限制数量
// * @return 查询结果列表
// */
// @Tool(description = "执行分页查询")
// public List<Map<String, Object>> selectBySqlWithPaging(
// @ToolParam(description = "SQL查询语句") String sql,
// @ToolParam(description = "偏移量,从0开始") int offset,
// @ToolParam(description = "每页数量,最大1000") int limit) {
//
// // 限制分页参数
// offset = Math.max(0, offset);
// limit = Math.min(Math.max(1, limit), maxResults);
//
// String pagedSql = sql + " LIMIT " + limit + " OFFSET " + offset;
// return executeQuery(pagedSql, Collections.emptyMap());
// }
//
// /**
// * 获取查询结果总数
// *
// * @param sql 原始SQL查询语句
// * @return 总记录数
// */
// @Tool(description = "获取查询结果总数")
// public Long countBySql(@ToolParam(description = "SQL查询语句") String sql) {
//
// if (!isValidSql(sql)) {
// log.warn("无效的SQL语句: {}", sql);
// return 0L;
// }
//
// try {
// // 构建COUNT查询
// String countSql = "SELECT COUNT(*) FROM (" + sql + ") as count_query";
// StopWatch stopWatch = new StopWatch();
// stopWatch.start();
//
// Long count = jdbcTemplate.queryForObject(countSql, Long.class);
//
// stopWatch.stop();
// log.info("COUNT查询执行完成,耗时: {}ms", stopWatch.getTotalTimeMillis());
//
// return count != null ? count : 0L;
//
// } catch (Exception e) {
// log.error("COUNT查询异常: sql={}, error={}", sql, e.getMessage(), e);
// return 0L;
// }
// }
/**
* 执行查询的核心方法
*/
private List<Map<String, Object>> executeQuery(String sql, Map<String, Object> params) {
if (StringUtils.isEmpty(sql)) {
log.warn("SQL语句为空");
return Collections.emptyList();
}
if (!isValidSql(sql)) {
log.warn("SQL安全检查失败: {}", sql);
return Collections.emptyList();
}
try {
return jdbcTemplate.queryForList(sql);
StopWatch stopWatch = new StopWatch();
stopWatch.start();
// 设置查询超时
jdbcTemplate.setQueryTimeout(queryTimeoutSeconds);
List<Map<String, Object>> results;
if (params.isEmpty()) {
results = jdbcTemplate.queryForList(sql);
} else {
results = namedParameterJdbcTemplate.queryForList(sql, params);
}
stopWatch.stop();
// 限制结果数量
if (results.size() > maxResults) {
log.warn("查询结果超过最大限制 {}, 实际数量: {}, 已截取", maxResults, results.size());
results = results.subList(0, maxResults);
}
log.info("SQL查询执行完成 - 耗时: {}ms, 结果数量: {}, SQL: {}",
stopWatch.getTotalTimeMillis(), results.size(),
sql.length() > 100 ? sql.substring(0, 100) + "..." : sql);
return results;
} catch (DataAccessException e) {
log.error("数据库访问异常: sql={}, params={}, error={}", sql, params, e.getMessage());
return Collections.emptyList();
} catch (Exception e) {
log.error("数据库查询异常:sql{} 异常: {}",sql,e.getMessage());
log.error("SQL查询异常: sql={}, params={}, error={}", sql, params, e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* SQL安全性验证
*/
private boolean isValidSql(String sql) {
if (StringUtils.isEmpty(sql)) {
return false;
}
String trimmedSql = sql.trim();
// 检查是否为SELECT语句
if (!ALLOWED_SQL_PATTERN.matcher(trimmedSql).matches()) {
log.warn("只允许SELECT语句: {}", sql);
return false;
}
// 检查危险关键字
if (DANGEROUS_KEYWORDS.matcher(trimmedSql).matches()) {
log.warn("SQL包含危险关键字: {}", sql);
return false;
}
// 检查SQL长度
if (trimmedSql.length() > 10000) {
log.warn("SQL语句过长: {} characters", trimmedSql.length());
return false;
}
return true;
}
// /**
// * 获取数据库表信息
// */
// @Tool(description = "获取数据库中的表列表信息")
// public List<Map<String, Object>> getTableInfo(@ToolParam(description = "表名模式,支持%通配符") String tableNamePattern) {
//
// if (StringUtils.isEmpty(tableNamePattern)) {
// tableNamePattern = "%";
// }
//
// String sql = "SELECT table_name, table_comment, table_type " +
// "FROM information_schema.tables " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name LIKE ? " +
// "ORDER BY table_name";
//
// try {
// return jdbcTemplate.queryForList(sql, tableNamePattern);
// } catch (Exception e) {
// log.error("获取表信息异常: pattern={}, error={}", tableNamePattern, e.getMessage());
// return Collections.emptyList();
// }
// }
//
// /**
// * 获取表的列信息
// */
// @Tool(description = "获取指定表的列信息")
// public List<Map<String, Object>> getTableColumns(@ToolParam(description = "表名") String tableName) {
//
// if (StringUtils.isEmpty(tableName)) {
// return Collections.emptyList();
// }
//
// String sql = "SELECT column_name, data_type, is_nullable, column_default, column_comment " +
// "FROM information_schema.columns " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name = ? " +
// "ORDER BY ordinal_position";
//
// try {
// return jdbcTemplate.queryForList(sql, tableName);
// } catch (Exception e) {
// log.error("获取表列信息异常: table={}, error={}", tableName, e.getMessage());
// return Collections.emptyList();
// }
// }
}
package com.ask.utils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import reactor.core.publisher.Flux;
import java.lang.reflect.Field;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
public class FluxUtils {
/**
* 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
* @param upstream 原始 SSE 流
* @return 带标签的逐块流
*/
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
AssistantMessage msg = resp.getResult().getOutput();
String reasoningContent = "";
String textContent = msg.getText(); // 普通回答
try {
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field f = msg.getClass().getDeclaredField("reasoningContent");
f.setAccessible(true);
reasoningContent = (String) f.get(msg);
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ }
StringBuilder sb = new StringBuilder();
// 推理阶段
if (!reasoningStarted.get()) {
reasoningStarted.set(true);
sb.append("<think>");
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段
if (StringUtils.isNotBlank(textContent)) {
if (!answerStarted.get()) {
answerStarted.set(true);
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream,StringBuilder
stringBuilder) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
AssistantMessage msg = resp.getResult().getOutput();
String reasoningContent = "";
String textContent = msg.getText(); // 普通回答
try {
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field f = msg.getClass().getDeclaredField("reasoningContent");
f.setAccessible(true);
reasoningContent = (String) f.get(msg);
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ }
StringBuilder sb = new StringBuilder();
// 推理阶段
if (!reasoningStarted.get()) {
reasoningStarted.set(true);
sb.append("<think>");
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if (StringUtils.isNotBlank(textContent)) {
stringBuilder.append(textContent);
if (answerStarted.compareAndSet(false, true)) {
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
}
//package com.ask.utils;
//
//import com.baomidou.mybatisplus.core.toolkit.StringUtils;
//import org.springframework.ai.chat.messages.AssistantMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
//import reactor.core.publisher.Flux;
//
//import java.lang.reflect.Field;
//import java.util.List;
//import java.util.concurrent.atomic.AtomicBoolean;
//
//public class FluxUtils {
//
// /**
// * 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
// *
// * @param upstream 原始 SSE 流
// * @return 带标签的逐块流
// */
// public static Flux<String> wrapModelStream(Flux<ChatResponse> upstream) {
//
// return upstream
// .flatMapIterable(resp -> {
// AssistantMessage msg = resp.getResult().getOutput();
//
// String reasoningContent = "";
// String textContent = msg.getText(); // 普通回答
//
// try {
// // 反射读取 DeepSeekAssistantMessage.reasoningContent
// Field f = msg.getClass().getDeclaredField("reasoningContent");
// f.setAccessible(true);
// reasoningContent = (String) f.get(msg);
// } catch (Exception ignored) {
// }
//
// return List.of(sb.toString());
// });
// }
//
//
//}
......@@ -109,6 +109,14 @@ logging:
pattern:
console: "%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"
# SQL工具配置
sql:
tools:
# 最大查询结果数量限制
max-results: 1000
# 查询超时时间(秒)
query-timeout: 30
# 本地文件系统
file:
local:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment