Commit 51f84d66 authored by 林洋洋's avatar 林洋洋

模型缓存调整

parent 60688825
...@@ -12,6 +12,7 @@ import io.swagger.v3.oas.annotations.Parameter; ...@@ -12,6 +12,7 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.apache.poi.util.StringUtil; import org.apache.poi.util.StringUtil;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.Arrays; import java.util.Arrays;
...@@ -36,37 +37,38 @@ public class AskModelController { ...@@ -36,37 +37,38 @@ public class AskModelController {
@Operation(summary = "创建模型") @Operation(summary = "创建模型")
@PostMapping("/model") @PostMapping("/model")
public R<Boolean> createModel(@Parameter(description = "模型信息") @RequestBody AskModel model){ public R<Boolean> createModel(@Parameter(description = "模型信息") @RequestBody AskModel model) {
return R.ok(askModelService.createModel(model)); return R.ok(askModelService.createModel(model));
} }
@Operation(summary = "获取模型列表") @Operation(summary = "获取模型列表")
@GetMapping("/model") @GetMapping("/model")
public R<List<AskModel>> models(@Parameter(description = "模型名称") String name, @Parameter(description = "模型类型") String modelType, @Parameter(description = "提供商") String provider){ public R<List<AskModel>> models(@Parameter(description = "模型名称") String name, @Parameter(description = "模型类型") String modelType, @Parameter(description = "提供商") String provider) {
return R.ok(askModelService.models(name,modelType,provider)); return R.ok(askModelService.models(name, modelType, provider));
} }
@Operation(summary = "根据ID获取模型") @Operation(summary = "根据ID获取模型")
@GetMapping("/model/{id}") @GetMapping("/model/{id}")
public R<AskModel> get(@Parameter(description = "模型ID") @PathVariable Long id){ public R<AskModel> get(@Parameter(description = "模型ID") @PathVariable Long id) {
return R.ok(askModelService.getById(id)); return R.ok(askModelService.getById(id));
} }
@Operation(summary = "删除模型") @Operation(summary = "删除模型")
@DeleteMapping("/model/{id}") @DeleteMapping("/model/{id}")
public R<Boolean> delete(@Parameter(description = "模型ID") @PathVariable Long id){ @CacheEvict(cacheNames = {"chatClient", "embeddingMode", "vectorStore"}, key = "#model.id")
public R<Boolean> delete(@Parameter(description = "模型ID") @PathVariable Long id) {
return R.ok(askModelService.removeById(id)); return R.ok(askModelService.removeById(id));
} }
@Operation(summary = "更新模型") @Operation(summary = "更新模型")
@PutMapping("/model/{id}") @PutMapping("/model/{id}")
public R<AskModel> update(@Parameter(description = "模型ID") @PathVariable Long id, @Parameter(description = "模型信息") @RequestBody AskModel model){ public R<AskModel> update(@Parameter(description = "模型ID") @PathVariable Long id, @Parameter(description = "模型信息") @RequestBody AskModel model) {
return R.ok(askModelService.updateModel(id,model)); return R.ok(askModelService.updateModel(id, model));
} }
@Operation(summary = "获取厂商列表") @Operation(summary = "获取厂商列表")
@GetMapping("/providers") @GetMapping("/providers")
public R<List<Map<String, String>>> getProviders(){ public R<List<Map<String, String>>> getProviders() {
List<Map<String, String>> providers = Arrays.stream(ModelProviderEnum.values()) List<Map<String, String>> providers = Arrays.stream(ModelProviderEnum.values())
.map(provider -> { .map(provider -> {
Map<String, String> providerMap = new HashMap<>(); Map<String, String> providerMap = new HashMap<>();
...@@ -80,7 +82,7 @@ public class AskModelController { ...@@ -80,7 +82,7 @@ public class AskModelController {
@Operation(summary = "获取模型类型列表") @Operation(summary = "获取模型类型列表")
@GetMapping("/types") @GetMapping("/types")
public R<List<Map<String, String>>> getModelTypes(){ public R<List<Map<String, String>>> getModelTypes() {
List<Map<String, String>> types = Arrays.stream(ModelTypeEnum.values()) List<Map<String, String>> types = Arrays.stream(ModelTypeEnum.values())
.map(type -> { .map(type -> {
Map<String, String> typeMap = new HashMap<>(); Map<String, String> typeMap = new HashMap<>();
......
...@@ -16,6 +16,7 @@ import org.springframework.ai.chat.model.ChatModel; ...@@ -16,6 +16,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore; import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable; import org.springframework.cache.annotation.Cacheable;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
...@@ -51,6 +52,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -51,6 +52,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public ChatClient getChatClientById(Long modelId) { public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) { if (Objects.isNull(baseModel)) {
return null; return null;
...@@ -63,6 +67,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -63,6 +67,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public EmbeddingModel getEmbeddingModelById(Long modelId) { public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) { if (Objects.isNull(baseModel)) {
return null; return null;
...@@ -86,6 +93,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -86,6 +93,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
} }
@Override @Override
@CacheEvict(cacheNames = {"chatClient", "embeddingMode", "vectorStore"}, key = "#model.id")
public AskModel updateModel(Long id, AskModel model) { public AskModel updateModel(Long id, AskModel model) {
model.setId(id); model.setId(id);
this.updateById(model); this.updateById(model);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment