Commit 60688825 authored by 林洋洋's avatar 林洋洋

添加模型管理接口,模型采用配置

parent 2c8dcb2a
package com.ask.api.entity;
import com.ask.common.base.BaseEntity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
@TableName("ask_model")
@Schema(description = "智能模型实体")
public class AskModel extends BaseEntity {
/**
* 主键ID
*/
@TableId(type = IdType.AUTO)
@Schema(description = "主键ID")
private Long id;
@Schema(description = "模型名称")
private String name;
@Schema(description = "模型类型")
private String modelType;
@Schema(description = "具体模型名称")
private String modelName;
@Schema(description = "模型提供商")
private String provider;
@Schema(description = "用户ID")
private Long userId;
@Schema(description = "模型状态", example = "1")
private Integer status;
@Schema(description = "基础URL地址")
private String baseUrl;
@Schema(description = "API密钥")
private String key;
@Schema(description = "最大token数", example = "4096")
private Integer maxTokens;
@Schema(description = "温度参数", example = "0.7")
private Float temperature;
}
......@@ -61,7 +61,7 @@
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-ollama</artifactId>
<artifactId>spring-ai-ollama</artifactId>
</dependency>
<!-- undertow容器 -->
<dependency>
......@@ -100,13 +100,13 @@
</dependency>
<!-- Spring AI -->
<!-- <dependency>-->
<!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-starter-model-deepseek</artifactId>-->
<!-- </dependency>-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-deepseek</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
<artifactId>spring-ai-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
......@@ -114,7 +114,7 @@
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-vector-store-pgvector</artifactId>
<artifactId>spring-ai-pgvector-store</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
......@@ -152,6 +152,15 @@
<artifactId>poi-tl</artifactId>
<version>1.12.1</version>
</dependency>
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
<version>2.9.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-deepseek</artifactId>
</dependency>
</dependencies>
<repositories>
<repository>
......
package com.ask.config;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.cache.caffeine.CaffeineCacheManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.concurrent.TimeUnit;
/**
* 本地缓存Caffeine配置类
*/
@EnableCaching
@Configuration
public class CaffeineCacheConfiguration {
@Bean("caffeineCacheManager")
public CacheManager caffeineCacheManager() {
CaffeineCacheManager cacheManager = new CaffeineCacheManager();
cacheManager.setCaffeine(Caffeine.newBuilder()
.expireAfterWrite(30, TimeUnit.MINUTES)
.maximumSize(10000));
return cacheManager;
}
}
......@@ -9,13 +9,14 @@ import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
......@@ -27,6 +28,8 @@ import reactor.netty.http.client.HttpClient;
import java.util.ArrayList;
import java.util.List;
import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType.HNSW;
@Configuration
public class CommonConfiguration {
......@@ -49,19 +52,19 @@ public class CommonConfiguration {
.build();
}
@Bean
public ChatClient openAiChatClient(OpenAiChatModel model) {
return ChatClient.builder(model)
.defaultAdvisors()
.build();
}
// @Bean
// public ChatClient openAiChatClient(OpenAiChatModel model) {
// return ChatClient.builder(model)
// .defaultAdvisors()
// .build();
// }
@Bean
public ChatClient deepseekChatClient(DeepSeekChatModel model) {
return ChatClient.builder(model)
.defaultAdvisors()
.build();
}
// @Bean
// public ChatClient deepseekChatClient(DeepSeekChatModel model) {
// return ChatClient.builder(model)
// .defaultAdvisors()
// .build();
// }
@Bean
public PromptChatMemoryAdvisor promptChatMemoryAdvisor(ChatMemory chatMemory){
return PromptChatMemoryAdvisor.builder(chatMemory).build();
......@@ -72,20 +75,28 @@ public class CommonConfiguration {
return MessageChatMemoryAdvisor.builder(chatMemory).build();
}
@Bean
public RetrievalAugmentationAdvisor retrievalAugmentationAdvisor(VectorStore vectorStore) {
// @Bean
// public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
// return PgVectorStore.builder(jdbcTemplate, embeddingModel)
// .vectorTableName("ask_vector_store") // Optional: defaults to "vector_store"
// .maxDocumentBatchSize(10000) // Optional: defaults to 10000
// .build();
// }
return RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.similarityThreshold(0.70)
.topK(5)
.vectorStore(vectorStore)
.build())
.documentPostProcessors(new MyDocumentPostProcessor())
.queryAugmenter(ContextualQueryAugmenter.builder()
.allowEmptyContext(true)
.build())
.build();
}
// @Bean
// public RetrievalAugmentationAdvisor retrievalAugmentationAdvisor(VectorStore vectorStore) {
//
// return RetrievalAugmentationAdvisor.builder()
// .documentRetriever(VectorStoreDocumentRetriever.builder()
// .similarityThreshold(0.70)
// .topK(5)
// .vectorStore(vectorStore)
// .build())
// .documentPostProcessors(new MyDocumentPostProcessor())
// .queryAugmenter(ContextualQueryAugmenter.builder()
// .allowEmptyContext(true)
// .build())
// .build();
// }
}
package com.ask.controller;
import com.ask.api.entity.AskModel;
import com.ask.api.vo.KeyAndValueVO;
import com.ask.common.core.R;
import com.ask.enums.ModelProviderEnum;
import com.ask.enums.ModelTypeEnum;
import com.ask.service.AskModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;
import org.apache.poi.util.StringUtil;
import org.springframework.web.bind.annotation.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* @author tarzan
* @date 2024-12-25 12:22:22
*/
@Tag(description = "智能模型管理", name = "智能模型管理")
@RestController
@RequestMapping("/ai/model")
@AllArgsConstructor
public class AskModelController {
private final AskModelService askModelService;
@Operation(summary = "创建模型")
@PostMapping("/model")
public R<Boolean> createModel(@Parameter(description = "模型信息") @RequestBody AskModel model){
return R.ok(askModelService.createModel(model));
}
@Operation(summary = "获取模型列表")
@GetMapping("/model")
public R<List<AskModel>> models(@Parameter(description = "模型名称") String name, @Parameter(description = "模型类型") String modelType, @Parameter(description = "提供商") String provider){
return R.ok(askModelService.models(name,modelType,provider));
}
@Operation(summary = "根据ID获取模型")
@GetMapping("/model/{id}")
public R<AskModel> get(@Parameter(description = "模型ID") @PathVariable Long id){
return R.ok(askModelService.getById(id));
}
@Operation(summary = "删除模型")
@DeleteMapping("/model/{id}")
public R<Boolean> delete(@Parameter(description = "模型ID") @PathVariable Long id){
return R.ok(askModelService.removeById(id));
}
@Operation(summary = "更新模型")
@PutMapping("/model/{id}")
public R<AskModel> update(@Parameter(description = "模型ID") @PathVariable Long id, @Parameter(description = "模型信息") @RequestBody AskModel model){
return R.ok(askModelService.updateModel(id,model));
}
@Operation(summary = "获取厂商列表")
@GetMapping("/providers")
public R<List<Map<String, String>>> getProviders(){
List<Map<String, String>> providers = Arrays.stream(ModelProviderEnum.values())
.map(provider -> {
Map<String, String> providerMap = new HashMap<>();
providerMap.put("provider", provider.getProvider());
providerMap.put("name", provider.getName());
return providerMap;
})
.collect(Collectors.toList());
return R.ok(providers);
}
@Operation(summary = "获取模型类型列表")
@GetMapping("/types")
public R<List<Map<String, String>>> getModelTypes(){
List<Map<String, String>> types = Arrays.stream(ModelTypeEnum.values())
.map(type -> {
Map<String, String> typeMap = new HashMap<>();
typeMap.put("code", type.getCode());
typeMap.put("name", type.getName());
typeMap.put("description", type.getDescription());
return typeMap;
})
.collect(Collectors.toList());
return R.ok(types);
}
}
......@@ -41,7 +41,6 @@ import java.util.UUID;
public class AskVectorStoreController {
private final AskVectorStoreService askVectorStoreService;
private final VectorStore vectorStore;
/**
* 分页查询向量存储
......@@ -144,7 +143,7 @@ public class AskVectorStoreController {
return R.failed("ID列表不能为空");
}
try {
vectorStore.delete(ids);
askVectorStoreService.removeBatchByIds(ids);
return R.ok(true);
} catch (Exception e) {
log.error("批量删除向量存储失败,IDs: {}, 错误: {}", ids, e.getMessage(), e);
......
......@@ -5,6 +5,7 @@ import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.ask.api.entity.ChatConversation;
import com.ask.common.core.R;
import com.ask.service.AskModelService;
import com.ask.service.ChatConversationService;
import com.ask.service.impl.ChatService;
import com.ask.service.impl.RagPromptService;
......@@ -23,9 +24,13 @@ import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolExecutionResult;
......@@ -52,32 +57,24 @@ import java.util.concurrent.atomic.AtomicBoolean;
@Tag(description = "ai", name = "AI对话模块")
public class ChatController {
private final ChatClient openAiChatClient;
private final ChatClient deepseekChatClient;
private final AskModelService askModelService;
private final ChatConversationService chatConversationService;
private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;
private final MessageChatMemoryAdvisor messageChatMemoryAdvisor;
private final PromptChatMemoryAdvisor promptChatMemoryAdvisor;
private final VectorStore vectorStore;
private final ChatService chatService;
private final RagPromptService ragPromptService;
private final ChatMemory chatMemory;
private final OpenAiChatModel openAiChatModel;
private final ExcelTools excelTools;
private final SqlTools sqlTools;
// private final ToolCallbackProvider toolCallbackProvider;
/**
* 获取会话ID
......@@ -109,13 +106,18 @@ public class ChatController {
@Operation(summary = "普通对话", description = "普通对话")
@GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> chat(@RequestParam String message,
@RequestParam String conversationId) {
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
Message systemMessage = new SystemMessage("你是一个AI问答助手,请准确回答用户问题,回答要求:请使用markdown格式输出");
Message userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
ChatClient chatClient = askModelService.getChatClientById(actualModelId);
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
......@@ -132,7 +134,10 @@ public class ChatController {
*/
@Operation(summary = "知识库对话", description = "知识库对话")
@GetMapping(value = "/rag/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message, @RequestParam @Parameter(description = "会话ID") String conversationId) {
public Flux<String> ragChat(@RequestParam @Parameter(description = "对话内容") String message,
@RequestParam @Parameter(description = "会话ID") String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
//获取对话历史
String historyMemory = chatService.getHistoryMemoryAsString(conversationId);
......@@ -151,7 +156,11 @@ public class ChatController {
String userPrompt = ragPromptService.createRagPrompt(message, context, historyMemory);
StringBuilder contentBuilder = new StringBuilder();
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt()
ChatClient chatClient = askModelService.getChatClientById(actualModelId);
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt()
.user(userPrompt)
.system("你是一个智能助手,基于以下上下文和历史对话回答问题,请用简洁的语言回答问题,并确保答案准确,要求" +
"1.以 Markdown 格式输出")
......@@ -169,13 +178,17 @@ public class ChatController {
@Operation(summary = "智能数据报表对话", description = "智能数据报表对话")
@GetMapping(value = "/chat/report", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> reportChat(@RequestParam String message,
@RequestParam String conversationId) {
@RequestParam String conversationId,
@RequestParam(required = false) Optional<Long> modelId) {
Long actualModelId = modelId.orElse(1L);
Message systemMessage = new SystemMessage("你是一个AI问答助手,请用回答用户问题,使用相关工具");
Message userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
ChatClient chatClient = askModelService.getChatClientById(actualModelId);
if (Objects.isNull(chatClient)) {
return Flux.error(new Throwable("模型创建失败"));
}
return FluxUtils.wrapDeepSeekStream(chatClient.prompt(prompt)
.advisors(messageChatMemoryAdvisor)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.tools(excelTools)
......@@ -185,26 +198,16 @@ public class ChatController {
}
// @Operation(summary = "智能问数据对话", description = "智能问数据对话")
// @GetMapping(value = "/chat/data", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
// public Flux<String> dataChat(@RequestParam String message,
// @RequestParam Long knowledgeBaseId,
// @RequestParam String conversationId) {
//
// Message systemMessage = new SystemMessage("你是一个AI智能问数助手,数据库采用postgres 16,请使用相关工具回答用户问题");
// Message userMessage = new UserMessage(message);
// Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
//
// return FluxUtils.wrapDeepSeekStream(deepseekChatClient.prompt(prompt)
// .advisors(messageChatMemoryAdvisor)
// .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
//// .advisors(retrievalAugmentationAdvisor)
//// .advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "knowledgeBaseId == "+knowledgeBaseId))
// .tools(sqlTools)
// .toolCallbacks(toolCallbackProvider)
// .advisors()
// .stream()
// .chatResponse());
//
// }
public void test() {
ChatModel chatModel = DeepSeekChatModel.builder()
.deepSeekApi(DeepSeekApi.builder().baseUrl("").apiKey("TEST").build())
.defaultOptions(DeepSeekChatOptions.builder().model("deepseek-r1").temperature(66.6).maxTokens(10000).build())
.build();
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors()
.build();
}
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ package com.ask.controller;
import com.ask.api.entity.KnowledgeBase;
import com.ask.common.core.R;
import com.ask.service.AskModelService;
import com.ask.service.KnowledgeBaseService;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
......@@ -13,16 +14,20 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.http.HttpHeaders;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
......@@ -39,10 +44,12 @@ import java.util.stream.Collectors;
public class KnowledgeBaseController {
private final KnowledgeBaseService knowledgeBaseService;
private final VectorStore vectorStore;
private final AskModelService askModelService;
/**
* 分页查询
*
* @param page 分页对象
* @param name 查询条件
* @return 分页数据
......@@ -53,12 +60,13 @@ public class KnowledgeBaseController {
@Parameter(description = "知识库名称") @RequestParam(required = false) String name) {
return R.ok(knowledgeBaseService.page(page,
Wrappers.lambdaQuery(KnowledgeBase.class)
.like(org.apache.commons.lang3.StringUtils.isNotBlank(name),KnowledgeBase::getName, name)
.like(org.apache.commons.lang3.StringUtils.isNotBlank(name), KnowledgeBase::getName, name)
.orderByAsc(KnowledgeBase::getId)));
}
/**
* 通过id查询知识库
*
* @param id id
* @return R
*/
......@@ -70,6 +78,7 @@ public class KnowledgeBaseController {
/**
* 新增知识库
*
* @param knowledgeBase 知识库
* @return R
*/
......@@ -91,6 +100,7 @@ public class KnowledgeBaseController {
/**
* 修改知识库
*
* @param knowledgeBase 知识库
* @return R
*/
......@@ -112,6 +122,7 @@ public class KnowledgeBaseController {
/**
* 通过id删除知识库
*
* @param id id
* @return R
*/
......@@ -123,6 +134,7 @@ public class KnowledgeBaseController {
/**
* 校验知识库名称是否重复
*
* @param name 知识库名称
* @param id 知识库ID(可选,修改时传入)
* @return R
......@@ -145,6 +157,7 @@ public class KnowledgeBaseController {
/**
* 向量搜索命中测试
*
* @param knowledgeBaseId 知识库ID
* @param content 搜索内容
* @param topK 返回最相似的K个结果(可选,默认5)
......@@ -192,8 +205,10 @@ public class KnowledgeBaseController {
new Filter.Value(1)
)
);
VectorStore vectorStore = askModelService.getVectorStoreById(2L);
if(Objects.isNull(vectorStore)){
return R.failed("向量模型获取失败");
}
// 执行向量搜索
List<Document> searchResults = vectorStore.similaritySearch(SearchRequest.builder().query(content).filterExpression(filterExpression).similarityThreshold(similarityThreshold).topK(topK).build());
......
......@@ -54,8 +54,6 @@ public class KnowledgeDocumentController {
private final KnowledgeDocumentService knowledgeDocumentService;
private final VectorStore vectorStore;
private final AskVectorStoreService askVectorStoreService;
......@@ -153,15 +151,14 @@ public class KnowledgeDocumentController {
// 删除知识库文档
knowledgeDocumentService.removeById(id);
// 构建基于documentId的过滤条件
Filter.Expression filterExpression = new Filter.Expression(
Filter.ExpressionType.EQ,
new Filter.Key("documentId"),
new Filter.Value(id)
);
// 删除向量存储中对应的文档切片
vectorStore.delete(filterExpression);
// // 构建基于documentId的过滤条件
// Filter.Expression filterExpression = new Filter.Expression(
// Filter.ExpressionType.EQ,
// new Filter.Key("documentId"),
// new Filter.Value(id)
// );
askVectorStoreService.remove(Wrappers.lambdaQuery(AskVectorStore.class)
.apply("metadata->>'documentId' = {0}", id)); // 删除向量存储中对应的文档切片
log.info("成功删除文档及其向量数据,文档ID: {}", id);
} catch (Exception e) {
......
package com.ask.enums;
import com.ask.model.IBaseModel;
import com.ask.model.IDeepSeekModel;
import com.ask.model.IOllamaModel;
import com.ask.model.IOpenAiModel;
import lombok.Getter;
import java.util.HashMap;
import java.util.Map;
@Getter
public enum ModelProviderEnum {
DeepSeek("DeepSeek", "DeepSeek",new IDeepSeekModel()),
Openai("OpenAI", "Openai",new IOpenAiModel()),
OLlama("OLlama", "model_ollama_provider",new IOllamaModel()),
;
private final String name;
private final String provider;
private final IBaseModel iBaseModel;
ModelProviderEnum(String name, String provider,IBaseModel iBaseModel) {
this.name = name;
this.provider = provider;
this.iBaseModel = iBaseModel;
}
public static IBaseModel get(String provider) {
if (provider == null || provider.isEmpty()) {
throw new IllegalArgumentException("Provider name cannot be null or empty.");
}
return getMap().getOrDefault(provider, null);
}
/**
* 获取模型提供者的映射。
*
* @return 包含所有模型提供者名称和其实例的映射
*/
public static Map<String, IBaseModel> getMap() {
Map<String, IBaseModel> map = new HashMap<>();
for (ModelProviderEnum providerEnum : values()) {
map.put(providerEnum.getProvider(), providerEnum.getIBaseModel());
}
return map;
}
}
package com.ask.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* AI模型类型枚举
*
* @author ai
* @date 2024/12/20
*/
@Getter
@AllArgsConstructor
public enum ModelTypeEnum {
/**
* 大语言模型
*/
LLM("LLM", "大语言模型", "用于文本理解、生成和对话的大型语言模型"),
/**
* 文本向量模型
*/
EMBEDDING("EMBEDDING", "文本向量模型", "将文本转换为向量表示的嵌入模型"),
/**
* 语音转文本模型
*/
STT("STT", "语音识别模型", "将语音转换为文本的语音识别模型"),
/**
* 文本转语音模型
*/
TTS("TTS", "语音生成模型", "将文本转换为语音的语音合成模型"),
/**
* AI视觉模型
*/
IMAGE("IMAGE", "AI视觉模型", "用于图像理解和分析的视觉模型"),
/**
* 文生图模型
*/
TTI("TTI", "文生图模型", "根据文本描述生成图像的模型"),
/**
* 重排序模型
*/
RERANKER("RERANKER", "重排序模型", "用于搜索结果重新排序的模型");
/**
* 类型编码
*/
private final String code;
/**
* 类型名称
*/
private final String name;
/**
* 类型描述
*/
private final String description;
/**
* 根据编码获取枚举
*/
public static ModelTypeEnum getByCode(String code) {
for (ModelTypeEnum type : values()) {
if (type.getCode().equals(code)) {
return type;
}
}
return null;
}
/**
* 根据名称获取枚举
*/
public static ModelTypeEnum getByName(String name) {
for (ModelTypeEnum type : values()) {
if (type.getName().equals(name)) {
return type;
}
}
return null;
}
/**
* 判断是否为有效的模型类型
*/
public static boolean isValidType(String code) {
return getByCode(code) != null;
}
}
\ No newline at end of file
package com.ask.mapper;
import com.ask.api.entity.AskModel;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
/**
* @author tarzan
* @date 2024-12-25 12:22:22
*/
@Mapper
public interface AskModelMapper extends BaseMapper<AskModel>{
}
package com.ask.model;
import com.ask.api.entity.AskModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
public interface IBaseModel {
public ChatModel buildChatModel(AskModel askModel);
public EmbeddingModel buildEmbedding(AskModel askModel);
}
package com.ask.model;
import com.ask.api.entity.AskModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.embedding.EmbeddingModel;
public class IDeepSeekModel implements IBaseModel {
@Override
public ChatModel buildChatModel(AskModel askModel) {
return DeepSeekChatModel.builder()
.deepSeekApi(DeepSeekApi.builder().baseUrl(askModel.getBaseUrl()).apiKey(askModel.getKey()).build())
.defaultOptions(DeepSeekChatOptions.builder().model(askModel.getModelName()).temperature(Double.valueOf(askModel.getTemperature())).maxTokens(askModel.getMaxTokens()).build())
.build();
}
@Override
public EmbeddingModel buildEmbedding(AskModel askModel) {
return null;
}
}
package com.ask.model;
import com.ask.api.entity.AskModel;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.retry.RetryUtils;
public class IOllamaModel implements IBaseModel {
@Override
public ChatModel buildChatModel(AskModel askModel) {
OllamaApi ollamaApi = OllamaApi.builder()
.baseUrl(askModel.getBaseUrl()).build();
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(
OllamaOptions.builder()
.model(askModel.getModelName())
.temperature(Double.valueOf(askModel.getTemperature()))
.build())
.build();
}
@Override
public EmbeddingModel buildEmbedding(AskModel askModel) {
OllamaApi ollamaApi = OllamaApi.builder()
.baseUrl(askModel.getBaseUrl())
.build();
return OllamaEmbeddingModel.builder().defaultOptions(OllamaOptions.builder()
.model(askModel.getModelName())
.truncate(false)
.build()).ollamaApi(ollamaApi).build();
}
}
package com.ask.model;
import com.ask.api.entity.AskModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.retry.RetryUtils;
public class IOpenAiModel implements IBaseModel {
@Override
public ChatModel buildChatModel(AskModel askModel) {
return null;
}
@Override
public EmbeddingModel buildEmbedding(AskModel askModel) {
OpenAiApi openAiApi = OpenAiApi.builder()
.baseUrl(askModel.getBaseUrl())
.apiKey(askModel.getKey())
.build();
return new OpenAiEmbeddingModel(
openAiApi,
MetadataMode.EMBED,
OpenAiEmbeddingOptions.builder()
.model(askModel.getModelName())
.dimensions(askModel.getMaxTokens())
.build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
}
package com.ask.service;
import com.ask.api.entity.AskModel;
import com.baomidou.mybatisplus.extension.service.IService;
import org.apache.poi.ss.formula.functions.T;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;
public interface AskModelService extends IService<AskModel> {
public Boolean createModel(AskModel askModel);
AskModel updateModel(Long id, AskModel model);
List<AskModel> models(String name, String modelType, String provider);
public AskModel getModelInfoById(Long modelId);
public ChatClient getChatClientById(Long modelId);
public EmbeddingModel getEmbeddingModelById(Long modelId);
public VectorStore getVectorStoreById(Long modelId);
}
package com.ask.service.impl;
import com.ask.api.entity.AskModel;
import com.ask.enums.ModelProviderEnum;
import com.ask.mapper.AskModelMapper;
import com.ask.model.IBaseModel;
import com.ask.service.AskModelService;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Objects;
@Slf4j
@Service
@AllArgsConstructor
public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> implements AskModelService {
private final JdbcTemplate jdbcTemplate;
@Override
public Boolean createModel(AskModel AskModel) {
// Long userId = StpUtil.getLoginIdAsLong();
// long count = this.lambdaQuery().eq(AskModel::getName, AskModel.getName()).count();
// if (count > 0) {
// return false;
// }
//AskModel.setUserId(userId);
AskModel.setStatus(1);
return save(AskModel);
}
@Override
public AskModel getModelInfoById(Long modelId) {
return this.getById(modelId);
}
@Cacheable(cacheNames = "chatClient", key = "#modelId")
@Override
public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId);
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) {
return null;
}
ChatModel chatModel = baseModel.buildChatModel(askModel);
return ChatClient.builder(chatModel).build();
}
@Cacheable(cacheNames = "embeddingMode", key = "#modelId")
@Override
public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId);
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) {
return null;
}
return baseModel.buildEmbedding(askModel);
}
@Override
@Cacheable(cacheNames = "vectorStore", key = "#modelId")
public VectorStore getVectorStoreById(Long modelId) {
EmbeddingModel embeddingModel = this.getEmbeddingModelById(modelId);
if (Objects.isNull(embeddingModel)) {
return null;
}
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
.vectorTableName("ask_vector_store") // Optional: defaults to "vector_store"
.maxDocumentBatchSize(10000) // Optional: defaults to 10000
.build();
}
@Override
public AskModel updateModel(Long id, AskModel model) {
model.setId(id);
this.updateById(model);
return model;
}
@Override
public List<AskModel> models(String name, String modelType, String provider) {
LambdaQueryWrapper<AskModel> wrapper = Wrappers.<AskModel>lambdaQuery()
.like(StringUtils.isNotBlank(name), AskModel::getName, name)
.eq(StringUtils.isNotBlank(provider), AskModel::getProvider, provider)
.eq(StringUtils.isNotBlank(modelType), AskModel::getModelType, modelType)
.orderByAsc(AskModel::getId);
return baseMapper.selectList(wrapper);
}
}
......@@ -3,6 +3,7 @@ package com.ask.service.impl;
import com.ask.api.entity.AskVectorStore;
import com.ask.api.vo.KeyAndValueVO;
import com.ask.mapper.AskVectorStoreMapper;
import com.ask.service.AskModelService;
import com.ask.service.AskVectorStoreService;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
......@@ -41,13 +42,11 @@ import java.util.stream.Collectors;
@Service
public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper, AskVectorStore> implements AskVectorStoreService {
@Autowired
private EmbeddingModel embeddingModel;
@Autowired
private JdbcTemplate jdbcTemplate;
private AskModelService askModelService;
......@@ -74,7 +73,7 @@ public class AskVectorStoreServiceImpl extends ServiceImpl<AskVectorStoreMapper,
String result = (title == null || title.trim().isEmpty()) ?
(content == null ? "" : content) :
title.trim() + "\n" + (content == null ? "" : content);
return embeddingModel.embed(result);
return askModelService.getEmbeddingModelById(2L).embed(result);
})
.toList();
......
......@@ -2,6 +2,7 @@ package com.ask.service.impl;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import com.ask.service.AskModelService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
......@@ -33,11 +34,13 @@ import static org.springframework.ai.chat.messages.MessageType.ASSISTANT;
@Service
public class ChatService {
private final VectorStore vectorStore;
private final ChatMemory chatMemory;
@Value("${file.local.base-url:http://8.152.98.45/api}")
private String baseUrl;
private final AskModelService askModelService;
/**
* rag召回
*
......@@ -48,6 +51,9 @@ public class ChatService {
* @return 召回文档
*/
public List<Document> retrieveDocuments(String query, double threshold, int topK, Filter.Expression filter) {
VectorStore vectorStore = askModelService.getVectorStoreById(2L);
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
.query(query)
.filterExpression(filter)
......@@ -57,7 +63,6 @@ public class ChatService {
if (CollectionUtils.isEmpty(documents)) {
return List.of();
}
// 根据ID去重
return documents.stream()
.filter(distinctByKey(Document::getId)) // 根据ID去重
......
package com.ask.tools;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class EchartsTools {
@Tool(description = "生成ECharts饼图配置,入参:标题、数据列表(name,value格式)")
public String generatePieChartConfig(
@ToolParam(description = "图表标题") String title,
@ToolParam(description = "数据列表,格式:name1,value1;name2,value2") String dataStr) {
log.info("生成饼图配置 - 标题: {}, 数据: {}", title, dataStr);
try {
JSONObject config = new JSONObject();
// 标题配置
JSONObject titleObj = new JSONObject();
titleObj.set("text", title);
titleObj.set("left", "center");
config.set("title", titleObj);
// 提示框配置
JSONObject tooltip = new JSONObject();
tooltip.set("trigger", "item");
tooltip.set("formatter", "{a} <br/>{b}: {c} ({d}%)");
config.set("tooltip", tooltip);
// 图例配置
JSONObject legend = new JSONObject();
legend.set("orient", "vertical");
legend.set("left", "left");
config.set("legend", legend);
// 系列配置
JSONArray series = new JSONArray();
JSONObject seriesItem = new JSONObject();
seriesItem.set("name", title);
seriesItem.set("type", "pie");
seriesItem.set("radius", "50%");
// 数据解析
JSONArray data = new JSONArray();
if (dataStr != null && !dataStr.trim().isEmpty()) {
String[] items = dataStr.split(";");
for (String item : items) {
String[] parts = item.split(",");
if (parts.length == 2) {
JSONObject dataItem = new JSONObject();
dataItem.set("name", parts[0].trim());
dataItem.set("value", Double.parseDouble(parts[1].trim()));
data.add(dataItem);
}
}
}
seriesItem.set("data", data);
JSONObject emphasis = new JSONObject();
JSONObject itemStyle = new JSONObject();
itemStyle.set("shadowBlur", 10);
itemStyle.set("shadowOffsetX", 0);
itemStyle.set("shadowColor", "rgba(0, 0, 0, 0.5)");
emphasis.set("itemStyle", itemStyle);
seriesItem.set("emphasis", emphasis);
series.add(seriesItem);
config.set("series", series);
return config.toString();
} catch (Exception e) {
log.error("生成饼图配置失败", e);
return "{\"error\": \"生成饼图配置失败: " + e.getMessage() + "\"}";
}
}
@Tool(description = "生成ECharts柱状图配置,入参:标题、X轴分类、Y轴数据")
public String generateBarChartConfig(
@ToolParam(description = "图表标题") String title,
@ToolParam(description = "X轴分类,用逗号分隔") String categories,
@ToolParam(description = "Y轴数据,用逗号分隔") String values) {
log.info("生成柱状图配置 - 标题: {}, 分类: {}, 数据: {}", title, categories, values);
try {
JSONObject config = new JSONObject();
// 标题配置
JSONObject titleObj = new JSONObject();
titleObj.set("text", title);
titleObj.set("left", "center");
config.set("title", titleObj);
// 提示框配置
JSONObject tooltip = new JSONObject();
tooltip.set("trigger", "axis");
tooltip.set("axisPointer", new JSONObject().set("type", "shadow"));
config.set("tooltip", tooltip);
// 网格配置
JSONObject grid = new JSONObject();
grid.set("left", "3%");
grid.set("right", "4%");
grid.set("bottom", "3%");
grid.set("containLabel", true);
config.set("grid", grid);
// X轴配置
JSONArray xAxis = new JSONArray();
JSONObject xAxisItem = new JSONObject();
xAxisItem.set("type", "category");
JSONArray xAxisData = new JSONArray();
if (categories != null && !categories.trim().isEmpty()) {
String[] cats = categories.split(",");
for (String cat : cats) {
xAxisData.add(cat.trim());
}
}
xAxisItem.set("data", xAxisData);
xAxis.add(xAxisItem);
config.set("xAxis", xAxis);
// Y轴配置
JSONArray yAxis = new JSONArray();
JSONObject yAxisItem = new JSONObject();
yAxisItem.set("type", "value");
yAxis.add(yAxisItem);
config.set("yAxis", yAxis);
// 系列配置
JSONArray series = new JSONArray();
JSONObject seriesItem = new JSONObject();
seriesItem.set("name", title);
seriesItem.set("type", "bar");
JSONArray data = new JSONArray();
if (values != null && !values.trim().isEmpty()) {
String[] vals = values.split(",");
for (String val : vals) {
data.add(Double.parseDouble(val.trim()));
}
}
seriesItem.set("data", data);
series.add(seriesItem);
config.set("series", series);
return config.toString();
} catch (Exception e) {
log.error("生成柱状图配置失败", e);
return "{\"error\": \"生成柱状图配置失败: " + e.getMessage() + "\"}";
}
}
@Tool(description = "生成ECharts折线图配置,入参:标题、X轴分类、Y轴数据")
public String generateLineChartConfig(
@ToolParam(description = "图表标题") String title,
@ToolParam(description = "X轴分类,用逗号分隔") String categories,
@ToolParam(description = "Y轴数据,用逗号分隔") String values) {
log.info("生成折线图配置 - 标题: {}, 分类: {}, 数据: {}", title, categories, values);
try {
JSONObject config = new JSONObject();
// 标题配置
JSONObject titleObj = new JSONObject();
titleObj.set("text", title);
titleObj.set("left", "center");
config.set("title", titleObj);
// 提示框配置
JSONObject tooltip = new JSONObject();
tooltip.set("trigger", "axis");
config.set("tooltip", tooltip);
// 图例配置
JSONObject legend = new JSONObject();
legend.set("data", new JSONArray().put(title));
config.set("legend", legend);
// 网格配置
JSONObject grid = new JSONObject();
grid.set("left", "3%");
grid.set("right", "4%");
grid.set("bottom", "3%");
grid.set("containLabel", true);
config.set("grid", grid);
// X轴配置
JSONArray xAxis = new JSONArray();
JSONObject xAxisItem = new JSONObject();
xAxisItem.set("type", "category");
xAxisItem.set("boundaryGap", false);
JSONArray xAxisData = new JSONArray();
if (categories != null && !categories.trim().isEmpty()) {
String[] cats = categories.split(",");
for (String cat : cats) {
xAxisData.add(cat.trim());
}
}
xAxisItem.set("data", xAxisData);
xAxis.add(xAxisItem);
config.set("xAxis", xAxis);
// Y轴配置
JSONArray yAxis = new JSONArray();
JSONObject yAxisItem = new JSONObject();
yAxisItem.set("type", "value");
yAxis.add(yAxisItem);
config.set("yAxis", yAxis);
// 系列配置
JSONArray series = new JSONArray();
JSONObject seriesItem = new JSONObject();
seriesItem.set("name", title);
seriesItem.set("type", "line");
seriesItem.set("stack", "Total");
JSONArray data = new JSONArray();
if (values != null && !values.trim().isEmpty()) {
String[] vals = values.split(",");
for (String val : vals) {
data.add(Double.parseDouble(val.trim()));
}
}
seriesItem.set("data", data);
series.add(seriesItem);
config.set("series", series);
return config.toString();
} catch (Exception e) {
log.error("生成折线图配置失败", e);
return "{\"error\": \"生成折线图配置失败: " + e.getMessage() + "\"}";
}
}
}
......@@ -17,53 +17,15 @@ spring:
username: postgres
password: e5d039e4ba5246068
driver-class-name: org.postgresql.Driver
cache:
type: caffeine
ai:
# mcp:
# client:
# sse:
# connections:
# charts:
# url: http://81.70.183.25:18000
model:
embedding: openai
vectorstore:
pgvector:
index-type: HNSW
distance-type: COSINE_DISTANCE
dimensions: 1024
max-document-batch-size: 10000 # Optional: Maximum number of documents per batch
schema-name: public
table-name: ask_vector_store
chat:
memory:
repository:
jdbc:
initialize-schema: never # 开发环境可以使用 always,方便调试
platform: postgresql
openai:
base-url: https://dashscope.aliyuncs.com/compatible-mode
api-key: sk-ae96ff281ff644c992843c64a711a950
chat:
options:
model: deepseek-r1-0528
embedding:
base-url: https://dashscope.aliyuncs.com/compatible-mode
options:
model: text-embedding-v4
deepseek:
base-url: https://dashscope.aliyuncs.com/compatible-mode/v1
api-key: sk-ae96ff281ff644c992843c64a711a950
chat:
enabled: true
options:
model: deepseek-r1-0528
ollama:
base-url: http://127.0.0.1:11434
embedding:
options:
model: nomic-embed-text
mybatis-plus:
mapper-locations: classpath*:/mapper/*Mapper.xml # mapper文件位置
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment