Commit 8e6a9177 authored by 林洋洋

Knowledge base module bug fixes

parent a72f886e
......@@ -27,5 +27,9 @@
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
</dependencies>
</project>
\ No newline at end of file
package com.ask.api.entity;
import com.ask.api.handle.JsonbTypeHandler;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import java.io.Serializable;
import java.util.Map;
......@@ -22,7 +21,7 @@ import java.util.Map;
*/
@Slf4j
@Data
@TableName("ask_vector_store")
@TableName(value = "ask_vector_store",autoResultMap = true)
@Schema(description = "向量存储")
public class AskVectorStore implements Serializable {
private static final long serialVersionUID = 1L;
......@@ -30,7 +29,7 @@ public class AskVectorStore implements Serializable {
/**
* Primary key ID
*/
@TableId(type = IdType.ASSIGN_UUID)
@TableId(type = IdType.AUTO)
@Schema(description = "主键ID")
private String id;
......@@ -45,14 +44,17 @@ public class AskVectorStore implements Serializable {
*/
@JsonIgnore
@Schema(description = "文档元数据")
private String metadata;
@TableField(typeHandler = JsonbTypeHandler.class)
private Map<String,Object> metadata;
@TableField(exist = false)
@Schema(description = "文档ID")
private Long documentId;
@TableField(exist = false)
@Schema(description = "文件名称")
private String fileName;
@TableField(exist = false)
@Schema(description = "文件路径")
private String filePath;
......@@ -69,25 +71,17 @@ public class AskVectorStore implements Serializable {
@Schema(description = "启用状态")
private Integer isEnabled;
/**
* Embedding data (JSON representation of a float array)
*/
@JsonIgnore
@Schema(description = "向量化数据")
private String embedding;
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
/**
* Parse the metadata and populate the corresponding fields
*/
public void parseMetadata() {
if (!StringUtils.hasText(this.metadata)) {
if (this.metadata==null || this.metadata.isEmpty()) {
return;
}
try {
Map<String, Object> metadataMap = OBJECT_MAPPER.readValue(this.metadata, Map.class);
Map<String, Object> metadataMap = this.metadata;
// Parse documentId
if (metadataMap.containsKey("documentId")) {
......
package com.ask.api.handle;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedTypes;
import org.postgresql.util.PGobject;
import java.sql.PreparedStatement;
import java.sql.SQLException;
/**
 * Writes Map parameters to PostgreSQL jsonb columns via PGobject; reads are inherited from JacksonTypeHandler.
 */
@Slf4j
@MappedTypes({Object.class})
public class JsonbTypeHandler extends JacksonTypeHandler {
public JsonbTypeHandler(Class<?> type) {
super(type);
}
@Override
public void setNonNullParameter(PreparedStatement ps, int i, Object parameter, JdbcType jdbcType) throws SQLException {
if (ps != null) {
PGobject jsonObject = new PGobject();
jsonObject.setType("jsonb");
jsonObject.setValue(toJson(parameter));
ps.setObject(i, jsonObject);
}
}
}
\ No newline at end of file
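For reference, a minimal sketch of the value this handler binds for a metadata map. It is not part of the commit; it assumes org.postgresql.util.PGobject, com.fasterxml.jackson.databind.ObjectMapper and com.fasterxml.jackson.core.JsonProcessingException are on the classpath, and toJsonb is a hypothetical helper mirroring setNonNullParameter.
// Hypothetical helper, illustrative only: builds the same jsonb value that
// setNonNullParameter binds via toJson(parameter).
static PGobject toJsonb(Map<String, Object> metadata) throws SQLException, JsonProcessingException {
    PGobject jsonb = new PGobject();
    jsonb.setType("jsonb");
    jsonb.setValue(new ObjectMapper().writeValueAsString(metadata));
    return jsonb;
}
// Usage sketch: ps.setObject(i, toJsonb(Map.of("documentId", 42L, "title", "demo")));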
......@@ -23,6 +23,7 @@ import javax.validation.Valid;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/**
* Vector store management
......@@ -57,10 +58,10 @@ public class AskVectorStoreController {
) {
LambdaQueryWrapper<AskVectorStore> wrapper = Wrappers.lambdaQuery(AskVectorStore.class)
.like(org.apache.commons.lang3.StringUtils.isNoneBlank(content), AskVectorStore::getContent, content.trim())
// Filter on the metadata jsonB field
.like(StringUtils.hasText(content), AskVectorStore::getContent, content)
// Filter on the metadata json field
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(documentId))
.apply(org.apache.commons.lang3.StringUtils.isNoneBlank(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%")
.apply(StringUtils.hasText(title), "metadata::jsonb ->> 'title' LIKE {0}", "%" + title + "%")
.orderByDesc(AskVectorStore::getId);
IPage<AskVectorStore> result = askVectorStoreService.page(page, wrapper);
result.getRecords().forEach(askVectorStore -> askVectorStore.parseMetadata());
......@@ -79,6 +80,7 @@ public class AskVectorStoreController {
if (!StringUtils.hasText(id)) {
return R.failed("ID不能为空");
}
AskVectorStore askVectorStore = askVectorStoreService.getById(id);
askVectorStore.parseMetadata();
return R.ok(askVectorStore);
......
......@@ -14,6 +14,7 @@ import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
......@@ -96,8 +97,8 @@ public class ChatController {
* @return
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = "text/html;charset=utf-8")
public Flux<String> ragChat(String message, String conversationId) {
@GetMapping(value = "/rag/chat", produces = "application/stream+json")
public Flux<ChatResponse> ragChat(String message, String conversationId) {
// Create a system message telling the LLM to return only the tool name and parameters
Message systemMessage = new SystemMessage("你是一个AI客服助手。请严格按照以下格式回答每个问题:");
......@@ -106,12 +107,12 @@ public class ChatController {
Prompt prompt = new Prompt(Arrays.asList(systemMessage, userMessage));
// Get the response using the modified prompt
FilterExpressionBuilder builder = new FilterExpressionBuilder();
Filter.Expression filter = builder.eq("source","1").build();
Filter.Expression filter = builder.eq("isEnabled",1).build();
return chatClient.prompt(prompt)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filter))
.advisors(retrievalAugmentationAdvisor)
.stream().content();
.stream().chatResponse();
}
}
\ No newline at end of file
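Because /rag/chat now streams Flux&lt;ChatResponse&gt; instead of plain text, a caller that only needs the text delta has to map it back. A minimal sketch, not from the commit; getText() assumes Spring AI 1.0.x (older milestones expose getContent() instead).
// Hypothetical caller-side mapping from ChatResponse chunks back to a text stream.
Flux<String> text = ragChat("你好", "conv-1")
        .mapNotNull(resp -> resp.getResult() == null
                ? null
                : resp.getResult().getOutput().getText());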
......@@ -189,7 +189,7 @@ public class KnowledgeBaseController {
)
);
// Run the vector search
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().query(content).filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
log.info("向量搜索测试完成 - 知识库ID: {}, 找到 {} 个相似结果", knowledgeBaseId, searchResults.size());
......
......@@ -71,9 +71,12 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
// 1. Generate all embedding vectors in batch
List<float[]> embeddings = askVectorStores.stream()
.map(store -> {
String content = (store.getTitle() != null ? store.getTitle() : "") + "\n" +
(store.getContent() != null ? store.getContent() : "");
return embeddingModel.embed(content);
String title = store.getTitle();
String content = store.getContent();
String result = (title == null || title.trim().isEmpty()) ?
(content == null ? "" : content) :
title.trim() + "\n" + (content == null ? "" : content);
return embeddingModel.embed(result);
})
.toList();
......@@ -85,7 +88,7 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
public void setValues(PreparedStatement ps, int i) throws SQLException {
AskVectorStore store = askVectorStores.get(i);
float[] embedding = embeddings.get(i);
// Object id = UUID.fromString(store.getId());
// Use PGvector for the vector data, consistent with PgVectorStore
PGvector pgVector = new PGvector(embedding);
......
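A minimal sketch, not shown in this hunk, of how the PGvector value is presumably bound inside setValues; the parameter index is an assumption, and com.pgvector.PGvector is the same JDBC helper Spring AI's PgVectorStore uses.
// Hypothetical continuation of setValues(ps, i): bind the embedding as a pgvector parameter.
PGvector pgVector = new PGvector(embedding);
ps.setObject(3, pgVector); // "3" is an assumed position of the embedding column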
......@@ -60,8 +60,8 @@ public class AsyncVectorizationServiceImpl implements AsyncVectorizationService
// Query all rows under this document that are not yet vectorized
LambdaQueryWrapper<AskVectorStore> wrapper = new LambdaQueryWrapper<AskVectorStore>()
.apply("metadata::jsonb ->> 'documentId' = {0}", document.getId())
.isNull(AskVectorStore::getEmbedding); // assuming a null embedding means not yet vectorized
.apply("metadata::jsonb ->> 'documentId' = {0}", String.valueOf(document.getId()));
; // assuming a null embedding means not yet vectorized
List<AskVectorStore> vectorStores = askVectorStoreService.list(wrapper);
......
......@@ -19,6 +19,7 @@ import org.springframework.ai.document.Document;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.core.io.InputStreamResource;
import org.springframework.stereotype.Service;
......@@ -45,7 +46,8 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
private final AskVectorStoreService askVectorStoreService;
private final AsyncVectorizationService asyncVectorizationService;
private final ObjectMapper objectMapper;
/**
* Read document segments with a table-of-contents structure from a PDF document
*
......@@ -76,17 +78,12 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
* @param maxTokensPerSlice Maximum tokens per slice (only effective for the CUSTOM strategy)
* @return List of document slices; each Document is one slice
*/
public List<Document> slicePdfDocument(InputStream inputStream, SliceStrategy sliceStrategy, Integer maxTokensPerSlice) {
InputStreamResource resource = new InputStreamResource(inputStream);
public List<Document> slicePdfDocument(String bucketName,String fileName , SliceStrategy sliceStrategy) {
InputStreamResource resource = new InputStreamResource(sysFileService.getFileStream(bucketName,fileName));
List<Document> documents = new ArrayList<>();
try {
switch (sliceStrategy) {
case PAGE:
// Slice by page - one Document per page
documents = sliceByPage(resource);
break;
case PARAGRAPH:
// Slice by paragraph - one Document per paragraph
documents = sliceByParagraph(resource);
......@@ -94,27 +91,17 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
case CUSTOM:
// Custom slicing - split by token count
documents = sliceByTokens(resource, maxTokensPerSlice != null ? maxTokensPerSlice : 500);
documents = sliceByTokens(resource);
break;
default:
log.warn("未知的切片策略: {}, 使用默认页面切片", sliceStrategy);
documents = sliceByPage(resource);
documents = sliceByTokens(resource);
}
//
// // Add metadata to each document slice
// for (int i = 0; i < documents.size(); i++) {
// Document doc = documents.get(i);
// doc.getMetadata().put("slice_index", i + 1);
// doc.getMetadata().put("slice_strategy", sliceStrategy.name());
// doc.getMetadata().put("total_slices", documents.size());
// }
log.info("PDF切片完成,策略: {}, 切片数量: {}", sliceStrategy, documents.size());
} catch (Exception e) {
}catch (Exception e) {
log.error("PDF切片失败,策略: {}, 错误: {}", sliceStrategy, e.getMessage(), e);
throw new RuntimeException("PDF切片失败: " + e.getMessage(), e);
documents = sliceByTokens(new InputStreamResource(sysFileService.getFileStream(bucketName,fileName)));
}
return documents;
......@@ -159,8 +146,9 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
.build())
.withPagesPerDocument(0)
.build());
List<Document> pageDocuments = pdfReader.read();
List<Document> paragraphDocuments = new ArrayList<>();
for (Document pageDoc : pageDocuments) {
......@@ -313,37 +301,25 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
Document longDoc = new Document(longParagraph);
List<Document> subDocuments = textSplitter.apply(List.of(longDoc));
// Set metadata for each sub-slice
List<Document> result = new ArrayList<>();
for (int i = 0; i < subDocuments.size(); i++) {
Document subDoc = subDocuments.get(i);
for (Document subDoc : subDocuments) {
subDoc.getMetadata().put("size", Objects.requireNonNull(subDoc.getText()).length());
subDoc.getMetadata().put("title", title);
result.add(subDoc);
}
log.info("超长段落分片完成,原段落token数: {}, 分片数: {}",
estimateTokenCount(longParagraph), subDocuments.size());
return result;
return subDocuments;
}
/**
* Slice by token count - each slice holds a specified number of tokens
*/
private List<Document> sliceByTokens(InputStreamResource resource, int maxTokensPerSlice) {
// First read the entire document content
ParagraphPdfDocumentReader pdfReader = new ParagraphPdfDocumentReader(resource,
PdfDocumentReaderConfig.builder()
.withPageTopMargin(0)
.withPageExtractedTextFormatter(ExtractedTextFormatter.builder()
.withNumberOfTopTextLinesToDelete(1) // strip the page header
.withNumberOfBottomTextLinesToDelete(1) // strip the page footer
.build())
.build());
private List<Document> sliceByTokens(InputStreamResource resource) {
List<Document> pageDocuments = pdfReader.read();
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resource);
List<Document> pageDocuments = tikaDocumentReader.read();
// Merge the content of all pages
StringBuilder fullContent = new StringBuilder();
......@@ -355,23 +331,18 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
// Slice using TokenTextSplitter
TokenTextSplitter textSplitter = new TokenTextSplitter(
maxTokensPerSlice, // max tokens per slice
100, // overlap tokens
50, // minimum slice size
maxTokensPerSlice * 2, // maximum slice size
4096, // chunk size (tokens)
50, // minimum chunk size in characters
50, // minimum length required for embedding
1000, // maximum number of chunks
true // keep separators
);
Document fullDocument = new Document(fullContent.toString().trim());
List<Document> tokenDocuments = textSplitter.apply(List.of(fullDocument));
// Add metadata to each token slice
for (int i = 0; i < tokenDocuments.size(); i++) {
Document doc = tokenDocuments.get(i);
doc.getMetadata().put("token_slice_index", i + 1);
doc.getMetadata().put("slice_type", "token");
doc.getMetadata().put("max_tokens_per_slice", maxTokensPerSlice);
doc.getMetadata().put("estimated_tokens", estimateTokenCount(doc.getText()));
for (Document subDoc : tokenDocuments) {
subDoc.getMetadata().put("size", Objects.requireNonNull(subDoc.getText()).length());
subDoc.getMetadata().put("title", "");
}
return tokenDocuments;
......@@ -381,7 +352,6 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
* PDF slicing strategy enum
*/
public enum SliceStrategy {
PAGE, // slice by page
PARAGRAPH, // slice by paragraph
CUSTOM // custom slicing (by token count)
}
......@@ -394,9 +364,8 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
// Read document content - using the new PDF slicing function
List<Document> segments = slicePdfDocument(
sysFileService.getFileStream(file.getBucketName(), file.getFileName()),
SliceStrategy.PARAGRAPH, // can be changed to PARAGRAPH or CUSTOM
1024 // only takes effect with the CUSTOM strategy
file.getBucketName(), file.getFileName(),
SliceStrategy.PARAGRAPH
);
if (segments.isEmpty()) {
log.warn("文档解析失败或内容为空: {}", file.getOriginal());
......@@ -466,11 +435,8 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
askVectorStore.setIsEnabled(1);
askVectorStore.setDocumentId(document.getId());
askVectorStore.setKnowledgeBaseId(document.getKnowledgeBaseId());
try {
askVectorStore.setMetadata(objectMapper.writeValueAsString(metadata));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
askVectorStore.setMetadata(metadata);
askVectorStores.add(askVectorStore);
});
askVectorStoreService.saveBatch(askVectorStores);
......@@ -557,5 +523,4 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
}
}
\ No newline at end of file
package com.ask;
import com.ask.config.VectorizationProperties;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
......@@ -14,6 +15,7 @@ import org.springframework.scheduling.annotation.EnableAsync;
@EnableAsync
@SpringBootApplication
@EnableConfigurationProperties({VectorizationProperties.class})
@MapperScan("com.ask.mapper")
public class AskDataAiApplication {
public static void main(String[] args) {
......
......@@ -13,7 +13,7 @@ spring:
max-request-size: 500MB # maximum request size
file-size-threshold: 0 # threshold for writing files to disk
datasource:
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db
url: jdbc:postgresql://81.70.183.25:25432/ask_data_ai_db?stringtype=unspecified
username: postgres
password: postgres123
driver-class-name: org.postgresql.Driver
......