Commit 2c8dcb2a authored by 林洋洋's avatar 林洋洋

添加docx 提取表格图片

parent 8c20bac4
package com.ask.api.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
......@@ -12,6 +14,7 @@ import java.time.LocalDateTime;
public class AskImagesRecord {
// Primary key (auto-increment)
@TableId(type = IdType.AUTO)
private Long id;
// Image file name
private String imageName;
......
......@@ -137,6 +137,11 @@
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-scratchpad</artifactId>
<version>5.2.2</version>
</dependency>
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-ooxml</artifactId>
......
......@@ -9,6 +9,7 @@ import com.ask.service.ChatConversationService;
import com.ask.service.impl.ChatService;
import com.ask.service.impl.RagPromptService;
import com.ask.tools.ExcelTools;
import com.ask.tools.SqlTools;
import com.ask.utils.FluxUtils;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
......@@ -31,6 +32,7 @@ import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
......@@ -73,6 +75,10 @@ public class ChatController {
private final ExcelTools excelTools;
private final SqlTools sqlTools;
// private final ToolCallbackProvider toolCallbackProvider;
/**
* 获取会话ID
*
......@@ -169,13 +175,36 @@ public class ChatController {
Message userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
return FluxUtils.wrapDeepSeekStream(openAiChatClient.prompt(prompt)
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.tools(excelTools)
.advisors()
.advisors()
.stream()
.chatResponse());
}
// @Operation(summary = "智能问数据对话", description = "智能问数据对话")
// @GetMapping(value = "/chat/data", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
// public Flux<String> dataChat(@RequestParam String message,
// @RequestParam Long knowledgeBaseId,
// @RequestParam String conversationId) {
//
// Message systemMessage = new SystemMessage("你是一个AI智能问数助手,数据库采用postgres 16,请使用相关工具回答用户问题");
// Message userMessage = new UserMessage(message);
// Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
//
// return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
// .advisors(messageChatMemoryAdvisor)
// .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
//// .advisors(retrievalAugmentationAdvisor)
//// .advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "knowledgeBaseId == "+knowledgeBaseId))
// .tools(sqlTools)
// .toolCallbacks(toolCallbackProvider)
// .advisors()
// .stream()
// .chatResponse());
//
// }
}
\ No newline at end of file
......@@ -20,8 +20,10 @@ package com.ask.controller;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import com.ask.api.dto.FileUploadRequest;
import com.ask.api.entity.AskImagesRecord;
import com.ask.api.entity.SysFile;
import com.ask.common.core.R;
import com.ask.mapper.AskImagesRecordMapper;
import com.ask.service.SysFileService;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
......@@ -46,6 +48,9 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import jakarta.validation.Valid;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
......@@ -64,7 +69,7 @@ public class SysFileController {
private final SysFileService sysFileService;
private final AskImagesRecordMapper askImagesRecordMapper;
......@@ -122,5 +127,62 @@ public class SysFileController {
sysFileService.getFileByUUid(bucket, fileName, response,originalName);
}
/**
 * Streams the stored bytes of one image record to the HTTP response.
 *
 * <p>Responds with 400 when {@code imageId} is not a valid integer and with 404 when no
 * record (or no image data) exists for it. On success the body is the raw image bytes with
 * a content type guessed from the leading magic bytes.
 *
 * @param imageId  identifier of the image record, taken from the request path
 * @param response servlet response the image is written to
 * @throws IOException if sending the error status or writing the body fails
 */
@GetMapping("/images/{imageId}")
public void fileByImageId(
        @Parameter(description = "图片Id", required = true, example = "1")
        @PathVariable String imageId,
        HttpServletResponse response) throws IOException {
    // The record key is a Long; convert up front so a malformed id is a client error
    // rather than a failure inside the data layer, and so the key is bound as a number
    // (NOTE(review): PostgreSQL may refuse to compare a bigint column with a text
    // parameter depending on the JDBC "stringtype" setting - confirm).
    final Long id;
    try {
        id = Long.valueOf(imageId);
    } catch (NumberFormatException ex) {
        response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid image id");
        return;
    }

    AskImagesRecord record = askImagesRecordMapper.selectById(id);
    if (record == null || record.getImageData() == null) {
        response.sendError(HttpServletResponse.SC_NOT_FOUND, "Image not found");
        return;
    }
    byte[] data = record.getImageData();

    // Content type: no MIME type is read from the record here, so sniff the magic bytes.
    response.setContentType(guessMime(data));
    // Explicit length so the client knows when the body is complete.
    response.setContentLength(data.length);

    // CORS headers: allow cross-origin access to the image.
    response.setHeader("Access-Control-Allow-Origin", "*");
    response.setHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
    response.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");
    response.setHeader("Access-Control-Max-Age", "86400");

    // Cache control (max-age is in seconds).
    response.setHeader("Cache-Control", "max-age=3600, public");
    response.setHeader("Pragma", "public");

    // Security headers.
    response.setHeader("X-Content-Type-Options", "nosniff");
    response.setHeader("X-Frame-Options", "SAMEORIGIN");

    // Write the image bytes.
    try (OutputStream os = response.getOutputStream()) {
        os.write(data);
        os.flush();
    }
}
/**
 * Guesses an image MIME type from the first four bytes (magic number) of the data.
 *
 * <p>Recognises PNG ({@code 89 50 4E 47}), JPEG ({@code FF D8 FF}) and GIF
 * ({@code 47 49 46 38}); anything else, including {@code null} or fewer than four bytes,
 * yields {@code application/octet-stream}.
 *
 * @param data raw file content, may be {@code null}
 * @return the guessed MIME type, never {@code null}
 */
private static String guessMime(byte[] data) {
    if (data == null || data.length < 4) {
        return "application/octet-stream";
    }
    // Pack the first four bytes big-endian into one int for signature comparison.
    int head = ((data[0] & 0xFF) << 24) | ((data[1] & 0xFF) << 16)
            | ((data[2] & 0xFF) << 8) | (data[3] & 0xFF);
    if (head == 0x89504E47) {
        return "image/png";
    }
    // JPEG: the first three bytes must equal FF D8 FF exactly; the fourth byte varies.
    // Masking with 0xFFFFFF00 compares all 24 bits (a 0xFFD8FF00 mask would accept any
    // second byte that merely has the D8 bits set, e.g. FF FF FF xx).
    if ((head & 0xFFFFFF00) == 0xFFD8FF00) {
        return "image/jpeg";
    }
    if (head == 0x47494638) {
        return "image/gif";
    }
    return "application/octet-stream";
}
}
......@@ -11,6 +11,7 @@ import com.ask.service.AskVectorStoreService;
import com.ask.service.AsyncVectorizationService;
import com.ask.service.KnowledgeDocumentService;
import com.ask.service.SysFileService;
import com.ask.utils.WordAllWithImages;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
......@@ -26,6 +27,7 @@ import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.core.io.InputStreamResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.io.InputStream;
......@@ -50,7 +52,7 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
private final AskVectorStoreService askVectorStoreService;
private final AsyncVectorizationService asyncVectorizationService;
private final DocumentParseService documentParseService;
private final WordAllWithImages wordAllWithImages;
/**
* 从PDF文档中读取带目录结构的文档分段
......@@ -77,8 +79,7 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
/**
* PDF文档切片函数 - 支持多种切片策略
*
* @param sliceStrategy 切片策略 (PAGE, PARAGRAPH, CUSTOM)
* @param sliceStrategy 切片策略 (PAGE, PARAGRAPH, CUSTOM)
* @return 文档片段列表,每个Document就是一片
*/
public List<Document> slicePdfDocument(String bucketName, String fileName, SliceStrategy sliceStrategy) {
......@@ -351,6 +352,7 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
return tokenDocuments;
}
private List<Document> sliceByTokens(String text) {
// 使用TokenTextSplitter进行切片
......@@ -386,18 +388,23 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
List<DocumentSegmentResult> results = new ArrayList<>();
for (SysFile file : request.getFiles()) {
List<Document> segments = new ArrayList<>();
if (Objects.equals(file.getType(), "docx")) {
String markdown = wordAllWithImages.readAll(sysFileService.getFileStream(file.getBucketName(), file.getFileName()));
List<Document> documents = sliceByTokens(markdown);
if (!CollectionUtils.isEmpty(documents)) {
segments.addAll(documents);
}
} else {
InputStreamResource resource = new InputStreamResource(sysFileService.getFileStream(file.getBucketName(), file.getFileName()));
List<Document> documents = sliceByTokens(resource);
// 读取文档内容 - 使用新的PDF切片函数
if (!CollectionUtils.isEmpty(documents)) {
segments.addAll(documents);
}
}
// String docText = documentParseService.extractText(sysFileService.getFileStream(file.getBucketName(), file.getFileName()));
// List<Document> segments = sliceByTokens(docText);
SliceStrategy sliceStrategy = SliceStrategy.CUSTOM;
// if ("pdf".equals(file.getType())) {
// sliceStrategy = SliceStrategy.PARAGRAPH;
// }
// 读取文档内容 - 使用新的PDF切片函数
List<Document> segments = slicePdfDocument(
file.getBucketName(), file.getFileName(),
sliceStrategy
);
if (segments.isEmpty()) {
log.warn("文档解析失败或内容为空: {}", file.getOriginal());
continue;
......@@ -432,16 +439,28 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
List<DocumentSegmentResult> results = new ArrayList<>();
for (SysFile file : request.getFiles()) {
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(sysFileService.getFileStream(file.getBucketName(), file.getFileName())));
List<Document> pageDocuments = tikaDocumentReader.read();
// 合并所有页面内容
StringBuilder text = new StringBuilder();
for (Document pageDoc : pageDocuments) {
if (StringUtils.hasText(pageDoc.getText())) {
text.append(pageDoc.getText()).append("\n");
if (Objects.equals(file.getType(), "docx")) {
try {
String markdown = wordAllWithImages.readAll(sysFileService.getFileStream(file.getBucketName(), file.getFileName()));
text.append(markdown);
} catch (Exception e) {
log.warn("文档{}解析失败: {}", file.getOriginal(), e.getMessage());
continue;
}
} else {
TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(sysFileService.getFileStream(file.getBucketName(), file.getFileName())));
List<Document> pageDocuments = tikaDocumentReader.read();
for (Document pageDoc : pageDocuments) {
if (StringUtils.hasText(pageDoc.getText())) {
text.append(pageDoc.getText()).append("\n");
}
}
}
if (text.isEmpty()) {
log.warn("文档解析失败或内容为空: {}", file.getOriginal());
continue;
......@@ -451,13 +470,13 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
List<Document> segments = new ArrayList<>();
for (String s : texts) {
if(org.apache.commons.lang3.StringUtils.isBlank(s)){
if (org.apache.commons.lang3.StringUtils.isBlank(s)) {
continue;
}
if(s.length()>request.getMaxLength()){
s = s.substring(0,request.getMaxLength());
if (s.length() > request.getMaxLength()) {
s = s.substring(0, request.getMaxLength());
}
Document document=new Document(s.trim());
Document document = new Document(s.trim());
segments.add(document);
}
......@@ -516,7 +535,7 @@ public class KnowledgeDocumentServiceImpl extends ServiceImpl<KnowledgeDocumentM
// 构建metadata
Map<String, Object> metadata = new HashMap<>();
metadata.put("knowledgeBaseId", document.getKnowledgeBaseId());
metadata.put("title", org.apache.commons.lang3.StringUtils.isBlank(vo.getTitle())?"":vo.getTitle());
metadata.put("title", org.apache.commons.lang3.StringUtils.isBlank(vo.getTitle()) ? "" : vo.getTitle());
metadata.put("documentId", document.getId());
metadata.put("fileName", document.getFileName());
metadata.put("filePath", document.getFilePath());
......
package com.ask.tools;
import com.ask.common.core.FileTemplate;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
public class SqlTools {
@Autowired
private JdbcTemplate jdbcTemplate;
@Tool(description = "查询数据库 入参SQL (String)")
public List<Map<String, Object>> selectBySql(@ToolParam(description = "SQL query string") String sql) {
if (StringUtils.isEmpty(sql)) {
return Collections.emptyList();
}
try {
return jdbcTemplate.queryForList(sql);
} catch (Exception e) {
log.error("数据库查询异常:sql{} 异常: {}",sql,e.getMessage());
return Collections.emptyList();
}
}
}
package com.ask.utils;
import com.ask.api.entity.AskImagesRecord;
import com.ask.mapper.AskImagesRecordMapper;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import org.apache.poi.xwpf.usermodel.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.io.*;
import java.util.Base64;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
@Component
@RequiredArgsConstructor
public final class WordAllWithImages {
private static final AtomicInteger IMG_COUNTER = new AtomicInteger(0);
private final AskImagesRecordMapper askImagesRecordMapper;
@Value("${file.local.base-url:http://8.152.98.45/api}")
private String baseUrl;
/**
* 读取 .docx 全部内容:正文 + 表格 + 图片(Base64 内嵌)
*/
public String readAll(InputStream inputStream) {
StringBuilder out = new StringBuilder();
try (XWPFDocument doc = new XWPFDocument(inputStream)) {
for (IBodyElement e : doc.getBodyElements()) {
switch (e.getElementType()) {
case PARAGRAPH:
XWPFParagraph p = (XWPFParagraph) e;
if (p.getPartType() != BodyType.TABLECELL) {
out.append(paragraphToMarkdown(p)).append("\n\n");
}
break;
case TABLE:
out.append(tableToMd((XWPFTable) e)).append("\n\n");
break;
default:
break;
}
}
}catch (Exception e){
return "";
}
return out.toString().trim();
}
/* -------------------- 段落(含图片)转 Markdown -------------------- */
private String paragraphToMarkdown(XWPFParagraph p) {
StringBuilder sb = new StringBuilder();
for (XWPFRun run : p.getRuns()) {
String text = run.text();
if (text != null && !text.isEmpty()) {
sb.append(text);
}
/* 处理内嵌图片 */
for (XWPFPicture pic : run.getEmbeddedPictures()) {
byte[] bytes = pic.getPictureData().getData();
AskImagesRecord askImagesRecord = new AskImagesRecord();
String imageName = UUID.randomUUID().toString().replace("-", "")+".png";
askImagesRecord.setImageName(imageName);
askImagesRecord.setImageData(bytes);
askImagesRecordMapper.insert(askImagesRecord);
sb.append("![image](")
.append(baseUrl)
.append("/admin/sys-file/images/")
.append(askImagesRecord.getId())
.append(")");
}
}
return sb.toString().trim();
}
/* -------------------- 表格转 Markdown -------------------- */
private static String tableToMd(XWPFTable tbl) {
StringBuilder md = new StringBuilder();
for (int r = 0; r < tbl.getNumberOfRows(); r++) {
XWPFTableRow row = tbl.getRow(r);
md.append("|");
for (XWPFTableCell cell : row.getTableCells()) {
md.append(" ").append(cell.getText().trim().replace("|", "\\|")).append(" |");
}
md.append("\n");
if (r == 0) {
md.append("|");
for (int i = 0; i < row.getTableCells().size(); i++) md.append(" --- |");
md.append("\n");
}
}
return md.toString().trim();
}
}
\ No newline at end of file
......@@ -18,6 +18,13 @@ spring:
password: e5d039e4ba5246068
driver-class-name: org.postgresql.Driver
ai:
# mcp:
# client:
# sse:
# connections:
# charts:
# url: http://81.70.183.25:18000
model:
embedding: openai
vectorstore:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment