Commit 51f84d66 authored by 林洋洋's avatar 林洋洋

模型缓存调整

parent 60688825
...@@ -12,6 +12,7 @@ import io.swagger.v3.oas.annotations.Parameter; ...@@ -12,6 +12,7 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.apache.poi.util.StringUtil; import org.apache.poi.util.StringUtil;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.Arrays; import java.util.Arrays;
...@@ -32,66 +33,67 @@ import java.util.stream.Collectors; ...@@ -32,66 +33,67 @@ import java.util.stream.Collectors;
@AllArgsConstructor @AllArgsConstructor
public class AskModelController { public class AskModelController {
private final AskModelService askModelService; private final AskModelService askModelService;
@Operation(summary = "创建模型") @Operation(summary = "创建模型")
@PostMapping("/model") @PostMapping("/model")
public R<Boolean> createModel(@Parameter(description = "模型信息") @RequestBody AskModel model){ public R<Boolean> createModel(@Parameter(description = "模型信息") @RequestBody AskModel model) {
return R.ok(askModelService.createModel(model)); return R.ok(askModelService.createModel(model));
} }
@Operation(summary = "获取模型列表") @Operation(summary = "获取模型列表")
@GetMapping("/model") @GetMapping("/model")
public R<List<AskModel>> models(@Parameter(description = "模型名称") String name, @Parameter(description = "模型类型") String modelType, @Parameter(description = "提供商") String provider){ public R<List<AskModel>> models(@Parameter(description = "模型名称") String name, @Parameter(description = "模型类型") String modelType, @Parameter(description = "提供商") String provider) {
return R.ok(askModelService.models(name,modelType,provider)); return R.ok(askModelService.models(name, modelType, provider));
} }
@Operation(summary = "根据ID获取模型") @Operation(summary = "根据ID获取模型")
@GetMapping("/model/{id}") @GetMapping("/model/{id}")
public R<AskModel> get(@Parameter(description = "模型ID") @PathVariable Long id){ public R<AskModel> get(@Parameter(description = "模型ID") @PathVariable Long id) {
return R.ok(askModelService.getById(id)); return R.ok(askModelService.getById(id));
} }
@Operation(summary = "删除模型") @Operation(summary = "删除模型")
@DeleteMapping("/model/{id}") @DeleteMapping("/model/{id}")
public R<Boolean> delete(@Parameter(description = "模型ID") @PathVariable Long id){ @CacheEvict(cacheNames = {"chatClient", "embeddingMode", "vectorStore"}, key = "#model.id")
return R.ok(askModelService.removeById(id)); public R<Boolean> delete(@Parameter(description = "模型ID") @PathVariable Long id) {
} return R.ok(askModelService.removeById(id));
}
@Operation(summary = "更新模型") @Operation(summary = "更新模型")
@PutMapping("/model/{id}") @PutMapping("/model/{id}")
public R<AskModel> update(@Parameter(description = "模型ID") @PathVariable Long id, @Parameter(description = "模型信息") @RequestBody AskModel model){ public R<AskModel> update(@Parameter(description = "模型ID") @PathVariable Long id, @Parameter(description = "模型信息") @RequestBody AskModel model) {
return R.ok(askModelService.updateModel(id,model)); return R.ok(askModelService.updateModel(id, model));
} }
@Operation(summary = "获取厂商列表") @Operation(summary = "获取厂商列表")
@GetMapping("/providers") @GetMapping("/providers")
public R<List<Map<String, String>>> getProviders(){ public R<List<Map<String, String>>> getProviders() {
List<Map<String, String>> providers = Arrays.stream(ModelProviderEnum.values()) List<Map<String, String>> providers = Arrays.stream(ModelProviderEnum.values())
.map(provider -> { .map(provider -> {
Map<String, String> providerMap = new HashMap<>(); Map<String, String> providerMap = new HashMap<>();
providerMap.put("provider", provider.getProvider()); providerMap.put("provider", provider.getProvider());
providerMap.put("name", provider.getName()); providerMap.put("name", provider.getName());
return providerMap; return providerMap;
}) })
.collect(Collectors.toList()); .collect(Collectors.toList());
return R.ok(providers); return R.ok(providers);
} }
@Operation(summary = "获取模型类型列表") @Operation(summary = "获取模型类型列表")
@GetMapping("/types") @GetMapping("/types")
public R<List<Map<String, String>>> getModelTypes(){ public R<List<Map<String, String>>> getModelTypes() {
List<Map<String, String>> types = Arrays.stream(ModelTypeEnum.values()) List<Map<String, String>> types = Arrays.stream(ModelTypeEnum.values())
.map(type -> { .map(type -> {
Map<String, String> typeMap = new HashMap<>(); Map<String, String> typeMap = new HashMap<>();
typeMap.put("code", type.getCode()); typeMap.put("code", type.getCode());
typeMap.put("name", type.getName()); typeMap.put("name", type.getName());
typeMap.put("description", type.getDescription()); typeMap.put("description", type.getDescription());
return typeMap; return typeMap;
}) })
.collect(Collectors.toList()); .collect(Collectors.toList());
return R.ok(types); return R.ok(types);
} }
} }
...@@ -16,6 +16,7 @@ import org.springframework.ai.chat.model.ChatModel; ...@@ -16,6 +16,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore; import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable; import org.springframework.cache.annotation.Cacheable;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
...@@ -51,6 +52,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -51,6 +52,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public ChatClient getChatClientById(Long modelId) { public ChatClient getChatClientById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) { if (Objects.isNull(baseModel)) {
return null; return null;
...@@ -63,6 +67,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -63,6 +67,9 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
@Override @Override
public EmbeddingModel getEmbeddingModelById(Long modelId) { public EmbeddingModel getEmbeddingModelById(Long modelId) {
AskModel askModel = this.getById(modelId); AskModel askModel = this.getById(modelId);
if(askModel.getStatus()==1){
return null;
}
IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider()); IBaseModel baseModel = ModelProviderEnum.get(askModel.getProvider());
if (Objects.isNull(baseModel)) { if (Objects.isNull(baseModel)) {
return null; return null;
...@@ -86,6 +93,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i ...@@ -86,6 +93,7 @@ public class AskModelServiceImpl extends ServiceImpl<AskModelMapper, AskModel> i
} }
@Override @Override
@CacheEvict(cacheNames = {"chatClient", "embeddingMode", "vectorStore"}, key = "#model.id")
public AskModel updateModel(Long id, AskModel model) { public AskModel updateModel(Long id, AskModel model) {
model.setId(id); model.setId(id);
this.updateById(model); this.updateById(model);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment