Commit 2ab2a099 authored by 林洋洋's avatar 林洋洋

添加普通AI对话

parent aec5a2d5
......@@ -39,7 +39,7 @@ public class GenDatasourceConf extends Model<GenDatasourceConf> {
/**
* 主键
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
private Long id;
/**
......
......@@ -45,7 +45,7 @@ public class GenTable extends Model<GenTable> {
/**
* id
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@Schema(description = "id")
private Long id;
......
......@@ -37,7 +37,7 @@ public class GenTableColumnEntity extends Model<GenDatasourceConf> {
/**
* 主键
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
private Long id;
/**
......
......@@ -29,11 +29,22 @@
<description>pig AI智能问答管理模块</description>
<properties>
<spring-ai.version>1.1.0-SNAPSHOT</spring-ai.version>
<anyline.version>8.7.2-jdk17-20240808</anyline.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<!-- ask api、model 模块-->
<!-- ask api、model 模块-->
<dependency>
<groupId>com.pig4cloud</groupId>
<artifactId>pig-ask-api</artifactId>
......@@ -127,7 +138,53 @@
<artifactId>anyline-data-jdbc-mysql</artifactId>
<version>${anyline.version}</version>
</dependency>
<!-- AI模块-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
</dependency>
<!-- RAG-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-vector-store-pgvector</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-advisors-vector-store</artifactId>
</dependency>
<!-- 对话记忆-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>
</dependencies>
<repositories>
<repository>
<id>spring-snapshots</id>
<name>Spring Snapshots</name>
<url>https://repo.spring.io/snapshot</url>
<releases>
<enabled>false</enabled>
</releases>
</repository>
<repository>
<name>Central Portal Snapshots</name>
<id>central-portal-snapshots</id>
<url>https://central.sonatype.com/repository/maven-snapshots/</url>
<releases>
<enabled>false</enabled>
</releases>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
<profiles>
<profile>
......
package com.pig4cloud.pig.ask.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.jdbc.PostgresChatMemoryRepositoryDialect;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
import java.util.ArrayList;
import java.util.List;
@Configuration
public class CommonConfiguration {
@Bean
public ChatMemory chatMemory (JdbcTemplate jdbcTemplate,PostgresChatMemoryDialect postgresChatMemoryDialect) {
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(postgresChatMemoryDialect)
.build();
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.build();
}
@Bean
public ChatClient chatClient(OpenAiChatModel model, ChatMemory chatMemory) {
List<Advisor> advisors = new ArrayList<>();
Advisor messageChatMemoryAdvisor =MessageChatMemoryAdvisor.builder(chatMemory).build();
advisors.add(messageChatMemoryAdvisor);
// Advisor questionAnswerAdvisor =QuestionAnswerAdvisor.builder(vectorStore).searchRequest(SearchRequest.builder().build()).build();
// advisors.add(questionAnswerAdvisor);
return ChatClient.builder(model)
.defaultAdvisors(advisors)
// .defaultToolCallbacks(toolCallbackProvider)
.defaultAdvisors().build();
}
}
package com.pig4cloud.pig.ask.config;
import com.pig4cloud.pig.common.security.util.SecurityUtils;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@Component
public class PostgresChatMemoryDialect implements JdbcChatMemoryRepositoryDialect {
@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM ask_chat_conversation_detail WHERE conversation_id = ? ORDER BY \"timestamp\"";
}
@Override
public String getInsertMessageSql() {
return "INSERT INTO ask_chat_conversation_detail (conversation_id, content, type, \"timestamp\") VALUES ( ?, ?, ?, ?)";
}
@Override
public String getSelectConversationIdsSql() {
return "SELECT DISTINCT conversation_id FROM ask_chat_conversation_detail";
}
@Override
public String getDeleteMessagesSql() {
return "UPDATE ask_chat_conversation_detail set del_flag = '1' WHERE conversation_id = ? ";
}
}
package com.pig4cloud.pig.ask.controller;
import com.alibaba.nacos.common.utils.UuidUtils;
import com.pig4cloud.pig.ask.api.entity.ChatConversation;
import com.pig4cloud.pig.ask.service.ChatConversationService;
import com.pig4cloud.pig.common.core.util.R;
import io.swagger.v3.oas.annotations.security.SecurityRequirement;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.HttpHeaders;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.Arrays;
import java.util.UUID;
@Slf4j
@RestController
@RequiredArgsConstructor
@RequestMapping("/chat/ai")
@Tag(description = "ai", name = "AI对话模块")
@SecurityRequirement(name = HttpHeaders.AUTHORIZATION)
public class ChatController {
private final ChatClient chatClient;
private final ChatConversationService chatConversationService;
/**
* 获取会话ID
* @return 新的会话ID
*/
@GetMapping("/conversation/id")
public R<ChatConversation> getConversationId(@RequestParam Integer agentId) {
ChatConversation chatConversation =new ChatConversation();
String conversationId=UuidUtils.generateUuid().replaceAll("-","");
chatConversation.setConversationId(conversationId);
chatConversation.setAgentId(agentId);
chatConversationService.save(chatConversation);
return R.ok(chatConversation);
}
/**
* 最基本的AI流式输出对话
*
* * @param message
* @return
*/
@GetMapping(value = "/chat", produces = "text/html;charset=utf-8")
public Flux<String> chat(String message,String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手,请按照用户提问的问题回答,回答内容务必使用markdown格式。");
// 用户消息
Message userMessage = new UserMessage(message);
// 创建提示,包含系统消息和用户消息
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
return chatClient.prompt(prompt).advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)).stream().content();
}
}
\ No newline at end of file
......@@ -40,8 +40,10 @@ import com.pig4cloud.pig.ask.api.enums.CommonColumnFiledEnum;
import com.pig4cloud.pig.ask.mapper.GenTableMapper;
import com.pig4cloud.pig.ask.service.GenTableService;
import com.pig4cloud.pig.ask.utils.DataSourceQueryUtils;
import com.pig4cloud.pig.common.datasource.enums.DsJdbcUrlEnum;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.anyline.metadata.Column;
import org.anyline.metadata.Database;
import org.anyline.metadata.Schema;
......@@ -63,13 +65,11 @@ import java.util.*;
* @author lengleng
* @date 2025/05/31
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class GenTableServiceImpl extends ServiceImpl<GenTableMapper, GenTable> implements GenTableService {
private final JdbcTemplate jdbcTemplate;
// private final GenGroupService genGroupService;
private final DataSourceQueryUtils dataSourceQueryUtils;
/**
* 查询表ddl 语句
......@@ -113,13 +113,12 @@ public class GenTableServiceImpl extends ServiceImpl<GenTableMapper, GenTable> i
*/
@Override
public IPage<TableDto> queryTablePage(Page<TableDto> page, TableParam table) {
// 手动切换数据源
DynamicDataSourceContextHolder.push(table.getDsName());
CacheProxy.clear();
String sql = null;
Map<String, Object> params = new HashMap<>();
if (DsJdbcUrlEnum.MYSQL.getDbName().equals(table.getDbType())) {
sql = "SELECT table_name,table_comment,create_time FROM information_schema.tables WHERE table_schema = (SELECT database()) AND table_type = 'BASE TABLE' ORDER BY table_name;";
sql = "SELECT table_name,table_comment,create_time FROM information_schema.tables WHERE table_schema = (SELECT database()) AND table_type = 'BASE TABLE' ORDER BY table_name;";
params.put("table_name", table.getTableName());
} else if (DsJdbcUrlEnum.PG.getDbName().equals(table.getDbType())) {
sql = """
SELECT
......@@ -130,30 +129,18 @@ public class GenTableServiceImpl extends ServiceImpl<GenTableMapper, GenTable> i
WHERE
t.table_schema = current_schema()
AND t.table_type = 'BASE TABLE'
AND (:tableName IS NULL OR t.table_name ILIKE :tableName)
ORDER BY
t.table_name
""";
params.put("tableName",
StrUtil.isBlank(table.getTableName()) ? null : "%" + table.getTableName() + "%");
}
if (StringUtils.isBlank(sql)) {
return new Page<>(page.getCurrent(), page.getSize());
}
List<TableDto> tableList = jdbcTemplate.query(sql, (rs, rowNum) -> {
TableDto t = new TableDto();
t.setTableName(rs.getString("table_name"));
t.setTableComment(rs.getString("table_comment"));
return t;
});
tableList = tableList.stream().filter(t->{
if(StringUtils.isBlank(table.getTableName())){
return true;
}
return t.getTableName().equals(table.getTableName());
}).toList();
// 根据 page 进行分页
List<TableDto> records = CollUtil.page((int) page.getCurrent() - 1, (int) page.getSize(), tableList);
page.setTotal(tableList.size());
page.setRecords(records);
return page;
return dataSourceQueryUtils.executePageQuery(page,table.getDsName(),sql,TableDto.class);
}
......@@ -168,7 +155,6 @@ public class GenTableServiceImpl extends ServiceImpl<GenTableMapper, GenTable> i
// 手动切换数据源
DynamicDataSourceContextHolder.push(dsName);
CacheProxy.clear();
AnylineService.MetaDataService metadata = ServiceProxy.metadata();
return ServiceProxy.metadata().tables().values().stream().map(Table::getName).toList();
}
......
package com.pig4cloud.pig.ask.utils;
import com.baomidou.dynamic.datasource.toolkit.DynamicDataSourceContextHolder;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@Slf4j
@Component
@RequiredArgsConstructor
public class DataSourceQueryUtils {
private final JdbcTemplate jdbcTemplate;
/**
* 通用数据源切换查询方法
*
* @param dsName 数据源名称
* @param operation 具体的查询操作
* @return 查询结果
*/
public <T> T executeQuery(String dsName, Supplier<T> operation) {
String previousDs = DynamicDataSourceContextHolder.peek();
try {
DynamicDataSourceContextHolder.push(dsName);
return operation.get();
} finally {
DynamicDataSourceContextHolder.poll();
if (StringUtils.isNotBlank(previousDs)) {
DynamicDataSourceContextHolder.push(previousDs);
}
}
}
/**
* 分页查询通用方法
*/
public <T> IPage<T> executePageQuery(Page<T> page, String dsName, String sql, Class<T> clazz) {
return executeQuery(dsName, () -> {
List<T> list = jdbcTemplate.query(sql, new BeanPropertyRowMapper<>(clazz));
return handlePaging(page, list);
});
}
/**
* 带参数的分页查询
*/
public <T> IPage<T> executePageQuery(Page<T> page, String dsName, String sql,
Map<String, Object> params, Class<T> clazz) {
return executeQuery(dsName, () -> {
NamedParameterJdbcTemplate namedTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
List<T> list = namedTemplate.query(sql, params, new BeanPropertyRowMapper<>(clazz));
return handlePaging(page, list);
});
}
/**
* 处理分页
*/
private <T> IPage<T> handlePaging(Page<T> page, List<T> list) {
int fromIndex = (int) ((page.getCurrent() - 1) * page.getSize());
int toIndex = Math.min(fromIndex + (int) page.getSize(), list.size());
List<T> records = fromIndex < list.size()
? list.subList(fromIndex, toIndex)
: new ArrayList<>();
page.setTotal(list.size());
page.setRecords(records);
return page;
}
/**
* 执行更新操作
*/
public int executeUpdate(String dsName, String sql, Map<String, Object> params) {
return executeQuery(dsName, () -> {
NamedParameterJdbcTemplate namedTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
return namedTemplate.update(sql, params);
});
}
}
......@@ -18,7 +18,32 @@ spring:
username: postgres
password: postgres123
driver-class-name: org.postgresql.Driver
ai:
vectorstore:
pgvector:
index-type: HNSW
distance-type: COSINE_DISTANCE
dimensions: 1024
max-document-batch-size: 10000 # Optional: Maximum number of documents per batch
schema-name: public
table-name: vector_store
chat:
memory:
repository:
jdbc:
initialize-schema: never # 开发环境可以使用 always,方便调试
platform: postgresql
openai:
base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950
chat:
options:
model: qwen-plus
embedding:
base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950
options:
model: text-embedding-v4
# 本地文件系统
file:
local:
......
......@@ -3,5 +3,5 @@ swagger:
enabled: true
title: Pig Swagger API
gateway: http://${GATEWAY-HOST:127.0.0.1}:${GATEWAY-PORT:9999}
token-url: ${swagger.gateway}/auth/oauth2/token
token-url: ${swagger.gateway}/admin/oauth2/token
scope: server
......@@ -46,7 +46,7 @@ public class SysDept extends Model<SysDept> {
private static final long serialVersionUID = 1L;
@TableId(value = "dept_id", type = IdType.ASSIGN_ID)
@TableId(value = "dept_id", type = IdType.AUTO)
@Schema(description = "部门id")
private Long deptId;
......
......@@ -40,7 +40,7 @@ public class SysDict extends Model<SysDict> {
/**
* 编号
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@Schema(description = "字典编号")
private Long id;
......
......@@ -41,7 +41,7 @@ public class SysDictItem extends Model<SysDictItem> {
/**
* 编号
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@Schema(description = "字典项id")
private Long id;
......
......@@ -41,7 +41,7 @@ public class SysFile extends Model<SysFile> {
/**
* 编号
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@Schema(description = "文件编号")
private Long id;
......
......@@ -46,7 +46,7 @@ public class SysLog implements Serializable {
/**
* 编号
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@ExcelProperty("日志编号")
@Schema(description = "日志编号")
private Long id;
......
......@@ -49,7 +49,7 @@ public class SysMenu extends Model<SysMenu> {
/**
* 菜单ID
*/
@TableId(value = "menu_id", type = IdType.ASSIGN_ID)
@TableId(value = "menu_id", type = IdType.AUTO)
@Schema(description = "菜单id")
private Long menuId;
......
......@@ -43,7 +43,7 @@ public class SysOauthClientDetails extends Model<SysOauthClientDetails> {
private static final long serialVersionUID = 1L;
@TableId(value = "id", type = IdType.ASSIGN_ID)
@TableId(value = "id", type = IdType.AUTO)
@Schema(description = "id")
private Long id;
......
......@@ -44,7 +44,7 @@ public class SysPost extends Model<SysPost> {
/**
* 岗位ID
*/
@TableId(value = "post_id", type = IdType.ASSIGN_ID)
@TableId(value = "post_id", type = IdType.AUTO)
@Schema(description = "岗位ID")
private Long postId;
......
......@@ -41,7 +41,7 @@ public class SysPublicParam extends Model<SysPublicParam> {
/**
* 编号
*/
@TableId(type = IdType.ASSIGN_ID)
@TableId(type = IdType.AUTO)
@Schema(description = "公共参数编号")
private Long publicId;
......
......@@ -43,7 +43,7 @@ public class SysRole extends Model<SysRole> {
private static final long serialVersionUID = 1L;
@TableId(value = "role_id", type = IdType.ASSIGN_ID)
@TableId(value = "role_id", type = IdType.AUTO)
@Schema(description = "角色编号")
private Long roleId;
......
......@@ -44,7 +44,7 @@ public class SysUser implements Serializable {
/**
* 主键ID
*/
@TableId(value = "user_id", type = IdType.ASSIGN_ID)
@TableId(value = "user_id", type = IdType.AUTO)
@Schema(description = "主键id")
private Long userId;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment