Commit 18813de1 authored by 林洋洋's avatar 林洋洋

修改工具和对话返回

parent 51f84d66
package com.ask.api.dto;
import lombok.Data;
@Data
public class ChatResult {
private String message;
private String reasoningContent;
}
...@@ -52,7 +52,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -52,7 +52,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public ChatClient getChatClientById(Long modelId) { public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){ if(askModel.getStatus()==0){
return null; return null;
} }
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
...@@ -67,7 +67,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -67,7 +67,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public EmbeddingModel getEmbeddingModelById(Long modelId) { public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){ if(askModel.getStatus()==0){
return null; return null;
} }
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
......
package com.ask.tools; package com.ask.tools;
import com.ask.common.core.FileTemplate;
import com.baomidou.mybatisplus.core.toolkit.StringUtils; import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.StopWatch;
import java.util.Collections; import java.util.*;
import java.util.List; import java.util.regex.Pattern;
import java.util.Map;
/**
* SQL查询工具类
* 提供安全的数据库查询功能,支持参数化查询和结果限制
*
* @author AI Assistant
*/
@Component @Component
@Slf4j @Slf4j
public class SqlTools { public class SqlTools {
@Autowired @Autowired
private JdbcTemplate jdbcTemplate; private JdbcTemplate jdbcTemplate;
@Autowired
private NamedParameterJdbcTemplate namedParameterJdbcTemplate;
@Value("${sql.tools.max-results:1000}")
private int maxResults;
@Value("${sql.tools.query-timeout:30}")
private int queryTimeoutSeconds;
// SQL注入防护:只允许SELECT语句,禁止危险关键字
private static final Pattern ALLOWED_SQL_PATTERN = Pattern.compile(
"^\\s*SELECT\\s+.*", Pattern.CASE_INSENSITIVE | Pattern.DOTALL
);
private static final Pattern DANGEROUS_KEYWORDS = Pattern.compile(
".*\\b(DELETE|UPDATE|INSERT|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE|UNION|SCRIPT|DECLARE)\\b.*",
Pattern.CASE_INSENSITIVE
);
@Tool(description = "查询数据库 入参SQL (String)") // /**
public List<Map<String, Object>> selectBySql(@ToolParam(description = "SQL query string") String sql) { // * 执行SQL查询
// *
// * @param sql SQL查询语句,只支持SELECT语句
// * @return 查询结果列表,最多返回配置的最大行数
// */
// @Tool(description = "安全执行数据库查询,只支持SELECT语句,返回结果有数量限制")
// public List<Map<String, Object>> selectBySql(
// @ToolParam(description = "SQL查询语句,只支持SELECT语句") String sql) {
//
// return executeQuery(sql, Collections.emptyMap());
// }
/**
* 执行参数化SQL查询
*
* @param sql SQL查询语句,支持命名参数 :paramName
* @param params 参数Map
* @return 查询结果列表
*/
@Tool(description = "Postgres执行参数化SQL查询,支持命名参数,更安全")
public List<Map<String, Object>> selectBySqlWithParams(
@ToolParam(description = "SQL查询语句,支持命名参数如 :name") String sql,
@ToolParam(description = "参数Map,key为参数名,value为参数值") Map<String, Object> params) {
log.info("selectBySqlWithParams:{}", sql);
return executeQuery(sql, params != null ? params : Collections.emptyMap());
}
// /**
// * 执行分页查询
// *
// * @param sql SQL查询语句
// * @param offset 偏移量
// * @param limit 限制数量
// * @return 查询结果列表
// */
// @Tool(description = "执行分页查询")
// public List<Map<String, Object>> selectBySqlWithPaging(
// @ToolParam(description = "SQL查询语句") String sql,
// @ToolParam(description = "偏移量,从0开始") int offset,
// @ToolParam(description = "每页数量,最大1000") int limit) {
//
// // 限制分页参数
// offset = Math.max(0, offset);
// limit = Math.min(Math.max(1, limit), maxResults);
//
// String pagedSql = sql + " LIMIT " + limit + " OFFSET " + offset;
// return executeQuery(pagedSql, Collections.emptyMap());
// }
//
// /**
// * 获取查询结果总数
// *
// * @param sql 原始SQL查询语句
// * @return 总记录数
// */
// @Tool(description = "获取查询结果总数")
// public Long countBySql(@ToolParam(description = "SQL查询语句") String sql) {
//
// if (!isValidSql(sql)) {
// log.warn("无效的SQL语句: {}", sql);
// return 0L;
// }
//
// try {
// // 构建COUNT查询
// String countSql = "SELECT COUNT(*) FROM (" + sql + ") as count_query";
// StopWatch stopWatch = new StopWatch();
// stopWatch.start();
//
// Long count = jdbcTemplate.queryForObject(countSql, Long.class);
//
// stopWatch.stop();
// log.info("COUNT查询执行完成,耗时: {}ms", stopWatch.getTotalTimeMillis());
//
// return count != null ? count : 0L;
//
// } catch (Exception e) {
// log.error("COUNT查询异常: sql={}, error={}", sql, e.getMessage(), e);
// return 0L;
// }
// }
/**
* 执行查询的核心方法
*/
private List<Map<String, Object>> executeQuery(String sql, Map<String, Object> params) {
if (StringUtils.isEmpty(sql)) { if (StringUtils.isEmpty(sql)) {
log.warn("SQL语句为空");
return Collections.emptyList(); return Collections.emptyList();
} }
if (!isValidSql(sql)) {
log.warn("SQL安全检查失败: {}", sql);
return Collections.emptyList();
}
try { try {
return jdbcTemplate.queryForList(sql); StopWatch stopWatch = new StopWatch();
stopWatch.start();
// 设置查询超时
jdbcTemplate.setQueryTimeout(queryTimeoutSeconds);
List<Map<String, Object>> results;
if (params.isEmpty()) {
results = jdbcTemplate.queryForList(sql);
} else {
results = namedParameterJdbcTemplate.queryForList(sql, params);
}
stopWatch.stop();
// 限制结果数量
if (results.size() > maxResults) {
log.warn("查询结果超过最大限制 {}, 实际数量: {}, 已截取", maxResults, results.size());
results = results.subList(0, maxResults);
}
log.info("SQL查询执行完成 - 耗时: {}ms, 结果数量: {}, SQL: {}",
stopWatch.getTotalTimeMillis(), results.size(),
sql.length() > 100 ? sql.substring(0, 100) + "..." : sql);
return results;
} catch (DataAccessException e) {
log.error("数据库访问异常: sql={}, params={}, error={}", sql, params, e.getMessage());
return Collections.emptyList();
} catch (Exception e) { } catch (Exception e) {
log.error("数据库查询异常:sql{} 异常: {}",sql,e.getMessage()); log.error("SQL查询异常: sql={}, params={}, error={}", sql, params, e.getMessage(), e);
return Collections.emptyList(); return Collections.emptyList();
} }
} }
/**
* SQL安全性验证
*/
private boolean isValidSql(String sql) {
if (StringUtils.isEmpty(sql)) {
return false;
}
String trimmedSql = sql.trim();
// 检查是否为SELECT语句
if (!ALLOWED_SQL_PATTERN.matcher(trimmedSql).matches()) {
log.warn("只允许SELECT语句: {}", sql);
return false;
}
// 检查危险关键字
if (DANGEROUS_KEYWORDS.matcher(trimmedSql).matches()) {
log.warn("SQL包含危险关键字: {}", sql);
return false;
}
// 检查SQL长度
if (trimmedSql.length() > 10000) {
log.warn("SQL语句过长: {} characters", trimmedSql.length());
return false;
}
return true;
}
// /**
// * 获取数据库表信息
// */
// @Tool(description = "获取数据库中的表列表信息")
// public List<Map<String, Object>> getTableInfo(@ToolParam(description = "表名模式,支持%通配符") String tableNamePattern) {
//
// if (StringUtils.isEmpty(tableNamePattern)) {
// tableNamePattern = "%";
// }
//
// String sql = "SELECT table_name, table_comment, table_type " +
// "FROM information_schema.tables " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name LIKE ? " +
// "ORDER BY table_name";
//
// try {
// return jdbcTemplate.queryForList(sql, tableNamePattern);
// } catch (Exception e) {
// log.error("获取表信息异常: pattern={}, error={}", tableNamePattern, e.getMessage());
// return Collections.emptyList();
// }
// }
//
// /**
// * 获取表的列信息
// */
// @Tool(description = "获取指定表的列信息")
// public List<Map<String, Object>> getTableColumns(@ToolParam(description = "表名") String tableName) {
//
// if (StringUtils.isEmpty(tableName)) {
// return Collections.emptyList();
// }
//
// String sql = "SELECT column_name, data_type, is_nullable, column_default, column_comment " +
// "FROM information_schema.columns " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name = ? " +
// "ORDER BY ordinal_position";
//
// try {
// return jdbcTemplate.queryForList(sql, tableName);
// } catch (Exception e) {
// log.error("获取表列信息异常: table={}, error={}", tableName, e.getMessage());
// return Collections.emptyList();
// }
// }
} }
package com.ask.utils; //package com.ask.utils;
//
import com.baomidou.mybatisplus.core.toolkit.StringUtils; //import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.springframework.ai.chat.messages.AssistantMessage; //import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse; //import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage; //import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import reactor.core.publisher.Flux; //import reactor.core.publisher.Flux;
//
import java.lang.reflect.Field; //import java.lang.reflect.Field;
import java.util.List; //import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean; //import java.util.concurrent.atomic.AtomicBoolean;
//
public class FluxUtils { //public class FluxUtils {
//
/** // /**
* 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String> // * 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
* @param upstream 原始 SSE 流 // *
* @return 带标签的逐块流 // * @param upstream 原始 SSE 流
*/ // * @return 带标签的逐块流
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream) { // */
AtomicBoolean reasoningStarted = new AtomicBoolean(false); // public static Flux<String> wrapModelStream(Flux<ChatResponse> upstream) {
AtomicBoolean answerStarted = new AtomicBoolean(false); //
// return upstream
return upstream // .flatMapIterable(resp -> {
.flatMapIterable(resp -> { // AssistantMessage msg = resp.getResult().getOutput();
AssistantMessage msg = resp.getResult().getOutput(); //
// String reasoningContent = "";
String reasoningContent = ""; // String textContent = msg.getText(); // 普通回答
String textContent = msg.getText(); // 普通回答 //
// try {
try { // // 反射读取 DeepSeekAssistantMessage.reasoningContent
// 反射读取 DeepSeekAssistantMessage.reasoningContent // Field f = msg.getClass().getDeclaredField("reasoningContent");
Field f = msg.getClass().getDeclaredField("reasoningContent"); // f.setAccessible(true);
f.setAccessible(true); // reasoningContent = (String) f.get(msg);
reasoningContent = (String) f.get(msg); // } catch (Exception ignored) {
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ } // }
//
StringBuilder sb = new StringBuilder(); // return List.of(sb.toString());
// });
// 推理阶段 // }
if (!reasoningStarted.get()) { //
reasoningStarted.set(true); //
sb.append("<think>"); //}
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段
if (StringUtils.isNotBlank(textContent)) {
if (!answerStarted.get()) {
answerStarted.set(true);
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream,StringBuilder
stringBuilder) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
AssistantMessage msg = resp.getResult().getOutput();
String reasoningContent = "";
String textContent = msg.getText(); // 普通回答
try {
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field f = msg.getClass().getDeclaredField("reasoningContent");
f.setAccessible(true);
reasoningContent = (String) f.get(msg);
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ }
StringBuilder sb = new StringBuilder();
// 推理阶段
if (!reasoningStarted.get()) {
reasoningStarted.set(true);
sb.append("<think>");
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if (StringUtils.isNotBlank(textContent)) {
stringBuilder.append(textContent);
if (answerStarted.compareAndSet(false, true)) {
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
}
...@@ -109,6 +109,14 @@ logging: ...@@ -109,6 +109,14 @@ logging:
pattern: pattern:
console: "%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n" console: "%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"
# SQL工具配置
sql:
tools:
# 最大查询结果数量限制
max-results: 1000
# 查询超时时间(秒)
query-timeout: 30
# 本地文件系统 # 本地文件系统
file: file:
local: local:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment