Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
A
ask_data_ai_admin
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
linyangyang
ask_data_ai_admin
Commits
18813de1
Commit
18813de1
authored
Jul 30, 2025
by
林洋洋
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改工具和对话返回
parent
51f84d66
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
402 additions
and
153 deletions
+402
-153
ChatResult.java
...data-ai-api/src/main/java/com/ask/api/dto/ChatResult.java
+9
-0
ChatController.java
...-biz/src/main/java/com/ask/controller/ChatController.java
+102
-35
AskModelServiceImpl.java
...c/main/java/com/ask/service/impl/AskModelServiceImpl.java
+2
-2
SqlTools.java
...ask-data-ai-biz/src/main/java/com/ask/tools/SqlTools.java
+238
-10
FluxUtils.java
...sk-data-ai-biz/src/main/java/com/ask/utils/FluxUtils.java
+43
-106
application.yml
...ta-ai/ask-data-ai-boot/src/main/resources/application.yml
+8
-0
No files found.
ask-data-ai/ask-data-ai-api/src/main/java/com/ask/api/dto/ChatResult.java
0 → 100644
View file @
18813de1
package
com
.
ask
.
api
.
dto
;
import
lombok.Data
;
@Data
public
class
ChatResult
{
private
String
message
;
private
String
reasoningContent
;
}
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/controller/ChatController.java
View file @
18813de1
...
...
@@ -3,15 +3,16 @@ package com.ask.controller;
import
cn.hutool.json.JSONArray
;
import
cn.hutool.json.JSONObject
;
import
com.ask.api.dto.ChatResult
;
import
com.ask.api.entity.ChatConversation
;
import
com.ask.common.core.R
;
import
com.ask.service.AskModelService
;
import
com.ask.service.ChatConversationService
;
import
com.ask.service.impl.ChatService
;
import
com.ask.service.impl.RagPromptService
;
import
com.ask.tools.EchartsTools
;
import
com.ask.tools.ExcelTools
;
import
com.ask.tools.SqlTools
;
import
com.ask.utils.FluxUtils
;
import
io.swagger.v3.oas.annotations.Operation
;
import
io.swagger.v3.oas.annotations.Parameter
;
import
io.swagger.v3.oas.annotations.tags.Tag
;
...
...
@@ -28,6 +29,7 @@ import org.springframework.ai.chat.model.ChatModel;
import
org.springframework.ai.chat.model.ChatResponse
;
import
org.springframework.ai.chat.prompt.Prompt
;
import
org.springframework.ai.deepseek.DeepSeekAssistantMessage
;
import
org.springframework.ai.deepseek.DeepSeekChatModel
;
import
org.springframework.ai.deepseek.DeepSeekChatOptions
;
import
org.springframework.ai.deepseek.api.DeepSeekApi
;
...
...
@@ -65,15 +67,14 @@ public class ChatController {
private
final
MessageChatMemoryAdvisor
messageChatMemoryAdvisor
;
private
final
ChatService
chatService
;
private
final
RagPromptService
ragPromptService
;
private
final
ExcelTools
excelTools
;
private
final
SqlTools
sqlTools
;
private
final
EchartsTools
echartsTools
;
/**
...
...
@@ -105,7 +106,7 @@ public class ChatController {
*/
@Operation
(
summary
=
"普通对话"
,
description
=
"普通对话"
)
@GetMapping
(
value
=
"/chat"
,
produces
=
MediaType
.
TEXT_EVENT_STREAM_VALUE
)
public
Flux
<
String
>
chat
(
@RequestParam
String
message
,
public
Flux
<
ChatResult
>
chat
(
@RequestParam
String
message
,
@RequestParam
String
conversationId
,
@RequestParam
(
required
=
false
)
Optional
<
Long
>
modelId
)
{
Long
actualModelId
=
modelId
.
orElse
(
1L
);
...
...
@@ -117,11 +118,20 @@ public class ChatController {
if
(
Objects
.
isNull
(
chatClient
))
{
return
Flux
.
error
(
new
Throwable
(
"模型创建失败"
));
}
return
FluxUtils
.
wrapDeepSeekStream
(
chatClient
.
prompt
(
prompt
)
return
chatClient
.
prompt
(
prompt
)
.
advisors
(
messageChatMemoryAdvisor
)
.
advisors
(
a
->
a
.
param
(
ChatMemory
.
CONVERSATION_ID
,
conversationId
))
.
stream
()
.
chatResponse
());
.
chatResponse
()
.
map
(
response
->
{
AssistantMessage
assistantMessage
=
response
.
getResult
().
getOutput
();
ChatResult
result
=
new
ChatResult
();
if
(
assistantMessage
instanceof
DeepSeekAssistantMessage
)
{
result
.
setReasoningContent
(((
DeepSeekAssistantMessage
)
assistantMessage
).
getReasoningContent
());
}
result
.
setMessage
(
assistantMessage
.
getText
());
return
result
;
});
}
...
...
@@ -134,7 +144,7 @@ public class ChatController {
*/
@Operation
(
summary
=
"知识库对话"
,
description
=
"知识库对话"
)
@GetMapping
(
value
=
"/rag/chat"
,
produces
=
MediaType
.
TEXT_EVENT_STREAM_VALUE
)
public
Flux
<
String
>
ragChat
(
@RequestParam
@Parameter
(
description
=
"对话内容"
)
String
message
,
public
Flux
<
ChatResult
>
ragChat
(
@RequestParam
@Parameter
(
description
=
"对话内容"
)
String
message
,
@RequestParam
@Parameter
(
description
=
"会话ID"
)
String
conversationId
,
@RequestParam
(
required
=
false
)
Optional
<
Long
>
modelId
)
{
Long
actualModelId
=
modelId
.
orElse
(
1L
);
...
...
@@ -160,24 +170,32 @@ public class ChatController {
if
(
Objects
.
isNull
(
chatClient
))
{
return
Flux
.
error
(
new
Throwable
(
"模型创建失败"
));
}
return
FluxUtils
.
wrapDeepSeekStream
(
chatClient
.
prompt
()
return
chatClient
.
prompt
()
.
user
(
userPrompt
)
.
system
(
"你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求"
+
"1.以 Markdown 格式输出"
)
.
stream
()
.
chatResponse
(),
contentBuilder
)
.
concatWith
(
Flux
.
just
(
reference
))
.
doOnComplete
(()
->
{
// 流结束时获取完整内容
.
chatResponse
()
.
map
(
response
->
{
AssistantMessage
assistantMessage
=
response
.
getResult
().
getOutput
();
ChatResult
result
=
new
ChatResult
();
if
(
assistantMessage
instanceof
DeepSeekAssistantMessage
)
{
result
.
setReasoningContent
(((
DeepSeekAssistantMessage
)
assistantMessage
).
getReasoningContent
());
}
result
.
setMessage
(
assistantMessage
.
getText
());
if
(
StringUtils
.
isNotBlank
(
assistantMessage
.
getText
()))
{
contentBuilder
.
append
(
assistantMessage
.
getText
());
}
return
result
;
}).
doOnComplete
(()
->
{
String
fullResponse
=
contentBuilder
.
toString
();
// 异步保存到数据库(添加错误处理)
chatService
.
saveHistoryMemory
(
conversationId
,
new
AssistantMessage
(
fullResponse
));
});
}
@Operation
(
summary
=
"智能数据报表对话"
,
description
=
"智能数据报表对话"
)
@GetMapping
(
value
=
"/chat/report"
,
produces
=
MediaType
.
TEXT_EVENT_STREAM_VALUE
)
public
Flux
<
String
>
reportChat
(
@RequestParam
String
message
,
public
Flux
<
ChatResult
>
reportChat
(
@RequestParam
String
message
,
@RequestParam
String
conversationId
,
@RequestParam
(
required
=
false
)
Optional
<
Long
>
modelId
)
{
Long
actualModelId
=
modelId
.
orElse
(
1L
);
...
...
@@ -188,25 +206,74 @@ public class ChatController {
if
(
Objects
.
isNull
(
chatClient
))
{
return
Flux
.
error
(
new
Throwable
(
"模型创建失败"
));
}
return
FluxUtils
.
wrapDeepSeekStream
(
chatClient
.
prompt
(
prompt
)
return
chatClient
.
prompt
(
prompt
)
.
advisors
(
messageChatMemoryAdvisor
)
.
advisors
(
a
->
a
.
param
(
ChatMemory
.
CONVERSATION_ID
,
conversationId
))
.
tools
(
excelTools
)
.
advisors
()
.
stream
()
.
chatResponse
());
.
chatResponse
()
.
map
(
response
->
{
AssistantMessage
assistantMessage
=
response
.
getResult
().
getOutput
();
ChatResult
result
=
new
ChatResult
();
if
(
assistantMessage
instanceof
DeepSeekAssistantMessage
)
{
result
.
setReasoningContent
(((
DeepSeekAssistantMessage
)
assistantMessage
).
getReasoningContent
());
}
result
.
setMessage
(
assistantMessage
.
getText
());
return
result
;
});
}
public
void
test
()
{
ChatModel
chatModel
=
DeepSeekChatModel
.
builder
()
.
deepSeekApi
(
DeepSeekApi
.
builder
().
baseUrl
(
""
).
apiKey
(
"TEST"
).
build
())
.
defaultOptions
(
DeepSeekChatOptions
.
builder
().
model
(
"deepseek-r1"
).
temperature
(
66.6
).
maxTokens
(
10000
).
build
())
.
build
();
ChatClient
chatClient
=
ChatClient
.
builder
(
chatModel
)
.
defaultAdvisors
()
.
build
();
@Operation
(
summary
=
"智能问数对话"
,
description
=
"智能问数对话"
)
@GetMapping
(
value
=
"/chat/data"
,
produces
=
MediaType
.
TEXT_EVENT_STREAM_VALUE
)
public
Flux
<
ChatResult
>
dataChat
(
@RequestParam
String
message
,
@RequestParam
String
conversationId
,
@RequestParam
(
required
=
false
)
Optional
<
Long
>
modelId
)
{
Long
actualModelId
=
modelId
.
orElse
(
1L
);
Message
systemMessage
=
new
SystemMessage
(
""
+
"【生产记录表 - ask_production_records】\n"
+
"用途:存储各公司及其下属项目的月度生产量数据。\n"
+
"\n"
+
"字段说明:\n"
+
"1. company_name (varchar100) → 公司名称(如:龙源环保) \n"
+
"2. subsidiary_name(varchar100) → 项目(如:脱硝催化剂) \n"
+
"3. unit (varchar50) → 单位(如:立方米、件、平方米) \n"
+
"4. value (numeric10,2) → 月生产量数值(保留两位小数) \n"
+
"5. year (int4) → 年份(如:2024) \n"
+
"6. month (int4) → 月份(1-12) \n"
+
"7. id (int4 PK) → 主键,自增\n"
+
"\n"
+
"【查询规则】\n"
+
"- 每次回答必须先通过 SQL 工具查询此表,禁止口算或推测。 \n"
+
"- 返回结果需附带“单位”字段。 \n"
+
"- 若数据不存在,直接回复“暂无记录”,禁止编造。"
+
"【回答要求】\n"
+
"- 可以采用表格+图表的形式展示数据"
);
Message
userMessage
=
new
UserMessage
(
message
);
Prompt
prompt
=
new
Prompt
(
List
.
of
(
systemMessage
,
userMessage
));
ChatClient
chatClient
=
askModelService
.
getChatClientById
(
actualModelId
);
if
(
Objects
.
isNull
(
chatClient
))
{
return
Flux
.
error
(
new
Throwable
(
"模型创建失败"
));
}
return
chatClient
.
prompt
(
prompt
)
.
advisors
(
messageChatMemoryAdvisor
)
.
advisors
(
a
->
a
.
param
(
ChatMemory
.
CONVERSATION_ID
,
conversationId
))
.
tools
(
sqlTools
)
.
tools
(
echartsTools
)
.
stream
()
.
chatResponse
()
.
map
(
response
->
{
AssistantMessage
assistantMessage
=
response
.
getResult
().
getOutput
();
ChatResult
result
=
new
ChatResult
();
if
(
assistantMessage
instanceof
DeepSeekAssistantMessage
)
{
result
.
setReasoningContent
(((
DeepSeekAssistantMessage
)
assistantMessage
).
getReasoningContent
());
}
result
.
setMessage
(
assistantMessage
.
getText
());
return
result
;
});
}
...
...
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/service/impl/AskModelServiceImpl.java
View file @
18813de1
...
...
@@ -52,7 +52,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public
ChatClient
getChatClientById
(
Long
modelId
)
{
AskModel
askModel
=
this
.
getById
(
modelId
);
if
(
askModel
.
getStatus
()==
1
){
if
(
askModel
.
getStatus
()==
0
){
return
null
;
}
IBaseModel
baseModel
=
ModelProviderEnum
.
get
(
askModel
.
getProvider
());
...
...
@@ -67,7 +67,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override
public
EmbeddingModel
getEmbeddingModelById
(
Long
modelId
)
{
AskModel
askModel
=
this
.
getById
(
modelId
);
if
(
askModel
.
getStatus
()==
1
){
if
(
askModel
.
getStatus
()==
0
){
return
null
;
}
IBaseModel
baseModel
=
ModelProviderEnum
.
get
(
askModel
.
getProvider
());
...
...
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/tools/SqlTools.java
View file @
18813de1
package
com
.
ask
.
tools
;
import
com.ask.common.core.FileTemplate
;
import
com.baomidou.mybatisplus.core.toolkit.StringUtils
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.ai.tool.annotation.Tool
;
import
org.springframework.ai.tool.annotation.ToolParam
;
import
org.springframework.beans.factory.annotation.Autowired
;
import
org.springframework.beans.factory.annotation.Value
;
import
org.springframework.dao.DataAccessException
;
import
org.springframework.jdbc.core.JdbcTemplate
;
import
org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate
;
import
org.springframework.stereotype.Component
;
import
org.springframework.util.StopWatch
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.*
;
import
java.util.regex.Pattern
;
/**
* SQL查询工具类
* 提供安全的数据库查询功能,支持参数化查询和结果限制
*
* @author AI Assistant
*/
@Component
@Slf4j
public
class
SqlTools
{
@Autowired
private
JdbcTemplate
jdbcTemplate
;
@Tool
(
description
=
"查询数据库 入参SQL (String)"
)
public
List
<
Map
<
String
,
Object
>>
selectBySql
(
@ToolParam
(
description
=
"SQL query string"
)
String
sql
)
{
@Autowired
private
NamedParameterJdbcTemplate
namedParameterJdbcTemplate
;
@Value
(
"${sql.tools.max-results:1000}"
)
private
int
maxResults
;
@Value
(
"${sql.tools.query-timeout:30}"
)
private
int
queryTimeoutSeconds
;
// SQL注入防护:只允许SELECT语句,禁止危险关键字
private
static
final
Pattern
ALLOWED_SQL_PATTERN
=
Pattern
.
compile
(
"^\\s*SELECT\\s+.*"
,
Pattern
.
CASE_INSENSITIVE
|
Pattern
.
DOTALL
);
private
static
final
Pattern
DANGEROUS_KEYWORDS
=
Pattern
.
compile
(
".*\\b(DELETE|UPDATE|INSERT|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE|UNION|SCRIPT|DECLARE)\\b.*"
,
Pattern
.
CASE_INSENSITIVE
);
// /**
// * 执行SQL查询
// *
// * @param sql SQL查询语句,只支持SELECT语句
// * @return 查询结果列表,最多返回配置的最大行数
// */
// @Tool(description = "安全执行数据库查询,只支持SELECT语句,返回结果有数量限制")
// public List<Map<String, Object>> selectBySql(
// @ToolParam(description = "SQL查询语句,只支持SELECT语句") String sql) {
//
// return executeQuery(sql, Collections.emptyMap());
// }
/**
* 执行参数化SQL查询
*
* @param sql SQL查询语句,支持命名参数 :paramName
* @param params 参数Map
* @return 查询结果列表
*/
@Tool
(
description
=
"Postgres执行参数化SQL查询,支持命名参数,更安全"
)
public
List
<
Map
<
String
,
Object
>>
selectBySqlWithParams
(
@ToolParam
(
description
=
"SQL查询语句,支持命名参数如 :name"
)
String
sql
,
@ToolParam
(
description
=
"参数Map,key为参数名,value为参数值"
)
Map
<
String
,
Object
>
params
)
{
log
.
info
(
"selectBySqlWithParams:{}"
,
sql
);
return
executeQuery
(
sql
,
params
!=
null
?
params
:
Collections
.
emptyMap
());
}
// /**
// * 执行分页查询
// *
// * @param sql SQL查询语句
// * @param offset 偏移量
// * @param limit 限制数量
// * @return 查询结果列表
// */
// @Tool(description = "执行分页查询")
// public List<Map<String, Object>> selectBySqlWithPaging(
// @ToolParam(description = "SQL查询语句") String sql,
// @ToolParam(description = "偏移量,从0开始") int offset,
// @ToolParam(description = "每页数量,最大1000") int limit) {
//
// // 限制分页参数
// offset = Math.max(0, offset);
// limit = Math.min(Math.max(1, limit), maxResults);
//
// String pagedSql = sql + " LIMIT " + limit + " OFFSET " + offset;
// return executeQuery(pagedSql, Collections.emptyMap());
// }
//
// /**
// * 获取查询结果总数
// *
// * @param sql 原始SQL查询语句
// * @return 总记录数
// */
// @Tool(description = "获取查询结果总数")
// public Long countBySql(@ToolParam(description = "SQL查询语句") String sql) {
//
// if (!isValidSql(sql)) {
// log.warn("无效的SQL语句: {}", sql);
// return 0L;
// }
//
// try {
// // 构建COUNT查询
// String countSql = "SELECT COUNT(*) FROM (" + sql + ") as count_query";
// StopWatch stopWatch = new StopWatch();
// stopWatch.start();
//
// Long count = jdbcTemplate.queryForObject(countSql, Long.class);
//
// stopWatch.stop();
// log.info("COUNT查询执行完成,耗时: {}ms", stopWatch.getTotalTimeMillis());
//
// return count != null ? count : 0L;
//
// } catch (Exception e) {
// log.error("COUNT查询异常: sql={}, error={}", sql, e.getMessage(), e);
// return 0L;
// }
// }
/**
* 执行查询的核心方法
*/
private
List
<
Map
<
String
,
Object
>>
executeQuery
(
String
sql
,
Map
<
String
,
Object
>
params
)
{
if
(
StringUtils
.
isEmpty
(
sql
))
{
log
.
warn
(
"SQL语句为空"
);
return
Collections
.
emptyList
();
}
if
(!
isValidSql
(
sql
))
{
log
.
warn
(
"SQL安全检查失败: {}"
,
sql
);
return
Collections
.
emptyList
();
}
try
{
return
jdbcTemplate
.
queryForList
(
sql
);
StopWatch
stopWatch
=
new
StopWatch
();
stopWatch
.
start
();
// 设置查询超时
jdbcTemplate
.
setQueryTimeout
(
queryTimeoutSeconds
);
List
<
Map
<
String
,
Object
>>
results
;
if
(
params
.
isEmpty
())
{
results
=
jdbcTemplate
.
queryForList
(
sql
);
}
else
{
results
=
namedParameterJdbcTemplate
.
queryForList
(
sql
,
params
);
}
stopWatch
.
stop
();
// 限制结果数量
if
(
results
.
size
()
>
maxResults
)
{
log
.
warn
(
"查询结果超过最大限制 {}, 实际数量: {}, 已截取"
,
maxResults
,
results
.
size
());
results
=
results
.
subList
(
0
,
maxResults
);
}
log
.
info
(
"SQL查询执行完成 - 耗时: {}ms, 结果数量: {}, SQL: {}"
,
stopWatch
.
getTotalTimeMillis
(),
results
.
size
(),
sql
.
length
()
>
100
?
sql
.
substring
(
0
,
100
)
+
"..."
:
sql
);
return
results
;
}
catch
(
DataAccessException
e
)
{
log
.
error
(
"数据库访问异常: sql={}, params={}, error={}"
,
sql
,
params
,
e
.
getMessage
());
return
Collections
.
emptyList
();
}
catch
(
Exception
e
)
{
log
.
error
(
"
数据库查询异常:sql{} 异常: {}"
,
sql
,
e
.
getMessage
()
);
log
.
error
(
"
SQL查询异常: sql={}, params={}, error={}"
,
sql
,
params
,
e
.
getMessage
(),
e
);
return
Collections
.
emptyList
();
}
}
/**
* SQL安全性验证
*/
private
boolean
isValidSql
(
String
sql
)
{
if
(
StringUtils
.
isEmpty
(
sql
))
{
return
false
;
}
String
trimmedSql
=
sql
.
trim
();
// 检查是否为SELECT语句
if
(!
ALLOWED_SQL_PATTERN
.
matcher
(
trimmedSql
).
matches
())
{
log
.
warn
(
"只允许SELECT语句: {}"
,
sql
);
return
false
;
}
// 检查危险关键字
if
(
DANGEROUS_KEYWORDS
.
matcher
(
trimmedSql
).
matches
())
{
log
.
warn
(
"SQL包含危险关键字: {}"
,
sql
);
return
false
;
}
// 检查SQL长度
if
(
trimmedSql
.
length
()
>
10000
)
{
log
.
warn
(
"SQL语句过长: {} characters"
,
trimmedSql
.
length
());
return
false
;
}
return
true
;
}
// /**
// * 获取数据库表信息
// */
// @Tool(description = "获取数据库中的表列表信息")
// public List<Map<String, Object>> getTableInfo(@ToolParam(description = "表名模式,支持%通配符") String tableNamePattern) {
//
// if (StringUtils.isEmpty(tableNamePattern)) {
// tableNamePattern = "%";
// }
//
// String sql = "SELECT table_name, table_comment, table_type " +
// "FROM information_schema.tables " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name LIKE ? " +
// "ORDER BY table_name";
//
// try {
// return jdbcTemplate.queryForList(sql, tableNamePattern);
// } catch (Exception e) {
// log.error("获取表信息异常: pattern={}, error={}", tableNamePattern, e.getMessage());
// return Collections.emptyList();
// }
// }
//
// /**
// * 获取表的列信息
// */
// @Tool(description = "获取指定表的列信息")
// public List<Map<String, Object>> getTableColumns(@ToolParam(description = "表名") String tableName) {
//
// if (StringUtils.isEmpty(tableName)) {
// return Collections.emptyList();
// }
//
// String sql = "SELECT column_name, data_type, is_nullable, column_default, column_comment " +
// "FROM information_schema.columns " +
// "WHERE table_schema = DATABASE() " +
// "AND table_name = ? " +
// "ORDER BY ordinal_position";
//
// try {
// return jdbcTemplate.queryForList(sql, tableName);
// } catch (Exception e) {
// log.error("获取表列信息异常: table={}, error={}", tableName, e.getMessage());
// return Collections.emptyList();
// }
// }
}
ask-data-ai/ask-data-ai-biz/src/main/java/com/ask/utils/FluxUtils.java
View file @
18813de1
package
com
.
ask
.
utils
;
import
com.baomidou.mybatisplus.core.toolkit.StringUtils
;
import
org.springframework.ai.chat.messages.AssistantMessage
;
import
org.springframework.ai.chat.model.ChatResponse
;
import
org.springframework.ai.deepseek.DeepSeekAssistantMessage
;
import
reactor.core.publisher.Flux
;
import
java.lang.reflect.Field
;
import
java.util.List
;
import
java.util.concurrent.atomic.AtomicBoolean
;
public
class
FluxUtils
{
/**
* 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
* @param upstream 原始 SSE 流
* @return 带标签的逐块流
*/
public
static
Flux
<
String
>
wrapDeepSeekStream
(
Flux
<
ChatResponse
>
upstream
)
{
AtomicBoolean
reasoningStarted
=
new
AtomicBoolean
(
false
);
AtomicBoolean
answerStarted
=
new
AtomicBoolean
(
false
);
return
upstream
.
flatMapIterable
(
resp
->
{
AssistantMessage
msg
=
resp
.
getResult
().
getOutput
();
String
reasoningContent
=
""
;
String
textContent
=
msg
.
getText
();
// 普通回答
try
{
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field
f
=
msg
.
getClass
().
getDeclaredField
(
"reasoningContent"
);
f
.
setAccessible
(
true
);
reasoningContent
=
(
String
)
f
.
get
(
msg
);
}
catch
(
Exception
ignored
)
{
/* 不是 DeepSeekAssistantMessage 时留空 */
}
StringBuilder
sb
=
new
StringBuilder
();
// 推理阶段
if
(!
reasoningStarted
.
get
())
{
reasoningStarted
.
set
(
true
);
sb
.
append
(
"<think>"
);
}
if
(
StringUtils
.
isNotBlank
(
reasoningContent
))
{
sb
.
append
(
reasoningContent
);
}
// 回答阶段
if
(
StringUtils
.
isNotBlank
(
textContent
))
{
if
(!
answerStarted
.
get
())
{
answerStarted
.
set
(
true
);
sb
.
append
(
"</think><answer>"
);
}
sb
.
append
(
textContent
);
}
return
List
.
of
(
sb
.
toString
());
})
.
concatWith
(
Flux
.
just
(
"</answer>"
));
}
public
static
Flux
<
String
>
wrapDeepSeekStream
(
Flux
<
ChatResponse
>
upstream
,
StringBuilder
stringBuilder
)
{
AtomicBoolean
reasoningStarted
=
new
AtomicBoolean
(
false
);
AtomicBoolean
answerStarted
=
new
AtomicBoolean
(
false
);
return
upstream
.
flatMapIterable
(
resp
->
{
AssistantMessage
msg
=
resp
.
getResult
().
getOutput
();
String
reasoningContent
=
""
;
String
textContent
=
msg
.
getText
();
// 普通回答
try
{
// 反射读取 DeepSeekAssistantMessage.reasoningContent
Field
f
=
msg
.
getClass
().
getDeclaredField
(
"reasoningContent"
);
f
.
setAccessible
(
true
);
reasoningContent
=
(
String
)
f
.
get
(
msg
);
}
catch
(
Exception
ignored
)
{
/* 不是 DeepSeekAssistantMessage 时留空 */
}
StringBuilder
sb
=
new
StringBuilder
();
// 推理阶段
if
(!
reasoningStarted
.
get
())
{
reasoningStarted
.
set
(
true
);
sb
.
append
(
"<think>"
);
}
if
(
StringUtils
.
isNotBlank
(
reasoningContent
))
{
sb
.
append
(
reasoningContent
);
}
// 回答阶段:第一次出现答案时输出 </think><answer>
if
(
StringUtils
.
isNotBlank
(
textContent
))
{
stringBuilder
.
append
(
textContent
);
if
(
answerStarted
.
compareAndSet
(
false
,
true
))
{
sb
.
append
(
"</think><answer>"
);
}
sb
.
append
(
textContent
);
}
return
List
.
of
(
sb
.
toString
());
})
.
concatWith
(
Flux
.
just
(
"</answer>"
));
}
}
//package com.ask.utils;
//
//import com.baomidou.mybatisplus.core.toolkit.StringUtils;
//import org.springframework.ai.chat.messages.AssistantMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
//import reactor.core.publisher.Flux;
//
//import java.lang.reflect.Field;
//import java.util.List;
//import java.util.concurrent.atomic.AtomicBoolean;
//
//public class FluxUtils {
//
// /**
// * 将 DeepSeek 的 Flux<ChatResponse> 转换成带 <think>/<answer> 的 Flux<String>
// *
// * @param upstream 原始 SSE 流
// * @return 带标签的逐块流
// */
// public static Flux<String> wrapModelStream(Flux<ChatResponse> upstream) {
//
// return upstream
// .flatMapIterable(resp -> {
// AssistantMessage msg = resp.getResult().getOutput();
//
// String reasoningContent = "";
// String textContent = msg.getText(); // 普通回答
//
// try {
// // 反射读取 DeepSeekAssistantMessage.reasoningContent
// Field f = msg.getClass().getDeclaredField("reasoningContent");
// f.setAccessible(true);
// reasoningContent = (String) f.get(msg);
// } catch (Exception ignored) {
// }
//
// return List.of(sb.toString());
// });
// }
//
//
//}
ask-data-ai/ask-data-ai-boot/src/main/resources/application.yml
View file @
18813de1
...
...
@@ -109,6 +109,14 @@ logging:
pattern
:
console
:
"
%d{HH:mm:ss.SSS}
[%thread]
%-5level
%logger{36}
-
%msg%n"
# SQL工具配置
sql
:
tools
:
# 最大查询结果数量限制
max-results
:
1000
# 查询超时时间(秒)
query-timeout
:
30
# 本地文件系统
file
:
local
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment