Commit d82e9721 authored by 林洋洋's avatar 林洋洋

AI对话优化

parent 8e6a9177
......@@ -2,11 +2,13 @@ package com.ask.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
......@@ -24,41 +26,48 @@ public class CommonConfiguration {
@Bean
public ChatMemory chatMemory (JdbcTemplate jdbcTemplate,PostgresChatMemoryDialect postgresChatMemoryDialect) {
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(postgresChatMemoryDialect)
.build();
public ChatMemory chatMemory(JdbcTemplate jdbcTemplate, PostgresChatMemoryDialect postgresChatMemoryDialect) {
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(postgresChatMemoryDialect)
.build();
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.build();
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(5)
.build();
}
@Bean
public ChatClient chatClient(OpenAiChatModel model, ChatMemory chatMemory) {
List<Advisor> advisors = new ArrayList<>();
Advisor messageChatMemoryAdvisor =MessageChatMemoryAdvisor.builder(chatMemory).build();
advisors.add(messageChatMemoryAdvisor);
return ChatClient.builder(model)
.defaultAdvisors(advisors)
.defaultAdvisors().build();
.defaultAdvisors()
.build();
}
@Bean
public PromptChatMemoryAdvisor promptChatMemoryAdvisor(ChatMemory chatMemory){
return PromptChatMemoryAdvisor.builder(chatMemory).build();
}
@Bean
public RetrievalAugmentationAdvisor retrievalAugmentationAdvisor(VectorStore vectorStore) {
@Bean
public MessageChatMemoryAdvisor messageChatMemoryAdvisor(ChatMemory chatMemory){
return MessageChatMemoryAdvisor.builder(chatMemory).build();
}
return RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.similarityThreshold(0.60)
.topK(5)
.vectorStore(vectorStore)
.build())
.queryAugmenter(ContextualQueryAugmenter.builder()
.allowEmptyContext(true)
.build())
.build();
}
@Bean
public RetrievalAugmentationAdvisor retrievalAugmentationAdvisor(VectorStore vectorStore) {
return RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.similarityThreshold(0.70)
.topK(5)
.vectorStore(vectorStore)
.build())
.documentPostProcessors(new MyDocumentPostProcessor())
.queryAugmenter(ContextualQueryAugmenter.builder()
.allowEmptyContext(true)
.build())
.build();
}
}
//package com.ask.config;
//
//import org.springframework.ai.chat.client.advisor.api.Advisor;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.prompt.Prompt;
//
//public class CustomPromptAdvisor implements Advisor {
//
//
//// // 拼接自定义prompt
//// String finalPrompt = String.format(
//// "【系统提示】\n%s\n\n【知识库内容】\n%s\n\n【历史对话】\n%s\n\n【用户问题】\n%s\n\n请结合以上内容,专业、简明、准确地用中文回答用户问题。如果知识库内容无法回答,请如实说明。",
//// instructions, this.context, this.memory, question
//// );
//}
\ No newline at end of file
package com.ask.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class MyDocumentPostProcessor implements DocumentPostProcessor {
@Override
public List<Document> process(Query query, List<Document> documents) {
log.info("问题 {} 召回 {} 个文档",query,documents.size());
return documents.stream()
.collect(Collectors.collectingAndThen(
Collectors.toMap(d -> d.getMetadata().get("id"), d -> d, (d1, d2) -> d1),
m -> new ArrayList<>(m.values())
));
}
}
\ No newline at end of file
......@@ -10,6 +10,8 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
......@@ -40,6 +42,9 @@ public class ChatController {
private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;
private final MessageChatMemoryAdvisor messageChatMemoryAdvisor;
private final PromptChatMemoryAdvisor promptChatMemoryAdvisor;
/**
* 获取会话ID
* @return 新的会话ID
......@@ -87,32 +92,34 @@ public class ChatController {
// 创建提示,包含系统消息和用户消息
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
return chatClient.prompt(prompt).advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)).stream().content();
return chatClient.prompt(prompt).advisors(messageChatMemoryAdvisor).advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)).stream().content();
}
/**
* 最基本的AI流式输出对话
* 知识库对话
*
* * @param message
* @return
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "application/stream+json")
public Flux<ChatResponse> ragChat(String message, String conversationId) {
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8")
public Flux<String> ragChat(String message, String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
Message userMessage = new UserMessage(message);
// 创建提示,包含系统消息和用户消息
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
//
// Message userMessage = new UserMessage(message);
// // 创建提示,包含系统消息和用户消息
// Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
FilterExpressionBuilder builder = new FilterExpressionBuilder();
Filter.Expression filter = builder.eq("isEnabled",1).build();
// List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder().query(message).filterExpression(filter).similarityThreshold(0.75).topK(5).build());
return chatClient.prompt(prompt)
return chatClient.prompt(message)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.advisors(promptChatMemoryAdvisor) //会话记忆
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter))
.advisors(retrievalAugmentationAdvisor)
.stream().chatResponse();
.advisors(retrievalAugmentationAdvisor) //知识库召回
.stream().content();
}
}
\ No newline at end of file
......@@ -188,6 +188,8 @@ public class KnowledgeBaseController {
new Filter.Value(1)
)
);
// 执行向量搜索
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().query(content).filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
......
server:
port: 9999
servlet:
context-path: /admin # 项目访问路径
spring:
application:
name: ask-data-ai
# 文件上传配置
servlet:
multipart:
enabled: true
max-file-size: 100MB # 单个文件最大大小
max-request-size: 500MB # 请求最大大小
file-size-threshold: 0 # 文件写入磁盘的阈值
datasource:
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db
username: postgres
password: postgres123
driver-class-name: org.postgresql.Driver
ai:
vectorstore:
pgvector:
index-type: HNSW
distance-type: COSINE_DISTANCE
dimensions: 1024
max-document-batch-size: 10000 # Optional: Maximum number of documents per batch
schema-name: public
table-name: ask_vector_store
chat:
memory:
repository:
jdbc:
initialize-schema: never # 开发环境可以使用 always,方便调试
platform: postgresql
openai:
base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950
chat:
options:
model: qwen-plus
embedding:
base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950
options:
model: text-embedding-v4
mybatis-plus:
mapper-locations: classpath*:/mapper/*Mapper.xml # mapper文件位置
global-config:
banner: false # 是否打印 mybatis-plus banner
db-config:
table-prefix: ask_ # 表名前缀
logic-delete-field: delFlag # 全局逻辑删除字段
logic-delete-value: 1 # 逻辑已删除值(默认为 1)
logic-not-delete-value: 0 # 逻辑未删除值(默认为 0)
insert-strategy: not_null # 插入策略
update-strategy: not_null # 更新策略
where-strategy: not_null # 查询策略
id-type: assign_uuid # 主键策略
configuration:
map-underscore-to-camel-case: true # 开启驼峰命名
cache-enabled: false # 是否开启二级缓存
call-setters-on-nulls: true # 是否在字段为 null 时调用 setter 方法
jdbc-type-for-null: 'null' # 指定当结果集中值为 null 的时候如何处理
log-impl: org.apache.ibatis.logging.stdout.StdOutImpl # 日志实现
# 启动时处理器配置
startup:
# 是否启用启动时向量化处理器
enabled: true
# 启动延迟时间(秒),避免启动时立即执行
delay_seconds: 30
# 每次提交任务间的延迟(毫秒)
task_interval_ms: 100
# 最大处理文档数量限制
max_documents: 100
# springdoc-openapi项目配置
springdoc:
swagger-ui:
enabled: true # 开启swagger-ui
path: /swagger-ui.html # 配置访问路径
api-docs:
enabled: true # 开启api-docs
path: /v3/api-docs # 配置访问路径
group-configs:
- group: 'default'
paths-to-match: '/**'
packages-to-scan: com.ask
default-produces-media-type: application/json
default-consumes-media-type: application/json
# knife4j的增强配置,不需要增强可以不配
knife4j:
enable: true
setting:
language: zh_cn
enable-swagger-models: true
enable-document-manage: true
swagger-model-name: 实体类列表
enable-version: false
enable-reload-cache-parameter: false
enable-after-script: false
enable-filter-multipart-api-method-type: POST
enable-filter-multipart-apis: false
enable-request-cache: true
enable-host: false
enable-host-text:
# swagger配置
swagger:
enabled: true
title: Ask Data AI接口文档
gateway: http://localhost:${server.port}/admin
token-url: ${swagger.gateway}/oauth2/token
scope: server
# 日志配置
logging:
level:
root: INFO
com.ask: DEBUG
org.springframework.ai: INFO
org.postgresql: WARN
pattern:
console: "%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"
# 本地文件系统
file:
local:
enable: true
base-path: /app/upFiles
<?xml version="1.0" encoding="UTF-8"?>
<configuration debug="false">
<!--定义日志文件的存储地址 -->
<property name="LOG_HOME" value="logs" />
<!--<property name="COLOR_PATTERN" value="%black(%contextName-) %red(%d{yyyy-MM-dd HH:mm:ss}) %green([%thread]) %highlight(%-5level) %boldMagenta( %replace(%caller{1}){'\t|Caller.{1}0|\r\n', ''})- %gray(%msg%xEx%n)" />-->
<!-- 控制台输出 -->
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
<!--格式化输出:%d表示日期,%thread表示线程名,%-5level:级别从左显示5个字符宽度%msg:日志消息,%n是换行符
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{50}:%L - %msg%n</pattern>-->
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %highlight(%-5level) %cyan(%logger{50}:%L) - %msg%n</pattern>
</encoder>
</appender>
<!-- 按照每天生成日志文件 -->
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
<!--日志文件输出的文件名 -->
<FileNamePattern>${LOG_HOME}/ask-%d{yyyy-MM-dd}.%i.log</FileNamePattern>
<!--日志文件保留天数 -->
<MaxHistory>30</MaxHistory>
<maxFileSize>10MB</maxFileSize>
</rollingPolicy>
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
<!--格式化输出:%d表示日期,%thread表示线程名,%-5level:级别从左显示5个字符宽度%msg:日志消息,%n是换行符 -->
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{50}:%L - %msg%n</pattern>
</encoder>
</appender>
<!-- 每天生成一个html格式的日志结束 -->
<!--myibatis log configure -->
<logger name="com.apache.ibatis" level="TRACE" />
<logger name="java.sql.Connection" level="DEBUG" />
<logger name="java.sql.Statement" level="DEBUG" />
<logger name="java.sql.PreparedStatement" level="DEBUG" />
<!-- 日志输出级别 -->
<root level="INFO">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
<logger name="okhttp3" level="ERROR"/>
</configuration>
\ No newline at end of file
# swagger 配置
swagger:
enabled: true
title: Pig Swagger API
gateway: http://${GATEWAY-HOST:127.0.0.1}:${GATEWAY-PORT:9999}/admin
token-url: ${swagger.gateway}/admin/oauth2/token
scope: server
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment