Commit d8561934 authored by luwei's avatar luwei

修改

parent d7a073ab
...@@ -91,6 +91,26 @@ def download_template(): ...@@ -91,6 +91,26 @@ def download_template():
) )
class FileUpdateRequest(BaseModel):
filename: str | None = Field(default=None, min_length=1, max_length=255)
remark: str | None = Field(default=None, max_length=500)
category_id: str | None = Field(default=None)
@router.put('/files/{file_id}')
def update_file(file_id: str, request: FileUpdateRequest):
try:
result = service.update_file(
file_id=file_id,
filename=request.filename,
remark=request.remark,
category_id=request.category_id,
)
return success_response(data=result, message='文件更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/files/{file_id}') @router.delete('/files/{file_id}')
def delete_file(file_id: str): def delete_file(file_id: str):
try: try:
...@@ -115,9 +135,9 @@ def get_file_quality(file_id: str): ...@@ -115,9 +135,9 @@ def get_file_quality(file_id: str):
@router.get('/files/{file_id}/records') @router.get('/files/{file_id}/records')
def get_file_records(file_id: str, limit: int = Query(default=500, ge=1, le=5000)): def get_file_records(file_id: str):
try: try:
result = service.get_file_records(file_id=file_id, limit=limit) result = service.get_file_records(file_id=file_id, limit=None)
return success_response(data=result) return success_response(data=result)
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from app.services.eval_service import EvalService from app.services.eval_service import EvalService
...@@ -14,8 +14,11 @@ class EvalRequest(BaseModel): ...@@ -14,8 +14,11 @@ class EvalRequest(BaseModel):
@router.get('/packages') @router.get('/packages')
def list_packages(): def list_packages(
return success_response(data=service.list_packages()) category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
@router.get('/models') @router.get('/models')
......
...@@ -28,7 +28,8 @@ class MPCParamsSchema(BaseModel): ...@@ -28,7 +28,8 @@ class MPCParamsSchema(BaseModel):
class CreateExperimentRequest(BaseModel): class CreateExperimentRequest(BaseModel):
name: str = Field(min_length=1, max_length=255) name: str = Field(min_length=1, max_length=255)
model_id: int model_id: int
package_id: int input_csv_path: str = Field(min_length=1, description='读取传感器数据的CSV文件路径')
output_csv_path: str = Field(min_length=1, description='输出曲线数据的CSV文件路径')
target_temp: float = Field(description='目标温度(°C)') target_temp: float = Field(description='目标温度(°C)')
sampling_interval: float = Field(default=1.0, gt=0, le=3600, description='采样周期(秒)') sampling_interval: float = Field(default=1.0, gt=0, le=3600, description='采样周期(秒)')
mpc_params: MPCParamsSchema = Field(default_factory=MPCParamsSchema) mpc_params: MPCParamsSchema = Field(default_factory=MPCParamsSchema)
...@@ -59,7 +60,8 @@ def create_experiment(req: CreateExperimentRequest): ...@@ -59,7 +60,8 @@ def create_experiment(req: CreateExperimentRequest):
exp = service.create_experiment( exp = service.create_experiment(
name=req.name.strip(), name=req.name.strip(),
model_id=req.model_id, model_id=req.model_id,
package_id=req.package_id, input_csv_path=req.input_csv_path.strip(),
output_csv_path=req.output_csv_path.strip(),
target_temp=req.target_temp, target_temp=req.target_temp,
sampling_interval=req.sampling_interval, sampling_interval=req.sampling_interval,
mpc_params=req.mpc_params.model_dump(), mpc_params=req.mpc_params.model_dump(),
...@@ -113,25 +115,5 @@ def get_data_points(exp_id: int, from_step: int = Query(default=0, ge=0)): ...@@ -113,25 +115,5 @@ def get_data_points(exp_id: int, from_step: int = Query(default=0, ge=0)):
return success_response(data=service.get_data_points(exp_id, from_step)) return success_response(data=service.get_data_points(exp_id, from_step))
@router.get('/experiments/{exp_id}/report')
def get_report(exp_id: int):
try:
return success_response(data=service.get_report(exp_id))
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
@router.post('/experiments/{exp_id}/export')
def export_to_history(exp_id: int):
try:
result = service.export_to_history(exp_id)
return success_response(data=result, message='已导出到历史数据')
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
# ── 历史数据 ──────────────────────────────────────────────────────────────────
@router.get('/history')
def list_history():
return success_response(data=service.list_history_experiments())
...@@ -20,6 +20,7 @@ class CategoryUpdateRequest(BaseModel): ...@@ -20,6 +20,7 @@ class CategoryUpdateRequest(BaseModel):
class CleanRules(BaseModel): class CleanRules(BaseModel):
enabled: bool = False enabled: bool = False
newton_interp: bool = False
current_min: float | None = None current_min: float | None = None
current_max: float | None = None current_max: float | None = None
voltage_min: float | None = None voltage_min: float | None = None
...@@ -28,6 +29,11 @@ class CleanRules(BaseModel): ...@@ -28,6 +29,11 @@ class CleanRules(BaseModel):
temperature_max: float | None = None temperature_max: float | None = None
class SmoothConfig(BaseModel):
enabled: bool = False
window: int = Field(default=5, ge=2, le=500)
@router.get('/categories') @router.get('/categories')
def get_categories(): def get_categories():
return success_response(data=service.get_category_tree()) return success_response(data=service.get_category_tree())
...@@ -63,8 +69,22 @@ def delete_category(category_id: str): ...@@ -63,8 +69,22 @@ def delete_category(category_id: str):
# Must be declared before /{package_id} routes to avoid path conflict # Must be declared before /{package_id} routes to avoid path conflict
@router.get('/data-files') @router.get('/data-files')
def list_all_data_files(): def list_all_data_files(
return success_response(data=service.list_all_data_files()) category_id: str = Query(default=''),
filename: str = Query(default=''),
remark: str = Query(default=''),
):
return success_response(data=service.list_all_data_files(
category_id=category_id,
filename=filename.strip(),
remark=remark.strip(),
))
class PackageUpdateRequest(BaseModel):
name: str = Field(min_length=1, max_length=255)
category_id: str | int | None = Field(default=None)
remark: str | None = Field(default=None)
class PackageCreateRequest(BaseModel): class PackageCreateRequest(BaseModel):
...@@ -73,6 +93,8 @@ class PackageCreateRequest(BaseModel): ...@@ -73,6 +93,8 @@ class PackageCreateRequest(BaseModel):
remark: str | None = Field(default=None) remark: str | None = Field(default=None)
file_ids: list[int] = Field(default_factory=list) file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None) clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
auto_split: bool = Field(default=False)
row_start: int | None = Field(default=None, ge=1) row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1) row_end: int | None = Field(default=None, ge=1)
...@@ -80,20 +102,23 @@ class PackageCreateRequest(BaseModel): ...@@ -80,20 +102,23 @@ class PackageCreateRequest(BaseModel):
class PreviewRequest(BaseModel): class PreviewRequest(BaseModel):
file_ids: list[int] = Field(default_factory=list) file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None) clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
row_start: int | None = Field(default=None, ge=1) row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1) row_end: int | None = Field(default=None, ge=1)
@router.post('/preview') @router.post('/preview')
def preview_package(request: PreviewRequest, limit: int = Query(default=300, ge=1, le=2000)): def preview_package(request: PreviewRequest):
try: try:
if request.row_start and request.row_end and request.row_start > request.row_end: if request.row_start and request.row_end and request.row_start > request.row_end:
raise ValueError('起始行不能大于结束行') raise ValueError('起始行不能大于结束行')
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
result = service.preview_records( result = service.preview_records(
file_ids=request.file_ids, file_ids=request.file_ids,
limit=limit, limit=None,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start, row_start=request.row_start,
row_end=request.row_end, row_end=request.row_end,
) )
...@@ -117,12 +142,27 @@ def create_package(request: PackageCreateRequest): ...@@ -117,12 +142,27 @@ def create_package(request: PackageCreateRequest):
raise ValueError('起始行不能大于结束行') raise ValueError('起始行不能大于结束行')
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id) category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
base_name = request.name.strip()
if request.auto_split:
pkgs = service.create_package_split(
name=base_name,
category_id=category_id,
remark=request.remark,
file_ids=request.file_ids,
clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start,
row_end=request.row_end,
)
return success_response(data=pkgs, message='数据包创建成功,已自动划分训练集/验证集/测试集')
pkg = service.create_package( pkg = service.create_package(
name=request.name.strip(), name=base_name,
category_id=category_id, category_id=category_id,
remark=request.remark, remark=request.remark,
file_ids=request.file_ids, file_ids=request.file_ids,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start, row_start=request.row_start,
row_end=request.row_end, row_end=request.row_end,
) )
...@@ -132,14 +172,29 @@ def create_package(request: PackageCreateRequest): ...@@ -132,14 +172,29 @@ def create_package(request: PackageCreateRequest):
@router.get('/{package_id}/records') @router.get('/{package_id}/records')
def get_package_records(package_id: str, limit: int = Query(default=500, ge=1, le=5000)): def get_package_records(package_id: str):
try: try:
result = service.get_package_records(package_id=package_id, limit=limit) result = service.get_package_records(package_id=package_id, limit=None)
return success_response(data=result) return success_response(data=result)
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
@router.put('/{package_id}')
def update_package(package_id: str, request: PackageUpdateRequest):
try:
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
result = service.update_package(
package_id=package_id,
name=request.name.strip(),
category_id=category_id,
remark=request.remark,
)
return success_response(data=result, message='数据包更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/{package_id}') @router.delete('/{package_id}')
def delete_package(package_id: str): def delete_package(package_id: str):
try: try:
......
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.services.train_service import TrainService from app.services.train_service import TrainService
...@@ -21,15 +21,29 @@ class LSTMParams(BaseModel): ...@@ -21,15 +21,29 @@ class LSTMParams(BaseModel):
class CreateTaskRequest(BaseModel): class CreateTaskRequest(BaseModel):
model_name: str = Field(min_length=1, max_length=255) model_name: str = Field(min_length=1, max_length=255)
package_id: int train_package_id: int
val_package_id: int
params: LSTMParams = Field(default_factory=LSTMParams) params: LSTMParams = Field(default_factory=LSTMParams)
class SaveModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
class UpdateModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
# ── packages ────────────────────────────────────────────────────────────────── # ── packages ──────────────────────────────────────────────────────────────────
@router.get('/packages') @router.get('/packages')
def list_packages(): def list_packages(
return success_response(data=service.list_packages()) category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
# ── tasks ───────────────────────────────────────────────────────────────────── # ── tasks ─────────────────────────────────────────────────────────────────────
...@@ -52,7 +66,8 @@ def create_task(request: CreateTaskRequest): ...@@ -52,7 +66,8 @@ def create_task(request: CreateTaskRequest):
try: try:
task = service.create_task( task = service.create_task(
model_name=request.model_name.strip(), model_name=request.model_name.strip(),
package_id=request.package_id, train_package_id=request.train_package_id,
val_package_id=request.val_package_id,
params=request.params.model_dump(), params=request.params.model_dump(),
) )
return success_response(data=task, message='训练任务已启动') return success_response(data=task, message='训练任务已启动')
...@@ -88,9 +103,9 @@ def delete_task(task_id: int): ...@@ -88,9 +103,9 @@ def delete_task(task_id: int):
@router.post('/tasks/{task_id}/save') @router.post('/tasks/{task_id}/save')
def save_model(task_id: int): def save_model(task_id: int, request: SaveModelRequest):
try: try:
model = service.save_model(task_id) model = service.save_model(task_id, model_name=request.model_name, description=request.description)
return success_response(data=model, message='模型已保存') return success_response(data=model, message='模型已保存')
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
...@@ -110,3 +125,15 @@ def delete_model(model_id: int): ...@@ -110,3 +125,15 @@ def delete_model(model_id: int):
return success_response(data=True, message='模型已删除') return success_response(data=True, message='模型已删除')
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
@router.patch('/models/{model_id}')
def update_model(model_id: int, request: UpdateModelRequest):
try:
model = service.update_saved_model(
model_id,
model_name=request.model_name,
description=request.description,
)
return success_response(data=model, message='已更新')
except ValueError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
\ No newline at end of file
...@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException ...@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import text
from app.api.data_management import router as data_management_router from app.api.data_management import router as data_management_router
from app.api.eval_management import router as eval_management_router from app.api.eval_management import router as eval_management_router
...@@ -16,10 +17,29 @@ from app.models.train_management import SavedModel, TrainTask # noqa: F401 ...@@ -16,10 +17,29 @@ from app.models.train_management import SavedModel, TrainTask # noqa: F401
from app.utils.response import error_response, success_response from app.utils.response import error_response, success_response
def _run_migrations() -> None:
"""Add new columns to existing tables without dropping data."""
migrations = [
"ALTER TABLE train_tasks ADD COLUMN val_package_id BIGINT NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305ID' AFTER package_name",
"ALTER TABLE train_tasks ADD COLUMN val_package_name VARCHAR(255) NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305\u540d\u79f0' AFTER val_package_id",
"ALTER TABLE train_tasks ADD COLUMN epoch_logs JSON NULL COMMENT '\u6bcf\u8f6e\u8bad\u7ec3\u65e5\u5fd7' AFTER val_package_name",
]
with engine.connect() as conn:
for stmt in migrations:
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
pass # Column already exists
def create_app() -> FastAPI: def create_app() -> FastAPI:
# Auto-create any missing tables (safe: uses CREATE TABLE IF NOT EXISTS internally) # Auto-create any missing tables (safe: uses CREATE TABLE IF NOT EXISTS internally)
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
# Safe column migrations for existing tables
_run_migrations()
app = FastAPI( app = FastAPI(
title='Thermal Control System API', title='Thermal Control System API',
version='0.1.0', version='0.1.0',
......
...@@ -96,7 +96,22 @@ def predict_lstm( ...@@ -96,7 +96,22 @@ def predict_lstm(
# ── metrics ─────────────────────────────────────────────────────────────── # ── metrics ───────────────────────────────────────────────────────────────
errors = preds_real - actuals_real errors = preds_real - actuals_real
mae = float(np.mean(np.abs(errors))) mae = float(np.mean(np.abs(errors)))
rmse = float(math.sqrt(float(np.mean(errors ** 2)))) mse = float(np.mean(errors ** 2))
rmse = float(math.sqrt(mse))
# MAPE — skip points where actual == 0 to avoid division by zero
nonzero_mask = actuals_real != 0
if nonzero_mask.any():
mape: float | None = float(
np.mean(np.abs(errors[nonzero_mask] / actuals_real[nonzero_mask])) * 100
)
else:
mape = None
# R² (coefficient of determination)
ss_res = float(np.sum(errors ** 2))
ss_tot = float(np.sum((actuals_real - float(np.mean(actuals_real))) ** 2))
r2: float | None = (1.0 - ss_res / ss_tot) if ss_tot != 0 else None
# ── build result points ─────────────────────────────────────────────────── # ── build result points ───────────────────────────────────────────────────
times = [str(r.get('time', '')) for r in records] times = [str(r.get('time', '')) for r in records]
...@@ -115,6 +130,8 @@ def predict_lstm( ...@@ -115,6 +130,8 @@ def predict_lstm(
'time': times[seq_len + i] if (seq_len + i) < len(times) else str(seq_len + i), 'time': times[seq_len + i] if (seq_len + i) < len(times) else str(seq_len + i),
'actual': round(float(actuals_real[i]), 4), 'actual': round(float(actuals_real[i]), 4),
'predicted': round(float(preds_real[i]), 4), 'predicted': round(float(preds_real[i]), 4),
'current': round(float(raw[seq_len + i, 0]), 4),
'voltage': round(float(raw[seq_len + i, 1]), 4),
} }
for i in sample_idx for i in sample_idx
] ]
...@@ -122,6 +139,9 @@ def predict_lstm( ...@@ -122,6 +139,9 @@ def predict_lstm(
return { return {
'total_count': total, 'total_count': total,
'mae': round(mae, 6), 'mae': round(mae, 6),
'mse': round(mse, 6),
'rmse': round(rmse, 6), 'rmse': round(rmse, 6),
'mape': round(mape, 4) if mape is not None else None,
'r2': round(r2, 6) if r2 is not None else None,
'chart_data': chart_data, 'chart_data': chart_data,
} }
...@@ -44,11 +44,6 @@ FEATURE_COLS = ['current', 'voltage', 'set_temperature', 'actual_temperature'] ...@@ -44,11 +44,6 @@ FEATURE_COLS = ['current', 'voltage', 'set_temperature', 'actual_temperature']
TARGET_COL = 'actual_temperature' TARGET_COL = 'actual_temperature'
TARGET_IDX = FEATURE_COLS.index(TARGET_COL) TARGET_IDX = FEATURE_COLS.index(TARGET_COL)
# Fixed dataset split ratios (train / val / test)
_TRAIN_RATIO = 0.70
_VAL_RATIO = 0.15
# test = 1 - _TRAIN_RATIO - _VAL_RATIO (≈ 0.15)
def _check_torch() -> None: def _check_torch() -> None:
if not _TORCH_AVAILABLE: if not _TORCH_AVAILABLE:
...@@ -88,12 +83,137 @@ def _make_sequences(data: np.ndarray, seq_len: int) -> tuple[np.ndarray, np.ndar ...@@ -88,12 +83,137 @@ def _make_sequences(data: np.ndarray, seq_len: int) -> tuple[np.ndarray, np.ndar
# ── public training entry point ─────────────────────────────────────────────── # ── public training entry point ───────────────────────────────────────────────
def train_lstm( def train_lstm(
records: list[dict], train_records: list[dict],
val_records: list[dict],
params: dict, params: dict,
save_path: Path, save_path: Path,
on_progress: Callable[[int, float, float | None], None], on_progress: Callable[[int, int, float, float | None], None],
cancel_event: threading.Event, cancel_event: threading.Event,
) -> dict[str, float | None]: ) -> dict[str, float | None]:
"""
Train an LSTM model on *train_records* validated against *val_records*.
Args:
train_records: list of dicts for the training set.
val_records: list of dicts for the validation set.
params: hyper-parameter dict (seq_len, hidden_size, num_layers,
epochs, batch_size, learning_rate).
save_path: destination .pt file.
on_progress: callback(pct, epoch, train_loss, val_loss) called after each epoch.
cancel_event: when set, training stops with InterruptedError.
Returns:
{'train_loss': float, 'val_loss': float|None}
"""
_check_torch()
seq_len = max(1, int(params.get('seq_len', 20)))
hidden_size = max(1, int(params.get('hidden_size', 64)))
num_layers = max(1, int(params.get('num_layers', 2)))
epochs = max(1, int(params.get('epochs', 50)))
batch_size = max(1, int(params.get('batch_size', 32)))
lr = float(params.get('learning_rate', 0.001))
# ── data preparation ────────────────────────────────────────────────────
train_data = _extract_features(train_records)
val_data = _extract_features(val_records)
min_required = seq_len + 10
if len(train_data) < min_required:
raise ValueError(
f'\u8bad\u7ec3\u96c6\u6709\u6548\u6570\u636e\u91cf\u4e0d\u8db3\uff1a\u9700\u81f3\u5c11 {min_required} \u6761\uff0c\u5f53\u524d\u4ec5 {len(train_data)} \u6761\u3002'
'\u8bf7\u68c0\u67e5\u6570\u636e\u5305\u5185\u5bb9\u6216\u51cf\u5c0f\u5e8f\u5217\u957f\u5ea6\u3002'
)
# min-max normalisation fitted on train set only
data_min = train_data.min(axis=0)
data_max = train_data.max(axis=0)
data_range = data_max - data_min
data_range[data_range == 0] = 1.0
train_norm = (train_data - data_min) / data_range
X_train, y_train = _make_sequences(train_norm, seq_len)
has_val = len(val_data) >= seq_len + 1
X_val_t = y_val_t = None
if has_val:
val_norm = (val_data - data_min) / data_range
X_val, y_val = _make_sequences(val_norm, seq_len)
if len(X_val) == 0:
has_val = False
else:
X_val_t = torch.tensor(X_val)
y_val_t = torch.tensor(y_val)
device = torch.device('cpu')
X_train_t = torch.tensor(X_train).to(device)
y_train_t = torch.tensor(y_train).to(device)
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=batch_size,
shuffle=True,
)
# ── model ────────────────────────────────────────────────────────────────
input_size = len(FEATURE_COLS)
model = _LSTMModel(input_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
train_loss = 0.0
val_loss: float | None = None
for epoch in range(epochs):
if cancel_event.is_set():
raise InterruptedError('\u8bad\u7ec3\u5df2\u53d6\u6d88')
# ── train step ───────────────────────────────────────────────────────
model.train()
epoch_loss = 0.0
for xb, yb in train_loader:
optimizer.zero_grad()
pred = model(xb)
loss = criterion(pred, yb)
loss.backward()
optimizer.step()
epoch_loss += loss.item() * len(xb)
train_loss = epoch_loss / len(X_train)
# ── val step ─────────────────────────────────────────────────────────
if has_val:
model.eval()
with torch.no_grad():
val_pred = model(X_val_t)
val_loss = criterion(val_pred, y_val_t).item()
pct = int((epoch + 1) / epochs * 100)
on_progress(pct, epoch + 1, train_loss, val_loss)
# ── persist ──────────────────────────────────────────────────────────────
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
{
'model_state': model.state_dict(),
'params': params,
'data_min': data_min.tolist(),
'data_max': data_max.tolist(),
'feature_cols': FEATURE_COLS,
'target_col': TARGET_COL,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'seq_len': seq_len,
},
save_path,
)
return {
'train_loss': round(float(train_loss), 6),
'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
}
""" """
Train an LSTM model on *records* and persist it to *save_path*. Train an LSTM model on *records* and persist it to *save_path*.
......
...@@ -19,7 +19,10 @@ class EvalRecord(Base): ...@@ -19,7 +19,10 @@ class EvalRecord(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
total_count: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text('0'), comment='评估数据点总数') total_count: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text('0'), comment='评估数据点总数')
mae: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对误差') mae: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对误差')
mse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方误差')
rmse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方根误差') rmse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方根误差')
mape: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对百分比误差(%)')
r2: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='决定系数')
# Store up to ~2000 sampled points for chart rendering # Store up to ~2000 sampled points for chart rendering
chart_data: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='图表数据(采样)') chart_data: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='图表数据(采样)')
created_at: Mapped[str] = mapped_column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) created_at: Mapped[str] = mapped_column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
...@@ -15,8 +15,10 @@ class MonitorExperiment(Base): ...@@ -15,8 +15,10 @@ class MonitorExperiment(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False, comment='试验名称') name: Mapped[str] = mapped_column(String(255), nullable=False, comment='试验名称')
model_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='模型ID') model_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='模型ID')
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称') model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='初始数据包ID') package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='初始数据包ID(旧版兼容)')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='数据包名称(旧版兼容)')
input_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输入CSV路径(传感器数据源)')
output_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输出CSV路径(生成曲线写入)')
target_temp: Mapped[float] = mapped_column(FLOAT, nullable=False, comment='目标温度(°C)') target_temp: Mapped[float] = mapped_column(FLOAT, nullable=False, comment='目标温度(°C)')
mpc_params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='MPC参数') mpc_params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='MPC参数')
status: Mapped[str] = mapped_column( status: Mapped[str] = mapped_column(
......
...@@ -13,8 +13,11 @@ class TrainTask(Base): ...@@ -13,8 +13,11 @@ class TrainTask(Base):
id: Mapped[int] = mapped_column(BIGINT, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(BIGINT, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称') model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='数据包ID') package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='训练集数据包ID')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='训练集数据包名称')
val_package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='验证集数据包ID')
val_package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='验证集数据包名称')
epoch_logs: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='每轮训练日志')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数') params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
status: Mapped[str] = mapped_column( status: Mapped[str] = mapped_column(
Enum('pending', 'running', 'completed', 'failed', 'cancelled', name='train_status_enum'), Enum('pending', 'running', 'completed', 'failed', 'cancelled', name='train_status_enum'),
...@@ -50,6 +53,7 @@ class SavedModel(Base): ...@@ -50,6 +53,7 @@ class SavedModel(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数') params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
file_path: Mapped[str] = mapped_column(String(500), nullable=False, comment='模型文件路径') file_path: Mapped[str] = mapped_column(String(500), nullable=False, comment='模型文件路径')
description: Mapped[str | None] = mapped_column(Text, nullable=True, comment='模型说明')
train_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) train_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
val_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) val_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
test_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) test_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
......
...@@ -152,6 +152,7 @@ class DataManagementService: ...@@ -152,6 +152,7 @@ class DataManagementService:
'category_id': item.category_id, 'category_id': item.category_id,
'uploaded_at': item.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if item.uploaded_at else '', 'uploaded_at': item.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if item.uploaded_at else '',
'data_count': item.data_count, 'data_count': item.data_count,
'remark': item.remark or '',
} }
for item in rows for item in rows
] ]
...@@ -212,6 +213,46 @@ class DataManagementService: ...@@ -212,6 +213,46 @@ class DataManagementService:
target_path.unlink() target_path.unlink()
raise raise
def update_file(self, file_id: str, filename: str | None = None, remark: str | None = None, category_id: str | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID')
if category_id is not None:
cat_db_id = self._parse_int_id(category_id, '分类ID')
with db_session() as session:
matched = session.query(DataFile).filter(DataFile.id == file_db_id).first()
if not matched:
raise ValueError('文件不存在')
if filename is not None:
stripped = filename.strip()
if not stripped:
raise ValueError('文件名不能为空')
matched.filename = stripped
if remark is not None:
matched.remark = remark.strip() or None
if category_id is not None:
category = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_file')
.first()
)
if not category:
raise ValueError('分类不存在')
matched.category_id = cat_db_id
session.commit()
session.refresh(matched)
return {
'id': matched.id,
'filename': matched.filename,
'remark': matched.remark or '',
'category_id': matched.category_id,
}
def delete_file(self, file_id: str) -> None: def delete_file(self, file_id: str) -> None:
file_db_id = self._parse_int_id(file_id, '文件ID') file_db_id = self._parse_int_id(file_id, '文件ID')
...@@ -227,7 +268,7 @@ class DataManagementService: ...@@ -227,7 +268,7 @@ class DataManagementService:
session.delete(matched) session.delete(matched)
session.commit() session.commit()
def get_file_records(self, file_id: str, limit: int = 500) -> dict[str, Any]: def get_file_records(self, file_id: str, limit: int | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID') file_db_id = self._parse_int_id(file_id, '文件ID')
with db_session() as session: with db_session() as session:
......
...@@ -5,7 +5,7 @@ from typing import Any ...@@ -5,7 +5,7 @@ from typing import Any
from app.database import db_session from app.database import db_session
from app.ml.lstm_predictor import predict_lstm from app.ml.lstm_predictor import predict_lstm
from app.models import DataFile, DataPackage, DataPackageFile from app.models import DataPackage
from app.models.eval_management import EvalRecord from app.models.eval_management import EvalRecord
from app.models.train_management import SavedModel from app.models.train_management import SavedModel
from app.services.data_management_service import DataManagementService from app.services.data_management_service import DataManagementService
...@@ -24,10 +24,19 @@ class EvalService: ...@@ -24,10 +24,19 @@ class EvalService:
# ── dropdown data ───────────────────────────────────────────────────────── # ── dropdown data ─────────────────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]: def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
rows = session.query(DataPackage).order_by(DataPackage.created_at.desc()).all() query = session.query(DataPackage)
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows] if category_id not in ('', 'all', None):
try:
db_id = int(category_id)
query = query.filter(DataPackage.category_id == db_id)
except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
def list_saved_models(self) -> list[dict[str, Any]]: def list_saved_models(self) -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
...@@ -67,7 +76,10 @@ class EvalService: ...@@ -67,7 +76,10 @@ class EvalService:
package_name=package_name, package_name=package_name,
total_count=result['total_count'], total_count=result['total_count'],
mae=result['mae'], mae=result['mae'],
mse=result['mse'],
rmse=result['rmse'], rmse=result['rmse'],
mape=result['mape'],
r2=result['r2'],
chart_data=result['chart_data'], chart_data=result['chart_data'],
) )
session.add(record) session.add(record)
...@@ -106,32 +118,13 @@ class EvalService: ...@@ -106,32 +118,13 @@ class EvalService:
def _load_package_records(self, package_id: int) -> list[dict[str, Any]]: def _load_package_records(self, package_id: int) -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first() pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first()
clean_rules = pkg.clean_rules if pkg else None if not pkg or not pkg.stored_name:
pf_rows = (
session.query(DataPackageFile)
.filter(DataPackageFile.package_id == package_id)
.order_by(DataPackageFile.sort_order.asc())
.all()
)
file_ids = [pf.file_id for pf in pf_rows]
if not file_ids:
return [] return []
files = session.query(DataFile).filter(DataFile.id.in_(file_ids)).all() stored_name = pkg.stored_name
file_map = {f.id: f for f in files}
all_records: list[dict[str, Any]] = []
for fid in file_ids:
if fid not in file_map:
continue
fmeta = file_map[fid]
path = self._dm._resolve_local_file_path(fmeta.file_path, fmeta.stored_name)
recs, _ = self._dm._read_records(path, limit=None)
all_records.extend(recs)
if clean_rules and clean_rules.get('enabled'):
all_records = self._apply_clean_rules(all_records, clean_rules)
return all_records pkg_path = self._base_dir / 'uploads' / 'packages' / stored_name
recs, _ = self._dm._read_records(pkg_path, limit=None)
return recs
@staticmethod @staticmethod
def _apply_clean_rules(records: list[dict[str, Any]], clean_rules: dict) -> list[dict[str, Any]]: def _apply_clean_rules(records: list[dict[str, Any]], clean_rules: dict) -> list[dict[str, Any]]:
...@@ -177,7 +170,10 @@ class EvalService: ...@@ -177,7 +170,10 @@ class EvalService:
'package_name': row.package_name, 'package_name': row.package_name,
'total_count': row.total_count, 'total_count': row.total_count,
'mae': row.mae, 'mae': row.mae,
'mse': row.mse,
'rmse': row.rmse, 'rmse': row.rmse,
'mape': row.mape,
'r2': row.r2,
'created_at': row.created_at.strftime('%Y-%m-%d %H:%M:%S') if row.created_at else '', 'created_at': row.created_at.strftime('%Y-%m-%d %H:%M:%S') if row.created_at else '',
} }
if include_chart: if include_chart:
......
This diff is collapsed.
...@@ -29,14 +29,19 @@ class TrainService: ...@@ -29,14 +29,19 @@ class TrainService:
# ── packages (for dropdown) ────────────────────────────────────────────── # ── packages (for dropdown) ──────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]: def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
rows = ( query = session.query(DataPackage)
session.query(DataPackage) if category_id not in ('', 'all', None):
.order_by(DataPackage.created_at.desc()) try:
.all() db_id = int(category_id)
) query = query.filter(DataPackage.category_id == db_id)
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows] except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
# ── tasks ──────────────────────────────────────────────────────────────── # ── tasks ────────────────────────────────────────────────────────────────
...@@ -52,16 +57,21 @@ class TrainService: ...@@ -52,16 +57,21 @@ class TrainService:
raise ValueError('训练任务不存在') raise ValueError('训练任务不存在')
return self._task_to_dict(task) return self._task_to_dict(task)
def create_task(self, model_name: str, package_id: int, params: dict) -> dict[str, Any]: def create_task(self, model_name: str, train_package_id: int, val_package_id: int, params: dict) -> dict[str, Any]:
with db_session() as session: with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first() train_pkg = session.query(DataPackage).filter(DataPackage.id == train_package_id).first()
if not pkg: if not train_pkg:
raise ValueError('数据包不存在') raise ValueError('训练集数据包不存在')
val_pkg = session.query(DataPackage).filter(DataPackage.id == val_package_id).first()
if not val_pkg:
raise ValueError('验证集数据包不存在')
task = TrainTask( task = TrainTask(
model_name=model_name, model_name=model_name,
package_id=package_id, package_id=train_package_id,
package_name=pkg.name, package_name=train_pkg.name,
val_package_id=val_package_id,
val_package_name=val_pkg.name,
params=params, params=params,
status='pending', status='pending',
progress=0, progress=0,
...@@ -71,7 +81,7 @@ class TrainService: ...@@ -71,7 +81,7 @@ class TrainService:
session.refresh(task) session.refresh(task)
task_dict = self._task_to_dict(task) task_dict = self._task_to_dict(task)
self._launch_thread(task_dict['id'], package_id, params) self._launch_thread(task_dict['id'], train_package_id, val_package_id, params)
return task_dict return task_dict
def cancel_task(self, task_id: int) -> None: def cancel_task(self, task_id: int) -> None:
...@@ -100,6 +110,8 @@ class TrainService: ...@@ -100,6 +110,8 @@ class TrainService:
model_name=task.model_name, model_name=task.model_name,
package_id=task.package_id, package_id=task.package_id,
package_name=task.package_name, package_name=task.package_name,
val_package_id=task.val_package_id,
val_package_name=task.val_package_name,
params=task.params, params=task.params,
status='pending', status='pending',
progress=0, progress=0,
...@@ -109,7 +121,7 @@ class TrainService: ...@@ -109,7 +121,7 @@ class TrainService:
session.refresh(new_task) session.refresh(new_task)
task_dict = self._task_to_dict(new_task) task_dict = self._task_to_dict(new_task)
self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['params']) self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['val_package_id'], task_dict['params'])
return task_dict return task_dict
def delete_task(self, task_id: int) -> None: def delete_task(self, task_id: int) -> None:
...@@ -128,7 +140,7 @@ class TrainService: ...@@ -128,7 +140,7 @@ class TrainService:
session.delete(task) session.delete(task)
session.commit() session.commit()
def save_model(self, task_id: int) -> dict[str, Any]: def save_model(self, task_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session: with db_session() as session:
task = session.query(TrainTask).filter(TrainTask.id == task_id).first() task = session.query(TrainTask).filter(TrainTask.id == task_id).first()
if not task: if not task:
...@@ -144,7 +156,8 @@ class TrainService: ...@@ -144,7 +156,8 @@ class TrainService:
saved = SavedModel( saved = SavedModel(
task_id=task_id, task_id=task_id,
model_name=task.model_name, model_name=(model_name.strip() if model_name and model_name.strip() else task.model_name),
description=description,
package_id=task.package_id, package_id=task.package_id,
package_name=task.package_name, package_name=task.package_name,
params=task.params, params=task.params,
...@@ -158,7 +171,18 @@ class TrainService: ...@@ -158,7 +171,18 @@ class TrainService:
session.commit() session.commit()
session.refresh(saved) session.refresh(saved)
return self._saved_to_dict(saved) return self._saved_to_dict(saved)
def update_saved_model(self, model_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session:
model = session.query(SavedModel).filter(SavedModel.id == model_id).first()
if not model:
raise ValueError('模型不存在')
if model_name is not None and model_name.strip():
model.model_name = model_name.strip()
if description is not None:
model.description = description
session.commit()
session.refresh(model)
return self._saved_to_dict(model)
# ── saved models ───────────────────────────────────────────────────────── # ── saved models ─────────────────────────────────────────────────────────
def list_saved_models(self) -> list[dict[str, Any]]: def list_saved_models(self) -> list[dict[str, Any]]:
...@@ -193,14 +217,15 @@ class TrainService: ...@@ -193,14 +217,15 @@ class TrainService:
if p.is_absolute(): if p.is_absolute():
return p return p
return self._base_dir / p return self._base_dir / p
def _launch_thread(self, task_id: int, package_id: int, params: dict) -> None:
def _launch_thread(self, task_id: int, train_package_id: int, val_package_id: int | None, params: dict) -> None:
cancel_event = threading.Event() cancel_event = threading.Event()
with _registry_lock: with _registry_lock:
_cancel_events[task_id] = cancel_event _cancel_events[task_id] = cancel_event
thread = threading.Thread( thread = threading.Thread(
target=self._training_worker, target=self._training_worker,
args=(task_id, package_id, params, cancel_event), args=(task_id, train_package_id, val_package_id, params, cancel_event),
daemon=True, daemon=True,
name=f'train-task-{task_id}', name=f'train-task-{task_id}',
) )
...@@ -209,35 +234,44 @@ class TrainService: ...@@ -209,35 +234,44 @@ class TrainService:
def _training_worker( def _training_worker(
self, self,
task_id: int, task_id: int,
package_id: int, train_package_id: int,
val_package_id: int | None,
params: dict, params: dict,
cancel_event: threading.Event, cancel_event: threading.Event,
) -> None: ) -> None:
try: try:
self._update_task(task_id, status='running', progress=0) self._update_task(task_id, status='running', progress=0)
records = self._load_package_records(package_id) train_records = self._load_package_records(train_package_id)
if not records: if not train_records:
raise ValueError('数据包没有有效数据,请检查关联文件') raise ValueError('训练集数据包没有有效数据,请检查关联文件')
val_records: list[dict[str, Any]] = []
if val_package_id:
val_records = self._load_package_records(val_package_id)
save_path = self._models_dir / f'task_{task_id}.pt' save_path = self._models_dir / f'task_{task_id}.pt'
last_pct = [0] epoch_log_buffer: list[dict[str, Any]] = []
def on_progress(pct: int, train_loss: float, val_loss: float | None) -> None: def on_progress(pct: int, epoch: int, train_loss: float, val_loss: float | None) -> None:
if cancel_event.is_set(): if cancel_event.is_set():
return return
# Throttle: persist at most every 2 % to reduce DB writes epoch_log_buffer.append({
if pct - last_pct[0] >= 2 or pct == 100: 'epoch': epoch,
last_pct[0] = pct 'train_loss': round(float(train_loss), 6),
self._update_task( 'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
task_id, })
progress=pct, self._update_task(
train_loss=round(float(train_loss), 6), task_id,
val_loss=round(float(val_loss), 6) if val_loss is not None else None, progress=pct,
) train_loss=round(float(train_loss), 6),
val_loss=round(float(val_loss), 6) if val_loss is not None else None,
epoch_logs=list(epoch_log_buffer),
)
result = train_lstm( result = train_lstm(
records=records, train_records=train_records,
val_records=val_records,
params=params, params=params,
save_path=save_path, save_path=save_path,
on_progress=on_progress, on_progress=on_progress,
...@@ -304,6 +338,9 @@ class TrainService: ...@@ -304,6 +338,9 @@ class TrainService:
'model_name': task.model_name, 'model_name': task.model_name,
'package_id': task.package_id, 'package_id': task.package_id,
'package_name': task.package_name, 'package_name': task.package_name,
'val_package_id': task.val_package_id,
'val_package_name': task.val_package_name,
'epoch_logs': task.epoch_logs or [],
'params': task.params, 'params': task.params,
'status': task.status, 'status': task.status,
'progress': task.progress, 'progress': task.progress,
...@@ -321,6 +358,7 @@ class TrainService: ...@@ -321,6 +358,7 @@ class TrainService:
'id': model.id, 'id': model.id,
'task_id': model.task_id, 'task_id': model.task_id,
'model_name': model.model_name, 'model_name': model.model_name,
'description': model.description or '',
'package_id': model.package_id, 'package_id': model.package_id,
'package_name': model.package_name, 'package_name': model.package_name,
'params': model.params, 'params': model.params,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
step,actual_temp,current_output,reference_temp,target_temp
0,20.534,0.0,23.4272,35.0
1,20.534,0.1796,23.4272,35.0
2,20.534,0.1796,23.4272,35.0
3,20.534,0.1949,23.4272,35.0
4,20.534,0.1949,23.4272,35.0
5,20.534,0.1949,23.4272,35.0
6,20.534,0.1949,23.4272,35.0
7,20.534,0.1949,23.4272,35.0
8,20.534,0.1354,23.4272,35.0
9,20.534,0.1354,23.4272,35.0
10,20.534,0.1354,23.4272,35.0
11,20.534,0.1354,23.4272,35.0
12,20.534,0.0,23.4272,35.0
13,20.534,0.1796,23.4272,35.0
14,20.534,0.1796,23.4272,35.0
15,20.534,0.1949,23.4272,35.0
16,20.534,0.1949,23.4272,35.0
17,20.534,0.1949,23.4272,35.0
18,20.534,0.1949,23.4272,35.0
19,20.534,0.1949,23.4272,35.0
20,20.534,0.1354,23.4272,35.0
21,20.534,0.1354,23.4272,35.0
22,20.534,0.1354,23.4272,35.0
23,20.534,0.1354,23.4272,35.0
24,20.534,0.0,23.4272,35.0
25,20.534,0.1796,23.4272,35.0
26,20.534,0.1796,23.4272,35.0
27,20.534,0.1949,23.4272,35.0
...@@ -11,10 +11,8 @@ const tabs = [ ...@@ -11,10 +11,8 @@ const tabs = [
{ label: '模型训练', path: '/model-training' }, { label: '模型训练', path: '/model-training' },
{ label: '模型库', path: '/model-list' }, { label: '模型库', path: '/model-list' },
{ label: '模型评估', path: '/model-evaluation' }, { label: '模型评估', path: '/model-evaluation' },
{ label: '实时监控', path: '/realtime-monitor' }, { label: '试验管理', path: '/realtime-monitor' },
{ label: '历史数据', path: '/history-data' }, { label: '实时监控', path: '/live-monitor' },
] ]
const activeTab = computed(() => { const activeTab = computed(() => {
......
...@@ -32,6 +32,10 @@ export function deleteDataFile(fileId) { ...@@ -32,6 +32,10 @@ export function deleteDataFile(fileId) {
return request.delete(`/data/files/${fileId}`) return request.delete(`/data/files/${fileId}`)
} }
export function updateDataFile(fileId, payload) {
return request.put(`/data/files/${fileId}`, payload)
}
export function getFileRecords(fileId, params) { export function getFileRecords(fileId, params) {
return request.get(`/data/files/${fileId}/records`, { params }) return request.get(`/data/files/${fileId}/records`, { params })
} }
......
import request from '@/utils/request' import request from '@/utils/request'
export function getEvalPackages() { export function getEvalPackages(params = {}) {
return request.get('/eval/packages') return request.get('/eval/packages', { params })
} }
export function getEvalModels() { export function getEvalModels() {
......
import request from '@/utils/request'
export function getHistoryList() {
return request.get('/monitor/history')
}
export function getHistoryDetail(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
...@@ -20,8 +20,8 @@ export function deletePkgCategory(categoryId) { ...@@ -20,8 +20,8 @@ export function deletePkgCategory(categoryId) {
// ── data files (for selection) ─────────────────────────────────────────────── // ── data files (for selection) ───────────────────────────────────────────────
export function getAllDataFiles() { export function getAllDataFiles(params) {
return request.get('/packages/data-files') return request.get('/packages/data-files', { params })
} }
// ── packages ───────────────────────────────────────────────────────────────── // ── packages ─────────────────────────────────────────────────────────────────
...@@ -34,6 +34,10 @@ export function createPackage(payload) { ...@@ -34,6 +34,10 @@ export function createPackage(payload) {
return request.post('/packages', payload) return request.post('/packages', payload)
} }
export function updatePackage(packageId, payload) {
return request.put(`/packages/${packageId}`, payload)
}
export function deletePackage(packageId) { export function deletePackage(packageId) {
return request.delete(`/packages/${packageId}`) return request.delete(`/packages/${packageId}`)
} }
......
...@@ -40,10 +40,4 @@ export function getDataPoints(expId, fromStep = 0) { ...@@ -40,10 +40,4 @@ export function getDataPoints(expId, fromStep = 0) {
return request.get(`/monitor/experiments/${expId}/data`, { params: { from_step: fromStep } }) return request.get(`/monitor/experiments/${expId}/data`, { params: { from_step: fromStep } })
} }
export function getReport(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
export function exportToHistory(expId) {
return request.post(`/monitor/experiments/${expId}/export`)
}
...@@ -28,8 +28,8 @@ export function deleteTrainTask(taskId) { ...@@ -28,8 +28,8 @@ export function deleteTrainTask(taskId) {
return request.delete(`/train/tasks/${taskId}`) return request.delete(`/train/tasks/${taskId}`)
} }
export function saveTrainModel(taskId) { export function saveTrainModel(taskId, payload = {}) {
return request.post(`/train/tasks/${taskId}/save`) return request.post(`/train/tasks/${taskId}/save`, payload)
} }
export function getSavedModels() { export function getSavedModels() {
...@@ -39,3 +39,7 @@ export function getSavedModels() { ...@@ -39,3 +39,7 @@ export function getSavedModels() {
export function deleteSavedModel(modelId) { export function deleteSavedModel(modelId) {
return request.delete(`/train/models/${modelId}`) return request.delete(`/train/models/${modelId}`)
} }
export function updateSavedModel(modelId, payload) {
return request.patch(`/train/models/${modelId}`, payload)
}
...@@ -23,9 +23,9 @@ const router = createRouter({ ...@@ -23,9 +23,9 @@ const router = createRouter({
component: () => import('@/views/RealtimeMonitor/components/ExperimentDetail.vue'), component: () => import('@/views/RealtimeMonitor/components/ExperimentDetail.vue'),
}, },
{ {
path: '/history-data', path: '/live-monitor',
name: 'history-data', name: 'live-monitor',
component: () => import('@/views/HistoryData/index.vue'), component: () => import('@/views/LiveMonitor/index.vue'),
}, },
{ {
path: '/model-training', path: '/model-training',
......
...@@ -7,6 +7,10 @@ const props = defineProps({ ...@@ -7,6 +7,10 @@ const props = defineProps({
type: Array, type: Array,
default: () => [], default: () => [],
}, },
totalCount: {
type: Number,
default: null,
},
}) })
const chartRef = ref(null) const chartRef = ref(null)
...@@ -225,7 +229,7 @@ onBeforeUnmount(() => { ...@@ -225,7 +229,7 @@ onBeforeUnmount(() => {
<div class="stats-row"> <div class="stats-row">
<div class="stat-card"> <div class="stat-card">
<div class="stat-label">总条数</div> <div class="stat-label">总条数</div>
<div class="stat-value">{{ stats.count }}</div> <div class="stat-value">{{ props.totalCount ?? stats.count }}</div>
</div> </div>
<div class="stat-card"> <div class="stat-card">
<div class="stat-label">最高温度</div> <div class="stat-label">最高温度</div>
......
<script setup> <script setup>
import { Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue' import { DataAnalysis, Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue'
import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue' import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElAutoResizer, ElMessage, ElMessageBox, ElTableV2 } from 'element-plus'
import { import {
createCategory, createCategory,
deleteCategory, deleteCategory,
...@@ -12,6 +12,7 @@ import { ...@@ -12,6 +12,7 @@ import {
getFileQuality, getFileQuality,
getFileRecords, getFileRecords,
updateCategory, updateCategory,
updateDataFile,
uploadDataFile, uploadDataFile,
} from '@/api/dataManagement' } from '@/api/dataManagement'
import DataCurve from './components/DataCurve.vue' import DataCurve from './components/DataCurve.vue'
...@@ -31,6 +32,21 @@ const recordList = ref([]) ...@@ -31,6 +32,21 @@ const recordList = ref([])
const contentLoading = ref(false) const contentLoading = ref(false)
const contentMode = ref('table') const contentMode = ref('table')
const recordColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = recordList.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const categoryDialogVisible = ref(false) const categoryDialogVisible = ref(false)
const categoryDialogMode = ref('create') const categoryDialogMode = ref('create')
const editingCategoryId = ref('') const editingCategoryId = ref('')
...@@ -48,6 +64,66 @@ const qualityLoading = ref(false) ...@@ -48,6 +64,66 @@ const qualityLoading = ref(false)
const qualityFile = ref(null) const qualityFile = ref(null)
const qualityResult = ref(null) const qualityResult = ref(null)
const fileEditDialogVisible = ref(false)
const fileEditSubmitting = ref(false)
const editingFileId = ref(null)
const fileEditForm = reactive({
filename: '',
remark: '',
categoryId: '',
})
const openFileEditDialog = (row) => {
editingFileId.value = row.id
fileEditForm.filename = row.filename
fileEditForm.remark = row.remark || ''
fileEditForm.categoryId = row.category_id ? String(row.category_id) : ''
fileEditDialogVisible.value = true
}
const submitFileEdit = async () => {
const name = fileEditForm.filename.trim()
if (!name) {
ElMessage.warning('文件名不能为空')
return
}
if (!fileEditForm.categoryId) {
ElMessage.warning('请选择分类')
return
}
fileEditSubmitting.value = true
const prevCategoryId = fileList.value.find((f) => f.id === editingFileId.value)?.category_id
try {
const result = await updateDataFile(editingFileId.value, {
filename: name,
remark: fileEditForm.remark.trim() || null,
category_id: fileEditForm.categoryId,
})
ElMessage.success('保存成功')
fileEditDialogVisible.value = false
const categoryChanged = String(result.category_id) !== String(prevCategoryId)
if (categoryChanged) {
await loadFileList()
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = null
recordList.value = []
}
} else {
const idx = fileList.value.findIndex((f) => f.id === editingFileId.value)
if (idx !== -1) {
fileList.value[idx].filename = result.filename
fileList.value[idx].remark = result.remark || ''
}
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = { ...currentFile.value, filename: result.filename, remark: result.remark || '' }
}
}
} finally {
fileEditSubmitting.value = false
}
}
const qualityColorClass = (percent) => { const qualityColorClass = (percent) => {
if (percent >= 90) return 'quality-good' if (percent >= 90) return 'quality-good'
if (percent >= 70) return 'quality-warn' if (percent >= 70) return 'quality-warn'
...@@ -407,7 +483,7 @@ const handleViewFile = async (row) => { ...@@ -407,7 +483,7 @@ const handleViewFile = async (row) => {
currentFile.value = row currentFile.value = row
contentLoading.value = true contentLoading.value = true
try { try {
const result = await getFileRecords(row.id, { limit: 500 }) const result = await getFileRecords(row.id)
recordList.value = result.records recordList.value = result.records
} finally { } finally {
contentLoading.value = false contentLoading.value = false
...@@ -506,15 +582,16 @@ onBeforeUnmount(() => { ...@@ -506,15 +582,16 @@ onBeforeUnmount(() => {
</el-form-item> </el-form-item>
</el-form> </el-form>
<el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)"> <el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)" row-class-name="file-row" @row-click="handleViewFile">
<el-table-column prop="filename" label="文件名" min-width="160" /> <el-table-column prop="filename" label="文件名" min-width="160" />
<el-table-column prop="remark" label="备注" min-width="120" show-overflow-tooltip />
<el-table-column prop="uploaded_at" label="上传时间" min-width="170" /> <el-table-column prop="uploaded_at" label="上传时间" min-width="170" />
<el-table-column prop="data_count" label="数据量" width="90" /> <el-table-column prop="data_count" label="数据量" width="90" />
<el-table-column label="操作" width="200" fixed="right"> <el-table-column label="操作" width="220" fixed="right">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleViewFile(row)">查看</el-button> <el-button link type="primary" :icon="Edit" @click.stop="openFileEditDialog(row)" />
<el-button link type="success" @click="handleQualityCheck(row)">质量判定</el-button> <el-button link type="success" :icon="DataAnalysis" @click.stop="handleQualityCheck(row)" title="质量判定" />
<el-button link type="danger" @click="handleDeleteFile(row)">删除</el-button> <el-button link type="danger" :icon="Delete" @click.stop="handleDeleteFile(row)" />
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
...@@ -528,6 +605,7 @@ onBeforeUnmount(() => { ...@@ -528,6 +605,7 @@ onBeforeUnmount(() => {
<span> <span>
文件内容 文件内容
<span v-if="currentFile" class="file-title">- {{ currentFile.filename }}</span> <span v-if="currentFile" class="file-title">- {{ currentFile.filename }}</span>
<span v-if="currentFile && currentFile.remark" class="file-remark">{{ currentFile.remark }}</span>
</span> </span>
<el-radio-group v-model="contentMode" size="small"> <el-radio-group v-model="contentMode" size="small">
<el-radio-button value="table">表格</el-radio-button> <el-radio-button value="table">表格</el-radio-button>
...@@ -539,21 +617,21 @@ onBeforeUnmount(() => { ...@@ -539,21 +617,21 @@ onBeforeUnmount(() => {
<div class="content-wrap" v-loading="contentLoading"> <div class="content-wrap" v-loading="contentLoading">
<el-empty v-if="!currentFile" description="请选择并查看一个文件" /> <el-empty v-if="!currentFile" description="请选择并查看一个文件" />
<el-table <el-auto-resizer v-else-if="contentMode === 'table'">
v-else-if="contentMode === 'table'" <template #default="{ height, width }">
:data="recordList" <el-table-v2
border :columns="recordColumns"
stripe :data="recordList"
height="calc(100vh - 250px)" :width="width"
> :height="height"
<el-table-column prop="time" label="时间" min-width="140" /> :row-height="36"
<el-table-column prop="current" label="电流" min-width="100" /> :header-height="40"
<el-table-column prop="voltage" label="电压" min-width="100" /> fixed
<el-table-column prop="set_temperature" label="设定温度" min-width="100" /> />
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" /> </template>
</el-table> </el-auto-resizer>
<DataCurve v-else :records="recordList" /> <DataCurve v-else :records="curveRecords" :total-count="recordList.length" />
</div> </div>
</el-card> </el-card>
</div> </div>
...@@ -699,6 +777,39 @@ onBeforeUnmount(() => { ...@@ -699,6 +777,39 @@ onBeforeUnmount(() => {
<el-button @click="qualityDialogVisible = false">关闭</el-button> <el-button @click="qualityDialogVisible = false">关闭</el-button>
</template> </template>
</el-dialog> </el-dialog>
<!-- 编辑文件对话框 -->
<el-dialog v-model="fileEditDialogVisible" title="编辑文件信息" width="460px" destroy-on-close>
<el-form label-position="top">
<el-form-item label="文件名" required>
<el-input v-model="fileEditForm.filename" maxlength="255" show-word-limit clearable />
</el-form-item>
<el-form-item label="所属分类" required>
<el-select v-model="fileEditForm.categoryId" placeholder="请选择分类" style="width: 100%">
<el-option
v-for="item in flatCategoryOptions"
:key="item.value"
:label="item.label"
:value="item.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input
v-model="fileEditForm.remark"
type="textarea"
:rows="3"
maxlength="500"
show-word-limit
placeholder="可选备注"
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="fileEditDialogVisible = false">取消</el-button>
<el-button type="primary" :loading="fileEditSubmitting" @click="submitFileEdit">保存</el-button>
</template>
</el-dialog>
</section> </section>
</template> </template>
...@@ -794,6 +905,13 @@ onBeforeUnmount(() => { ...@@ -794,6 +905,13 @@ onBeforeUnmount(() => {
margin-left: 4px; margin-left: 4px;
} }
.file-remark {
font-size: 12px;
color: var(--text-tertiary);
font-weight: 400;
margin-left: 2px;
}
.tree-node { .tree-node {
width: 100%; width: 100%;
display: flex; display: flex;
...@@ -948,4 +1066,8 @@ onBeforeUnmount(() => { ...@@ -948,4 +1066,8 @@ onBeforeUnmount(() => {
.quality-no-config { .quality-no-config {
color: #e6a23c; color: #e6a23c;
} }
:deep(.file-row) {
cursor: pointer;
}
</style> </style>
This diff is collapsed.
<script setup>
import * as echarts from 'echarts'
import { Refresh } from '@element-plus/icons-vue'
import { ref, onMounted, onBeforeUnmount, nextTick } from 'vue'
import { useRouter } from 'vue-router'
import { getExperiments, getDataPoints } from '@/api/realtimeMonitor'
const router = useRouter()
const runningExps = ref([])
const loadingList = ref(false)
// echarts instances & data state, keyed by experiment id
const chartsMap = {} // id -> echarts instance
const dataMap = {} // id -> data points array
const fromStepMap = {} // id -> next from_step
const MAX_DISPLAY = 300
let listTimer = null
let dataTimer = null
// ── 数据加载 ──────────────────────────────────────────────────────────────────
const loadList = async () => {
loadingList.value = true
try {
const all = await getExperiments()
const running = all.filter((e) => e.status === 'running')
// 清理已停止试验的图表实例
const runningIds = new Set(running.map((e) => e.id))
for (const id of Object.keys(chartsMap)) {
if (!runningIds.has(Number(id))) {
chartsMap[id]?.dispose()
delete chartsMap[id]
delete dataMap[id]
delete fromStepMap[id]
}
}
runningExps.value = running
} finally {
loadingList.value = false
}
}
const fetchAllData = async () => {
for (const exp of runningExps.value) {
try {
const from = fromStepMap[exp.id] ?? 0
const newPts = await getDataPoints(exp.id, from)
if (newPts.length) {
dataMap[exp.id] = [...(dataMap[exp.id] ?? []), ...newPts]
fromStepMap[exp.id] = dataMap[exp.id][dataMap[exp.id].length - 1].step_idx + 1
await nextTick()
updateChart(exp.id, exp.target_temp)
}
} catch {
// ignore per-experiment errors
}
}
}
// ── 图表 ──────────────────────────────────────────────────────────────────────
const getOrInitChart = (expId) => {
if (!chartsMap[expId]) {
const el = document.getElementById(`live-chart-${expId}`)
if (el) chartsMap[expId] = echarts.init(el)
}
return chartsMap[expId]
}
const updateChart = (expId, targetTemp) => {
const c = getOrInitChart(expId)
if (!c) return
const pts = dataMap[expId] ?? []
if (!pts.length) return
const step = pts.length > MAX_DISPLAY ? Math.ceil(pts.length / MAX_DISPLAY) : 1
const display = pts.filter((_, i) => i % step === 0)
const xData = display.map((d) => `${d.step_idx}`)
const actuals = display.map((d) => d.actual_temp)
const currents = display.map((d) => d.current_output)
const targetLine = display.map(() => targetTemp ?? null)
c.setOption(
{
animation: false,
color: ['#409EFF', '#F56C6C', '#E6A23C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const lines = [`<b>步 ${params[0].axisValue}</b>`]
params.forEach((p) => {
const unit = p.seriesName === '电流输出' ? ' A' : ' °C'
const v = p.data != null ? Number(p.data).toFixed(2) + unit : '--'
lines.push(`${p.marker}${p.seriesName}: <b>${v}</b>`)
})
return lines.join('<br/>')
},
},
legend: { bottom: 2, itemWidth: 14, itemHeight: 7, textStyle: { fontSize: 11, color: '#475569' } },
grid: { top: 12, left: 8, right: 50, bottom: 40, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#94a3b8', fontSize: 10, interval: Math.max(0, Math.floor(xData.length / 6) - 1) },
axisLine: { lineStyle: { color: '#e2e8f0' } },
},
yAxis: [
{
type: 'value',
name: '°C',
nameTextStyle: { color: '#64748b', fontSize: 10 },
axisLabel: { color: '#64748b', fontSize: 10 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
},
{
type: 'value',
name: 'A',
position: 'right',
nameTextStyle: { color: '#E6A23C', fontSize: 10 },
axisLabel: { color: '#E6A23C', fontSize: 10 },
splitLine: { show: false },
min: 0,
},
],
series: [
{
name: '实际温度',
type: 'line',
yAxisIndex: 0,
data: actuals,
smooth: true,
symbol: 'none',
lineStyle: { width: 1.5 },
},
{
name: '目标温度',
type: 'line',
yAxisIndex: 0,
data: targetLine,
symbol: 'none',
lineStyle: { width: 1, type: 'dotted', color: '#F56C6C' },
},
{
name: '电流输出',
type: 'line',
yAxisIndex: 1,
data: currents,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, color: '#E6A23C' },
areaStyle: { color: '#E6A23C', opacity: 0.05 },
},
],
},
true,
)
}
const onResize = () => {
Object.values(chartsMap).forEach((c) => c?.resize())
}
// ── 生命周期 ──────────────────────────────────────────────────────────────────
onMounted(async () => {
await loadList()
await nextTick()
// 加载各试验初始数据
await fetchAllData()
listTimer = setInterval(loadList, 5000)
dataTimer = setInterval(fetchAllData, 2000)
window.addEventListener('resize', onResize)
})
onBeforeUnmount(() => {
clearInterval(listTimer)
clearInterval(dataTimer)
window.removeEventListener('resize', onResize)
Object.values(chartsMap).forEach((c) => c?.dispose())
})
</script>
<template>
<div class="live-page">
<div class="live-header">
<span class="live-title">实时监控</span>
<el-tag type="success" size="small" style="margin-left:10px">
{{ runningExps.length }} 个试验运行中
</el-tag>
<el-button
:icon="Refresh"
size="small"
plain
:loading="loadingList"
style="margin-left:auto"
@click="loadList"
>
刷新
</el-button>
<el-button size="small" plain @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</div>
<div v-if="!runningExps.length && !loadingList" class="empty-state">
<el-empty description="当前没有正在运行的试验" :image-size="120">
<el-button type="primary" @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</el-empty>
</div>
<div v-else class="exp-grid">
<el-card
v-for="exp in runningExps"
:key="exp.id"
shadow="hover"
class="exp-card"
>
<template #header>
<div class="exp-card-header">
<span class="exp-name">{{ exp.name }}</span>
<el-tag type="success" size="small" effect="plain">运行中</el-tag>
</div>
<div class="exp-meta">
<span>目标 {{ exp.target_temp?.toFixed(1) }} °C</span>
<span>步数 {{ exp.total_steps }}</span>
<span>{{ exp.model_name }}</span>
</div>
</template>
<div
:id="`live-chart-${exp.id}`"
class="mini-chart"
>
<div v-if="!(dataMap[exp.id]?.length)" class="chart-placeholder">
<el-icon class="is-loading" style="font-size:20px;color:#409EFF"><Refresh /></el-icon>
<span>等待数据…</span>
</div>
</div>
</el-card>
</div>
</div>
</template>
<style lang="scss" scoped>
.live-page {
padding: 20px;
display: flex;
flex-direction: column;
gap: 16px;
height: 100%;
}
.live-header {
display: flex;
align-items: center;
gap: 8px;
flex-shrink: 0;
}
.live-title {
font-size: 16px;
font-weight: 700;
color: var(--text-primary);
}
.empty-state {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.exp-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(480px, 1fr));
gap: 16px;
align-content: start;
}
.exp-card {
:deep(.el-card__header) {
padding: 10px 16px 6px;
}
:deep(.el-card__body) {
padding: 0;
}
}
.exp-card-header {
display: flex;
align-items: center;
gap: 8px;
}
.exp-name {
font-size: 14px;
font-weight: 600;
color: var(--text-primary);
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.exp-meta {
display: flex;
gap: 12px;
font-size: 12px;
color: var(--text-tertiary);
margin-top: 4px;
}
.mini-chart {
height: 240px;
width: 100%;
position: relative;
}
.chart-placeholder {
position: absolute;
inset: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 8px;
color: #94a3b8;
font-size: 13px;
}
</style>
...@@ -23,6 +23,12 @@ const validData = computed(() => ...@@ -23,6 +23,12 @@ const validData = computed(() =>
const xLabels = computed(() => validData.value.map((d) => d.time || String(d.index ?? ''))) const xLabels = computed(() => validData.value.map((d) => d.time || String(d.index ?? '')))
const actualSeries = computed(() => validData.value.map((d) => d.actual ?? null)) const actualSeries = computed(() => validData.value.map((d) => d.actual ?? null))
const predictedSeries = computed(() => validData.value.map((d) => d.predicted ?? null)) const predictedSeries = computed(() => validData.value.map((d) => d.predicted ?? null))
const currentSeries = computed(() => validData.value.map((d) => d.current ?? null))
const voltageSeries = computed(() => validData.value.map((d) => d.voltage ?? null))
const hasCurrentVoltage = computed(() =>
validData.value.some((d) => d.current != null || d.voltage != null),
)
const renderChart = () => { const renderChart = () => {
if (!chartRef.value) return if (!chartRef.value) return
...@@ -30,10 +36,82 @@ const renderChart = () => { ...@@ -30,10 +36,82 @@ const renderChart = () => {
const hasData = validData.value.length > 0 const hasData = validData.value.length > 0
const yAxes = [
{
type: 'value',
name: '温度 (℃)',
position: 'left',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#409EFF' } },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
]
if (hasCurrentVoltage.value) {
yAxes.push({
type: 'value',
name: '电流 (A) / 电压 (V)',
position: 'right',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#E6A23C' } },
splitLine: { show: false },
})
}
const series = [
{
name: '实际温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
]
if (hasCurrentVoltage.value) {
series.push(
{
name: '电流 (A)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'solid' },
data: currentSeries.value,
},
{
name: '电压 (V)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'dotted' },
data: voltageSeries.value,
},
)
}
const legendData = hasCurrentVoltage.value
? ['实际温度 (℃)', '预测温度 (℃)', '电流 (A)', '电压 (V)']
: ['实际温度 (℃)', '预测温度 (℃)']
chartInstance.setOption( chartInstance.setOption(
{ {
animation: false, animation: false,
color: ['#409EFF', '#F56C6C'], color: ['#409EFF', '#F56C6C', '#E6A23C', '#67C23A'],
tooltip: { tooltip: {
trigger: 'axis', trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)', backgroundColor: 'rgba(255,255,255,0.96)',
...@@ -48,10 +126,13 @@ const renderChart = () => { ...@@ -48,10 +126,13 @@ const renderChart = () => {
] ]
params.forEach((item) => { params.forEach((item) => {
const val = item.data != null ? Number(item.data).toFixed(4) : '--' const val = item.data != null ? Number(item.data).toFixed(4) : '--'
const unit = item.seriesName.includes('℃') ? ' ℃' :
item.seriesName.includes('(A)') ? ' A' :
item.seriesName.includes('(V)') ? ' V' : ''
lines.push( lines.push(
`<div style="display:flex;align-items:center;gap:8px;min-width:180px;justify-content:space-between;"> `<div style="display:flex;align-items:center;gap:8px;min-width:200px;justify-content:space-between;">
<span>${item.marker}${item.seriesName}</span> <span>${item.marker}${item.seriesName}</span>
<strong>${val}</strong> <strong>${val}${unit}</strong>
</div>`, </div>`,
) )
}) })
...@@ -63,9 +144,9 @@ const renderChart = () => { ...@@ -63,9 +144,9 @@ const renderChart = () => {
itemWidth: 20, itemWidth: 20,
itemHeight: 10, itemHeight: 10,
textStyle: { color: '#475569', fontSize: 12 }, textStyle: { color: '#475569', fontSize: 12 },
data: ['实际温度 (℃)', '预测温度 (℃)'], data: legendData,
}, },
grid: { top: 16, left: 16, right: 20, bottom: 52, containLabel: true }, grid: { top: 16, left: 16, right: hasCurrentVoltage.value ? 80 : 20, bottom: 52, containLabel: true },
xAxis: { xAxis: {
type: 'category', type: 'category',
boundaryGap: false, boundaryGap: false,
...@@ -79,31 +160,8 @@ const renderChart = () => { ...@@ -79,31 +160,8 @@ const renderChart = () => {
}, },
axisLine: { lineStyle: { color: '#cbd5e1' } }, axisLine: { lineStyle: { color: '#cbd5e1' } },
}, },
yAxis: { yAxis: yAxes,
type: 'value', series,
name: '温度 (℃)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
series: [
{
name: '实际温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
],
graphic: hasData graphic: hasData
? [] ? []
: [ : [
...@@ -145,3 +203,4 @@ onBeforeUnmount(() => { ...@@ -145,3 +203,4 @@ onBeforeUnmount(() => {
<template> <template>
<div ref="chartRef" :style="{ width: '100%', height: props.height }" /> <div ref="chartRef" :style="{ width: '100%', height: props.height }" />
</template> </template>
<script setup> <script setup>
import { Refresh } from '@element-plus/icons-vue' import { Refresh } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, ref } from 'vue' import { onMounted, reactive, ref } from 'vue'
import { deleteSavedModel, getSavedModels } from '@/api/trainManagement' import { useRouter } from 'vue-router'
import { deleteSavedModel, getSavedModels, updateSavedModel } from '@/api/trainManagement'
const router = useRouter()
const models = ref([]) const models = ref([])
const loading = ref(false) const loading = ref(false)
...@@ -16,6 +19,43 @@ const loadModels = async () => { ...@@ -16,6 +19,43 @@ const loadModels = async () => {
} }
} }
// ── edit dialog ───────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editRow = ref(null)
const editForm = reactive({ model_name: '', description: '' })
const editSaving = ref(false)
const handleEdit = (row) => {
editRow.value = row
editForm.model_name = row.model_name
editForm.description = row.description || ''
editVisible.value = true
}
const confirmEdit = async () => {
if (!editForm.model_name.trim()) return ElMessage.warning('模型名称不能为空')
editSaving.value = true
try {
await updateSavedModel(editRow.value.id, {
model_name: editForm.model_name.trim(),
description: editForm.description.trim(),
})
ElMessage.success('已更新')
editVisible.value = false
await loadModels()
} catch (e) {
ElMessage.error(e?.message || '更新失败')
} finally {
editSaving.value = false
}
}
// ── evaluate shortcut ─────────────────────────────────────────────────────────
const handleEvaluate = (row) => {
router.push({ name: 'model-evaluation', query: { model_id: row.id } })
}
// ── delete ────────────────────────────────────────────────────────────────────
const handleDelete = async (row) => { const handleDelete = async (row) => {
try { try {
await ElMessageBox.confirm(`确定删除模型"${row.model_name}"吗?删除后无法恢复。`, '提示', { await ElMessageBox.confirm(`确定删除模型"${row.model_name}"吗?删除后无法恢复。`, '提示', {
...@@ -29,14 +69,16 @@ const handleDelete = async (row) => { ...@@ -29,14 +69,16 @@ const handleDelete = async (row) => {
} }
} }
// ── display helpers ───────────────────────────────────────────────────────────
const formatParams = (params) => { const formatParams = (params) => {
if (!params) return '-' if (!params) return '-'
return [ return [
`seq=${params.seq_len}`, `序列长度 ${params.seq_len}`,
`hidden=${params.hidden_size}`, `隐藏层 ${params.hidden_size}`,
`layers=${params.num_layers}`, `层数 ${params.num_layers}`,
`epochs=${params.epochs}`, `轮数 ${params.epochs}`,
`lr=${params.learning_rate}`, `批次 ${params.batch_size}`,
`学习率 ${params.learning_rate}`,
].join(' / ') ].join(' / ')
} }
...@@ -65,10 +107,15 @@ onMounted(loadModels) ...@@ -65,10 +107,15 @@ onMounted(loadModels)
<el-table :data="models" v-loading="loading" border stripe height="calc(100vh - 160px)"> <el-table :data="models" v-loading="loading" border stripe height="calc(100vh - 160px)">
<el-table-column type="index" width="55" label="#" align="center" /> <el-table-column type="index" width="55" label="#" align="center" />
<el-table-column prop="model_name" label="模型名称" min-width="150" show-overflow-tooltip /> <el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip />
<el-table-column prop="description" label="说明" min-width="160" show-overflow-tooltip>
<template #default="{ row }">
<span class="desc-text">{{ row.description || '—' }}</span>
</template>
</el-table-column>
<el-table-column prop="package_name" label="训练数据包" min-width="140" show-overflow-tooltip /> <el-table-column prop="package_name" label="训练数据包" min-width="140" show-overflow-tooltip />
<el-table-column label="LSTM 参数" min-width="280" show-overflow-tooltip> <el-table-column label="LSTM 参数" min-width="310" show-overflow-tooltip>
<template #default="{ row }"> <template #default="{ row }">
<el-tooltip :content="formatParams(row.params)" placement="top"> <el-tooltip :content="formatParams(row.params)" placement="top">
<span class="params-text">{{ formatParams(row.params) }}</span> <span class="params-text">{{ formatParams(row.params) }}</span>
...@@ -84,8 +131,10 @@ onMounted(loadModels) ...@@ -84,8 +131,10 @@ onMounted(loadModels)
<el-table-column prop="created_at" label="保存时间" width="170" /> <el-table-column prop="created_at" label="保存时间" width="170" />
<el-table-column label="操作" width="90" fixed="right" align="center"> <el-table-column label="操作" width="160" fixed="right" align="center">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleEdit(row)">编辑</el-button>
<el-button link type="success" @click="handleEvaluate(row)">评估</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button> <el-button link type="danger" @click="handleDelete(row)">删除</el-button>
</template> </template>
</el-table-column> </el-table-column>
...@@ -98,6 +147,29 @@ onMounted(loadModels) ...@@ -98,6 +147,29 @@ onMounted(loadModels)
/> />
</el-card> </el-card>
</div> </div>
<!-- ── edit dialog ──────────────────────────────────────────────── -->
<el-dialog v-model="editVisible" title="编辑模型信息" width="480px" :close-on-click-modal="false">
<el-form :model="editForm" label-position="top">
<el-form-item label="模型名称" required>
<el-input v-model="editForm.model_name" maxlength="100" show-word-limit />
</el-form-item>
<el-form-item label="说明">
<el-input
v-model="editForm.description"
type="textarea"
:rows="3"
placeholder="模型用途、训练说明等"
maxlength="500"
show-word-limit
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editSaving" @click="confirmEdit">保存</el-button>
</template>
</el-dialog>
</template> </template>
<style lang="scss" scoped> <style lang="scss" scoped>
...@@ -144,7 +216,8 @@ onMounted(loadModels) ...@@ -144,7 +216,8 @@ onMounted(loadModels)
} }
.params-text, .params-text,
.loss-text { .loss-text,
.desc-text {
font-size: 12px; font-size: 12px;
color: var(--text-secondary); color: var(--text-secondary);
white-space: nowrap; white-space: nowrap;
...@@ -158,3 +231,5 @@ onMounted(loadModels) ...@@ -158,3 +231,5 @@ onMounted(loadModels)
font-weight: 500; font-weight: 500;
} }
</style> </style>
This diff is collapsed.
<script setup> <script setup>
import { ref, watch } from 'vue' import { computed, ref, watch } from 'vue'
import { ElAutoResizer, ElTableV2 } from 'element-plus'
import { getPackageRecords } from '@/api/packageManagement' import { getPackageRecords } from '@/api/packageManagement'
import DataCurve from '@/views/DataManagement/components/DataCurve.vue' import DataCurve from '@/views/DataManagement/components/DataCurve.vue'
...@@ -18,6 +19,21 @@ const loading = ref(false) ...@@ -18,6 +19,21 @@ const loading = ref(false)
const records = ref([]) const records = ref([])
const contentMode = ref('table') const contentMode = ref('table')
const detailColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = records.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const loadRecords = async (pkg) => { const loadRecords = async (pkg) => {
if (!pkg) { if (!pkg) {
records.value = [] records.value = []
...@@ -25,7 +41,7 @@ const loadRecords = async (pkg) => { ...@@ -25,7 +41,7 @@ const loadRecords = async (pkg) => {
} }
loading.value = true loading.value = true
try { try {
const result = await getPackageRecords(pkg.id, { limit: 500 }) const result = await getPackageRecords(pkg.id)
records.value = result.records records.value = result.records
} finally { } finally {
loading.value = false loading.value = false
...@@ -56,21 +72,21 @@ watch(() => props.package, (pkg) => { ...@@ -56,21 +72,21 @@ watch(() => props.package, (pkg) => {
<div class="content-wrap" v-loading="loading"> <div class="content-wrap" v-loading="loading">
<el-empty v-if="!props.package" description="请选择一个数据包查看" /> <el-empty v-if="!props.package" description="请选择一个数据包查看" />
<el-table <el-auto-resizer v-else-if="contentMode === 'table'">
v-else-if="contentMode === 'table'" <template #default="{ height, width }">
:data="records" <el-table-v2
border :columns="detailColumns"
stripe :data="records"
height="calc(100vh - 250px)" :width="width"
> :height="height"
<el-table-column prop="time" label="时间" min-width="140" /> :row-height="36"
<el-table-column prop="current" label="电流" min-width="90" /> :header-height="40"
<el-table-column prop="voltage" label="电压" min-width="90" /> fixed
<el-table-column prop="set_temperature" label="设定温度" min-width="100" /> />
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" /> </template>
</el-table> </el-auto-resizer>
<DataCurve v-else :records="records" /> <DataCurve v-else :records="curveRecords" :total-count="records.length" />
</div> </div>
</el-card> </el-card>
</template> </template>
......
<script setup> <script setup>
import { Plus, Search } from '@element-plus/icons-vue' import { Delete, Edit, Plus, Search } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, reactive, ref, watch } from 'vue' import { onMounted, reactive, ref, watch } from 'vue'
import { deletePackage, getPackages } from '@/api/packageManagement' import { deletePackage, getPackages, updatePackage } from '@/api/packageManagement'
import { getPkgCategoryTree } from '@/api/packageManagement'
const props = defineProps({ const props = defineProps({
categoryId: { categoryId: {
...@@ -52,6 +53,61 @@ const handleView = (row) => { ...@@ -52,6 +53,61 @@ const handleView = (row) => {
emit('view', row) emit('view', row)
} }
// ── edit ──────────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editLoading = ref(false)
const editForm = reactive({ id: null, name: '', category_id: null, remark: '' })
const categoryOptions = ref([])
const loadCategoryOptions = async () => {
if (categoryOptions.value.length) return
try {
const tree = await getPkgCategoryTree()
const flatten = (nodes, result = []) => {
for (const n of nodes) {
result.push({ value: n.id, label: n.name })
if (n.children?.length) flatten(n.children, result)
}
return result
}
categoryOptions.value = flatten(tree)
} catch {}
}
const handleEdit = async (row) => {
await loadCategoryOptions()
editForm.id = row.id
editForm.name = row.name
editForm.category_id = row.category_id ?? null
editForm.remark = row.remark ?? ''
editVisible.value = true
}
const submitEdit = async () => {
if (!editForm.name.trim()) {
ElMessage.warning('请输入数据包名称')
return
}
editLoading.value = true
try {
const updated = await updatePackage(editForm.id, {
name: editForm.name.trim(),
category_id: editForm.category_id || null,
remark: editForm.remark.trim() || null,
})
ElMessage.success('修改成功')
editVisible.value = false
// 更新列表中对应行
const idx = packageList.value.findIndex(p => p.id === editForm.id)
if (idx !== -1) Object.assign(packageList.value[idx], updated)
if (currentPackage.value?.id === editForm.id) Object.assign(currentPackage.value, updated)
} catch (e) {
ElMessage.error(e?.response?.data?.detail || '修改失败')
} finally {
editLoading.value = false
}
}
const handleDelete = async (row) => { const handleDelete = async (row) => {
try { try {
await ElMessageBox.confirm(`确定删除数据包"${row.name}"吗?`, '提示', { type: 'warning' }) await ElMessageBox.confirm(`确定删除数据包"${row.name}"吗?`, '提示', { type: 'warning' })
...@@ -105,18 +161,48 @@ onMounted(loadPackages) ...@@ -105,18 +161,48 @@ onMounted(loadPackages)
highlight-current-row highlight-current-row
height="calc(100vh - 265px)" height="calc(100vh - 265px)"
:row-class-name="({ row }) => (currentPackage?.id === row.id ? 'current-row' : '')" :row-class-name="({ row }) => (currentPackage?.id === row.id ? 'current-row' : '')"
style="cursor: pointer"
@row-click="handleView"
> >
<el-table-column prop="name" label="数据包名称" min-width="150" show-overflow-tooltip /> <el-table-column prop="name" label="数据包名称" min-width="100" show-overflow-tooltip />
<el-table-column prop="created_at" label="创建时间" min-width="160" /> <el-table-column prop="created_at" label="创建时间" min-width="100" />
<el-table-column prop="data_count" label="数据量" width="80" align="center" /> <el-table-column prop="data_count" label="数据量" width="80" align="center" />
<el-table-column label="操作" width="130" fixed="right"> <el-table-column label="操作" width="150" fixed="right" align="center">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleView(row)">查看</el-button> <div>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button> <el-button link type="primary" :icon="Edit" @click.stop="handleEdit(row)" />
<el-button link type="danger" :icon="Delete" @click.stop="handleDelete(row)" />
</div>
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
</el-card> </el-card>
<!-- 编辑数据包 -->
<el-dialog v-model="editVisible" title="编辑数据包" width="480px" :close-on-click-modal="false">
<el-form label-width="90px">
<el-form-item label="名称" required>
<el-input v-model="editForm.name" placeholder="数据包名称" clearable />
</el-form-item>
<el-form-item label="分类">
<el-select v-model="editForm.category_id" placeholder="选择分类(可选)" clearable style="width:100%">
<el-option
v-for="opt in categoryOptions"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input v-model="editForm.remark" type="textarea" :rows="3" placeholder="备注(可选)" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editLoading" @click="submitEdit">保存</el-button>
</template>
</el-dialog>
</template> </template>
<style lang="scss" scoped> <style lang="scss" scoped>
......
...@@ -8,7 +8,6 @@ import { ...@@ -8,7 +8,6 @@ import {
createExperiment, createExperiment,
deleteExperiment, deleteExperiment,
getMonitorModels, getMonitorModels,
getMonitorPackages,
} from '@/api/realtimeMonitor' } from '@/api/realtimeMonitor'
const router = useRouter() const router = useRouter()
...@@ -30,7 +29,6 @@ const loadExperiments = async () => { ...@@ -30,7 +29,6 @@ const loadExperiments = async () => {
const dialogVisible = ref(false) const dialogVisible = ref(false)
const submitting = ref(false) const submitting = ref(false)
const models = ref([]) const models = ref([])
const packages = ref([])
const defaultMpcParams = () => ({ const defaultMpcParams = () => ({
P: 20, P: 20,
...@@ -50,7 +48,8 @@ const defaultMpcParams = () => ({ ...@@ -50,7 +48,8 @@ const defaultMpcParams = () => ({
const form = reactive({ const form = reactive({
name: '', name: '',
model_id: '', model_id: '',
package_id: '', input_csv_path: '',
output_csv_path: '',
target_temp: 35.0, target_temp: 35.0,
sampling_interval: 1.0, sampling_interval: 1.0,
mpc_params: defaultMpcParams(), mpc_params: defaultMpcParams(),
...@@ -62,7 +61,8 @@ const openDialog = async () => { ...@@ -62,7 +61,8 @@ const openDialog = async () => {
Object.assign(form, { Object.assign(form, {
name: '', name: '',
model_id: '', model_id: '',
package_id: '', input_csv_path: '',
output_csv_path: '',
target_temp: 35.0, target_temp: 35.0,
sampling_interval: 1.0, sampling_interval: 1.0,
mpc_params: defaultMpcParams(), mpc_params: defaultMpcParams(),
...@@ -70,23 +70,23 @@ const openDialog = async () => { ...@@ -70,23 +70,23 @@ const openDialog = async () => {
showAdvanced.value = false showAdvanced.value = false
dialogVisible.value = true dialogVisible.value = true
if (!models.value.length) { if (!models.value.length) {
const [m, p] = await Promise.all([getMonitorModels(), getMonitorPackages()]) models.value = await getMonitorModels()
models.value = m
packages.value = p
} }
} }
const handleCreate = async () => { const handleCreate = async () => {
if (!form.name.trim()) { ElMessage.warning('请输入试验名称'); return } if (!form.name.trim()) { ElMessage.warning('请输入试验名称'); return }
if (!form.model_id) { ElMessage.warning('请选择预测模型'); return } if (!form.model_id) { ElMessage.warning('请选择预测模型'); return }
if (!form.package_id) { ElMessage.warning('请选择初始数据包'); return } if (!form.input_csv_path.trim()) { ElMessage.warning('请输入输入CSV文件路径'); return }
if (!form.output_csv_path.trim()) { ElMessage.warning('请输入输出CSV文件路径'); return }
submitting.value = true submitting.value = true
try { try {
await createExperiment({ await createExperiment({
name: form.name.trim(), name: form.name.trim(),
model_id: form.model_id, model_id: form.model_id,
package_id: form.package_id, input_csv_path: form.input_csv_path.trim(),
output_csv_path: form.output_csv_path.trim(),
target_temp: form.target_temp, target_temp: form.target_temp,
sampling_interval: form.sampling_interval, sampling_interval: form.sampling_interval,
mpc_params: { ...form.mpc_params }, mpc_params: { ...form.mpc_params },
...@@ -133,7 +133,7 @@ onMounted(loadExperiments) ...@@ -133,7 +133,7 @@ onMounted(loadExperiments)
<el-card shadow="hover" class="page-card"> <el-card shadow="hover" class="page-card">
<template #header> <template #header>
<div class="card-header-row"> <div class="card-header-row">
<span class="card-title">实时监控试验</span> <span class="card-title">试验管理</span>
<div class="header-actions"> <div class="header-actions">
<el-button :icon="Refresh" size="small" plain :loading="loading" @click="loadExperiments"> <el-button :icon="Refresh" size="small" plain :loading="loading" @click="loadExperiments">
刷新 刷新
...@@ -148,7 +148,7 @@ onMounted(loadExperiments) ...@@ -148,7 +148,7 @@ onMounted(loadExperiments)
<el-table :data="experiments" v-loading="loading" stripe> <el-table :data="experiments" v-loading="loading" stripe>
<el-table-column prop="name" label="试验名称" min-width="160" /> <el-table-column prop="name" label="试验名称" min-width="160" />
<el-table-column prop="model_name" label="预测模型" min-width="140" /> <el-table-column prop="model_name" label="预测模型" min-width="140" />
<el-table-column prop="package_name" label="数据包" min-width="140" /> <el-table-column prop="input_csv_path" label="输入CSV" min-width="180" show-overflow-tooltip />
<el-table-column prop="target_temp" label="目标温度(°C)" width="120" align="right"> <el-table-column prop="target_temp" label="目标温度(°C)" width="120" align="right">
<template #default="{ row }">{{ row.target_temp?.toFixed(1) }}</template> <template #default="{ row }">{{ row.target_temp?.toFixed(1) }}</template>
</el-table-column> </el-table-column>
...@@ -180,8 +180,8 @@ onMounted(loadExperiments) ...@@ -180,8 +180,8 @@ onMounted(loadExperiments)
</el-card> </el-card>
<!-- 创建试验对话框 --> <!-- 创建试验对话框 -->
<el-dialog v-model="dialogVisible" title="创建试验" width="600px" :close-on-click-modal="false"> <el-dialog v-model="dialogVisible" title="创建试验" width="640px" :close-on-click-modal="false">
<el-form label-width="130px" @submit.prevent> <el-form label-width="140px" @submit.prevent>
<el-form-item label="试验名称" required> <el-form-item label="试验名称" required>
<el-input v-model="form.name" placeholder="请输入试验名称" /> <el-input v-model="form.name" placeholder="请输入试验名称" />
</el-form-item> </el-form-item>
...@@ -193,10 +193,19 @@ onMounted(loadExperiments) ...@@ -193,10 +193,19 @@ onMounted(loadExperiments)
</el-option> </el-option>
</el-select> </el-select>
</el-form-item> </el-form-item>
<el-form-item label="初始数据包" required> <el-form-item label="输入CSV路径" required>
<el-select v-model="form.package_id" placeholder="选择数据包" style="width:100%"> <el-input
<el-option v-for="p in packages" :key="p.id" :label="p.name" :value="p.id" /> v-model="form.input_csv_path"
</el-select> placeholder="传感器数据源 CSV 文件的绝对路径"
/>
<div class="field-hint">每步控制将从该文件末尾读取最新传感器数据</div>
</el-form-item>
<el-form-item label="输出CSV路径" required>
<el-input
v-model="form.output_csv_path"
placeholder="生成曲线写入的 CSV 文件绝对路径"
/>
<div class="field-hint">MPC 每步控制结果(温度、电流)将追加写入该文件</div>
</el-form-item> </el-form-item>
<el-form-item label="目标温度(°C)" required> <el-form-item label="目标温度(°C)" required>
<el-input-number v-model="form.target_temp" :step="0.5" :precision="1" style="width:160px" /> <el-input-number v-model="form.target_temp" :step="0.5" :precision="1" style="width:160px" />
...@@ -210,7 +219,7 @@ onMounted(loadExperiments) ...@@ -210,7 +219,7 @@ onMounted(loadExperiments)
:precision="1" :precision="1"
style="width:160px" style="width:160px"
/> />
<span style="margin-left:8px;color:#94a3b8;font-size:12px">每步等待时长,需与传感器采集频率一致</span> <span style="margin-left:8px;color:#94a3b8;font-size:12px">需与传感器采集频率一致</span>
</el-form-item> </el-form-item>
<!-- MPC 参数(高级) --> <!-- MPC 参数(高级) -->
...@@ -290,9 +299,17 @@ onMounted(loadExperiments) ...@@ -290,9 +299,17 @@ onMounted(loadExperiments)
gap: 8px; gap: 8px;
} }
.field-hint {
font-size: 12px;
color: #94a3b8;
margin-top: 4px;
line-height: 1.4;
}
.params-grid { .params-grid {
display: grid; display: grid;
grid-template-columns: 1fr 1fr; grid-template-columns: 1fr 1fr;
gap: 0 16px; gap: 0 16px;
} }
</style> </style>
...@@ -12,7 +12,7 @@ export default defineConfig({ ...@@ -12,7 +12,7 @@ export default defineConfig({
server: { server: {
proxy: { proxy: {
'/api': { '/api': {
target: 'http://127.0.0.1:8000', target: 'http://127.0.0.1:8002',
changeOrigin: true, changeOrigin: true,
}, },
}, },
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment