Commit 1d749be5 authored by 林洋洋's avatar 林洋洋

切片相关代码提交

parent 92ff9ab0
......@@ -69,6 +69,13 @@ public class AskVectorStore implements Serializable {
@Schema(description = "启用状态")
private Integer isEnabled;
/**
* 向量化数据(float数组的JSON表示)
*/
@JsonIgnore
@Schema(description = "向量化数据")
private String embedding;
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
/**
......
......@@ -15,13 +15,13 @@ import java.util.List;
*/
public interface AskVectorStoreService extends IService<AskVectorStore> {
/**
* 向量化存储方法
* 获取内容和标题,向量化存储到向量字段上
* 批量更新向量化字段 embedding
*
* @param askVectorStore 向量存储实体
* @return 是否成功
* @param askVectorStores 需要更新的向量存储列表
* @return 更新成功的记录数
*/
boolean vectorizeAndStore(@Param("entity") AskVectorStore askVectorStore);
int batchUpdateVectorEmbedding(List<AskVectorStore> askVectorStores);
}
\ No newline at end of file
......@@ -5,15 +5,27 @@ import com.ask.mapper.AskVectorStoreMapper;
import com.ask.service.AskVectorStoreService;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.postgresql.util.PGobject;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
......@@ -32,84 +44,83 @@ import java.util.stream.Collectors;
public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper, AskVectorStore> implements AskVectorStoreService {
@Autowired
private VectorStore vectorStore;
private EmbeddingModel embeddingModel;
@Autowired
private ObjectMapper objectMapper;
private JdbcTemplate jdbcTemplate;
/**
* 批量更新向量化字段 embedding
*
* @param askVectorStores 需要更新的向量存储列表
* @return 更新成功的记录数
*/
public int batchUpdateVectorEmbedding(List<AskVectorStore> askVectorStores) {
if (askVectorStores == null || askVectorStores.isEmpty()) {
log.warn("批量更新向量化:输入列表为空");
return 0;
}
@Override
public boolean vectorizeAndStore(AskVectorStore askVectorStore) {
try {
// 校验必要字段
if (!StringUtils.hasText(askVectorStore.getContent())) {
log.warn("向量化存储失败:文档内容为空");
return false;
}
log.info("开始批量更新向量化,总数: {}", askVectorStores.size());
// 1. 批量生成所有embedding向量
List<float[]> embeddings = askVectorStores.stream()
.map(store -> {
String content = (store.getTitle() != null ? store.getTitle() : "") + "\n" +
(store.getContent() != null ? store.getContent() : "");
return embeddingModel.embed(content);
})
.toList();
// 构建文档内容,如果有标题则添加标题
StringBuilder contentBuilder = new StringBuilder();
if (StringUtils.hasText(askVectorStore.getTitle())) {
contentBuilder.append("标题: ").append(askVectorStore.getTitle()).append("\n\n");
}
contentBuilder.append(askVectorStore.getContent());
// 2. 批量更新数据库 - 参考PgVectorStore的实现方式
String sql = "UPDATE ask_vector_store SET embedding = ? WHERE id = ?";
String documentContent = contentBuilder.toString();
int[] updateCounts = jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
AskVectorStore store = askVectorStores.get(i);
float[] embedding = embeddings.get(i);
// 使用PGvector处理向量数据,与PgVectorStore保持一致
PGvector pgVector = new PGvector(embedding);
StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, pgVector);
StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, store.getId());
}
@Override
public int getBatchSize() {
return askVectorStores.size();
}
});
// 3. 统计更新成功的记录数
int successCount = 0;
int failureCount = 0;
// 构建元数据
Map<String, Object> metadata = new HashMap<>();
if (StringUtils.hasText(askVectorStore.getMetadata())) {
try {
// 解析已有的元数据
Map<String, Object> existingMetadata = objectMapper.readValue(
askVectorStore.getMetadata(), Map.class);
metadata.putAll(existingMetadata);
} catch (Exception e) {
log.warn("解析已有元数据失败:{}", e.getMessage());
for (int i = 0; i < updateCounts.length; i++) {
if (updateCounts[i] > 0) {
successCount++;
} else {
failureCount++;
log.warn("向量化更新失败,ID: {},可能记录不存在", askVectorStores.get(i).getId());
}
}
log.info("批量更新向量化完成:总数={}, 成功={}, 失败={}",
askVectorStores.size(), successCount, failureCount);
// 添加向量化相关的元数据
metadata.put("id", askVectorStore.getId());
metadata.put("title", askVectorStore.getTitle());
metadata.put("documentId", askVectorStore.getDocumentId());
metadata.put("knowledgeBaseId", askVectorStore.getKnowledgeBaseId());
metadata.put("isEnabled", askVectorStore.getIsEnabled());
metadata.put("vectorized", 1); // 标记为已向量化
metadata.put("vectorizeTime", LocalDateTime.now().toString());
// 创建 Spring AI Document 对象
Document document = new Document(askVectorStore.getId(), documentContent, metadata);
// 向量化存储到 VectorStore
vectorStore.add(List.of(document));
// 更新数据库中的向量化状态
updateVectorizedStatus(askVectorStore.getId(), metadata);
log.info("向量化存储成功:id={}, title={}", askVectorStore.getId(), askVectorStore.getTitle());
return true;
} catch (Exception e) {
log.error("向量化存储失败:id={}, error={}",
askVectorStore.getId(), e.getMessage(), e);
return false;
}
}
/**
* 更新向量化状态
*/
private void updateVectorizedStatus(String id, Map<String, Object> metadata) {
try {
String metadataJson = objectMapper.writeValueAsString(metadata);
LambdaUpdateWrapper<AskVectorStore> updateWrapper = new LambdaUpdateWrapper<>();
updateWrapper.eq(AskVectorStore::getId, id)
.set(AskVectorStore::getMetadata, metadataJson);
this.update(updateWrapper);
return successCount;
} catch (Exception e) {
log.warn("更新向量化状态失败:id={}, error={}", id, e.getMessage());
log.error("批量更新向量化异常:{}", e.getMessage(), e);
throw new RuntimeException("批量更新向量化失败: " + e.getMessage(), e);
}
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment