Commit 8e6a9177 authored by 林洋洋

知识库模块BUG修复

parent a72f886e
...@@ -27,5 +27,9 @@ ...@@ -27,5 +27,9 @@
<artifactId>lombok</artifactId> <artifactId>lombok</artifactId>
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
</dependencies> </dependencies>
</project> </project>
\ No newline at end of file
package com.ask.api.entity; package com.ask.api.entity;
import com.ask.api.handle.JsonbTypeHandler;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import java.io.Serializable; import java.io.Serializable;
import java.util.Map; import java.util.Map;
...@@ -22,7 +21,7 @@ import java.util.Map; ...@@ -22,7 +21,7 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Data @Data
@TableName("ask_vector_store") @TableName(value = "ask_vector_store",autoResultMap = true)
@Schema(description = "向量存储") @Schema(description = "向量存储")
public class AskVectorStore implements Serializable { public class AskVectorStore implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
...@@ -30,7 +29,7 @@ public class AskVectorStore implements Serializable { ...@@ -30,7 +29,7 @@ public class AskVectorStore implements Serializable {
/** /**
* 主键ID * 主键ID
*/ */
@TableId(type = IdType.ASSIGN_UUID) @TableId(type = IdType.AUTO)
@Schema(description = "主键ID") @Schema(description = "主键ID")
private String id; private String id;
...@@ -45,14 +44,17 @@ public class AskVectorStore implements Serializable { ...@@ -45,14 +44,17 @@ public class AskVectorStore implements Serializable {
*/ */
@JsonIgnore @JsonIgnore
@Schema(description = "文档元数据") @Schema(description = "文档元数据")
private String metadata; @TableField(typeHandler = JsonbTypeHandler.class)
private Map<String,Object> metadata;
@TableField(exist = false) @TableField(exist = false)
@Schema(description = "文档ID") @Schema(description = "文档ID")
private Long documentId; private Long documentId;
@TableField(exist = false)
@Schema(description = "文件名称") @Schema(description = "文件名称")
private String fileName; private String fileName;
@TableField(exist = false)
@Schema(description = "文件路径") @Schema(description = "文件路径")
private String filePath; private String filePath;
...@@ -69,25 +71,17 @@ public class AskVectorStore implements Serializable { ...@@ -69,25 +71,17 @@ public class AskVectorStore implements Serializable {
@Schema(description = "启用状态") @Schema(description = "启用状态")
private Integer isEnabled; private Integer isEnabled;
/**
* 向量化数据(float数组的JSON表示)
*/
@JsonIgnore
@Schema(description = "向量化数据")
private String embedding;
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
/** /**
* 解析metadata JSON字符串,填充对应的字段 * 解析metadata JSON字符串,填充对应的字段
*/ */
public void parseMetadata() { public void parseMetadata() {
if (!StringUtils.hasText(this.metadata)) { if (this.metadata==null || this.metadata.isEmpty()) {
return; return;
} }
try { try {
Map<String, Object> metadataMap = OBJECT_MAPPER.readValue(this.metadata, Map.class); Map<String, Object> metadataMap = this.metadata;
// 解析 documentId // 解析 documentId
if (metadataMap.containsKey("documentId")) { if (metadataMap.containsKey("documentId")) {
......
package com.ask.api.handle;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedTypes;
import org.postgresql.util.PGobject;
import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.SQLException;
/**
 * MyBatis type handler that writes values to a PostgreSQL {@code jsonb} column.
 *
 * <p>Only the write path ({@link #setNonNullParameter}) is customized here; all
 * other behavior is inherited unchanged from {@link JacksonTypeHandler}.
 */
@Slf4j
@MappedTypes({Object.class})
public class JsonbTypeHandler extends JacksonTypeHandler {

    /**
     * Creates a handler for the given Java type.
     *
     * @param type the Java type, passed straight through to the parent handler
     */
    public JsonbTypeHandler(Class<?> type) {
        super(type);
    }

    /**
     * Binds a non-null parameter as a {@code jsonb}-typed {@link PGobject}.
     *
     * <p>The value is serialized with the inherited {@code toJson} and wrapped in a
     * {@code PGobject} whose type is set to {@code "jsonb"} before being bound.
     *
     * @param ps        the statement to bind on; when {@code null} nothing is bound
     * @param i         the parameter index handed to {@code ps.setObject}
     * @param parameter the value to serialize to JSON
     * @param jdbcType  not used by this implementation
     * @throws SQLException if building the {@code PGobject} or binding it fails
     */
    @Override
    public void setNonNullParameter(PreparedStatement ps, int i, Object parameter, JdbcType jdbcType) throws SQLException {
        // Nothing to bind without a statement; mirrors the original null check.
        if (ps == null) {
            return;
        }
        PGobject jsonbValue = new PGobject();
        jsonbValue.setType("jsonb");
        jsonbValue.setValue(toJson(parameter));
        ps.setObject(i, jsonbValue);
    }
}
\ No newline at end of file
...@@ -23,6 +23,7 @@ import javax.validation.Valid; ...@@ -23,6 +23,7 @@ import javax.validation.Valid;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID;
/** /**
* 向量存储管理 * 向量存储管理
...@@ -57,10 +58,10 @@ public class AskVectorStoreController { ...@@ -57,10 +58,10 @@ public class AskVectorStoreController {
) { ) {
LambdaQueryWrapper<AskVectorStore> wrapper = Wrappers.lambdaQuery(AskVectorStore.class) LambdaQueryWrapper<AskVectorStore> wrapper = Wrappers.lambdaQuery(AskVectorStore.class)
.like(org.apache.commons.lang3.StringUtils.isNoneBlank(content), AskVectorStore::getContent, content.trim()) .like(StringUtils.hasText(content), AskVectorStore::getContent, content)
// 使用metadata jsonB字段进行过滤 // 使用metadata json字段进行过滤
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(documentId)) .apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(documentId))
.apply(org.apache.commons.lang3.StringUtils.isNoneBlank(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%") .apply(StringUtils.hasText(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%")
.orderByDesc(AskVectorStore::getId); .orderByDesc(AskVectorStore::getId);
IPage<AskVectorStore> result = askVectorStoreService.page(page, wrapper); IPage<AskVectorStore> result = askVectorStoreService.page(page, wrapper);
result.getRecords().forEach(askVectorStore -> askVectorStore.parseMetadata()); result.getRecords().forEach(askVectorStore -> askVectorStore.parseMetadata());
...@@ -79,7 +80,8 @@ public class AskVectorStoreController { ...@@ -79,7 +80,8 @@ public class AskVectorStoreController {
if (!StringUtils.hasText(id)) { if (!StringUtils.hasText(id)) {
return R.failed("ID不能为空"); return R.failed("ID不能为空");
} }
AskVectorStore askVectorStore = askVectorStoreService.getById(id);
AskVectorStore askVectorStore = askVectorStoreService.getById(id);
askVectorStore.parseMetadata(); askVectorStore.parseMetadata();
return R.ok(askVectorStore); return R.ok(askVectorStore);
} }
......
...@@ -14,6 +14,7 @@ import org.springframework.ai.chat.memory.ChatMemory; ...@@ -14,6 +14,7 @@ import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
...@@ -96,8 +97,8 @@ public class ChatController { ...@@ -96,8 +97,8 @@ public class ChatController {
* @return * @return
*/ */
@Operation(summary = "知识库对话", description = "知识库对话") @Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8") @GetMapping(value = "/rag/chat", produces = "application/stream+json")
public Flux<String> ragChat(String message, String conversationId) { public Flux<ChatResponse> ragChat(String message, String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数 // 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:"); Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
...@@ -106,12 +107,12 @@ public class ChatController { ...@@ -106,12 +107,12 @@ public class ChatController {
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage)); Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应 // 使用修改后的提示获取响应
FilterExpressionBuilder builder = new FilterExpressionBuilder(); FilterExpressionBuilder builder = new FilterExpressionBuilder();
Filter.Expression filter = builder.eq("source","1").build(); Filter.Expression filter = builder.eq("isEnabled",1).build();
return chatClient.prompt(prompt) return chatClient.prompt(prompt)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)) .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter)) .advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter))
.advisors(retrievalAugmentationAdvisor) .advisors(retrievalAugmentationAdvisor)
.stream().content(); .stream().chatResponse();
} }
} }
\ No newline at end of file
...@@ -189,7 +189,7 @@ public class KnowledgeBaseController { ...@@ -189,7 +189,7 @@ public class KnowledgeBaseController {
) )
); );
// 执行向量搜索 // 执行向量搜索
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build()); List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().query(content).filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
log.info("向量搜索测试完成 - 知识库ID: {}, 找到 {} 个相似结果", knowledgeBaseId, searchResults.size()); log.info("向量搜索测试完成 - 知识库ID: {}, 找到 {} 个相似结果", knowledgeBaseId, searchResults.size());
......
...@@ -71,9 +71,12 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper, ...@@ -71,9 +71,12 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
// 1. 批量生成所有embedding向量 // 1. 批量生成所有embedding向量
List<float[]> embeddings = askVectorStores.stream() List<float[]> embeddings = askVectorStores.stream()
.map(store -> { .map(store -> {
String content = (store.getTitle() != null ? store.getTitle() : "") + "\n" + String title = store.getTitle();
(store.getContent() != null ? store.getContent() : ""); String content = store.getContent();
return embeddingModel.embed(content); String result = (title == null || title.trim().isEmpty()) ?
(content == null ? "" : content) :
title.trim() + "\n" + (content == null ? "" : content);
return embeddingModel.embed(result);
}) })
.toList(); .toList();
...@@ -85,7 +88,7 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper, ...@@ -85,7 +88,7 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
public void setValues(PreparedStatement ps, int i) throws SQLException { public void setValues(PreparedStatement ps, int i) throws SQLException {
AskVectorStore store = askVectorStores.get(i); AskVectorStore store = askVectorStores.get(i);
float[] embedding = embeddings.get(i); float[] embedding = embeddings.get(i);
// Object id = UUID.fromString(store.getId());
// 使用PGvector处理向量数据,与PgVectorStore保持一致 // 使用PGvector处理向量数据,与PgVectorStore保持一致
PGvector pgVector = new PGvector(embedding); PGvector pgVector = new PGvector(embedding);
......
...@@ -60,8 +60,8 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService ...@@ -60,8 +60,8 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
// 查询该文档下所有未向量化的数据 // 查询该文档下所有未向量化的数据
LambdaQueryWrapper<AskVectorStore> wrapper = new LambdaQueryWrapper<AskVectorStore>() LambdaQueryWrapper<AskVectorStore> wrapper = new LambdaQueryWrapper<AskVectorStore>()
.apply("metadata::jsonb ->> 'documentId' = {0}", document.getId()) .apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(document.getId()));
.isNull(AskVectorStore::getEmbedding); // 假设embedding为null表示未向量化 ; // 假设embedding为null表示未向量化
List<AskVectorStore> vectorStores = askVectorStoreService.list(wrapper); List<AskVectorStore> vectorStores = askVectorStoreService.list(wrapper);
......
package com.ask; package com.ask;
import com.ask.config.VectorizationProperties; import com.ask.config.VectorizationProperties;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
...@@ -14,6 +15,7 @@ import org.springframework.scheduling.annotation.EnableAsync; ...@@ -14,6 +15,7 @@ import org.springframework.scheduling.annotation.EnableAsync;
@EnableAsync @EnableAsync
@SpringBootApplication @SpringBootApplication
@EnableConfigurationProperties({VectorizationProperties.class}) @EnableConfigurationProperties({VectorizationProperties.class})
@MapperScan("com.ask.mapper")
public class AskDataAiApplication { public class AskDataAiApplication {
public static void main(String[] args) { public static void main(String[] args) {
......
...@@ -13,7 +13,7 @@ spring: ...@@ -13,7 +13,7 @@ spring:
max-request-size: 500MB # 请求最大大小 max-request-size: 500MB # 请求最大大小
file-size-threshold: 0 # 文件写入磁盘的阈值 file-size-threshold: 0 # 文件写入磁盘的阈值
datasource: datasource:
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db?stringtype=unspecified
username: postgres username: postgres
password: postgres123 password: postgres123
driver-class-name: org.postgresql.Driver driver-class-name: org.postgresql.Driver
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment