Commit 8e6a9177 authored by 林洋洋

知识库模块BUG修复

parent a72f886e
......@@ -27,5 +27,9 @@
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
</dependencies>
</project>
\ No newline at end of file
package com.ask.api.entity;
import com.ask.api.handle.JsonbTypeHandler;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import java.io.Serializable;
import java.util.Map;
......@@ -22,7 +21,7 @@ import java.util.Map;
*/
@Slf4j
@Data
@TableName("ask_vector_store")
@TableName(value = "ask_vector_store",autoResultMap = true)
@Schema(description = "向量存储")
public class AskVectorStore implements Serializable {
private static final long serialVersionUID = 1L;
......@@ -30,7 +29,7 @@ public class AskVectorStore implements Serializable {
/**
* 主键ID
*/
@TableId(type = IdType.ASSIGN_UUID)
@TableId(type = IdType.AUTO)
@Schema(description = "主键ID")
private String id;
......@@ -45,14 +44,17 @@ public class AskVectorStore implements Serializable {
*/
@JsonIgnore
@Schema(description = "文档元数据")
private String metadata;
@TableField(typeHandler = JsonbTypeHandler.class)
private Map<String,Object> metadata;
@TableField(exist = false)
@Schema(description = "文档ID")
private Long documentId;
@TableField(exist = false)
@Schema(description = "文件名称")
private String fileName;
@TableField(exist = false)
@Schema(description = "文件路径")
private String filePath;
......@@ -69,25 +71,17 @@ public class AskVectorStore implements Serializable {
@Schema(description = "启用状态")
private Integer isEnabled;
/**
* 向量化数据(float数组的JSON表示)
*/
@JsonIgnore
@Schema(description = "向量化数据")
private String embedding;
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
/**
* 解析metadata JSON字符串,填充对应的字段
*/
public void parseMetadata() {
if (!StringUtils.hasText(this.metadata)) {
if (this.metadata==null || this.metadata.isEmpty()) {
return;
}
try {
Map<String, Object> metadataMap = OBJECT_MAPPER.readValue(this.metadata, Map.class);
Map<String, Object> metadataMap = this.metadata;
// 解析 documentId
if (metadataMap.containsKey("documentId")) {
......
package com.ask.api.handle;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedTypes;
import org.postgresql.util.PGobject;
import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.SQLException;
/**
 * Type handler that binds a value to a prepared statement as a PostgreSQL
 * {@code jsonb} object.
 *
 * <p>Only the write side is customised here; everything else is left to the
 * {@link JacksonTypeHandler} superclass.
 */
@Slf4j
@MappedTypes({Object.class})
public class JsonbTypeHandler extends JacksonTypeHandler {

    /**
     * Creates a handler for the given Java type.
     *
     * @param type the Java type passed through to the superclass constructor
     */
    public JsonbTypeHandler(Class<?> type) {
        super(type);
    }

    /**
     * Binds {@code parameter} at position {@code i} as a {@link PGobject} whose
     * type is {@code "jsonb"} and whose value is {@code toJson(parameter)}.
     *
     * <p>When {@code ps} is {@code null} the call is a no-op.
     *
     * @param ps        the statement to bind on; may be {@code null}
     * @param i         the parameter index handed to {@code ps.setObject}
     * @param parameter the value to serialise to JSON
     * @param jdbcType  not used by this implementation
     * @throws SQLException if building the {@link PGobject} or binding it fails
     */
    @Override
    public void setNonNullParameter(PreparedStatement ps, int i, Object parameter, JdbcType jdbcType) throws SQLException {
        // Guard clause: nothing to bind on without a statement.
        if (ps == null) {
            return;
        }

        PGobject jsonbValue = new PGobject();
        jsonbValue.setType("jsonb");
        jsonbValue.setValue(toJson(parameter));

        ps.setObject(i, jsonbValue);
    }
}
\ No newline at end of file
......@@ -23,6 +23,7 @@ import javax.validation.Valid;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/**
* 向量存储管理
......@@ -57,10 +58,10 @@ public class AskVectorStoreController {
) {
LambdaQueryWrapper<AskVectorStore> wrapper = Wrappers.lambdaQuery(AskVectorStore.class)
.like(org.apache.commons.lang3.StringUtils.isNoneBlank(content), AskVectorStore::getContent, content.trim())
// 使用metadata jsonB字段进行过滤
.like(StringUtils.hasText(content), AskVectorStore::getContent, content)
// 使用metadata json字段进行过滤
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(documentId))
.apply(org.apache.commons.lang3.StringUtils.isNoneBlank(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%")
.apply(StringUtils.hasText(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%")
.orderByDesc(AskVectorStore::getId);
IPage<AskVectorStore> result = askVectorStoreService.page(page, wrapper);
result.getRecords().forEach(askVectorStore -> askVectorStore.parseMetadata());
......@@ -79,7 +80,8 @@ public class AskVectorStoreController {
if (!StringUtils.hasText(id)) {
return R.failed("ID不能为空");
}
AskVectorStore askVectorStore = askVectorStoreService.getById(id);
AskVectorStore askVectorStore = askVectorStoreService.getById(id);
askVectorStore.parseMetadata();
return R.ok(askVectorStore);
}
......
......@@ -14,6 +14,7 @@ import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
......@@ -96,8 +97,8 @@ public class ChatController {
* @return
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8")
public Flux<String> ragChat(String message, String conversationId) {
@GetMapping(value = "/rag/chat", produces = "application/stream+json")
public Flux<ChatResponse> ragChat(String message, String conversationId) {
// 创建系统消息,告诉大模型只返回工具名和参数
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
......@@ -106,12 +107,12 @@ public class ChatController {
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// 使用修改后的提示获取响应
FilterExpressionBuilder builder = new FilterExpressionBuilder();
Filter.Expression filter = builder.eq("source","1").build();
Filter.Expression filter = builder.eq("isEnabled",1).build();
return chatClient.prompt(prompt)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter))
.advisors(retrievalAugmentationAdvisor)
.stream().content();
.stream().chatResponse();
}
}
\ No newline at end of file
......@@ -189,7 +189,7 @@ public class KnowledgeBaseController {
)
);
// 执行向量搜索
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().query(content).filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
log.info("向量搜索测试完成 - 知识库ID: {}, 找到 {} 个相似结果", knowledgeBaseId, searchResults.size());
......
......@@ -71,9 +71,12 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
// 1. 批量生成所有embedding向量
List<float[]> embeddings = askVectorStores.stream()
.map(store -> {
String content = (store.getTitle() != null ? store.getTitle() : "") + "\n" +
(store.getContent() != null ? store.getContent() : "");
return embeddingModel.embed(content);
String title = store.getTitle();
String content = store.getContent();
String result = (title == null || title.trim().isEmpty()) ?
(content == null ? "" : content) :
title.trim() + "\n" + (content == null ? "" : content);
return embeddingModel.embed(result);
})
.toList();
......@@ -85,7 +88,7 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
public void setValues(PreparedStatement ps, int i) throws SQLException {
AskVectorStore store = askVectorStores.get(i);
float[] embedding = embeddings.get(i);
// Object id = UUID.fromString(store.getId());
// 使用PGvector处理向量数据,与PgVectorStore保持一致
PGvector pgVector = new PGvector(embedding);
......
......@@ -60,8 +60,8 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
// 查询该文档下所有未向量化的数据
LambdaQueryWrapper<AskVectorStore> wrapper = new LambdaQueryWrapper<AskVectorStore>()
.apply("metadata::jsonb ->> 'documentId' = {0}", document.getId())
.isNull(AskVectorStore::getEmbedding); // 假设embedding为null表示未向量化
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(document.getId()));
; // 假设embedding为null表示未向量化
List<AskVectorStore> vectorStores = askVectorStoreService.list(wrapper);
......
package com.ask;
import com.ask.config.VectorizationProperties;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
......@@ -14,6 +15,7 @@ import org.springframework.scheduling.annotation.EnableAsync;
@EnableAsync
@SpringBootApplication
@EnableConfigurationProperties({VectorizationProperties.class})
@MapperScan("com.ask.mapper")
public class AskDataAiApplication {
public static void main(String[] args) {
......
......@@ -13,7 +13,7 @@ spring:
max-request-size: 500MB # 请求最大大小
file-size-threshold: 0 # 文件写入磁盘的阈值
datasource:
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db?stringtype=unspecified
username: postgres
password: postgres123
driver-class-name: org.postgresql.Driver
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment