Commit 18813de1 authored by 林洋洋's avatar 林洋洋

修改工具和对话返回

parent 51f84d66
package com.ask.api.dto;
import lombok.Data;
@Data
public class ChatResult {
private String message;
private String reasoningContent;
}
......@@ -52,7 +52,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
if(askModel.getStatus()==0){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
......@@ -67,7 +67,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
if(askModel.getStatus()==0){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
......
package com.ask.tools;
import com.ask.common.core.FileTemplate;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StopWatch;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.regex.Pattern;
/**
* SQL查询工具类
* 提供安全的数据库查询功能,支持参数化查询和结果限制
*
* @author AI Assistant
*/
@Component
@Slf4j
public class SqlTools {
@Autowired
private JdbcTemplate jdbcTemplate;
@Autowired
private NamedParameterJdbcTemplate namedParameterJdbcTemplate;
@Value("${sql.tools.max-results:1000}")
private int maxResults;
@Value("${sql.tools.query-timeout:30}")
private int queryTimeoutSeconds;
// SQL注入防护:只允许SELECT语句,禁止危险关键字
private static final Pattern ALLOWED_SQL_PATTERN = Pattern.compile(
"^\\s*SELECT\\s+.*", Pattern.CASE_INSENSITIVE | Pattern.DOTALL
);
private static final Pattern DANGEROUS_KEYWORDS = Pattern.compile(
".*\\b(DELETE|UPDATE|INSERT|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE|UNION|SCRIPT|DECLARE)\\b.*",
Pattern.CASE_INSENSITIVE
);
@Tool(description = "查询数据库 入参SQL (String)")
public List<Map<String, Object>> selectBySql(@ToolParam(description = "SQL query string") String sql) {
// /**
// * 执行SQL查询
// *
// * @param sql SQL查询语句,只支持SELECT语句
// * @return 查询结果列表,最多返回配置的最大行数
// */
// @Tool(description = "安全执行数据库查询,只支持SELECT语句,返回结果有数量限制")
// public List<Map<String, Object>> selectBySql(
// @ToolParam(description = "SQL查询语句,只支持SELECT语句") String sql) {
//
// return executeQuery(sql, Collections.emptyMap());
// }
/**
* 执行参数化SQL查询
*
* @param sql SQL查询语句,支持命名参数 :paramName
* @param params 参数Map
* @return 查询结果列表
*/
@Tool(description = "Postgres执行参数化SQL查询,支持命名参数,更安全")
public List<Map<String, Object>> selectBySqlWithParams(
@ToolParam(description = "SQL查询语句,支持命名参数如 :name") String sql,
@ToolParam(description = "参数Map,key为参数名,value为参数值") Map<String, Object> params) {
log.info("selectBySqlWithParams:{}", sql);
return executeQuery(sql, params != null ? params : Collections.emptyMap());
}
// /**
// * 执行分页查询
// *
// * @param sql SQL查询语句
// * @param offset 偏移量
// * @param limit 限制数量
// * @return 查询结果列表
// */
// @Tool(description = "执行分页查询")
// public List<Map<String, Object>> selectBySqlWithPaging(
// @ToolParam(description = "SQL查询语句") String sql,
// @ToolParam(description = "偏移量,从0开始") int offset,
// @ToolParam(description = "每页数量,最大1000") int limit) {
//
// // 限制分页参数
// offset = Math.max(0, offset);
// limit = Math.min(Math.max(1, limit), maxResults);
//
// String pagedSql = sql + " LIMIT " + limit + " OFFSET " + offset;
// return executeQuery(pagedSql, Collections.emptyMap());
// }
//
// /**
// * 获取查询结果总数
// *
// * @param sql 原始SQL查询语句
// * @return 总记录数
// */
// @Tool(description = "获取查询结果总数")
// public Long countBySql(@ToolParam(description = "SQL查询语句") String sql) {
//
// if (!isValidSql(sql)) {
// log.warn("无效的SQL语句: {}", sql);
// return 0L;
// }
//
// try {
// // 构建COUNT查询
// String countSql = "SELECT COUNT(*) FROM (" + sql + ") as count_query";
// StopWatch stopWatch = new StopWatch();
// stopWatch.start();
//
// Long count = jdbcTemplate.queryForObject(countSql, Long.class);
//
// stopWatch.stop();
// log.info("COUNT查询执行完成,耗时: {}ms", stopWatch.getTotalTimeMillis());
//
// return count != null ? count : 0L;
//
// } catch (Exception e) {
// log.error("COUNT查询异常: sql={}, error={}", sql, e.getMessage(), e);
// return 0L;
// }
// }
/**
* 执行查询的核心方法
*/
private List<Map<String, Object>> executeQuery(String sql, Map<String, Object> params) {
if (StringUtils.isEmpty(sql)) {
log.warn("SQL语句为空");
return Collections.emptyList();
}
if (!isValidSql(sql)) {
log.warn("SQL安全检查失败: {}", sql);
return Collections.emptyList();
}
try {
return jdbcTemplate.queryForList(sql);
StopWatch stopWatch = new StopWatch();
stopWatch.start();
// 设置查询超时
jdbcTemplate.setQueryTimeout(queryTimeoutSeconds);
List<Map<String, Object>> results;
if (params.isEmpty()) {
results = jdbcTemplate.queryForList(sql);
} else {
results = namedParameterJdbcTemplate.queryForList(sql, params);
}
stopWatch.stop();
// 限制结果数量
if (results.size() > maxResults) {
log.warn("查询结果超过最大限制 {}, 实际数量: {}, 已截取", maxResults, results.size());
results = results.subList(0, maxResults);
}
log.info("SQL查询执行完成 - 耗时: {}ms, 结果数量: {}, SQL: {}",
stopWatch.getTotalTimeMillis(), results.size(),
sql.length() > 100 ? sql.substring(0, 100) + "..." : sql);
return results;
} catch (DataAccessException e) {
log.error("数据库访问异常: sql={}, params={}, error={}", sql, params, e.getMessage());
return Collections.emptyList();
} catch (Exception e) {
log.error("数据库查询异常:sql{} 异常: {}",sql,e.getMessage());
log.error("SQL查询异常: sql={}, params={}, error={}", sql, params, e.getMessage(), e);
return Collections.emptyList();
}
}
/**
* SQL安全性验证
*/
private boolean isValidSql(String sql) {
if (StringUtils.isEmpty(sql)) {
return false;
}
String trimmedSql = sql.trim();
// 检查是否为SELECT语句
if (!ALLOWED_SQL_PATTERN.matcher(trimmedSql).matches()) {
log.warn("只允许SELECT语句: {}", sql);
return false;
}
// 检查危险关键字
if (DANGEROUS_KEYWORDS.matcher(trimmedSql).matches()) {
log.warn("SQL包含危险关键字: {}", sql);
return false;
}
// 检查SQL长度
if (trimmedSql.length() > 10000) {
log.warn("SQL语句过长: {} characters", trimmedSql.length());
return false;
}
return true;
}
// /**
// * 获取数据库表信息
// */
// @Tool(description = "获取数据库中的表列表信息")
// public List<Map<String, Object>> getTableInfo(@ToolParam(description = "表名模式,支持%通配符") String tableNamePattern) {
//
// if (StringUtils.isEmpty(tableNamePattern)) {
// tableNamePattern = "%";
// }
//
// String sql = "SELECT table_name, table_comment, table_type " +
// "FROM information_schema.tables " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name LIKE ? " +
// "ORDER BY table_name";
//
// try {
// return jdbcTemplate.queryForList(sql, tableNamePattern);
// } catch (Exception e) {
// log.error("获取表信息异常: pattern={}, error={}", tableNamePattern, e.getMessage());
// return Collections.emptyList();
// }
// }
//
// /**
// * 获取表的列信息
// */
// @Tool(description = "获取指定表的列信息")
// public List<Map<String, Object>> getTableColumns(@ToolParam(description = "表名") String tableName) {
//
// if (StringUtils.isEmpty(tableName)) {
// return Collections.emptyList();
// }
//
// String sql = "SELECT column_name, data_type, is_nullable, column_default, column_comment " +
// "FROM information_schema.columns " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name = ? " +
// "ORDER BY ordinal_position";
//
// try {
// return jdbcTemplate.queryForList(sql, tableName);
// } catch (Exception e) {
// log.error("获取表列信息异常: table={}, error={}", tableName, e.getMessage());
// return Collections.emptyList();
// }
// }
}
package com.ask.utils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import reactor.core.publisher.Flux;
import java.lang.reflect.Field;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
public class FluxUtils {
/**
* 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
* @param upstream 原始 SSE 流
* @return 带标签的逐块流
*/
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
AssistantMessage msg = resp.getResult().getOutput();
String reasoningContent = "";
String textContent = msg.getText(); // 普通回答
try {
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field f = msg.getClass().getDeclaredField("reasoningContent");
f.setAccessible(true);
reasoningContent = (String) f.get(msg);
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ }
StringBuilder sb = new StringBuilder();
// 推理阶段
if (!reasoningStarted.get()) {
reasoningStarted.set(true);
sb.append("<think>");
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段
if (StringUtils.isNotBlank(textContent)) {
if (!answerStarted.get()) {
answerStarted.set(true);
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
public static Flux<String> wrapDeepSeekStream(Flux<ChatResponse> upstream,StringBuilder
stringBuilder) {
AtomicBoolean reasoningStarted = new AtomicBoolean(false);
AtomicBoolean answerStarted = new AtomicBoolean(false);
return upstream
.flatMapIterable(resp -> {
AssistantMessage msg = resp.getResult().getOutput();
String reasoningContent = "";
String textContent = msg.getText(); // 普通回答
try {
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field f = msg.getClass().getDeclaredField("reasoningContent");
f.setAccessible(true);
reasoningContent = (String) f.get(msg);
} catch (Exception ignored) { /* 不是 DeepSeekAssistantMessage 时留空 */ }
StringBuilder sb = new StringBuilder();
// 推理阶段
if (!reasoningStarted.get()) {
reasoningStarted.set(true);
sb.append("<think>");
}
if (StringUtils.isNotBlank(reasoningContent)) {
sb.append(reasoningContent);
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if (StringUtils.isNotBlank(textContent)) {
stringBuilder.append(textContent);
if (answerStarted.compareAndSet(false, true)) {
sb.append("</think><answer>");
}
sb.append(textContent);
}
return List.of(sb.toString());
})
.concatWith(Flux.just("</answer>"));
}
}
//package com.ask.utils;
//
//import com.baomidou.mybatisplus.core.toolkit.StringUtils;
//import org.springframework.ai.chat.messages.AssistantMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
//import reactor.core.publisher.Flux;
//
//import java.lang.reflect.Field;
//import java.util.List;
//import java.util.concurrent.atomic.AtomicBoolean;
//
//public class FluxUtils {
//
// /**
// * 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
// *
// * @param upstream 原始 SSE 流
// * @return 带标签的逐块流
// */
// public static Flux<String> wrapModelStream(Flux<ChatResponse> upstream) {
//
// return upstream
// .flatMapIterable(resp -> {
// AssistantMessage msg = resp.getResult().getOutput();
//
// String reasoningContent = "";
// String textContent = msg.getText(); // 普通回答
//
// try {
// // 反射读取 DeepSeekAssistantMessage.reasoningContent
// Field f = msg.getClass().getDeclaredField("reasoningContent");
// f.setAccessible(true);
// reasoningContent = (String) f.get(msg);
// } catch (Exception ignored) {
// }
//
// return List.of(sb.toString());
// });
// }
//
//
//}
......@@ -109,6 +109,14 @@ logging:
pattern:
console: "%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"
# SQL工具配置
sql:
tools:
# 最大查询结果数量限制
max-results: 1000
# 查询超时时间(秒)
query-timeout: 30
# 本地文件系统
file:
local:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment