Commit d8561934 authored by luwei's avatar luwei

修改

parent d7a073ab
...@@ -91,6 +91,26 @@ def download_template(): ...@@ -91,6 +91,26 @@ def download_template():
) )
class FileUpdateRequest(BaseModel):
filename: str | None = Field(default=None, min_length=1, max_length=255)
remark: str | None = Field(default=None, max_length=500)
category_id: str | None = Field(default=None)
@router.put('/files/{file_id}')
def update_file(file_id: str, request: FileUpdateRequest):
try:
result = service.update_file(
file_id=file_id,
filename=request.filename,
remark=request.remark,
category_id=request.category_id,
)
return success_response(data=result, message='文件更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/files/{file_id}') @router.delete('/files/{file_id}')
def delete_file(file_id: str): def delete_file(file_id: str):
try: try:
...@@ -115,9 +135,9 @@ def get_file_quality(file_id: str): ...@@ -115,9 +135,9 @@ def get_file_quality(file_id: str):
@router.get('/files/{file_id}/records') @router.get('/files/{file_id}/records')
def get_file_records(file_id: str, limit: int = Query(default=500, ge=1, le=5000)): def get_file_records(file_id: str):
try: try:
result = service.get_file_records(file_id=file_id, limit=limit) result = service.get_file_records(file_id=file_id, limit=None)
return success_response(data=result) return success_response(data=result)
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from app.services.eval_service import EvalService from app.services.eval_service import EvalService
...@@ -14,8 +14,11 @@ class EvalRequest(BaseModel): ...@@ -14,8 +14,11 @@ class EvalRequest(BaseModel):
@router.get('/packages') @router.get('/packages')
def list_packages(): def list_packages(
return success_response(data=service.list_packages()) category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
@router.get('/models') @router.get('/models')
......
...@@ -28,7 +28,8 @@ class MPCParamsSchema(BaseModel): ...@@ -28,7 +28,8 @@ class MPCParamsSchema(BaseModel):
class CreateExperimentRequest(BaseModel): class CreateExperimentRequest(BaseModel):
name: str = Field(min_length=1, max_length=255) name: str = Field(min_length=1, max_length=255)
model_id: int model_id: int
package_id: int input_csv_path: str = Field(min_length=1, description='读取传感器数据的CSV文件路径')
output_csv_path: str = Field(min_length=1, description='输出曲线数据的CSV文件路径')
target_temp: float = Field(description='目标温度(°C)') target_temp: float = Field(description='目标温度(°C)')
sampling_interval: float = Field(default=1.0, gt=0, le=3600, description='采样周期(秒)') sampling_interval: float = Field(default=1.0, gt=0, le=3600, description='采样周期(秒)')
mpc_params: MPCParamsSchema = Field(default_factory=MPCParamsSchema) mpc_params: MPCParamsSchema = Field(default_factory=MPCParamsSchema)
...@@ -59,7 +60,8 @@ def create_experiment(req: CreateExperimentRequest): ...@@ -59,7 +60,8 @@ def create_experiment(req: CreateExperimentRequest):
exp = service.create_experiment( exp = service.create_experiment(
name=req.name.strip(), name=req.name.strip(),
model_id=req.model_id, model_id=req.model_id,
package_id=req.package_id, input_csv_path=req.input_csv_path.strip(),
output_csv_path=req.output_csv_path.strip(),
target_temp=req.target_temp, target_temp=req.target_temp,
sampling_interval=req.sampling_interval, sampling_interval=req.sampling_interval,
mpc_params=req.mpc_params.model_dump(), mpc_params=req.mpc_params.model_dump(),
...@@ -113,25 +115,5 @@ def get_data_points(exp_id: int, from_step: int = Query(default=0, ge=0)): ...@@ -113,25 +115,5 @@ def get_data_points(exp_id: int, from_step: int = Query(default=0, ge=0)):
return success_response(data=service.get_data_points(exp_id, from_step)) return success_response(data=service.get_data_points(exp_id, from_step))
@router.get('/experiments/{exp_id}/report')
def get_report(exp_id: int):
try:
return success_response(data=service.get_report(exp_id))
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
@router.post('/experiments/{exp_id}/export')
def export_to_history(exp_id: int):
try:
result = service.export_to_history(exp_id)
return success_response(data=result, message='已导出到历史数据')
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
# ── 历史数据 ──────────────────────────────────────────────────────────────────
@router.get('/history')
def list_history():
return success_response(data=service.list_history_experiments())
...@@ -20,6 +20,7 @@ class CategoryUpdateRequest(BaseModel): ...@@ -20,6 +20,7 @@ class CategoryUpdateRequest(BaseModel):
class CleanRules(BaseModel): class CleanRules(BaseModel):
enabled: bool = False enabled: bool = False
newton_interp: bool = False
current_min: float | None = None current_min: float | None = None
current_max: float | None = None current_max: float | None = None
voltage_min: float | None = None voltage_min: float | None = None
...@@ -28,6 +29,11 @@ class CleanRules(BaseModel): ...@@ -28,6 +29,11 @@ class CleanRules(BaseModel):
temperature_max: float | None = None temperature_max: float | None = None
class SmoothConfig(BaseModel):
enabled: bool = False
window: int = Field(default=5, ge=2, le=500)
@router.get('/categories') @router.get('/categories')
def get_categories(): def get_categories():
return success_response(data=service.get_category_tree()) return success_response(data=service.get_category_tree())
...@@ -63,8 +69,22 @@ def delete_category(category_id: str): ...@@ -63,8 +69,22 @@ def delete_category(category_id: str):
# Must be declared before /{package_id} routes to avoid path conflict # Must be declared before /{package_id} routes to avoid path conflict
@router.get('/data-files') @router.get('/data-files')
def list_all_data_files(): def list_all_data_files(
return success_response(data=service.list_all_data_files()) category_id: str = Query(default=''),
filename: str = Query(default=''),
remark: str = Query(default=''),
):
return success_response(data=service.list_all_data_files(
category_id=category_id,
filename=filename.strip(),
remark=remark.strip(),
))
class PackageUpdateRequest(BaseModel):
name: str = Field(min_length=1, max_length=255)
category_id: str | int | None = Field(default=None)
remark: str | None = Field(default=None)
class PackageCreateRequest(BaseModel): class PackageCreateRequest(BaseModel):
...@@ -73,6 +93,8 @@ class PackageCreateRequest(BaseModel): ...@@ -73,6 +93,8 @@ class PackageCreateRequest(BaseModel):
remark: str | None = Field(default=None) remark: str | None = Field(default=None)
file_ids: list[int] = Field(default_factory=list) file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None) clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
auto_split: bool = Field(default=False)
row_start: int | None = Field(default=None, ge=1) row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1) row_end: int | None = Field(default=None, ge=1)
...@@ -80,20 +102,23 @@ class PackageCreateRequest(BaseModel): ...@@ -80,20 +102,23 @@ class PackageCreateRequest(BaseModel):
class PreviewRequest(BaseModel): class PreviewRequest(BaseModel):
file_ids: list[int] = Field(default_factory=list) file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None) clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
row_start: int | None = Field(default=None, ge=1) row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1) row_end: int | None = Field(default=None, ge=1)
@router.post('/preview') @router.post('/preview')
def preview_package(request: PreviewRequest, limit: int = Query(default=300, ge=1, le=2000)): def preview_package(request: PreviewRequest):
try: try:
if request.row_start and request.row_end and request.row_start > request.row_end: if request.row_start and request.row_end and request.row_start > request.row_end:
raise ValueError('起始行不能大于结束行') raise ValueError('起始行不能大于结束行')
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
result = service.preview_records( result = service.preview_records(
file_ids=request.file_ids, file_ids=request.file_ids,
limit=limit, limit=None,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start, row_start=request.row_start,
row_end=request.row_end, row_end=request.row_end,
) )
...@@ -117,12 +142,27 @@ def create_package(request: PackageCreateRequest): ...@@ -117,12 +142,27 @@ def create_package(request: PackageCreateRequest):
raise ValueError('起始行不能大于结束行') raise ValueError('起始行不能大于结束行')
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id) category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
base_name = request.name.strip()
if request.auto_split:
pkgs = service.create_package_split(
name=base_name,
category_id=category_id,
remark=request.remark,
file_ids=request.file_ids,
clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start,
row_end=request.row_end,
)
return success_response(data=pkgs, message='数据包创建成功,已自动划分训练集/验证集/测试集')
pkg = service.create_package( pkg = service.create_package(
name=request.name.strip(), name=base_name,
category_id=category_id, category_id=category_id,
remark=request.remark, remark=request.remark,
file_ids=request.file_ids, file_ids=request.file_ids,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start, row_start=request.row_start,
row_end=request.row_end, row_end=request.row_end,
) )
...@@ -132,14 +172,29 @@ def create_package(request: PackageCreateRequest): ...@@ -132,14 +172,29 @@ def create_package(request: PackageCreateRequest):
@router.get('/{package_id}/records') @router.get('/{package_id}/records')
def get_package_records(package_id: str, limit: int = Query(default=500, ge=1, le=5000)): def get_package_records(package_id: str):
try: try:
result = service.get_package_records(package_id=package_id, limit=limit) result = service.get_package_records(package_id=package_id, limit=None)
return success_response(data=result) return success_response(data=result)
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
@router.put('/{package_id}')
def update_package(package_id: str, request: PackageUpdateRequest):
try:
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
result = service.update_package(
package_id=package_id,
name=request.name.strip(),
category_id=category_id,
remark=request.remark,
)
return success_response(data=result, message='数据包更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/{package_id}') @router.delete('/{package_id}')
def delete_package(package_id: str): def delete_package(package_id: str):
try: try:
......
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.services.train_service import TrainService from app.services.train_service import TrainService
...@@ -21,15 +21,29 @@ class LSTMParams(BaseModel): ...@@ -21,15 +21,29 @@ class LSTMParams(BaseModel):
class CreateTaskRequest(BaseModel): class CreateTaskRequest(BaseModel):
model_name: str = Field(min_length=1, max_length=255) model_name: str = Field(min_length=1, max_length=255)
package_id: int train_package_id: int
val_package_id: int
params: LSTMParams = Field(default_factory=LSTMParams) params: LSTMParams = Field(default_factory=LSTMParams)
class SaveModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
class UpdateModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
# ── packages ────────────────────────────────────────────────────────────────── # ── packages ──────────────────────────────────────────────────────────────────
@router.get('/packages') @router.get('/packages')
def list_packages(): def list_packages(
return success_response(data=service.list_packages()) category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
# ── tasks ───────────────────────────────────────────────────────────────────── # ── tasks ─────────────────────────────────────────────────────────────────────
...@@ -52,7 +66,8 @@ def create_task(request: CreateTaskRequest): ...@@ -52,7 +66,8 @@ def create_task(request: CreateTaskRequest):
try: try:
task = service.create_task( task = service.create_task(
model_name=request.model_name.strip(), model_name=request.model_name.strip(),
package_id=request.package_id, train_package_id=request.train_package_id,
val_package_id=request.val_package_id,
params=request.params.model_dump(), params=request.params.model_dump(),
) )
return success_response(data=task, message='训练任务已启动') return success_response(data=task, message='训练任务已启动')
...@@ -88,9 +103,9 @@ def delete_task(task_id: int): ...@@ -88,9 +103,9 @@ def delete_task(task_id: int):
@router.post('/tasks/{task_id}/save') @router.post('/tasks/{task_id}/save')
def save_model(task_id: int): def save_model(task_id: int, request: SaveModelRequest):
try: try:
model = service.save_model(task_id) model = service.save_model(task_id, model_name=request.model_name, description=request.description)
return success_response(data=model, message='模型已保存') return success_response(data=model, message='模型已保存')
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
...@@ -110,3 +125,15 @@ def delete_model(model_id: int): ...@@ -110,3 +125,15 @@ def delete_model(model_id: int):
return success_response(data=True, message='模型已删除') return success_response(data=True, message='模型已删除')
except ValueError as error: except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error raise HTTPException(status_code=400, detail=str(error)) from error
@router.patch('/models/{model_id}')
def update_model(model_id: int, request: UpdateModelRequest):
try:
model = service.update_saved_model(
model_id,
model_name=request.model_name,
description=request.description,
)
return success_response(data=model, message='已更新')
except ValueError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
\ No newline at end of file
...@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException ...@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import text
from app.api.data_management import router as data_management_router from app.api.data_management import router as data_management_router
from app.api.eval_management import router as eval_management_router from app.api.eval_management import router as eval_management_router
...@@ -16,10 +17,29 @@ from app.models.train_management import SavedModel, TrainTask # noqa: F401 ...@@ -16,10 +17,29 @@ from app.models.train_management import SavedModel, TrainTask # noqa: F401
from app.utils.response import error_response, success_response from app.utils.response import error_response, success_response
def _run_migrations() -> None:
"""Add new columns to existing tables without dropping data."""
migrations = [
"ALTER TABLE train_tasks ADD COLUMN val_package_id BIGINT NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305ID' AFTER package_name",
"ALTER TABLE train_tasks ADD COLUMN val_package_name VARCHAR(255) NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305\u540d\u79f0' AFTER val_package_id",
"ALTER TABLE train_tasks ADD COLUMN epoch_logs JSON NULL COMMENT '\u6bcf\u8f6e\u8bad\u7ec3\u65e5\u5fd7' AFTER val_package_name",
]
with engine.connect() as conn:
for stmt in migrations:
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
pass # Column already exists
def create_app() -> FastAPI: def create_app() -> FastAPI:
# Auto-create any missing tables (safe: uses CREATE TABLE IF NOT EXISTS internally) # Auto-create any missing tables (safe: uses CREATE TABLE IF NOT EXISTS internally)
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
# Safe column migrations for existing tables
_run_migrations()
app = FastAPI( app = FastAPI(
title='Thermal Control System API', title='Thermal Control System API',
version='0.1.0', version='0.1.0',
......
...@@ -96,7 +96,22 @@ def predict_lstm( ...@@ -96,7 +96,22 @@ def predict_lstm(
# ── metrics ─────────────────────────────────────────────────────────────── # ── metrics ───────────────────────────────────────────────────────────────
errors = preds_real - actuals_real errors = preds_real - actuals_real
mae = float(np.mean(np.abs(errors))) mae = float(np.mean(np.abs(errors)))
rmse = float(math.sqrt(float(np.mean(errors ** 2)))) mse = float(np.mean(errors ** 2))
rmse = float(math.sqrt(mse))
# MAPE — skip points where actual == 0 to avoid division by zero
nonzero_mask = actuals_real != 0
if nonzero_mask.any():
mape: float | None = float(
np.mean(np.abs(errors[nonzero_mask] / actuals_real[nonzero_mask])) * 100
)
else:
mape = None
# R² (coefficient of determination)
ss_res = float(np.sum(errors ** 2))
ss_tot = float(np.sum((actuals_real - float(np.mean(actuals_real))) ** 2))
r2: float | None = (1.0 - ss_res / ss_tot) if ss_tot != 0 else None
# ── build result points ─────────────────────────────────────────────────── # ── build result points ───────────────────────────────────────────────────
times = [str(r.get('time', '')) for r in records] times = [str(r.get('time', '')) for r in records]
...@@ -115,6 +130,8 @@ def predict_lstm( ...@@ -115,6 +130,8 @@ def predict_lstm(
'time': times[seq_len + i] if (seq_len + i) < len(times) else str(seq_len + i), 'time': times[seq_len + i] if (seq_len + i) < len(times) else str(seq_len + i),
'actual': round(float(actuals_real[i]), 4), 'actual': round(float(actuals_real[i]), 4),
'predicted': round(float(preds_real[i]), 4), 'predicted': round(float(preds_real[i]), 4),
'current': round(float(raw[seq_len + i, 0]), 4),
'voltage': round(float(raw[seq_len + i, 1]), 4),
} }
for i in sample_idx for i in sample_idx
] ]
...@@ -122,6 +139,9 @@ def predict_lstm( ...@@ -122,6 +139,9 @@ def predict_lstm(
return { return {
'total_count': total, 'total_count': total,
'mae': round(mae, 6), 'mae': round(mae, 6),
'mse': round(mse, 6),
'rmse': round(rmse, 6), 'rmse': round(rmse, 6),
'mape': round(mape, 4) if mape is not None else None,
'r2': round(r2, 6) if r2 is not None else None,
'chart_data': chart_data, 'chart_data': chart_data,
} }
...@@ -44,11 +44,6 @@ FEATURE_COLS = ['current', 'voltage', 'set_temperature', 'actual_temperature'] ...@@ -44,11 +44,6 @@ FEATURE_COLS = ['current', 'voltage', 'set_temperature', 'actual_temperature']
TARGET_COL = 'actual_temperature' TARGET_COL = 'actual_temperature'
TARGET_IDX = FEATURE_COLS.index(TARGET_COL) TARGET_IDX = FEATURE_COLS.index(TARGET_COL)
# Fixed dataset split ratios (train / val / test)
_TRAIN_RATIO = 0.70
_VAL_RATIO = 0.15
# test = 1 - _TRAIN_RATIO - _VAL_RATIO (≈ 0.15)
def _check_torch() -> None: def _check_torch() -> None:
if not _TORCH_AVAILABLE: if not _TORCH_AVAILABLE:
...@@ -88,12 +83,137 @@ def _make_sequences(data: np.ndarray, seq_len: int) -> tuple[np.ndarray, np.ndar ...@@ -88,12 +83,137 @@ def _make_sequences(data: np.ndarray, seq_len: int) -> tuple[np.ndarray, np.ndar
# ── public training entry point ─────────────────────────────────────────────── # ── public training entry point ───────────────────────────────────────────────
def train_lstm( def train_lstm(
records: list[dict], train_records: list[dict],
val_records: list[dict],
params: dict, params: dict,
save_path: Path, save_path: Path,
on_progress: Callable[[int, float, float | None], None], on_progress: Callable[[int, int, float, float | None], None],
cancel_event: threading.Event, cancel_event: threading.Event,
) -> dict[str, float | None]: ) -> dict[str, float | None]:
"""
Train an LSTM model on *train_records* validated against *val_records*.
Args:
train_records: list of dicts for the training set.
val_records: list of dicts for the validation set.
params: hyper-parameter dict (seq_len, hidden_size, num_layers,
epochs, batch_size, learning_rate).
save_path: destination .pt file.
on_progress: callback(pct, epoch, train_loss, val_loss) called after each epoch.
cancel_event: when set, training stops with InterruptedError.
Returns:
{'train_loss': float, 'val_loss': float|None}
"""
_check_torch()
seq_len = max(1, int(params.get('seq_len', 20)))
hidden_size = max(1, int(params.get('hidden_size', 64)))
num_layers = max(1, int(params.get('num_layers', 2)))
epochs = max(1, int(params.get('epochs', 50)))
batch_size = max(1, int(params.get('batch_size', 32)))
lr = float(params.get('learning_rate', 0.001))
# ── data preparation ────────────────────────────────────────────────────
train_data = _extract_features(train_records)
val_data = _extract_features(val_records)
min_required = seq_len + 10
if len(train_data) < min_required:
raise ValueError(
f'\u8bad\u7ec3\u96c6\u6709\u6548\u6570\u636e\u91cf\u4e0d\u8db3\uff1a\u9700\u81f3\u5c11 {min_required} \u6761\uff0c\u5f53\u524d\u4ec5 {len(train_data)} \u6761\u3002'
'\u8bf7\u68c0\u67e5\u6570\u636e\u5305\u5185\u5bb9\u6216\u51cf\u5c0f\u5e8f\u5217\u957f\u5ea6\u3002'
)
# min-max normalisation fitted on train set only
data_min = train_data.min(axis=0)
data_max = train_data.max(axis=0)
data_range = data_max - data_min
data_range[data_range == 0] = 1.0
train_norm = (train_data - data_min) / data_range
X_train, y_train = _make_sequences(train_norm, seq_len)
has_val = len(val_data) >= seq_len + 1
X_val_t = y_val_t = None
if has_val:
val_norm = (val_data - data_min) / data_range
X_val, y_val = _make_sequences(val_norm, seq_len)
if len(X_val) == 0:
has_val = False
else:
X_val_t = torch.tensor(X_val)
y_val_t = torch.tensor(y_val)
device = torch.device('cpu')
X_train_t = torch.tensor(X_train).to(device)
y_train_t = torch.tensor(y_train).to(device)
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=batch_size,
shuffle=True,
)
# ── model ────────────────────────────────────────────────────────────────
input_size = len(FEATURE_COLS)
model = _LSTMModel(input_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
train_loss = 0.0
val_loss: float | None = None
for epoch in range(epochs):
if cancel_event.is_set():
raise InterruptedError('\u8bad\u7ec3\u5df2\u53d6\u6d88')
# ── train step ───────────────────────────────────────────────────────
model.train()
epoch_loss = 0.0
for xb, yb in train_loader:
optimizer.zero_grad()
pred = model(xb)
loss = criterion(pred, yb)
loss.backward()
optimizer.step()
epoch_loss += loss.item() * len(xb)
train_loss = epoch_loss / len(X_train)
# ── val step ─────────────────────────────────────────────────────────
if has_val:
model.eval()
with torch.no_grad():
val_pred = model(X_val_t)
val_loss = criterion(val_pred, y_val_t).item()
pct = int((epoch + 1) / epochs * 100)
on_progress(pct, epoch + 1, train_loss, val_loss)
# ── persist ──────────────────────────────────────────────────────────────
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
{
'model_state': model.state_dict(),
'params': params,
'data_min': data_min.tolist(),
'data_max': data_max.tolist(),
'feature_cols': FEATURE_COLS,
'target_col': TARGET_COL,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'seq_len': seq_len,
},
save_path,
)
return {
'train_loss': round(float(train_loss), 6),
'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
}
""" """
Train an LSTM model on *records* and persist it to *save_path*. Train an LSTM model on *records* and persist it to *save_path*.
......
...@@ -19,7 +19,10 @@ class EvalRecord(Base): ...@@ -19,7 +19,10 @@ class EvalRecord(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
total_count: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text('0'), comment='评估数据点总数') total_count: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text('0'), comment='评估数据点总数')
mae: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对误差') mae: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对误差')
mse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方误差')
rmse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方根误差') rmse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方根误差')
mape: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对百分比误差(%)')
r2: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='决定系数')
# Store up to ~2000 sampled points for chart rendering # Store up to ~2000 sampled points for chart rendering
chart_data: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='图表数据(采样)') chart_data: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='图表数据(采样)')
created_at: Mapped[str] = mapped_column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) created_at: Mapped[str] = mapped_column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
...@@ -15,8 +15,10 @@ class MonitorExperiment(Base): ...@@ -15,8 +15,10 @@ class MonitorExperiment(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False, comment='试验名称') name: Mapped[str] = mapped_column(String(255), nullable=False, comment='试验名称')
model_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='模型ID') model_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='模型ID')
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称') model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='初始数据包ID') package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='初始数据包ID(旧版兼容)')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='数据包名称(旧版兼容)')
input_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输入CSV路径(传感器数据源)')
output_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输出CSV路径(生成曲线写入)')
target_temp: Mapped[float] = mapped_column(FLOAT, nullable=False, comment='目标温度(°C)') target_temp: Mapped[float] = mapped_column(FLOAT, nullable=False, comment='目标温度(°C)')
mpc_params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='MPC参数') mpc_params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='MPC参数')
status: Mapped[str] = mapped_column( status: Mapped[str] = mapped_column(
......
...@@ -13,8 +13,11 @@ class TrainTask(Base): ...@@ -13,8 +13,11 @@ class TrainTask(Base):
id: Mapped[int] = mapped_column(BIGINT, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(BIGINT, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称') model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='数据包ID') package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='训练集数据包ID')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='训练集数据包名称')
val_package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='验证集数据包ID')
val_package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='验证集数据包名称')
epoch_logs: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='每轮训练日志')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数') params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
status: Mapped[str] = mapped_column( status: Mapped[str] = mapped_column(
Enum('pending', 'running', 'completed', 'failed', 'cancelled', name='train_status_enum'), Enum('pending', 'running', 'completed', 'failed', 'cancelled', name='train_status_enum'),
...@@ -50,6 +53,7 @@ class SavedModel(Base): ...@@ -50,6 +53,7 @@ class SavedModel(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称') package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数') params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
file_path: Mapped[str] = mapped_column(String(500), nullable=False, comment='模型文件路径') file_path: Mapped[str] = mapped_column(String(500), nullable=False, comment='模型文件路径')
description: Mapped[str | None] = mapped_column(Text, nullable=True, comment='模型说明')
train_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) train_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
val_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) val_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
test_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True) test_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
......
...@@ -152,6 +152,7 @@ class DataManagementService: ...@@ -152,6 +152,7 @@ class DataManagementService:
'category_id': item.category_id, 'category_id': item.category_id,
'uploaded_at': item.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if item.uploaded_at else '', 'uploaded_at': item.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if item.uploaded_at else '',
'data_count': item.data_count, 'data_count': item.data_count,
'remark': item.remark or '',
} }
for item in rows for item in rows
] ]
...@@ -212,6 +213,46 @@ class DataManagementService: ...@@ -212,6 +213,46 @@ class DataManagementService:
target_path.unlink() target_path.unlink()
raise raise
def update_file(self, file_id: str, filename: str | None = None, remark: str | None = None, category_id: str | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID')
if category_id is not None:
cat_db_id = self._parse_int_id(category_id, '分类ID')
with db_session() as session:
matched = session.query(DataFile).filter(DataFile.id == file_db_id).first()
if not matched:
raise ValueError('文件不存在')
if filename is not None:
stripped = filename.strip()
if not stripped:
raise ValueError('文件名不能为空')
matched.filename = stripped
if remark is not None:
matched.remark = remark.strip() or None
if category_id is not None:
category = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_file')
.first()
)
if not category:
raise ValueError('分类不存在')
matched.category_id = cat_db_id
session.commit()
session.refresh(matched)
return {
'id': matched.id,
'filename': matched.filename,
'remark': matched.remark or '',
'category_id': matched.category_id,
}
def delete_file(self, file_id: str) -> None: def delete_file(self, file_id: str) -> None:
file_db_id = self._parse_int_id(file_id, '文件ID') file_db_id = self._parse_int_id(file_id, '文件ID')
...@@ -227,7 +268,7 @@ class DataManagementService: ...@@ -227,7 +268,7 @@ class DataManagementService:
session.delete(matched) session.delete(matched)
session.commit() session.commit()
def get_file_records(self, file_id: str, limit: int = 500) -> dict[str, Any]: def get_file_records(self, file_id: str, limit: int | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID') file_db_id = self._parse_int_id(file_id, '文件ID')
with db_session() as session: with db_session() as session:
......
...@@ -5,7 +5,7 @@ from typing import Any ...@@ -5,7 +5,7 @@ from typing import Any
from app.database import db_session from app.database import db_session
from app.ml.lstm_predictor import predict_lstm from app.ml.lstm_predictor import predict_lstm
from app.models import DataFile, DataPackage, DataPackageFile from app.models import DataPackage
from app.models.eval_management import EvalRecord from app.models.eval_management import EvalRecord
from app.models.train_management import SavedModel from app.models.train_management import SavedModel
from app.services.data_management_service import DataManagementService from app.services.data_management_service import DataManagementService
...@@ -24,10 +24,19 @@ class EvalService: ...@@ -24,10 +24,19 @@ class EvalService:
# ── dropdown data ───────────────────────────────────────────────────────── # ── dropdown data ─────────────────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]: def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
rows = session.query(DataPackage).order_by(DataPackage.created_at.desc()).all() query = session.query(DataPackage)
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows] if category_id not in ('', 'all', None):
try:
db_id = int(category_id)
query = query.filter(DataPackage.category_id == db_id)
except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
def list_saved_models(self) -> list[dict[str, Any]]: def list_saved_models(self) -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
...@@ -67,7 +76,10 @@ class EvalService: ...@@ -67,7 +76,10 @@ class EvalService:
package_name=package_name, package_name=package_name,
total_count=result['total_count'], total_count=result['total_count'],
mae=result['mae'], mae=result['mae'],
mse=result['mse'],
rmse=result['rmse'], rmse=result['rmse'],
mape=result['mape'],
r2=result['r2'],
chart_data=result['chart_data'], chart_data=result['chart_data'],
) )
session.add(record) session.add(record)
...@@ -106,32 +118,13 @@ class EvalService: ...@@ -106,32 +118,13 @@ class EvalService:
def _load_package_records(self, package_id: int) -> list[dict[str, Any]]: def _load_package_records(self, package_id: int) -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first() pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first()
clean_rules = pkg.clean_rules if pkg else None if not pkg or not pkg.stored_name:
pf_rows = (
session.query(DataPackageFile)
.filter(DataPackageFile.package_id == package_id)
.order_by(DataPackageFile.sort_order.asc())
.all()
)
file_ids = [pf.file_id for pf in pf_rows]
if not file_ids:
return [] return []
files = session.query(DataFile).filter(DataFile.id.in_(file_ids)).all() stored_name = pkg.stored_name
file_map = {f.id: f for f in files}
all_records: list[dict[str, Any]] = []
for fid in file_ids:
if fid not in file_map:
continue
fmeta = file_map[fid]
path = self._dm._resolve_local_file_path(fmeta.file_path, fmeta.stored_name)
recs, _ = self._dm._read_records(path, limit=None)
all_records.extend(recs)
if clean_rules and clean_rules.get('enabled'):
all_records = self._apply_clean_rules(all_records, clean_rules)
return all_records pkg_path = self._base_dir / 'uploads' / 'packages' / stored_name
recs, _ = self._dm._read_records(pkg_path, limit=None)
return recs
@staticmethod @staticmethod
def _apply_clean_rules(records: list[dict[str, Any]], clean_rules: dict) -> list[dict[str, Any]]: def _apply_clean_rules(records: list[dict[str, Any]], clean_rules: dict) -> list[dict[str, Any]]:
...@@ -177,7 +170,10 @@ class EvalService: ...@@ -177,7 +170,10 @@ class EvalService:
'package_name': row.package_name, 'package_name': row.package_name,
'total_count': row.total_count, 'total_count': row.total_count,
'mae': row.mae, 'mae': row.mae,
'mse': row.mse,
'rmse': row.rmse, 'rmse': row.rmse,
'mape': row.mape,
'r2': row.r2,
'created_at': row.created_at.strftime('%Y-%m-%d %H:%M:%S') if row.created_at else '', 'created_at': row.created_at.strftime('%Y-%m-%d %H:%M:%S') if row.created_at else '',
} }
if include_chart: if include_chart:
......
...@@ -82,7 +82,8 @@ class MonitorService: ...@@ -82,7 +82,8 @@ class MonitorService:
self, self,
name: str, name: str,
model_id: int, model_id: int,
package_id: int, input_csv_path: str,
output_csv_path: str,
target_temp: float, target_temp: float,
sampling_interval: float, sampling_interval: float,
mpc_params: dict, mpc_params: dict,
...@@ -91,16 +92,19 @@ class MonitorService: ...@@ -91,16 +92,19 @@ class MonitorService:
model = session.query(SavedModel).filter(SavedModel.id == model_id).first() model = session.query(SavedModel).filter(SavedModel.id == model_id).first()
if not model: if not model:
raise ValueError('模型不存在') raise ValueError('模型不存在')
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first()
if not pkg: input_path = Path(input_csv_path)
raise ValueError('数据包不存在') if not input_path.exists():
raise ValueError(f'输入CSV文件不存在: {input_csv_path}')
exp = MonitorExperiment( exp = MonitorExperiment(
name=name, name=name,
model_id=model_id, model_id=model_id,
model_name=model.model_name, model_name=model.model_name,
package_id=package_id, package_id=None,
package_name=pkg.name, package_name=None,
input_csv_path=input_csv_path,
output_csv_path=output_csv_path,
target_temp=target_temp, target_temp=target_temp,
sampling_interval=sampling_interval, sampling_interval=sampling_interval,
mpc_params=mpc_params, mpc_params=mpc_params,
...@@ -183,97 +187,6 @@ class MonitorService: ...@@ -183,97 +187,6 @@ class MonitorService:
) )
return [self._point_to_dict(p) for p in rows] return [self._point_to_dict(p) for p in rows]
# ── 报告 ──────────────────────────────────────────────────────────────────
def get_report(self, exp_id: int) -> dict[str, Any]:
with db_session() as session:
exp = session.query(MonitorExperiment).filter(MonitorExperiment.id == exp_id).first()
if not exp:
raise ValueError('试验不存在')
rows = (
session.query(MonitorDataPoint)
.filter(MonitorDataPoint.experiment_id == exp_id)
.order_by(MonitorDataPoint.step_idx)
.all()
)
if not rows:
return {'experiment': self._exp_to_dict(exp), 'summary': None, 'points': []}
actuals = np.array([p.actual_temp for p in rows], dtype=np.float64)
references = np.array([p.reference_temp for p in rows], dtype=np.float64)
currents = np.array([p.current_output for p in rows], dtype=np.float64)
target = float(exp.target_temp)
errors = actuals - target
mae = float(np.mean(np.abs(errors)))
rmse = float(np.sqrt(np.mean(errors ** 2)))
# 超调量(仅升温场景)
if target > actuals[0]:
overshoot = max(0.0, float(np.max(actuals)) - target)
else:
overshoot = max(0.0, target - float(np.min(actuals)))
# 调节时间:首次进入 ±2°C 稳定带并持续 10 步
settling_step = None
band = 2.0
for i in range(len(actuals)):
if abs(actuals[i] - target) <= band:
end = min(i + 10, len(actuals))
if all(abs(actuals[j] - target) <= band * 1.5 for j in range(i, end)):
settling_step = int(i)
break
# 图表数据(最多 600 点)
n = len(rows)
step = max(1, n // 600)
chart = [self._point_to_dict(rows[i]) for i in range(0, n, step)]
return {
'experiment': self._exp_to_dict(exp),
'summary': {
'total_steps': n,
'duration_s': round(n * float(exp.sampling_interval or 1.0), 1),
'target_temp': target,
'initial_temp': round(float(actuals[0]), 3),
'final_temp': round(float(actuals[-1]), 3),
'mae': round(mae, 4),
'rmse': round(rmse, 4),
'overshoot': round(overshoot, 4),
'settling_step': settling_step,
'avg_current': round(float(np.mean(currents)), 4),
'max_current': round(float(np.max(currents)), 4),
},
'points': chart,
}
# ── 历史数据列表 ───────────────────────────────────────────────────────────
def list_history_experiments(self) -> list[dict[str, Any]]:
with db_session() as session:
rows = (
session.query(MonitorExperiment)
.filter(MonitorExperiment.exported == 1)
.order_by(MonitorExperiment.stop_time.desc())
.all()
)
return [self._exp_to_dict(e) for e in rows]
# ── 导出到历史数据 ─────────────────────────────────────────────────────────
def export_to_history(self, exp_id: int) -> dict[str, Any]:
with db_session() as session:
exp = session.query(MonitorExperiment).filter(MonitorExperiment.id == exp_id).first()
if not exp:
raise ValueError('试验不存在')
if exp.status == 'running':
raise ValueError('请先停止试验再导出')
exp.exported = 1
session.commit()
return {'exported': True}
# ── 控制线程 ─────────────────────────────────────────────────────────────── # ── 控制线程 ───────────────────────────────────────────────────────────────
def _simulation_worker(self, exp_id: int, cancel_event: threading.Event) -> None: def _simulation_worker(self, exp_id: int, cancel_event: threading.Event) -> None:
...@@ -293,7 +206,8 @@ class MonitorService: ...@@ -293,7 +206,8 @@ class MonitorService:
if not exp: if not exp:
return return
model_id = exp.model_id model_id = exp.model_id
package_id = exp.package_id input_csv_path = exp.input_csv_path
output_csv_path = exp.output_csv_path
target_temp = float(exp.target_temp) target_temp = float(exp.target_temp)
sampling_interval = float(exp.sampling_interval or 1.0) sampling_interval = float(exp.sampling_interval or 1.0)
mpc_params_dict = exp.mpc_params or {} mpc_params_dict = exp.mpc_params or {}
...@@ -334,12 +248,15 @@ class MonitorService: ...@@ -334,12 +248,15 @@ class MonitorService:
ctrl = MPCController.from_checkpoint(model_path, params) ctrl = MPCController.from_checkpoint(model_path, params)
seq_len = ctrl.predictor.seq_len seq_len = ctrl.predictor.seq_len
# ── 初始热身:从数据包读取最新记录填满缓冲区 ───────────────────── # ── 初始热身:从输入CSV读取最新记录填满缓冲区 ─────────────────────
init_records = self._load_package_records_tail(package_id, seq_len + 10) if not input_csv_path:
self._set_error(exp_id, '试验未配置输入CSV路径')
return
init_records = self._load_csv_records_tail(input_csv_path, seq_len + 10)
if len(init_records) < seq_len: if len(init_records) < seq_len:
self._set_error( self._set_error(
exp_id, exp_id,
f'数据包记录不足 {seq_len} 条(当前 {len(init_records)} 条),无法初始化', f'输入CSV记录不足 {seq_len} 条(当前 {len(init_records)} 条),无法初始化',
) )
return return
...@@ -358,8 +275,8 @@ class MonitorService: ...@@ -358,8 +275,8 @@ class MonitorService:
if cancel_event.is_set(): if cancel_event.is_set():
break break
# 1. 从数据包文件重新读取最新 seq_len 条真实传感器记录 # 1. 从输入CSV重新读取最新 seq_len 条真实传感器记录
fresh_records = self._load_package_records_tail(package_id, seq_len) fresh_records = self._load_csv_records_tail(input_csv_path, seq_len)
if len(fresh_records) < seq_len: if len(fresh_records) < seq_len:
# 数据还未更新到足够条数,跳过本步等待 # 数据还未更新到足够条数,跳过本步等待
cancel_event.wait(sampling_interval) cancel_event.wait(sampling_interval)
...@@ -418,7 +335,11 @@ class MonitorService: ...@@ -418,7 +335,11 @@ class MonitorService:
).update({'total_steps': step + 1}) ).update({'total_steps': step + 1})
session.commit() session.commit()
# 10. 等待采样周期(可被 stop 信号提前中断) # 10. 写入输出 CSV(每步追加一行)
if output_csv_path:
self._append_output_row(output_csv_path, step, y_k, u_cmd, w_k, target_temp)
# 11. 等待采样周期(可被 stop 信号提前中断)
cancel_event.wait(sampling_interval) cancel_event.wait(sampling_interval)
# ── 正常结束(达到最大步数) ────────────────────────────────────── # ── 正常结束(达到最大步数) ──────────────────────────────────────
...@@ -547,6 +468,47 @@ class MonitorService: ...@@ -547,6 +468,47 @@ class MonitorService:
'[exp=%d] 程控电源指令 → %.4f A', exp_id, u_cmd '[exp=%d] 程控电源指令 → %.4f A', exp_id, u_cmd
) )
def _load_csv_records_tail(self, csv_path: str, n: int) -> list[dict[str, Any]]:
"""从 CSV 文件加载最后 n 条有效记录,模拟传感器实时写入场景。"""
from app.services.data_management_service import DataManagementService
dm = DataManagementService()
path = Path(csv_path)
if not path.exists():
return []
records, _ = dm._read_records(path, limit=None)
valid = [
r for r in records
if r.get('actual_temperature') is not None
and float(r.get('actual_temperature', 9999)) < 9000
]
return valid[-n:] if len(valid) >= n else valid
@staticmethod
def _append_output_row(
csv_path: str,
step: int,
actual_temp: float,
current_output: float,
reference_temp: float,
target_temp: float,
) -> None:
"""将每步控制结果追加写入输出 CSV 文件。"""
import csv as csv_module
path = Path(csv_path)
path.parent.mkdir(parents=True, exist_ok=True)
write_header = not path.exists() or path.stat().st_size == 0
with path.open('a', newline='', encoding='utf-8-sig') as f:
writer = csv_module.writer(f)
if write_header:
writer.writerow(['step', 'actual_temp', 'current_output', 'reference_temp', 'target_temp'])
writer.writerow([
step,
round(actual_temp, 4),
round(current_output, 4),
round(reference_temp, 4),
round(target_temp, 4),
])
@staticmethod @staticmethod
def _exp_to_dict(exp: MonitorExperiment) -> dict[str, Any]: def _exp_to_dict(exp: MonitorExperiment) -> dict[str, Any]:
return { return {
...@@ -554,8 +516,8 @@ class MonitorService: ...@@ -554,8 +516,8 @@ class MonitorService:
'name': exp.name, 'name': exp.name,
'model_id': exp.model_id, 'model_id': exp.model_id,
'model_name': exp.model_name, 'model_name': exp.model_name,
'package_id': exp.package_id, 'input_csv_path': exp.input_csv_path,
'package_name': exp.package_name, 'output_csv_path': exp.output_csv_path,
'target_temp': exp.target_temp, 'target_temp': exp.target_temp,
'sampling_interval': exp.sampling_interval, 'sampling_interval': exp.sampling_interval,
'mpc_params': exp.mpc_params, 'mpc_params': exp.mpc_params,
......
import bisect
import csv import csv
import io import io
from datetime import datetime from datetime import datetime
...@@ -117,11 +118,30 @@ class PackageManagementService: ...@@ -117,11 +118,30 @@ class PackageManagementService:
# ── data files (for selection) ─────────────────────────────────────────── # ── data files (for selection) ───────────────────────────────────────────
def list_all_data_files(self) -> list[dict[str, Any]]: def list_all_data_files(
self,
category_id: str = '',
filename: str = '',
remark: str = '',
) -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
rows = ( query = (
session.query(DataFile, Category) session.query(DataFile, Category)
.outerjoin(Category, DataFile.category_id == Category.id) .outerjoin(Category, DataFile.category_id == Category.id)
)
if category_id not in ('', 'all', None):
db_id = self._parse_int_id(category_id, '分类ID')
query = query.filter(DataFile.category_id == db_id)
if filename:
query = query.filter(DataFile.filename.like(f'%{filename}%'))
if remark:
query = query.filter(DataFile.remark.like(f'%{remark}%'))
rows = (
query
.order_by(Category.name.asc(), DataFile.uploaded_at.desc(), DataFile.id.desc()) .order_by(Category.name.asc(), DataFile.uploaded_at.desc(), DataFile.id.desc())
.all() .all()
) )
...@@ -132,6 +152,7 @@ class PackageManagementService: ...@@ -132,6 +152,7 @@ class PackageManagementService:
'category_id': f.category_id, 'category_id': f.category_id,
'category_name': c.name if c else '', 'category_name': c.name if c else '',
'data_count': f.data_count, 'data_count': f.data_count,
'remark': f.remark or '',
'uploaded_at': f.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if f.uploaded_at else '', 'uploaded_at': f.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if f.uploaded_at else '',
} }
for f, c in rows for f, c in rows
...@@ -160,6 +181,7 @@ class PackageManagementService: ...@@ -160,6 +181,7 @@ class PackageManagementService:
remark: str | None, remark: str | None,
file_ids: list[int], file_ids: list[int],
clean_rules: dict | None = None, clean_rules: dict | None = None,
smooth: dict | None = None,
row_start: int | None = None, row_start: int | None = None,
row_end: int | None = None, row_end: int | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
...@@ -194,10 +216,15 @@ class PackageManagementService: ...@@ -194,10 +216,15 @@ class PackageManagementService:
merged = self._merge_records( merged = self._merge_records(
file_ids, file_map, limit=None, file_ids, file_map, limit=None,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=row_start, row_end=row_end, row_start=row_start, row_end=row_end,
) )
total_count = merged['count'] total_count = merged['count']
stored_rules: dict | None = clean_rules if (clean_rules and clean_rules.get('enabled')) else None stored_rules: dict | None = clean_rules if (clean_rules and clean_rules.get('enabled')) else None
if smooth and smooth.get('enabled'):
if stored_rules is None:
stored_rules = {}
stored_rules['smooth'] = smooth
pkg = DataPackage( pkg = DataPackage(
name=name, name=name,
...@@ -228,6 +255,168 @@ class PackageManagementService: ...@@ -228,6 +255,168 @@ class PackageManagementService:
session.refresh(pkg) session.refresh(pkg)
return self._pkg_to_dict(pkg) return self._pkg_to_dict(pkg)
def create_package_split(
self,
name: str,
category_id: str | None,
remark: str | None,
file_ids: list[int],
clean_rules: dict | None = None,
smooth: dict | None = None,
row_start: int | None = None,
row_end: int | None = None,
) -> list[dict[str, Any]]:
"""按 70/15/15 比例将数据拆分为训练集/验证集/测试集三个数据包。"""
if not name:
raise ValueError('数据包名称不能为空')
if not file_ids:
raise ValueError('请至少选择一个数据文件')
cat_db_id: int | None = None
if category_id and str(category_id).strip().lower() not in {'', 'all', 'none', 'null'}:
cat_db_id = self._parse_int_id(str(category_id), '分类ID')
with db_session() as session:
if cat_db_id is not None:
cat = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_package')
.first()
)
if not cat:
raise ValueError('数据包分类不存在')
files = session.query(DataFile).filter(DataFile.id.in_(file_ids)).all()
if len(files) != len(file_ids):
raise ValueError('部分数据文件不存在')
file_map = {f.id: f for f in files}
# 先完整合并(含清洗/平滑/截取)
merged = self._merge_records(
file_ids, file_map, limit=None,
clean_rules=clean_rules,
smooth=smooth,
row_start=row_start, row_end=row_end,
)
all_records = merged['records']
total = len(all_records)
# 按 70/15/15 计算切分点
train_end = round(total * 0.70)
val_end = train_end + round(total * 0.15)
slices = [
(f'{name}-训练集', all_records[:train_end]),
(f'{name}-验证集', all_records[train_end:val_end]),
(f'{name}-测试集', all_records[val_end:]),
]
stored_rules: dict | None = clean_rules if (clean_rules and clean_rules.get('enabled')) else None
if smooth and smooth.get('enabled'):
if stored_rules is None:
stored_rules = {}
stored_rules['smooth'] = smooth
pkg_dir = self._dm._upload_dir / 'packages'
pkg_dir.mkdir(parents=True, exist_ok=True)
results: list[dict[str, Any]] = []
for split_name, records in slices:
pkg = DataPackage(
name=split_name,
category_id=cat_db_id,
remark=remark,
data_count=len(records),
clean_rules=stored_rules,
row_start=row_start,
row_end=row_end,
)
session.add(pkg)
session.flush()
ts = datetime.now().strftime('%Y%m%d%H%M%S%f')
stored_name = f'pkg_{pkg.id}_{ts}.csv'
self._write_records_csv(records, pkg_dir / stored_name)
pkg.stored_name = stored_name
pkg.file_path = f'/app/uploads/packages/{stored_name}'
for idx, fid in enumerate(file_ids):
pf = DataPackageFile(package_id=pkg.id, file_id=fid, sort_order=idx)
session.add(pf)
results.append(self._pkg_to_dict(pkg))
session.commit()
return results
def update_package(
self,
package_id: str,
name: str,
category_id: str | None,
remark: str | None,
) -> dict[str, Any]:
db_id = self._parse_int_id(package_id, '数据包ID')
cat_db_id: int | None = None
if category_id and str(category_id).strip().lower() not in {'', 'all', 'none', 'null'}:
cat_db_id = self._parse_int_id(str(category_id), '分类ID')
with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == db_id).first()
if not pkg:
raise ValueError('数据包不存在')
if cat_db_id is not None:
cat = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_package')
.first()
)
if not cat:
raise ValueError('数据包分类不存在')
pkg.name = name
pkg.category_id = cat_db_id
pkg.remark = remark
session.commit()
session.refresh(pkg)
return self._pkg_to_dict(pkg)
def update_package(
self,
package_id: str,
name: str,
category_id: str | None,
remark: str | None,
) -> dict[str, Any]:
db_id = self._parse_int_id(package_id, '数据包ID')
cat_db_id: int | None = None
if category_id and str(category_id).strip().lower() not in {'', 'all', 'none', 'null'}:
cat_db_id = self._parse_int_id(str(category_id), '分类ID')
with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == db_id).first()
if not pkg:
raise ValueError('数据包不存在')
if cat_db_id is not None:
cat = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_package')
.first()
)
if not cat:
raise ValueError('数据包分类不存在')
pkg.name = name
pkg.category_id = cat_db_id
pkg.remark = remark
session.commit()
session.refresh(pkg)
return self._pkg_to_dict(pkg)
def delete_package(self, package_id: str) -> None: def delete_package(self, package_id: str) -> None:
db_id = self._parse_int_id(package_id, '数据包ID') db_id = self._parse_int_id(package_id, '数据包ID')
with db_session() as session: with db_session() as session:
...@@ -244,7 +433,7 @@ class PackageManagementService: ...@@ -244,7 +433,7 @@ class PackageManagementService:
if pkg_file.exists(): if pkg_file.exists():
pkg_file.unlink() pkg_file.unlink()
def get_package_records(self, package_id: str, limit: int = 500) -> dict[str, Any]: def get_package_records(self, package_id: str, limit: int | None = None) -> dict[str, Any]:
db_id = self._parse_int_id(package_id, '数据包ID') db_id = self._parse_int_id(package_id, '数据包ID')
with db_session() as session: with db_session() as session:
...@@ -269,14 +458,16 @@ class PackageManagementService: ...@@ -269,14 +458,16 @@ class PackageManagementService:
return self._merge_records( return self._merge_records(
file_ids, file_map, limit, file_ids, file_map, limit,
clean_rules=stored_clean_rules, clean_rules=stored_clean_rules,
smooth=stored_clean_rules.get('smooth') if stored_clean_rules else None,
row_start=stored_row_start, row_end=stored_row_end, row_start=stored_row_start, row_end=stored_row_end,
) )
def preview_records( def preview_records(
self, self,
file_ids: list[int], file_ids: list[int],
limit: int = 300, limit: int | None = None,
clean_rules: dict | None = None, clean_rules: dict | None = None,
smooth: dict | None = None,
row_start: int | None = None, row_start: int | None = None,
row_end: int | None = None, row_end: int | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
...@@ -290,6 +481,7 @@ class PackageManagementService: ...@@ -290,6 +481,7 @@ class PackageManagementService:
return self._merge_records( return self._merge_records(
file_ids, file_map, limit, file_ids, file_map, limit,
clean_rules=clean_rules, clean_rules=clean_rules,
smooth=smooth,
row_start=row_start, row_end=row_end, row_start=row_start, row_end=row_end,
) )
...@@ -301,12 +493,14 @@ class PackageManagementService: ...@@ -301,12 +493,14 @@ class PackageManagementService:
file_map: dict[int, Any], file_map: dict[int, Any],
limit: int | None, limit: int | None,
clean_rules: dict | None = None, clean_rules: dict | None = None,
smooth: dict | None = None,
row_start: int | None = None, row_start: int | None = None,
row_end: int | None = None, row_end: int | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
use_filter = bool(clean_rules and clean_rules.get('enabled')) use_filter = bool(clean_rules and clean_rules.get('enabled'))
use_smooth = bool(smooth and smooth.get('enabled'))
use_range = row_start is not None or row_end is not None use_range = row_start is not None or row_end is not None
need_full_read = use_filter or use_range need_full_read = use_filter or use_range or use_smooth
all_records: list[dict[str, Any]] = [] all_records: list[dict[str, Any]] = []
total_count = 0 total_count = 0
remaining = limit remaining = limit
...@@ -337,6 +531,9 @@ class PackageManagementService: ...@@ -337,6 +531,9 @@ class PackageManagementService:
end_idx = row_end if row_end else len(all_records) end_idx = row_end if row_end else len(all_records)
all_records = all_records[start_idx:end_idx] all_records = all_records[start_idx:end_idx]
if use_smooth:
all_records = self._apply_smooth(all_records, smooth.get('window', 5))
total_count = len(all_records) total_count = len(all_records)
if limit is not None and limit > 0: if limit is not None and limit > 0:
...@@ -351,30 +548,122 @@ class PackageManagementService: ...@@ -351,30 +548,122 @@ class PackageManagementService:
v_max = clean_rules.get('voltage_max') v_max = clean_rules.get('voltage_max')
t_min = clean_rules.get('temperature_min') t_min = clean_rules.get('temperature_min')
t_max = clean_rules.get('temperature_max') t_max = clean_rules.get('temperature_max')
newton_interp = bool(clean_rules.get('newton_interp', False))
numeric_cols = ['current', 'voltage', 'set_temperature', 'actual_temperature']
def _is_out_cur(v):
if v is None:
return False
return (c_min is not None and v < c_min) or (c_max is not None and v > c_max)
def _is_out_vol(v):
if v is None:
return False
return (v_min is not None and v < v_min) or (v_max is not None and v > v_max)
def _is_out_tmp(v):
if v is None:
return False
return (t_min is not None and v < t_min) or (t_max is not None and v > t_max)
result = [] result: list[dict[str, Any]] = []
for r in records: for r in records:
current = r.get('current') cur_out = _is_out_cur(r.get('current'))
voltage = r.get('voltage') vol_out = _is_out_vol(r.get('voltage'))
temp = r.get('actual_temperature') tmp_out = _is_out_tmp(r.get('actual_temperature'))
if current is not None: if not (cur_out or vol_out or tmp_out):
if c_min is not None and current < c_min: result.append(r)
continue
if c_max is not None and current > c_max:
continue continue
if voltage is not None:
if v_min is not None and voltage < v_min: if not newton_interp:
# 直接剔除行
continue continue
if v_max is not None and voltage > v_max:
# 牛顿插值:保留行但将野值字段置 None,后续插值填补
row = dict(r)
if cur_out:
row['current'] = None
if vol_out:
row['voltage'] = None
if tmp_out:
row['actual_temperature'] = None
result.append(row)
if newton_interp:
# 对各数值列独立做牛顿插值
for col in numeric_cols:
col_vals = [row.get(col) for row in result]
filled = PackageManagementService._interpolate_column(col_vals)
for i, row in enumerate(result):
row[col] = filled[i]
return result
@staticmethod
def _interpolate_column(col: list) -> list:
"""用牛顿差商插值填补列中的 None 值(bisect 二分查找邻居,O(n log n))。"""
valid_pts = [(i, v) for i, v in enumerate(col) if v is not None]
if not valid_pts:
return col
valid_indices = [p[0] for p in valid_pts] # 有序索引列表,供二分查找
result = list(col)
for i, v in enumerate(col):
if v is not None:
continue continue
if temp is not None: # 二分定位:pos 是 i 在 valid_indices 中的插入位置
if t_min is not None and temp < t_min: pos = bisect.bisect_left(valid_indices, i)
before = valid_pts[max(0, pos - 3):pos]
after = valid_pts[pos:pos + 3]
points = before + after
if not points:
continue continue
if t_max is not None and temp > t_max: if len(points) == 1:
result[i] = points[0][1]
continue continue
xs = [float(p[0]) for p in points]
ys = [float(p[1]) for p in points]
result[i] = PackageManagementService._newton_eval(xs, ys, float(i))
return result
result.append(r) @staticmethod
def _newton_eval(xs: list[float], ys: list[float], x: float) -> float:
"""牛顿差商表求插值多项式在 x 处的值。"""
n = len(xs)
dd = list(ys)
coeffs = [dd[0]]
for j in range(1, n):
dd = [
(dd[k + 1] - dd[k]) / (xs[k + j] - xs[k])
for k in range(n - j)
]
coeffs.append(dd[0])
# 霍纳法求值
val = coeffs[-1]
for j in range(n - 2, -1, -1):
val = val * (x - xs[j]) + coeffs[j]
return round(val, 6)
@staticmethod
def _apply_smooth(records: list[dict[str, Any]], window: int) -> list[dict[str, Any]]:
"""滑动均值平滑(居中窗口),对数值列原地均值化。"""
if window < 2 or not records:
return records
numeric_cols = ['current', 'voltage', 'set_temperature', 'actual_temperature']
n = len(records)
result = [dict(r) for r in records]
half = window // 2
for col in numeric_cols:
vals = [r.get(col) for r in records]
for i in range(n):
start = max(0, i - half)
end = min(n, start + window)
if end - start < window:
start = max(0, end - window)
window_vals = [v for v in vals[start:end] if v is not None]
if window_vals:
result[i][col] = round(sum(window_vals) / len(window_vals), 6)
return result return result
def _pkg_to_dict(self, pkg: DataPackage) -> dict[str, Any]: def _pkg_to_dict(self, pkg: DataPackage) -> dict[str, Any]:
......
...@@ -29,14 +29,19 @@ class TrainService: ...@@ -29,14 +29,19 @@ class TrainService:
# ── packages (for dropdown) ────────────────────────────────────────────── # ── packages (for dropdown) ──────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]: def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session: with db_session() as session:
rows = ( query = session.query(DataPackage)
session.query(DataPackage) if category_id not in ('', 'all', None):
.order_by(DataPackage.created_at.desc()) try:
.all() db_id = int(category_id)
) query = query.filter(DataPackage.category_id == db_id)
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows] except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
# ── tasks ──────────────────────────────────────────────────────────────── # ── tasks ────────────────────────────────────────────────────────────────
...@@ -52,16 +57,21 @@ class TrainService: ...@@ -52,16 +57,21 @@ class TrainService:
raise ValueError('训练任务不存在') raise ValueError('训练任务不存在')
return self._task_to_dict(task) return self._task_to_dict(task)
def create_task(self, model_name: str, package_id: int, params: dict) -> dict[str, Any]: def create_task(self, model_name: str, train_package_id: int, val_package_id: int, params: dict) -> dict[str, Any]:
with db_session() as session: with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first() train_pkg = session.query(DataPackage).filter(DataPackage.id == train_package_id).first()
if not pkg: if not train_pkg:
raise ValueError('数据包不存在') raise ValueError('训练集数据包不存在')
val_pkg = session.query(DataPackage).filter(DataPackage.id == val_package_id).first()
if not val_pkg:
raise ValueError('验证集数据包不存在')
task = TrainTask( task = TrainTask(
model_name=model_name, model_name=model_name,
package_id=package_id, package_id=train_package_id,
package_name=pkg.name, package_name=train_pkg.name,
val_package_id=val_package_id,
val_package_name=val_pkg.name,
params=params, params=params,
status='pending', status='pending',
progress=0, progress=0,
...@@ -71,7 +81,7 @@ class TrainService: ...@@ -71,7 +81,7 @@ class TrainService:
session.refresh(task) session.refresh(task)
task_dict = self._task_to_dict(task) task_dict = self._task_to_dict(task)
self._launch_thread(task_dict['id'], package_id, params) self._launch_thread(task_dict['id'], train_package_id, val_package_id, params)
return task_dict return task_dict
def cancel_task(self, task_id: int) -> None: def cancel_task(self, task_id: int) -> None:
...@@ -100,6 +110,8 @@ class TrainService: ...@@ -100,6 +110,8 @@ class TrainService:
model_name=task.model_name, model_name=task.model_name,
package_id=task.package_id, package_id=task.package_id,
package_name=task.package_name, package_name=task.package_name,
val_package_id=task.val_package_id,
val_package_name=task.val_package_name,
params=task.params, params=task.params,
status='pending', status='pending',
progress=0, progress=0,
...@@ -109,7 +121,7 @@ class TrainService: ...@@ -109,7 +121,7 @@ class TrainService:
session.refresh(new_task) session.refresh(new_task)
task_dict = self._task_to_dict(new_task) task_dict = self._task_to_dict(new_task)
self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['params']) self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['val_package_id'], task_dict['params'])
return task_dict return task_dict
def delete_task(self, task_id: int) -> None: def delete_task(self, task_id: int) -> None:
...@@ -128,7 +140,7 @@ class TrainService: ...@@ -128,7 +140,7 @@ class TrainService:
session.delete(task) session.delete(task)
session.commit() session.commit()
def save_model(self, task_id: int) -> dict[str, Any]: def save_model(self, task_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session: with db_session() as session:
task = session.query(TrainTask).filter(TrainTask.id == task_id).first() task = session.query(TrainTask).filter(TrainTask.id == task_id).first()
if not task: if not task:
...@@ -144,7 +156,8 @@ class TrainService: ...@@ -144,7 +156,8 @@ class TrainService:
saved = SavedModel( saved = SavedModel(
task_id=task_id, task_id=task_id,
model_name=task.model_name, model_name=(model_name.strip() if model_name and model_name.strip() else task.model_name),
description=description,
package_id=task.package_id, package_id=task.package_id,
package_name=task.package_name, package_name=task.package_name,
params=task.params, params=task.params,
...@@ -158,7 +171,18 @@ class TrainService: ...@@ -158,7 +171,18 @@ class TrainService:
session.commit() session.commit()
session.refresh(saved) session.refresh(saved)
return self._saved_to_dict(saved) return self._saved_to_dict(saved)
def update_saved_model(self, model_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session:
model = session.query(SavedModel).filter(SavedModel.id == model_id).first()
if not model:
raise ValueError('模型不存在')
if model_name is not None and model_name.strip():
model.model_name = model_name.strip()
if description is not None:
model.description = description
session.commit()
session.refresh(model)
return self._saved_to_dict(model)
# ── saved models ───────────────────────────────────────────────────────── # ── saved models ─────────────────────────────────────────────────────────
def list_saved_models(self) -> list[dict[str, Any]]: def list_saved_models(self) -> list[dict[str, Any]]:
...@@ -193,14 +217,15 @@ class TrainService: ...@@ -193,14 +217,15 @@ class TrainService:
if p.is_absolute(): if p.is_absolute():
return p return p
return self._base_dir / p return self._base_dir / p
def _launch_thread(self, task_id: int, package_id: int, params: dict) -> None:
def _launch_thread(self, task_id: int, train_package_id: int, val_package_id: int | None, params: dict) -> None:
cancel_event = threading.Event() cancel_event = threading.Event()
with _registry_lock: with _registry_lock:
_cancel_events[task_id] = cancel_event _cancel_events[task_id] = cancel_event
thread = threading.Thread( thread = threading.Thread(
target=self._training_worker, target=self._training_worker,
args=(task_id, package_id, params, cancel_event), args=(task_id, train_package_id, val_package_id, params, cancel_event),
daemon=True, daemon=True,
name=f'train-task-{task_id}', name=f'train-task-{task_id}',
) )
...@@ -209,35 +234,44 @@ class TrainService: ...@@ -209,35 +234,44 @@ class TrainService:
def _training_worker( def _training_worker(
self, self,
task_id: int, task_id: int,
package_id: int, train_package_id: int,
val_package_id: int | None,
params: dict, params: dict,
cancel_event: threading.Event, cancel_event: threading.Event,
) -> None: ) -> None:
try: try:
self._update_task(task_id, status='running', progress=0) self._update_task(task_id, status='running', progress=0)
records = self._load_package_records(package_id) train_records = self._load_package_records(train_package_id)
if not records: if not train_records:
raise ValueError('数据包没有有效数据,请检查关联文件') raise ValueError('训练集数据包没有有效数据,请检查关联文件')
val_records: list[dict[str, Any]] = []
if val_package_id:
val_records = self._load_package_records(val_package_id)
save_path = self._models_dir / f'task_{task_id}.pt' save_path = self._models_dir / f'task_{task_id}.pt'
last_pct = [0] epoch_log_buffer: list[dict[str, Any]] = []
def on_progress(pct: int, train_loss: float, val_loss: float | None) -> None: def on_progress(pct: int, epoch: int, train_loss: float, val_loss: float | None) -> None:
if cancel_event.is_set(): if cancel_event.is_set():
return return
# Throttle: persist at most every 2 % to reduce DB writes epoch_log_buffer.append({
if pct - last_pct[0] >= 2 or pct == 100: 'epoch': epoch,
last_pct[0] = pct 'train_loss': round(float(train_loss), 6),
'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
})
self._update_task( self._update_task(
task_id, task_id,
progress=pct, progress=pct,
train_loss=round(float(train_loss), 6), train_loss=round(float(train_loss), 6),
val_loss=round(float(val_loss), 6) if val_loss is not None else None, val_loss=round(float(val_loss), 6) if val_loss is not None else None,
epoch_logs=list(epoch_log_buffer),
) )
result = train_lstm( result = train_lstm(
records=records, train_records=train_records,
val_records=val_records,
params=params, params=params,
save_path=save_path, save_path=save_path,
on_progress=on_progress, on_progress=on_progress,
...@@ -304,6 +338,9 @@ class TrainService: ...@@ -304,6 +338,9 @@ class TrainService:
'model_name': task.model_name, 'model_name': task.model_name,
'package_id': task.package_id, 'package_id': task.package_id,
'package_name': task.package_name, 'package_name': task.package_name,
'val_package_id': task.val_package_id,
'val_package_name': task.val_package_name,
'epoch_logs': task.epoch_logs or [],
'params': task.params, 'params': task.params,
'status': task.status, 'status': task.status,
'progress': task.progress, 'progress': task.progress,
...@@ -321,6 +358,7 @@ class TrainService: ...@@ -321,6 +358,7 @@ class TrainService:
'id': model.id, 'id': model.id,
'task_id': model.task_id, 'task_id': model.task_id,
'model_name': model.model_name, 'model_name': model.model_name,
'description': model.description or '',
'package_id': model.package_id, 'package_id': model.package_id,
'package_name': model.package_name, 'package_name': model.package_name,
'params': model.params, 'params': model.params,
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
step,actual_temp,current_output,reference_temp,target_temp
0,20.534,0.0,23.4272,35.0
1,20.534,0.1796,23.4272,35.0
2,20.534,0.1796,23.4272,35.0
3,20.534,0.1949,23.4272,35.0
4,20.534,0.1949,23.4272,35.0
5,20.534,0.1949,23.4272,35.0
6,20.534,0.1949,23.4272,35.0
7,20.534,0.1949,23.4272,35.0
8,20.534,0.1354,23.4272,35.0
9,20.534,0.1354,23.4272,35.0
10,20.534,0.1354,23.4272,35.0
11,20.534,0.1354,23.4272,35.0
12,20.534,0.0,23.4272,35.0
13,20.534,0.1796,23.4272,35.0
14,20.534,0.1796,23.4272,35.0
15,20.534,0.1949,23.4272,35.0
16,20.534,0.1949,23.4272,35.0
17,20.534,0.1949,23.4272,35.0
18,20.534,0.1949,23.4272,35.0
19,20.534,0.1949,23.4272,35.0
20,20.534,0.1354,23.4272,35.0
21,20.534,0.1354,23.4272,35.0
22,20.534,0.1354,23.4272,35.0
23,20.534,0.1354,23.4272,35.0
24,20.534,0.0,23.4272,35.0
25,20.534,0.1796,23.4272,35.0
26,20.534,0.1796,23.4272,35.0
27,20.534,0.1949,23.4272,35.0
...@@ -11,10 +11,8 @@ const tabs = [ ...@@ -11,10 +11,8 @@ const tabs = [
{ label: '模型训练', path: '/model-training' }, { label: '模型训练', path: '/model-training' },
{ label: '模型库', path: '/model-list' }, { label: '模型库', path: '/model-list' },
{ label: '模型评估', path: '/model-evaluation' }, { label: '模型评估', path: '/model-evaluation' },
{ label: '实时监控', path: '/realtime-monitor' }, { label: '试验管理', path: '/realtime-monitor' },
{ label: '历史数据', path: '/history-data' }, { label: '实时监控', path: '/live-monitor' },
] ]
const activeTab = computed(() => { const activeTab = computed(() => {
......
...@@ -32,6 +32,10 @@ export function deleteDataFile(fileId) { ...@@ -32,6 +32,10 @@ export function deleteDataFile(fileId) {
return request.delete(`/data/files/${fileId}`) return request.delete(`/data/files/${fileId}`)
} }
export function updateDataFile(fileId, payload) {
return request.put(`/data/files/${fileId}`, payload)
}
export function getFileRecords(fileId, params) { export function getFileRecords(fileId, params) {
return request.get(`/data/files/${fileId}/records`, { params }) return request.get(`/data/files/${fileId}/records`, { params })
} }
......
import request from '@/utils/request' import request from '@/utils/request'
export function getEvalPackages() { export function getEvalPackages(params = {}) {
return request.get('/eval/packages') return request.get('/eval/packages', { params })
} }
export function getEvalModels() { export function getEvalModels() {
......
import request from '@/utils/request'
export function getHistoryList() {
return request.get('/monitor/history')
}
export function getHistoryDetail(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
...@@ -20,8 +20,8 @@ export function deletePkgCategory(categoryId) { ...@@ -20,8 +20,8 @@ export function deletePkgCategory(categoryId) {
// ── data files (for selection) ─────────────────────────────────────────────── // ── data files (for selection) ───────────────────────────────────────────────
export function getAllDataFiles() { export function getAllDataFiles(params) {
return request.get('/packages/data-files') return request.get('/packages/data-files', { params })
} }
// ── packages ───────────────────────────────────────────────────────────────── // ── packages ─────────────────────────────────────────────────────────────────
...@@ -34,6 +34,10 @@ export function createPackage(payload) { ...@@ -34,6 +34,10 @@ export function createPackage(payload) {
return request.post('/packages', payload) return request.post('/packages', payload)
} }
export function updatePackage(packageId, payload) {
return request.put(`/packages/${packageId}`, payload)
}
export function deletePackage(packageId) { export function deletePackage(packageId) {
return request.delete(`/packages/${packageId}`) return request.delete(`/packages/${packageId}`)
} }
......
...@@ -40,10 +40,4 @@ export function getDataPoints(expId, fromStep = 0) { ...@@ -40,10 +40,4 @@ export function getDataPoints(expId, fromStep = 0) {
return request.get(`/monitor/experiments/${expId}/data`, { params: { from_step: fromStep } }) return request.get(`/monitor/experiments/${expId}/data`, { params: { from_step: fromStep } })
} }
export function getReport(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
export function exportToHistory(expId) {
return request.post(`/monitor/experiments/${expId}/export`)
}
...@@ -28,8 +28,8 @@ export function deleteTrainTask(taskId) { ...@@ -28,8 +28,8 @@ export function deleteTrainTask(taskId) {
return request.delete(`/train/tasks/${taskId}`) return request.delete(`/train/tasks/${taskId}`)
} }
export function saveTrainModel(taskId) { export function saveTrainModel(taskId, payload = {}) {
return request.post(`/train/tasks/${taskId}/save`) return request.post(`/train/tasks/${taskId}/save`, payload)
} }
export function getSavedModels() { export function getSavedModels() {
...@@ -39,3 +39,7 @@ export function getSavedModels() { ...@@ -39,3 +39,7 @@ export function getSavedModels() {
export function deleteSavedModel(modelId) { export function deleteSavedModel(modelId) {
return request.delete(`/train/models/${modelId}`) return request.delete(`/train/models/${modelId}`)
} }
export function updateSavedModel(modelId, payload) {
return request.patch(`/train/models/${modelId}`, payload)
}
...@@ -23,9 +23,9 @@ const router = createRouter({ ...@@ -23,9 +23,9 @@ const router = createRouter({
component: () => import('@/views/RealtimeMonitor/components/ExperimentDetail.vue'), component: () => import('@/views/RealtimeMonitor/components/ExperimentDetail.vue'),
}, },
{ {
path: '/history-data', path: '/live-monitor',
name: 'history-data', name: 'live-monitor',
component: () => import('@/views/HistoryData/index.vue'), component: () => import('@/views/LiveMonitor/index.vue'),
}, },
{ {
path: '/model-training', path: '/model-training',
......
...@@ -7,6 +7,10 @@ const props = defineProps({ ...@@ -7,6 +7,10 @@ const props = defineProps({
type: Array, type: Array,
default: () => [], default: () => [],
}, },
totalCount: {
type: Number,
default: null,
},
}) })
const chartRef = ref(null) const chartRef = ref(null)
...@@ -225,7 +229,7 @@ onBeforeUnmount(() => { ...@@ -225,7 +229,7 @@ onBeforeUnmount(() => {
<div class="stats-row"> <div class="stats-row">
<div class="stat-card"> <div class="stat-card">
<div class="stat-label">总条数</div> <div class="stat-label">总条数</div>
<div class="stat-value">{{ stats.count }}</div> <div class="stat-value">{{ props.totalCount ?? stats.count }}</div>
</div> </div>
<div class="stat-card"> <div class="stat-card">
<div class="stat-label">最高温度</div> <div class="stat-label">最高温度</div>
......
<script setup> <script setup>
import { Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue' import { DataAnalysis, Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue'
import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue' import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElAutoResizer, ElMessage, ElMessageBox, ElTableV2 } from 'element-plus'
import { import {
createCategory, createCategory,
deleteCategory, deleteCategory,
...@@ -12,6 +12,7 @@ import { ...@@ -12,6 +12,7 @@ import {
getFileQuality, getFileQuality,
getFileRecords, getFileRecords,
updateCategory, updateCategory,
updateDataFile,
uploadDataFile, uploadDataFile,
} from '@/api/dataManagement' } from '@/api/dataManagement'
import DataCurve from './components/DataCurve.vue' import DataCurve from './components/DataCurve.vue'
...@@ -31,6 +32,21 @@ const recordList = ref([]) ...@@ -31,6 +32,21 @@ const recordList = ref([])
const contentLoading = ref(false) const contentLoading = ref(false)
const contentMode = ref('table') const contentMode = ref('table')
const recordColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = recordList.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const categoryDialogVisible = ref(false) const categoryDialogVisible = ref(false)
const categoryDialogMode = ref('create') const categoryDialogMode = ref('create')
const editingCategoryId = ref('') const editingCategoryId = ref('')
...@@ -48,6 +64,66 @@ const qualityLoading = ref(false) ...@@ -48,6 +64,66 @@ const qualityLoading = ref(false)
const qualityFile = ref(null) const qualityFile = ref(null)
const qualityResult = ref(null) const qualityResult = ref(null)
const fileEditDialogVisible = ref(false)
const fileEditSubmitting = ref(false)
const editingFileId = ref(null)
const fileEditForm = reactive({
filename: '',
remark: '',
categoryId: '',
})
const openFileEditDialog = (row) => {
editingFileId.value = row.id
fileEditForm.filename = row.filename
fileEditForm.remark = row.remark || ''
fileEditForm.categoryId = row.category_id ? String(row.category_id) : ''
fileEditDialogVisible.value = true
}
const submitFileEdit = async () => {
const name = fileEditForm.filename.trim()
if (!name) {
ElMessage.warning('文件名不能为空')
return
}
if (!fileEditForm.categoryId) {
ElMessage.warning('请选择分类')
return
}
fileEditSubmitting.value = true
const prevCategoryId = fileList.value.find((f) => f.id === editingFileId.value)?.category_id
try {
const result = await updateDataFile(editingFileId.value, {
filename: name,
remark: fileEditForm.remark.trim() || null,
category_id: fileEditForm.categoryId,
})
ElMessage.success('保存成功')
fileEditDialogVisible.value = false
const categoryChanged = String(result.category_id) !== String(prevCategoryId)
if (categoryChanged) {
await loadFileList()
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = null
recordList.value = []
}
} else {
const idx = fileList.value.findIndex((f) => f.id === editingFileId.value)
if (idx !== -1) {
fileList.value[idx].filename = result.filename
fileList.value[idx].remark = result.remark || ''
}
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = { ...currentFile.value, filename: result.filename, remark: result.remark || '' }
}
}
} finally {
fileEditSubmitting.value = false
}
}
const qualityColorClass = (percent) => { const qualityColorClass = (percent) => {
if (percent >= 90) return 'quality-good' if (percent >= 90) return 'quality-good'
if (percent >= 70) return 'quality-warn' if (percent >= 70) return 'quality-warn'
...@@ -407,7 +483,7 @@ const handleViewFile = async (row) => { ...@@ -407,7 +483,7 @@ const handleViewFile = async (row) => {
currentFile.value = row currentFile.value = row
contentLoading.value = true contentLoading.value = true
try { try {
const result = await getFileRecords(row.id, { limit: 500 }) const result = await getFileRecords(row.id)
recordList.value = result.records recordList.value = result.records
} finally { } finally {
contentLoading.value = false contentLoading.value = false
...@@ -506,15 +582,16 @@ onBeforeUnmount(() => { ...@@ -506,15 +582,16 @@ onBeforeUnmount(() => {
</el-form-item> </el-form-item>
</el-form> </el-form>
<el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)"> <el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)" row-class-name="file-row" @row-click="handleViewFile">
<el-table-column prop="filename" label="文件名" min-width="160" /> <el-table-column prop="filename" label="文件名" min-width="160" />
<el-table-column prop="remark" label="备注" min-width="120" show-overflow-tooltip />
<el-table-column prop="uploaded_at" label="上传时间" min-width="170" /> <el-table-column prop="uploaded_at" label="上传时间" min-width="170" />
<el-table-column prop="data_count" label="数据量" width="90" /> <el-table-column prop="data_count" label="数据量" width="90" />
<el-table-column label="操作" width="200" fixed="right"> <el-table-column label="操作" width="220" fixed="right">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleViewFile(row)">查看</el-button> <el-button link type="primary" :icon="Edit" @click.stop="openFileEditDialog(row)" />
<el-button link type="success" @click="handleQualityCheck(row)">质量判定</el-button> <el-button link type="success" :icon="DataAnalysis" @click.stop="handleQualityCheck(row)" title="质量判定" />
<el-button link type="danger" @click="handleDeleteFile(row)">删除</el-button> <el-button link type="danger" :icon="Delete" @click.stop="handleDeleteFile(row)" />
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
...@@ -528,6 +605,7 @@ onBeforeUnmount(() => { ...@@ -528,6 +605,7 @@ onBeforeUnmount(() => {
<span> <span>
文件内容 文件内容
<span v-if="currentFile" class="file-title">- {{ currentFile.filename }}</span> <span v-if="currentFile" class="file-title">- {{ currentFile.filename }}</span>
<span v-if="currentFile && currentFile.remark" class="file-remark">{{ currentFile.remark }}</span>
</span> </span>
<el-radio-group v-model="contentMode" size="small"> <el-radio-group v-model="contentMode" size="small">
<el-radio-button value="table">表格</el-radio-button> <el-radio-button value="table">表格</el-radio-button>
...@@ -539,21 +617,21 @@ onBeforeUnmount(() => { ...@@ -539,21 +617,21 @@ onBeforeUnmount(() => {
<div class="content-wrap" v-loading="contentLoading"> <div class="content-wrap" v-loading="contentLoading">
<el-empty v-if="!currentFile" description="请选择并查看一个文件" /> <el-empty v-if="!currentFile" description="请选择并查看一个文件" />
<el-table <el-auto-resizer v-else-if="contentMode === 'table'">
v-else-if="contentMode === 'table'" <template #default="{ height, width }">
<el-table-v2
:columns="recordColumns"
:data="recordList" :data="recordList"
border :width="width"
stripe :height="height"
height="calc(100vh - 250px)" :row-height="36"
> :header-height="40"
<el-table-column prop="time" label="时间" min-width="140" /> fixed
<el-table-column prop="current" label="电流" min-width="100" /> />
<el-table-column prop="voltage" label="电压" min-width="100" /> </template>
<el-table-column prop="set_temperature" label="设定温度" min-width="100" /> </el-auto-resizer>
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" />
</el-table>
<DataCurve v-else :records="recordList" /> <DataCurve v-else :records="curveRecords" :total-count="recordList.length" />
</div> </div>
</el-card> </el-card>
</div> </div>
...@@ -699,6 +777,39 @@ onBeforeUnmount(() => { ...@@ -699,6 +777,39 @@ onBeforeUnmount(() => {
<el-button @click="qualityDialogVisible = false">关闭</el-button> <el-button @click="qualityDialogVisible = false">关闭</el-button>
</template> </template>
</el-dialog> </el-dialog>
<!-- 编辑文件对话框 -->
<el-dialog v-model="fileEditDialogVisible" title="编辑文件信息" width="460px" destroy-on-close>
<el-form label-position="top">
<el-form-item label="文件名" required>
<el-input v-model="fileEditForm.filename" maxlength="255" show-word-limit clearable />
</el-form-item>
<el-form-item label="所属分类" required>
<el-select v-model="fileEditForm.categoryId" placeholder="请选择分类" style="width: 100%">
<el-option
v-for="item in flatCategoryOptions"
:key="item.value"
:label="item.label"
:value="item.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input
v-model="fileEditForm.remark"
type="textarea"
:rows="3"
maxlength="500"
show-word-limit
placeholder="可选备注"
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="fileEditDialogVisible = false">取消</el-button>
<el-button type="primary" :loading="fileEditSubmitting" @click="submitFileEdit">保存</el-button>
</template>
</el-dialog>
</section> </section>
</template> </template>
...@@ -794,6 +905,13 @@ onBeforeUnmount(() => { ...@@ -794,6 +905,13 @@ onBeforeUnmount(() => {
margin-left: 4px; margin-left: 4px;
} }
.file-remark {
font-size: 12px;
color: var(--text-tertiary);
font-weight: 400;
margin-left: 2px;
}
.tree-node { .tree-node {
width: 100%; width: 100%;
display: flex; display: flex;
...@@ -948,4 +1066,8 @@ onBeforeUnmount(() => { ...@@ -948,4 +1066,8 @@ onBeforeUnmount(() => {
.quality-no-config { .quality-no-config {
color: #e6a23c; color: #e6a23c;
} }
:deep(.file-row) {
cursor: pointer;
}
</style> </style>
<script setup>
import * as echarts from 'echarts'
import { Refresh } from '@element-plus/icons-vue'
import { ElMessage } from 'element-plus'
import { onMounted, onBeforeUnmount, ref, nextTick } from 'vue'
import { getHistoryList, getHistoryDetail } from '@/api/historyData'
// ── 列表 ───────────────────────────────────────────────────────────────────────
const records = ref([])
const loadingList = ref(false)
const selectedId = ref(null)
const loadList = async () => {
loadingList.value = true
try {
records.value = await getHistoryList()
} catch {
ElMessage.error('加载历史数据失败')
} finally {
loadingList.value = false
}
}
// ── 详情 ───────────────────────────────────────────────────────────────────────
const detail = ref(null)
const loadingDetail = ref(false)
let tempChart = null
let currChart = null
const tempChartRef = ref(null)
const currChartRef = ref(null)
const disposeCharts = () => {
tempChart?.dispose()
currChart?.dispose()
tempChart = null
currChart = null
}
const initCharts = (points, targetTemp) => {
disposeCharts()
if (!tempChartRef.value || !currChartRef.value) return
tempChart = echarts.init(tempChartRef.value)
currChart = echarts.init(currChartRef.value)
const xData = points.map((d) => `步${d.step_idx}`)
const actuals = points.map((d) => d.actual_temp)
const refs = points.map((d) => d.reference_temp)
const currents = points.map((d) => d.current_output)
const targetLine = points.map(() => targetTemp)
const intervalVal = Math.max(0, Math.floor(xData.length / 10) - 1)
tempChart.setOption({
animation: false,
color: ['#409EFF', '#67C23A', '#F56C6C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const lines = [`<div style="margin-bottom:4px;font-weight:600">${params[0].axisValue}</div>`]
params.forEach((p) => {
const v = p.data != null ? Number(p.data).toFixed(3) + ' °C' : '--'
lines.push(
`<div style="display:flex;justify-content:space-between;gap:16px">
<span>${p.marker}${p.seriesName}</span><strong>${v}</strong>
</div>`,
)
})
return lines.join('')
},
},
legend: { bottom: 4, itemWidth: 18, itemHeight: 8, textStyle: { color: '#475569', fontSize: 12 } },
grid: { top: 16, left: 16, right: 16, bottom: 52, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#64748b', fontSize: 11, interval: intervalVal },
axisLine: { lineStyle: { color: '#cbd5e1' } },
},
yAxis: {
type: 'value',
name: '温度 (°C)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
},
series: [
{ name: '实际温度', type: 'line', data: actuals, smooth: true, symbol: 'none', lineStyle: { width: 2 } },
{ name: '参考轨迹', type: 'line', data: refs, smooth: true, symbol: 'none', lineStyle: { width: 1.5, type: 'dashed' } },
{ name: '目标温度', type: 'line', data: targetLine, symbol: 'none', lineStyle: { width: 1.5, type: 'dotted', color: '#F56C6C' } },
],
})
currChart.setOption({
animation: false,
color: ['#E6A23C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const v = params[0].data != null ? Number(params[0].data).toFixed(3) + ' A' : '--'
return `<div style="font-weight:600">${params[0].axisValue}</div>
<div>${params[0].marker}电流输出:<strong>${v}</strong></div>`
},
},
legend: { bottom: 4, itemWidth: 18, itemHeight: 8, textStyle: { color: '#475569', fontSize: 12 } },
grid: { top: 16, left: 16, right: 16, bottom: 52, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#64748b', fontSize: 11, interval: intervalVal },
axisLine: { lineStyle: { color: '#cbd5e1' } },
},
yAxis: {
type: 'value',
name: '电流 (A)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
min: 0,
},
series: [
{ name: '电流输出', type: 'line', data: currents, smooth: false, symbol: 'none', lineStyle: { width: 2 }, areaStyle: { opacity: 0.08 } },
],
})
}
const selectRecord = async (row) => {
if (selectedId.value === row.id) return
selectedId.value = row.id
detail.value = null
loadingDetail.value = true
try {
detail.value = await getHistoryDetail(row.id)
await nextTick()
if (detail.value?.points?.length) {
initCharts(detail.value.points, detail.value.experiment?.target_temp)
}
} catch {
ElMessage.error('加载详情失败')
} finally {
loadingDetail.value = false
}
}
const onResize = () => {
tempChart?.resize()
currChart?.resize()
}
onMounted(async () => {
await loadList()
if (records.value.length) {
await selectRecord(records.value[0])
}
window.addEventListener('resize', onResize)
})
onBeforeUnmount(() => {
window.removeEventListener('resize', onResize)
disposeCharts()
})
const fmtNum = (v, digits = 4) => (v != null ? Number(v).toFixed(digits) : '--')
</script>
<template>
<div class="history-page">
<!-- ── 左侧:记录列表 ────────────────────────────────────────────────── -->
<div class="left-panel">
<div class="panel-header">
<span class="panel-title">历史数据</span>
<el-button :icon="Refresh" size="small" plain :loading="loadingList" @click="loadList">
刷新
</el-button>
</div>
<div v-loading="loadingList" class="record-list">
<el-empty v-if="!loadingList && !records.length" description="暂无已导出的历史数据" :image-size="80" />
<div
v-for="row in records"
:key="row.id"
class="record-item"
:class="{ active: selectedId === row.id }"
@click="selectRecord(row)"
>
<div class="record-name">{{ row.name }}</div>
<div class="record-meta">
<span>{{ row.model_name }}</span>
<el-tag size="small" type="info" class="steps-tag">{{ row.total_steps }}</el-tag>
</div>
<div class="record-time">{{ row.stop_time ?? row.created_at }}</div>
</div>
</div>
</div>
<!-- ── 右侧:详情 ────────────────────────────────────────────────────── -->
<div class="right-panel" v-loading="loadingDetail">
<el-empty v-if="!detail && !loadingDetail" description="请在左侧选择一条历史记录" :image-size="100" />
<template v-if="detail">
<!-- 基本信息 -->
<el-card shadow="hover" class="info-card">
<div class="info-grid">
<div class="info-item">
<span class="label">试验名称</span>
<span class="value">{{ detail.experiment?.name }}</span>
</div>
<div class="info-item">
<span class="label">预测模型</span>
<span class="value">{{ detail.experiment?.model_name }}</span>
</div>
<div class="info-item">
<span class="label">初始数据包</span>
<span class="value">{{ detail.experiment?.package_name }}</span>
</div>
<div class="info-item">
<span class="label">目标温度</span>
<span class="value highlight">{{ detail.experiment?.target_temp?.toFixed(1) }} °C</span>
</div>
<div class="info-item">
<span class="label">开始时间</span>
<span class="value">{{ detail.experiment?.start_time ?? '--' }}</span>
</div>
<div class="info-item">
<span class="label">停止时间</span>
<span class="value">{{ detail.experiment?.stop_time ?? '--' }}</span>
</div>
</div>
</el-card>
<!-- 指标摘要 -->
<el-card v-if="detail.summary" shadow="hover" class="summary-card">
<template #header><span class="card-title">控制性能指标</span></template>
<div class="metrics-row">
<div class="metric-chip">
<span class="metric-label">采集步数</span>
<span class="metric-value">{{ detail.summary.total_steps }}</span>
</div>
<div class="metric-chip">
<span class="metric-label">仿真时长</span>
<span class="metric-value">{{ detail.summary.duration_s }} s</span>
</div>
<div class="metric-chip">
<span class="metric-label">初始温度</span>
<span class="metric-value">{{ fmtNum(detail.summary.initial_temp, 2) }} °C</span>
</div>
<div class="metric-chip">
<span class="metric-label">最终温度</span>
<span class="metric-value">{{ fmtNum(detail.summary.final_temp, 2) }} °C</span>
</div>
<div class="metric-chip">
<span class="metric-label">MAE</span>
<span class="metric-value">{{ fmtNum(detail.summary.mae) }} °C</span>
</div>
<div class="metric-chip">
<span class="metric-label">RMSE</span>
<span class="metric-value">{{ fmtNum(detail.summary.rmse) }} °C</span>
</div>
<div class="metric-chip">
<span class="metric-label">最大超调</span>
<span class="metric-value">{{ fmtNum(detail.summary.overshoot) }} °C</span>
</div>
<div class="metric-chip">
<span class="metric-label">调节时间</span>
<span class="metric-value">{{ detail.summary.settling_step ?? '未稳定' }} 步</span>
</div>
<div class="metric-chip">
<span class="metric-label">平均电流</span>
<span class="metric-value">{{ fmtNum(detail.summary.avg_current) }} A</span>
</div>
<div class="metric-chip">
<span class="metric-label">最大电流</span>
<span class="metric-value">{{ fmtNum(detail.summary.max_current) }} A</span>
</div>
</div>
</el-card>
<!-- 温度曲线 -->
<el-card shadow="hover" class="chart-card">
<template #header><span class="card-title">温度曲线</span></template>
<el-empty v-if="!detail.points?.length" description="暂无数据点" :image-size="60" />
<div v-else ref="tempChartRef" class="chart-body" />
</el-card>
<!-- 电流曲线 -->
<el-card shadow="hover" class="chart-card">
<template #header><span class="card-title">电流输出曲线</span></template>
<el-empty v-if="!detail.points?.length" description="暂无数据点" :image-size="60" />
<div v-else ref="currChartRef" class="chart-body" />
</el-card>
</template>
</div>
</div>
</template>
<style scoped>
.history-page {
display: flex;
gap: 16px;
height: calc(100vh - 60px);
padding: 16px;
box-sizing: border-box;
overflow: hidden;
}
/* ── 左侧 ── */
.left-panel {
width: 280px;
flex-shrink: 0;
display: flex;
flex-direction: column;
border: 1px solid #e2e8f0;
border-radius: 8px;
background: #fff;
overflow: hidden;
}
.panel-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 12px 16px;
border-bottom: 1px solid #e2e8f0;
flex-shrink: 0;
}
.panel-title {
font-size: 14px;
font-weight: 600;
color: #1e293b;
}
.record-list {
flex: 1;
overflow-y: auto;
padding: 8px;
}
.record-item {
padding: 10px 12px;
border-radius: 6px;
cursor: pointer;
transition: background 0.15s;
margin-bottom: 4px;
border: 1px solid transparent;
}
.record-item:hover {
background: #f1f5f9;
}
.record-item.active {
background: #eff6ff;
border-color: #bfdbfe;
}
.record-name {
font-size: 13px;
font-weight: 600;
color: #1e293b;
margin-bottom: 4px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.record-meta {
display: flex;
align-items: center;
justify-content: space-between;
font-size: 12px;
color: #64748b;
margin-bottom: 2px;
}
.steps-tag {
flex-shrink: 0;
}
.record-time {
font-size: 11px;
color: #94a3b8;
}
/* ── 右侧 ── */
.right-panel {
flex: 1;
min-width: 0;
overflow-y: auto;
display: flex;
flex-direction: column;
gap: 14px;
}
.info-card,
.summary-card,
.chart-card {
flex-shrink: 0;
}
.info-grid {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 12px 16px;
}
.info-item {
display: flex;
flex-direction: column;
gap: 2px;
}
.label {
font-size: 11px;
color: #94a3b8;
}
.value {
font-size: 13px;
color: #1e293b;
font-weight: 500;
}
.value.highlight {
color: #2563eb;
font-weight: 700;
}
.metrics-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.metric-chip {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 8px 14px;
display: flex;
flex-direction: column;
gap: 2px;
min-width: 100px;
}
.metric-label {
font-size: 11px;
color: #94a3b8;
}
.metric-value {
font-size: 14px;
font-weight: 600;
color: #1e293b;
}
.card-title {
font-size: 14px;
font-weight: 600;
color: #1e293b;
}
.chart-body {
height: 260px;
width: 100%;
}
</style>
<script setup>
import * as echarts from 'echarts'
import { Refresh } from '@element-plus/icons-vue'
import { ref, onMounted, onBeforeUnmount, nextTick } from 'vue'
import { useRouter } from 'vue-router'
import { getExperiments, getDataPoints } from '@/api/realtimeMonitor'
const router = useRouter()
const runningExps = ref([])
const loadingList = ref(false)
// echarts instances & data state, keyed by experiment id
const chartsMap = {} // id -> echarts instance
const dataMap = {} // id -> data points array
const fromStepMap = {} // id -> next from_step
const MAX_DISPLAY = 300
let listTimer = null
let dataTimer = null
// ── 数据加载 ──────────────────────────────────────────────────────────────────
const loadList = async () => {
loadingList.value = true
try {
const all = await getExperiments()
const running = all.filter((e) => e.status === 'running')
// 清理已停止试验的图表实例
const runningIds = new Set(running.map((e) => e.id))
for (const id of Object.keys(chartsMap)) {
if (!runningIds.has(Number(id))) {
chartsMap[id]?.dispose()
delete chartsMap[id]
delete dataMap[id]
delete fromStepMap[id]
}
}
runningExps.value = running
} finally {
loadingList.value = false
}
}
const fetchAllData = async () => {
for (const exp of runningExps.value) {
try {
const from = fromStepMap[exp.id] ?? 0
const newPts = await getDataPoints(exp.id, from)
if (newPts.length) {
dataMap[exp.id] = [...(dataMap[exp.id] ?? []), ...newPts]
fromStepMap[exp.id] = dataMap[exp.id][dataMap[exp.id].length - 1].step_idx + 1
await nextTick()
updateChart(exp.id, exp.target_temp)
}
} catch {
// ignore per-experiment errors
}
}
}
// ── 图表 ──────────────────────────────────────────────────────────────────────
const getOrInitChart = (expId) => {
if (!chartsMap[expId]) {
const el = document.getElementById(`live-chart-${expId}`)
if (el) chartsMap[expId] = echarts.init(el)
}
return chartsMap[expId]
}
const updateChart = (expId, targetTemp) => {
const c = getOrInitChart(expId)
if (!c) return
const pts = dataMap[expId] ?? []
if (!pts.length) return
const step = pts.length > MAX_DISPLAY ? Math.ceil(pts.length / MAX_DISPLAY) : 1
const display = pts.filter((_, i) => i % step === 0)
const xData = display.map((d) => `${d.step_idx}`)
const actuals = display.map((d) => d.actual_temp)
const currents = display.map((d) => d.current_output)
const targetLine = display.map(() => targetTemp ?? null)
c.setOption(
{
animation: false,
color: ['#409EFF', '#F56C6C', '#E6A23C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const lines = [`<b>步 ${params[0].axisValue}</b>`]
params.forEach((p) => {
const unit = p.seriesName === '电流输出' ? ' A' : ' °C'
const v = p.data != null ? Number(p.data).toFixed(2) + unit : '--'
lines.push(`${p.marker}${p.seriesName}: <b>${v}</b>`)
})
return lines.join('<br/>')
},
},
legend: { bottom: 2, itemWidth: 14, itemHeight: 7, textStyle: { fontSize: 11, color: '#475569' } },
grid: { top: 12, left: 8, right: 50, bottom: 40, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#94a3b8', fontSize: 10, interval: Math.max(0, Math.floor(xData.length / 6) - 1) },
axisLine: { lineStyle: { color: '#e2e8f0' } },
},
yAxis: [
{
type: 'value',
name: '°C',
nameTextStyle: { color: '#64748b', fontSize: 10 },
axisLabel: { color: '#64748b', fontSize: 10 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
},
{
type: 'value',
name: 'A',
position: 'right',
nameTextStyle: { color: '#E6A23C', fontSize: 10 },
axisLabel: { color: '#E6A23C', fontSize: 10 },
splitLine: { show: false },
min: 0,
},
],
series: [
{
name: '实际温度',
type: 'line',
yAxisIndex: 0,
data: actuals,
smooth: true,
symbol: 'none',
lineStyle: { width: 1.5 },
},
{
name: '目标温度',
type: 'line',
yAxisIndex: 0,
data: targetLine,
symbol: 'none',
lineStyle: { width: 1, type: 'dotted', color: '#F56C6C' },
},
{
name: '电流输出',
type: 'line',
yAxisIndex: 1,
data: currents,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, color: '#E6A23C' },
areaStyle: { color: '#E6A23C', opacity: 0.05 },
},
],
},
true,
)
}
const onResize = () => {
Object.values(chartsMap).forEach((c) => c?.resize())
}
// ── 生命周期 ──────────────────────────────────────────────────────────────────
onMounted(async () => {
await loadList()
await nextTick()
// 加载各试验初始数据
await fetchAllData()
listTimer = setInterval(loadList, 5000)
dataTimer = setInterval(fetchAllData, 2000)
window.addEventListener('resize', onResize)
})
onBeforeUnmount(() => {
clearInterval(listTimer)
clearInterval(dataTimer)
window.removeEventListener('resize', onResize)
Object.values(chartsMap).forEach((c) => c?.dispose())
})
</script>
<template>
<div class="live-page">
<div class="live-header">
<span class="live-title">实时监控</span>
<el-tag type="success" size="small" style="margin-left:10px">
{{ runningExps.length }} 个试验运行中
</el-tag>
<el-button
:icon="Refresh"
size="small"
plain
:loading="loadingList"
style="margin-left:auto"
@click="loadList"
>
刷新
</el-button>
<el-button size="small" plain @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</div>
<div v-if="!runningExps.length && !loadingList" class="empty-state">
<el-empty description="当前没有正在运行的试验" :image-size="120">
<el-button type="primary" @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</el-empty>
</div>
<div v-else class="exp-grid">
<el-card
v-for="exp in runningExps"
:key="exp.id"
shadow="hover"
class="exp-card"
>
<template #header>
<div class="exp-card-header">
<span class="exp-name">{{ exp.name }}</span>
<el-tag type="success" size="small" effect="plain">运行中</el-tag>
</div>
<div class="exp-meta">
<span>目标 {{ exp.target_temp?.toFixed(1) }} °C</span>
<span>步数 {{ exp.total_steps }}</span>
<span>{{ exp.model_name }}</span>
</div>
</template>
<div
:id="`live-chart-${exp.id}`"
class="mini-chart"
>
<div v-if="!(dataMap[exp.id]?.length)" class="chart-placeholder">
<el-icon class="is-loading" style="font-size:20px;color:#409EFF"><Refresh /></el-icon>
<span>等待数据…</span>
</div>
</div>
</el-card>
</div>
</div>
</template>
<style lang="scss" scoped>
.live-page {
padding: 20px;
display: flex;
flex-direction: column;
gap: 16px;
height: 100%;
}
.live-header {
display: flex;
align-items: center;
gap: 8px;
flex-shrink: 0;
}
.live-title {
font-size: 16px;
font-weight: 700;
color: var(--text-primary);
}
.empty-state {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.exp-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(480px, 1fr));
gap: 16px;
align-content: start;
}
.exp-card {
:deep(.el-card__header) {
padding: 10px 16px 6px;
}
:deep(.el-card__body) {
padding: 0;
}
}
.exp-card-header {
display: flex;
align-items: center;
gap: 8px;
}
.exp-name {
font-size: 14px;
font-weight: 600;
color: var(--text-primary);
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.exp-meta {
display: flex;
gap: 12px;
font-size: 12px;
color: var(--text-tertiary);
margin-top: 4px;
}
.mini-chart {
height: 240px;
width: 100%;
position: relative;
}
.chart-placeholder {
position: absolute;
inset: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 8px;
color: #94a3b8;
font-size: 13px;
}
</style>
...@@ -23,6 +23,12 @@ const validData = computed(() => ...@@ -23,6 +23,12 @@ const validData = computed(() =>
const xLabels = computed(() => validData.value.map((d) => d.time || String(d.index ?? ''))) const xLabels = computed(() => validData.value.map((d) => d.time || String(d.index ?? '')))
const actualSeries = computed(() => validData.value.map((d) => d.actual ?? null)) const actualSeries = computed(() => validData.value.map((d) => d.actual ?? null))
const predictedSeries = computed(() => validData.value.map((d) => d.predicted ?? null)) const predictedSeries = computed(() => validData.value.map((d) => d.predicted ?? null))
const currentSeries = computed(() => validData.value.map((d) => d.current ?? null))
const voltageSeries = computed(() => validData.value.map((d) => d.voltage ?? null))
const hasCurrentVoltage = computed(() =>
validData.value.some((d) => d.current != null || d.voltage != null),
)
const renderChart = () => { const renderChart = () => {
if (!chartRef.value) return if (!chartRef.value) return
...@@ -30,10 +36,82 @@ const renderChart = () => { ...@@ -30,10 +36,82 @@ const renderChart = () => {
const hasData = validData.value.length > 0 const hasData = validData.value.length > 0
const yAxes = [
{
type: 'value',
name: '温度 (℃)',
position: 'left',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#409EFF' } },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
]
if (hasCurrentVoltage.value) {
yAxes.push({
type: 'value',
name: '电流 (A) / 电压 (V)',
position: 'right',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#E6A23C' } },
splitLine: { show: false },
})
}
const series = [
{
name: '实际温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
]
if (hasCurrentVoltage.value) {
series.push(
{
name: '电流 (A)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'solid' },
data: currentSeries.value,
},
{
name: '电压 (V)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'dotted' },
data: voltageSeries.value,
},
)
}
const legendData = hasCurrentVoltage.value
? ['实际温度 (℃)', '预测温度 (℃)', '电流 (A)', '电压 (V)']
: ['实际温度 (℃)', '预测温度 (℃)']
chartInstance.setOption( chartInstance.setOption(
{ {
animation: false, animation: false,
color: ['#409EFF', '#F56C6C'], color: ['#409EFF', '#F56C6C', '#E6A23C', '#67C23A'],
tooltip: { tooltip: {
trigger: 'axis', trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)', backgroundColor: 'rgba(255,255,255,0.96)',
...@@ -48,10 +126,13 @@ const renderChart = () => { ...@@ -48,10 +126,13 @@ const renderChart = () => {
] ]
params.forEach((item) => { params.forEach((item) => {
const val = item.data != null ? Number(item.data).toFixed(4) : '--' const val = item.data != null ? Number(item.data).toFixed(4) : '--'
const unit = item.seriesName.includes('℃') ? ' ℃' :
item.seriesName.includes('(A)') ? ' A' :
item.seriesName.includes('(V)') ? ' V' : ''
lines.push( lines.push(
`<div style="display:flex;align-items:center;gap:8px;min-width:180px;justify-content:space-between;"> `<div style="display:flex;align-items:center;gap:8px;min-width:200px;justify-content:space-between;">
<span>${item.marker}${item.seriesName}</span> <span>${item.marker}${item.seriesName}</span>
<strong>${val}</strong> <strong>${val}${unit}</strong>
</div>`, </div>`,
) )
}) })
...@@ -63,9 +144,9 @@ const renderChart = () => { ...@@ -63,9 +144,9 @@ const renderChart = () => {
itemWidth: 20, itemWidth: 20,
itemHeight: 10, itemHeight: 10,
textStyle: { color: '#475569', fontSize: 12 }, textStyle: { color: '#475569', fontSize: 12 },
data: ['实际温度 (℃)', '预测温度 (℃)'], data: legendData,
}, },
grid: { top: 16, left: 16, right: 20, bottom: 52, containLabel: true }, grid: { top: 16, left: 16, right: hasCurrentVoltage.value ? 80 : 20, bottom: 52, containLabel: true },
xAxis: { xAxis: {
type: 'category', type: 'category',
boundaryGap: false, boundaryGap: false,
...@@ -79,31 +160,8 @@ const renderChart = () => { ...@@ -79,31 +160,8 @@ const renderChart = () => {
}, },
axisLine: { lineStyle: { color: '#cbd5e1' } }, axisLine: { lineStyle: { color: '#cbd5e1' } },
}, },
yAxis: { yAxis: yAxes,
type: 'value', series,
name: '温度 (℃)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
series: [
{
name: '实际温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
],
graphic: hasData graphic: hasData
? [] ? []
: [ : [
...@@ -145,3 +203,4 @@ onBeforeUnmount(() => { ...@@ -145,3 +203,4 @@ onBeforeUnmount(() => {
<template> <template>
<div ref="chartRef" :style="{ width: '100%', height: props.height }" /> <div ref="chartRef" :style="{ width: '100%', height: props.height }" />
</template> </template>
<script setup> <script setup>
import { Refresh } from '@element-plus/icons-vue' import { Refresh, Search } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, reactive, ref } from 'vue' import { onMounted, reactive, ref } from 'vue'
import { useRoute } from 'vue-router'
import { import {
deleteEvalRecord, deleteEvalRecord,
getEvalModels, getEvalModels,
...@@ -10,16 +11,18 @@ import { ...@@ -10,16 +11,18 @@ import {
getEvalRecords, getEvalRecords,
runEvaluation, runEvaluation,
} from '@/api/evalManagement' } from '@/api/evalManagement'
import { getPkgCategoryTree } from '@/api/packageManagement'
import EvalChart from './components/EvalChart.vue' import EvalChart from './components/EvalChart.vue'
const route = useRoute()
// ── form ────────────────────────────────────────────────────────────────────── // ── form ──────────────────────────────────────────────────────────────────────
const packages = ref([])
const models = ref([]) const models = ref([])
const form = reactive({ model_id: '', package_id: '' }) const form = reactive({ model_id: '', package_id: '', package_name: '' })
const evaluating = ref(false) const evaluating = ref(false)
// Current evaluation result (shown directly above the records table) // Current evaluation result
const currentResult = ref(null) // { model_name, package_name, mae, rmse, chart_data } const currentResult = ref(null)
// ── records ─────────────────────────────────────────────────────────────────── // ── records ───────────────────────────────────────────────────────────────────
const records = ref([]) const records = ref([])
...@@ -30,11 +33,56 @@ const dialogVisible = ref(false) ...@@ -30,11 +33,56 @@ const dialogVisible = ref(false)
const dialogRecord = ref(null) const dialogRecord = ref(null)
const dialogLoading = ref(false) const dialogLoading = ref(false)
// ── package picker modal ──────────────────────────────────────────────────────
const pickerVisible = ref(false)
const pickerSearch = ref('')
const pickerCategoryId = ref('')
const pickerCategories = ref([])
const pickerPackages = ref([])
const pickerLoading = ref(false)
const openPicker = async () => {
pickerSearch.value = ''
pickerCategoryId.value = ''
pickerVisible.value = true
if (!pickerCategories.value.length) {
try {
const tree = await getPkgCategoryTree()
const flatten = (nodes, result = []) => {
for (const n of nodes) {
result.push({ value: String(n.id), label: n.name })
if (n.children?.length) flatten(n.children, result)
}
return result
}
pickerCategories.value = flatten(tree)
} catch {}
}
await loadPickerPackages()
}
const loadPickerPackages = async () => {
pickerLoading.value = true
try {
const data = await getEvalPackages({
category_id: pickerCategoryId.value || '',
name: pickerSearch.value.trim(),
})
pickerPackages.value = data
} finally {
pickerLoading.value = false
}
}
const selectPackage = (pkg) => {
form.package_id = pkg.id
form.package_name = pkg.name
pickerVisible.value = false
}
// ── actions ─────────────────────────────────────────────────────────────────── // ── actions ───────────────────────────────────────────────────────────────────
const loadDropdowns = async () => { const loadModels = async () => {
const [pkgs, mdls] = await Promise.all([getEvalPackages(), getEvalModels()]) models.value = await getEvalModels()
packages.value = pkgs
models.value = mdls
} }
const loadRecords = async () => { const loadRecords = async () => {
...@@ -92,9 +140,16 @@ const handleDelete = async (row) => { ...@@ -92,9 +140,16 @@ const handleDelete = async (row) => {
} }
const fmtMetric = (v) => (v != null ? Number(v).toFixed(5) : '-') const fmtMetric = (v) => (v != null ? Number(v).toFixed(5) : '-')
const fmtMape = (v) => (v != null ? Number(v).toFixed(3) + ' %' : '-')
const fmtR2 = (v) => (v != null ? Number(v).toFixed(6) : '-')
onMounted(async () => { onMounted(async () => {
await Promise.all([loadDropdowns(), loadRecords()]) await Promise.all([loadModels(), loadRecords()])
// Pre-select model from query param (e.g. from ModelList 评估 button)
const qModelId = route.query.model_id
if (qModelId) {
form.model_id = Number(qModelId)
}
}) })
</script> </script>
...@@ -112,7 +167,7 @@ onMounted(async () => { ...@@ -112,7 +167,7 @@ onMounted(async () => {
v-model="form.model_id" v-model="form.model_id"
placeholder="请选择已保存模型" placeholder="请选择已保存模型"
filterable filterable
style="width: 260px" style="width: 280px"
> >
<el-option <el-option
v-for="m in models" v-for="m in models"
...@@ -124,19 +179,15 @@ onMounted(async () => { ...@@ -124,19 +179,15 @@ onMounted(async () => {
</el-form-item> </el-form-item>
<el-form-item label="选择数据包" required> <el-form-item label="选择数据包" required>
<el-select <div class="pkg-picker-row">
v-model="form.package_id" <el-input
placeholder="请选择评估数据包" :model-value="form.package_name || ''"
filterable placeholder="点击右侧按钮选择数据包"
style="width: 240px" readonly
> style="width: 220px"
<el-option
v-for="p in packages"
:key="p.id"
:label="`${p.name}(${p.data_count} 条)`"
:value="p.id"
/> />
</el-select> <el-button @click="openPicker">选择</el-button>
</div>
</el-form-item> </el-form-item>
<el-form-item> <el-form-item>
...@@ -155,26 +206,42 @@ onMounted(async () => { ...@@ -155,26 +206,42 @@ onMounted(async () => {
</template> </template>
<template v-else-if="currentResult"> <template v-else-if="currentResult">
<div class="metrics-row"> <!-- 样本 / 模型信息行 -->
<div class="metrics-row meta-row">
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">评估样本</span> <span class="metric-label">评估样本</span>
<span class="metric-value">{{ currentResult.total_count }}</span> <span class="metric-value">{{ currentResult.total_count }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">MAE</span> <span class="metric-label">模型</span>
<span class="metric-value">{{ fmtMetric(currentResult.mae) }}</span> <span class="metric-value">{{ currentResult.model_name }}</span>
</div>
<div class="metric-chip">
<span class="metric-label">数据包</span>
<span class="metric-value">{{ currentResult.package_name }}</span>
</div>
</div>
<!-- 评估指标行 -->
<div class="metrics-row">
<div class="metric-chip">
<span class="metric-label">MSE(均方误差)</span>
<span class="metric-value">{{ fmtMetric(currentResult.mse) }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">RMSE</span> <span class="metric-label">RMSE(均方根误差)</span>
<span class="metric-value">{{ fmtMetric(currentResult.rmse) }}</span> <span class="metric-value">{{ fmtMetric(currentResult.rmse) }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">模型</span> <span class="metric-label">MAE(平均绝对误差)</span>
<span class="metric-value">{{ currentResult.model_name }}</span> <span class="metric-value">{{ fmtMetric(currentResult.mae) }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">数据包</span> <span class="metric-label">MAPE(平均绝对百分比误差)</span>
<span class="metric-value">{{ currentResult.package_name }}</span> <span class="metric-value">{{ fmtMape(currentResult.mape) }}</span>
</div>
<div class="metric-chip">
<span class="metric-label">R²(决定系数)</span>
<span class="metric-value">{{ fmtR2(currentResult.r2) }}</span>
</div> </div>
</div> </div>
<EvalChart :chart-data="currentResult.chart_data" height="340px" /> <EvalChart :chart-data="currentResult.chart_data" height="340px" />
...@@ -201,19 +268,34 @@ onMounted(async () => { ...@@ -201,19 +268,34 @@ onMounted(async () => {
> >
<el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip /> <el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip />
<el-table-column prop="package_name" label="评估数据包" min-width="130" show-overflow-tooltip /> <el-table-column prop="package_name" label="评估数据包" min-width="130" show-overflow-tooltip />
<el-table-column label="样本数" width="90" align="center"> <el-table-column label="样本数" width="88" align="center">
<template #default="{ row }">{{ row.total_count }}</template> <template #default="{ row }">{{ row.total_count }}</template>
</el-table-column> </el-table-column>
<el-table-column label="MAE (℃)" width="110" align="center"> <el-table-column label="MSE" width="100" align="center">
<template #default="{ row }"> <template #default="{ row }">
<span class="metric-cell">{{ fmtMetric(row.mae) }}</span> <span class="metric-cell">{{ fmtMetric(row.mse) }}</span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="RMSE (℃)" width="110" align="center"> <el-table-column label="RMSE (℃)" width="100" align="center">
<template #default="{ row }"> <template #default="{ row }">
<span class="metric-cell">{{ fmtMetric(row.rmse) }}</span> <span class="metric-cell">{{ fmtMetric(row.rmse) }}</span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="MAE (℃)" width="100" align="center">
<template #default="{ row }">
<span class="metric-cell">{{ fmtMetric(row.mae) }}</span>
</template>
</el-table-column>
<el-table-column label="MAPE" width="90" align="center">
<template #default="{ row }">
<span class="metric-cell">{{ fmtMape(row.mape) }}</span>
</template>
</el-table-column>
<el-table-column label="R²" width="90" align="center">
<template #default="{ row }">
<span class="metric-cell">{{ fmtR2(row.r2) }}</span>
</template>
</el-table-column>
<el-table-column prop="created_at" label="评估时间" width="165" /> <el-table-column prop="created_at" label="评估时间" width="165" />
<el-table-column label="操作" width="120" fixed="right" align="center"> <el-table-column label="操作" width="120" fixed="right" align="center">
<template #default="{ row }"> <template #default="{ row }">
...@@ -228,30 +310,46 @@ onMounted(async () => { ...@@ -228,30 +310,46 @@ onMounted(async () => {
<el-dialog <el-dialog
v-model="dialogVisible" v-model="dialogVisible"
:title="dialogRecord ? `${dialogRecord.model_name} — ${dialogRecord.package_name}` : '评估详情'" :title="dialogRecord ? `${dialogRecord.model_name} — ${dialogRecord.package_name}` : '评估详情'"
width="860px" width="900px"
destroy-on-close destroy-on-close
> >
<div v-loading="dialogLoading" style="min-height: 120px"> <div v-loading="dialogLoading" style="min-height: 120px">
<template v-if="dialogRecord"> <template v-if="dialogRecord">
<div class="metrics-row" style="margin-bottom: 12px"> <!-- 样本数 / 评估时间行 -->
<div class="metrics-row meta-row" style="margin-bottom: 8px">
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">样本数</span> <span class="metric-label">样本数</span>
<span class="metric-value">{{ dialogRecord.total_count }}</span> <span class="metric-value">{{ dialogRecord.total_count }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">MAE</span> <span class="metric-label">评估时间</span>
<span class="metric-value">{{ fmtMetric(dialogRecord.mae) }}</span> <span class="metric-value">{{ dialogRecord.created_at }}</span>
</div>
</div>
<!-- 评估指标行 -->
<div class="metrics-row" style="margin-bottom: 12px">
<div class="metric-chip">
<span class="metric-label">MSE(均方误差)</span>
<span class="metric-value">{{ fmtMetric(dialogRecord.mse) }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">RMSE</span> <span class="metric-label">RMSE(均方根误差)</span>
<span class="metric-value">{{ fmtMetric(dialogRecord.rmse) }}</span> <span class="metric-value">{{ fmtMetric(dialogRecord.rmse) }}</span>
</div> </div>
<div class="metric-chip"> <div class="metric-chip">
<span class="metric-label">评估时间</span> <span class="metric-label">MAE(平均绝对误差)</span>
<span class="metric-value">{{ dialogRecord.created_at }}</span> <span class="metric-value">{{ fmtMetric(dialogRecord.mae) }}</span>
</div>
<div class="metric-chip">
<span class="metric-label">MAPE(平均绝对百分比误差)</span>
<span class="metric-value">{{ fmtMape(dialogRecord.mape) }}</span>
</div>
<div class="metric-chip">
<span class="metric-label">R²(决定系数)</span>
<span class="metric-value">{{ fmtR2(dialogRecord.r2) }}</span>
</div> </div>
</div> </div>
<EvalChart :chart-data="dialogRecord.chart_data || []" height="380px" /> <EvalChart :chart-data="dialogRecord.chart_data || []" height="400px" />
</template> </template>
</div> </div>
...@@ -259,6 +357,55 @@ onMounted(async () => { ...@@ -259,6 +357,55 @@ onMounted(async () => {
<el-button @click="dialogVisible = false">关闭</el-button> <el-button @click="dialogVisible = false">关闭</el-button>
</template> </template>
</el-dialog> </el-dialog>
<!-- ── package picker dialog ─────────────────────────────────────────── -->
<el-dialog
v-model="pickerVisible"
title="选择评估数据包"
width="640px"
:close-on-click-modal="false"
>
<div class="picker-filters">
<el-select
v-model="pickerCategoryId"
placeholder="所有分类"
clearable
style="width:180px"
@change="loadPickerPackages"
>
<el-option
v-for="opt in pickerCategories"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
<el-input
v-model="pickerSearch"
placeholder="数据包名称 / 备注"
:prefix-icon="Search"
clearable
style="flex:1"
@keyup.enter="loadPickerPackages"
@clear="loadPickerPackages"
/>
<el-button type="primary" @click="loadPickerPackages">搜索</el-button>
</div>
<el-table
:data="pickerPackages"
v-loading="pickerLoading"
border
stripe
highlight-current-row
height="360px"
style="cursor:pointer;margin-top:10px"
@row-click="selectPackage"
>
<el-table-column prop="name" label="数据包名称" show-overflow-tooltip />
<el-table-column prop="data_count" label="数据量" width="90" align="center" />
</el-table>
</el-dialog>
</div> </div>
</template> </template>
...@@ -309,6 +456,18 @@ onMounted(async () => { ...@@ -309,6 +456,18 @@ onMounted(async () => {
} }
} }
.pkg-picker-row {
display: flex;
gap: 8px;
align-items: center;
}
.picker-filters {
display: flex;
gap: 8px;
align-items: center;
}
.eval-loading { .eval-loading {
display: flex; display: flex;
align-items: center; align-items: center;
...@@ -325,6 +484,15 @@ onMounted(async () => { ...@@ -325,6 +484,15 @@ onMounted(async () => {
padding: 16px 0 12px; padding: 16px 0 12px;
} }
.meta-row {
padding-bottom: 4px;
.metric-chip {
background: #f1f5f9;
border-color: #cbd5e1;
}
}
.metric-chip { .metric-chip {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
...@@ -357,3 +525,4 @@ onMounted(async () => { ...@@ -357,3 +525,4 @@ onMounted(async () => {
font-weight: 500; font-weight: 500;
} }
</style> </style>
<script setup> <script setup>
import { Refresh } from '@element-plus/icons-vue' import { Refresh } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, ref } from 'vue' import { onMounted, reactive, ref } from 'vue'
import { deleteSavedModel, getSavedModels } from '@/api/trainManagement' import { useRouter } from 'vue-router'
import { deleteSavedModel, getSavedModels, updateSavedModel } from '@/api/trainManagement'
const router = useRouter()
const models = ref([]) const models = ref([])
const loading = ref(false) const loading = ref(false)
...@@ -16,6 +19,43 @@ const loadModels = async () => { ...@@ -16,6 +19,43 @@ const loadModels = async () => {
} }
} }
// ── edit dialog ───────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editRow = ref(null)
const editForm = reactive({ model_name: '', description: '' })
const editSaving = ref(false)
const handleEdit = (row) => {
editRow.value = row
editForm.model_name = row.model_name
editForm.description = row.description || ''
editVisible.value = true
}
const confirmEdit = async () => {
if (!editForm.model_name.trim()) return ElMessage.warning('模型名称不能为空')
editSaving.value = true
try {
await updateSavedModel(editRow.value.id, {
model_name: editForm.model_name.trim(),
description: editForm.description.trim(),
})
ElMessage.success('已更新')
editVisible.value = false
await loadModels()
} catch (e) {
ElMessage.error(e?.message || '更新失败')
} finally {
editSaving.value = false
}
}
// ── evaluate shortcut ─────────────────────────────────────────────────────────
const handleEvaluate = (row) => {
router.push({ name: 'model-evaluation', query: { model_id: row.id } })
}
// ── delete ────────────────────────────────────────────────────────────────────
const handleDelete = async (row) => { const handleDelete = async (row) => {
try { try {
await ElMessageBox.confirm(`确定删除模型"${row.model_name}"吗?删除后无法恢复。`, '提示', { await ElMessageBox.confirm(`确定删除模型"${row.model_name}"吗?删除后无法恢复。`, '提示', {
...@@ -29,14 +69,16 @@ const handleDelete = async (row) => { ...@@ -29,14 +69,16 @@ const handleDelete = async (row) => {
} }
} }
// ── display helpers ───────────────────────────────────────────────────────────
const formatParams = (params) => { const formatParams = (params) => {
if (!params) return '-' if (!params) return '-'
return [ return [
`seq=${params.seq_len}`, `序列长度 ${params.seq_len}`,
`hidden=${params.hidden_size}`, `隐藏层 ${params.hidden_size}`,
`layers=${params.num_layers}`, `层数 ${params.num_layers}`,
`epochs=${params.epochs}`, `轮数 ${params.epochs}`,
`lr=${params.learning_rate}`, `批次 ${params.batch_size}`,
`学习率 ${params.learning_rate}`,
].join(' / ') ].join(' / ')
} }
...@@ -65,10 +107,15 @@ onMounted(loadModels) ...@@ -65,10 +107,15 @@ onMounted(loadModels)
<el-table :data="models" v-loading="loading" border stripe height="calc(100vh - 160px)"> <el-table :data="models" v-loading="loading" border stripe height="calc(100vh - 160px)">
<el-table-column type="index" width="55" label="#" align="center" /> <el-table-column type="index" width="55" label="#" align="center" />
<el-table-column prop="model_name" label="模型名称" min-width="150" show-overflow-tooltip /> <el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip />
<el-table-column prop="description" label="说明" min-width="160" show-overflow-tooltip>
<template #default="{ row }">
<span class="desc-text">{{ row.description || '—' }}</span>
</template>
</el-table-column>
<el-table-column prop="package_name" label="训练数据包" min-width="140" show-overflow-tooltip /> <el-table-column prop="package_name" label="训练数据包" min-width="140" show-overflow-tooltip />
<el-table-column label="LSTM 参数" min-width="280" show-overflow-tooltip> <el-table-column label="LSTM 参数" min-width="310" show-overflow-tooltip>
<template #default="{ row }"> <template #default="{ row }">
<el-tooltip :content="formatParams(row.params)" placement="top"> <el-tooltip :content="formatParams(row.params)" placement="top">
<span class="params-text">{{ formatParams(row.params) }}</span> <span class="params-text">{{ formatParams(row.params) }}</span>
...@@ -84,8 +131,10 @@ onMounted(loadModels) ...@@ -84,8 +131,10 @@ onMounted(loadModels)
<el-table-column prop="created_at" label="保存时间" width="170" /> <el-table-column prop="created_at" label="保存时间" width="170" />
<el-table-column label="操作" width="90" fixed="right" align="center"> <el-table-column label="操作" width="160" fixed="right" align="center">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleEdit(row)">编辑</el-button>
<el-button link type="success" @click="handleEvaluate(row)">评估</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button> <el-button link type="danger" @click="handleDelete(row)">删除</el-button>
</template> </template>
</el-table-column> </el-table-column>
...@@ -98,6 +147,29 @@ onMounted(loadModels) ...@@ -98,6 +147,29 @@ onMounted(loadModels)
/> />
</el-card> </el-card>
</div> </div>
<!-- ── edit dialog ──────────────────────────────────────────────── -->
<el-dialog v-model="editVisible" title="编辑模型信息" width="480px" :close-on-click-modal="false">
<el-form :model="editForm" label-position="top">
<el-form-item label="模型名称" required>
<el-input v-model="editForm.model_name" maxlength="100" show-word-limit />
</el-form-item>
<el-form-item label="说明">
<el-input
v-model="editForm.description"
type="textarea"
:rows="3"
placeholder="模型用途、训练说明等"
maxlength="500"
show-word-limit
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editSaving" @click="confirmEdit">保存</el-button>
</template>
</el-dialog>
</template> </template>
<style lang="scss" scoped> <style lang="scss" scoped>
...@@ -144,7 +216,8 @@ onMounted(loadModels) ...@@ -144,7 +216,8 @@ onMounted(loadModels)
} }
.params-text, .params-text,
.loss-text { .loss-text,
.desc-text {
font-size: 12px; font-size: 12px;
color: var(--text-secondary); color: var(--text-secondary);
white-space: nowrap; white-space: nowrap;
...@@ -158,3 +231,5 @@ onMounted(loadModels) ...@@ -158,3 +231,5 @@ onMounted(loadModels)
font-weight: 500; font-weight: 500;
} }
</style> </style>
<script setup> <script setup>
import { Refresh } from '@element-plus/icons-vue' import * as echarts from 'echarts'
import { Refresh, Search } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue' import { computed, nextTick, onBeforeUnmount, onMounted, reactive, ref, watch } from 'vue'
import { import {
cancelTrainTask, cancelTrainTask,
createTrainTask, createTrainTask,
deleteTrainTask, deleteTrainTask,
getTrainPackages,
getTrainTasks, getTrainTasks,
restartTrainTask, restartTrainTask,
saveTrainModel, saveTrainModel,
} from '@/api/trainManagement' } from '@/api/trainManagement'
import { getPackages, getPkgCategoryTree } from '@/api/packageManagement'
// ── form ────────────────────────────────────────────────────────────────────── // ── form ──────────────────────────────────────────────────────────────────────
const form = reactive({ const form = reactive({
model_name: '', model_name: '',
package_id: '', train_package_id: null,
train_package_name: '',
val_package_id: null,
val_package_name: '',
params: { params: {
seq_len: 20, seq_len: 20,
hidden_size: 64, hidden_size: 64,
...@@ -26,9 +30,62 @@ const form = reactive({ ...@@ -26,9 +30,62 @@ const form = reactive({
}, },
}) })
const packages = ref([])
const submitting = ref(false) const submitting = ref(false)
// ── package picker modal ──────────────────────────────────────────────────────
const pickerVisible = ref(false)
const pickerMode = ref('train') // 'train' | 'val'
const pickerSearch = ref('')
const pickerCategoryId = ref('')
const pickerCategories = ref([])
const pickerPackages = ref([])
const pickerLoading = ref(false)
const openPicker = async (mode) => {
pickerMode.value = mode
pickerSearch.value = ''
pickerCategoryId.value = ''
pickerVisible.value = true
if (!pickerCategories.value.length) {
try {
const tree = await getPkgCategoryTree()
const flatten = (nodes, result = []) => {
for (const n of nodes) {
result.push({ value: String(n.id), label: n.name })
if (n.children?.length) flatten(n.children, result)
}
return result
}
pickerCategories.value = flatten(tree)
} catch {}
}
await loadPickerPackages()
}
const loadPickerPackages = async () => {
pickerLoading.value = true
try {
const data = await getPackages({
category_id: pickerCategoryId.value || '',
name: pickerSearch.value.trim(),
})
pickerPackages.value = data
} finally {
pickerLoading.value = false
}
}
const selectPackage = (pkg) => {
if (pickerMode.value === 'train') {
form.train_package_id = pkg.id
form.train_package_name = pkg.name
} else {
form.val_package_id = pkg.id
form.val_package_name = pkg.name
}
pickerVisible.value = false
}
// ── tasks table ─────────────────────────────────────────────────────────────── // ── tasks table ───────────────────────────────────────────────────────────────
const tasks = ref([]) const tasks = ref([])
const loadingTasks = ref(false) const loadingTasks = ref(false)
...@@ -55,6 +112,11 @@ const loadTasks = async () => { ...@@ -55,6 +112,11 @@ const loadTasks = async () => {
tasks.value = await getTrainTasks() tasks.value = await getTrainTasks()
if (hasActiveTasks.value) { if (hasActiveTasks.value) {
startPolling() startPolling()
// refresh selected task detail if still active
if (selectedTask.value && ['pending', 'running'].includes(selectedTask.value.status)) {
const fresh = tasks.value.find((t) => t.id === selectedTask.value.id)
if (fresh) selectedTask.value = fresh
}
} else { } else {
stopPolling() stopPolling()
} }
...@@ -64,20 +126,16 @@ const loadTasks = async () => { ...@@ -64,20 +126,16 @@ const loadTasks = async () => {
} }
const handleStartTraining = async () => { const handleStartTraining = async () => {
if (!form.model_name.trim()) { if (!form.model_name.trim()) return ElMessage.warning('请输入模型名称')
ElMessage.warning('请输入模型名称') if (!form.train_package_id) return ElMessage.warning('请选择训练集数据包')
return if (!form.val_package_id) return ElMessage.warning('请选择验证集数据包')
}
if (!form.package_id) {
ElMessage.warning('请选择数据包')
return
}
submitting.value = true submitting.value = true
try { try {
await createTrainTask({ await createTrainTask({
model_name: form.model_name.trim(), model_name: form.model_name.trim(),
package_id: form.package_id, train_package_id: form.train_package_id,
val_package_id: form.val_package_id,
params: { ...form.params }, params: { ...form.params },
}) })
ElMessage.success('训练任务已提交') ElMessage.success('训练任务已提交')
...@@ -100,21 +158,47 @@ const handleCancel = async (task) => { ...@@ -100,21 +158,47 @@ const handleCancel = async (task) => {
const handleRestart = async (task) => { const handleRestart = async (task) => {
try { try {
await ElMessageBox.confirm(
`确定重新训练"${task.model_name}"?将使用相同配置和数据包新建一个训练任务。`,
'确认重新训练',
{ type: 'warning', confirmButtonText: '确认', cancelButtonText: '取消' },
)
await restartTrainTask(task.id) await restartTrainTask(task.id)
ElMessage.success('重新训练已启动') ElMessage.success('重新训练已启动')
await loadTasks() await loadTasks()
} catch (e) { } catch {
ElMessage.error(e?.message || '重启失败') // user cancelled
} }
} }
const handleSave = async (task) => { // ── save model modal ──────────────────────────────────────────────────────────
const saveModalVisible = ref(false)
const saveModalTask = ref(null)
const saveForm = reactive({ model_name: '', description: '' })
const saving = ref(false)
const handleSave = (task) => {
saveModalTask.value = task
saveForm.model_name = task.model_name
saveForm.description = ''
saveModalVisible.value = true
}
const confirmSave = async () => {
if (!saveForm.model_name.trim()) return ElMessage.warning('请输入模型名称')
saving.value = true
try { try {
await saveTrainModel(task.id) await saveTrainModel(saveModalTask.value.id, {
model_name: saveForm.model_name.trim(),
description: saveForm.description.trim(),
})
ElMessage.success('模型已保存,可在模型列表中查看') ElMessage.success('模型已保存,可在模型列表中查看')
saveModalVisible.value = false
await loadTasks() await loadTasks()
} catch (e) { } catch (e) {
ElMessage.error(e?.message || '保存失败') ElMessage.error(e?.message || '保存失败')
} finally {
saving.value = false
} }
} }
...@@ -125,12 +209,72 @@ const handleDelete = async (task) => { ...@@ -125,12 +209,72 @@ const handleDelete = async (task) => {
}) })
await deleteTrainTask(task.id) await deleteTrainTask(task.id)
ElMessage.success('已删除') ElMessage.success('已删除')
if (selectedTask.value?.id === task.id) selectedTask.value = null
await loadTasks() await loadTasks()
} catch { } catch {
// user cancelled // user cancelled
} }
} }
// ── task detail (epoch logs) ──────────────────────────────────────────────────
const selectedTask = ref(null)
const detailMode = ref('table') // 'table' | 'chart'
const chartRef = ref(null)
let chartInstance = null
const epochLogs = computed(() => selectedTask.value?.epoch_logs || [])
const handleRowClick = (row) => {
selectedTask.value = row
detailMode.value = 'table'
}
const renderEpochChart = () => {
if (!chartRef.value || detailMode.value !== 'chart') return
if (!chartInstance) chartInstance = echarts.init(chartRef.value)
const logs = epochLogs.value
const epochs = logs.map((l) => l.epoch)
const trainLoss = logs.map((l) => l.train_loss)
const valLoss = logs.map((l) => l.val_loss)
chartInstance.setOption(
{
animation: false,
color: ['#409EFF', '#F56C6C'],
tooltip: {
trigger: 'axis',
formatter(params) {
const e = params[0]?.axisValue
const lines = [`<b>Epoch ${e}</b>`]
params.forEach((p) => {
if (p.data != null)
lines.push(
`<span style="color:${p.color}">● </span>${p.seriesName}: ${Number(p.data).toExponential(4)}`,
)
})
return lines.join('<br/>')
},
},
legend: { data: ['训练损失', '验证损失'], top: 4 },
grid: { left: 70, right: 20, top: 36, bottom: 36 },
xAxis: { type: 'category', data: epochs, name: 'Epoch', nameLocation: 'end' },
yAxis: { type: 'value', name: 'Loss', scale: true },
series: [
{ name: '训练损失', type: 'line', data: trainLoss, smooth: true, symbol: 'none' },
{ name: '验证损失', type: 'line', data: valLoss, smooth: true, symbol: 'none' },
],
},
true,
)
}
watch(detailMode, (val) => {
if (val === 'chart') nextTick(() => { renderEpochChart(); chartInstance?.resize() })
})
watch(epochLogs, () => {
if (detailMode.value === 'chart') nextTick(renderEpochChart)
})
// ── display helpers ─────────────────────────────────────────────────────────── // ── display helpers ───────────────────────────────────────────────────────────
const STATUS_MAP = { const STATUS_MAP = {
pending: { type: 'info', text: '等待中' }, pending: { type: 'info', text: '等待中' },
...@@ -158,23 +302,22 @@ const formatLoss = (task) => { ...@@ -158,23 +302,22 @@ const formatLoss = (task) => {
if (task.train_loss == null) return '-' if (task.train_loss == null) return '-'
const parts = [`训练: ${Number(task.train_loss).toFixed(5)}`] const parts = [`训练: ${Number(task.train_loss).toFixed(5)}`]
if (task.val_loss != null) parts.push(`验证: ${Number(task.val_loss).toFixed(5)}`) if (task.val_loss != null) parts.push(`验证: ${Number(task.val_loss).toFixed(5)}`)
if (task.test_loss != null) parts.push(`测试: ${Number(task.test_loss).toFixed(5)}`)
return parts.join(' / ') return parts.join(' / ')
} }
onMounted(async () => { onMounted(async () => {
loadingTasks.value = true loadingTasks.value = true
try { try {
await Promise.all([ await loadTasks()
getTrainPackages().then((d) => (packages.value = d)),
loadTasks(),
])
} finally { } finally {
loadingTasks.value = false loadingTasks.value = false
} }
}) })
onBeforeUnmount(stopPolling) onBeforeUnmount(() => {
stopPolling()
chartInstance?.dispose()
})
</script> </script>
<template> <template>
...@@ -186,24 +329,8 @@ onBeforeUnmount(stopPolling) ...@@ -186,24 +329,8 @@ onBeforeUnmount(stopPolling)
</template> </template>
<el-form :model="form" label-position="top" class="train-form"> <el-form :model="form" label-position="top" class="train-form">
<!-- row 1: package + model name --> <!-- row 1: model name -->
<div class="form-row"> <div class="form-row">
<el-form-item label="选择数据包" required class="form-item-wide">
<el-select
v-model="form.package_id"
placeholder="请选择数据包"
filterable
style="width: 100%"
>
<el-option
v-for="pkg in packages"
:key="pkg.id"
:label="`${pkg.name}(${pkg.data_count} 条)`"
:value="pkg.id"
/>
</el-select>
</el-form-item>
<el-form-item label="模型名称" required class="form-item-wide"> <el-form-item label="模型名称" required class="form-item-wide">
<el-input <el-input
v-model="form.model_name" v-model="form.model_name"
...@@ -214,101 +341,74 @@ onBeforeUnmount(stopPolling) ...@@ -214,101 +341,74 @@ onBeforeUnmount(stopPolling)
</el-form-item> </el-form-item>
</div> </div>
<!-- row 2: LSTM params --> <!-- row 2: package pickers -->
<div class="form-row">
<el-form-item label="训练集数据包" required class="form-item-wide">
<div class="pkg-picker-row">
<el-input
:model-value="form.train_package_name || ''"
placeholder="点击右侧按钮选择训练集"
readonly
class="pkg-input"
/>
<el-button @click="openPicker('train')">选择</el-button>
</div>
</el-form-item>
<el-form-item label="验证集数据包" required class="form-item-wide">
<div class="pkg-picker-row">
<el-input
:model-value="form.val_package_name || ''"
placeholder="点击右侧按钮选择验证集"
readonly
class="pkg-input"
/>
<el-button @click="openPicker('val')">选择</el-button>
</div>
</el-form-item>
</div>
<!-- row 3: LSTM params -->
<div class="params-section"> <div class="params-section">
<span class="params-label">LSTM 超参数</span> <span class="params-label">LSTM 超参数</span>
<div class="params-grid"> <div class="params-grid">
<el-form-item label="序列长度"> <el-form-item label="序列长度">
<el-input-number <el-input-number v-model="form.params.seq_len" :min="5" :max="500" :step="5" controls-position="right" style="width:100%" />
v-model="form.params.seq_len"
:min="5"
:max="500"
:step="5"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
<el-form-item label="隐藏层大小"> <el-form-item label="隐藏层大小">
<el-input-number <el-input-number v-model="form.params.hidden_size" :min="8" :max="1024" :step="8" controls-position="right" style="width:100%" />
v-model="form.params.hidden_size"
:min="8"
:max="1024"
:step="8"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
<el-form-item label="LSTM 层数"> <el-form-item label="LSTM 层数">
<el-input-number <el-input-number v-model="form.params.num_layers" :min="1" :max="8" controls-position="right" style="width:100%" />
v-model="form.params.num_layers"
:min="1"
:max="8"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
<el-form-item label="训练轮数 (Epochs)"> <el-form-item label="训练轮数 (Epochs)">
<el-input-number <el-input-number v-model="form.params.epochs" :min="1" :max="2000" :step="10" controls-position="right" style="width:100%" />
v-model="form.params.epochs"
:min="1"
:max="2000"
:step="10"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
<el-form-item label="批次大小 (Batch)"> <el-form-item label="批次大小 (Batch)">
<el-input-number <el-input-number v-model="form.params.batch_size" :min="1" :max="512" :step="8" controls-position="right" style="width:100%" />
v-model="form.params.batch_size"
:min="1"
:max="512"
:step="8"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
<el-form-item label="学习率"> <el-form-item label="学习率">
<el-input-number <el-input-number v-model="form.params.learning_rate" :min="0.00001" :max="1" :step="0.0001" :precision="5" controls-position="right" style="width:100%" />
v-model="form.params.learning_rate"
:min="0.00001"
:max="1"
:step="0.0001"
:precision="5"
controls-position="right"
style="width: 100%"
/>
</el-form-item> </el-form-item>
</div> </div>
</div> </div>
<!-- action --> <!-- action -->
<div class="form-action"> <div class="form-action">
<el-button <el-button type="primary" size="large" :loading="submitting" @click="handleStartTraining">
type="primary"
size="large"
:loading="submitting"
@click="handleStartTraining"
>
开始训练 开始训练
</el-button> </el-button>
</div> </div>
</el-form> </el-form>
</el-card> </el-card>
<!-- ── tasks table ──────────────────────────────────────────────────── --> <!-- ── bottom: tasks list + epoch detail ────────────────────────────── -->
<div class="bottom-area">
<!-- tasks card -->
<el-card class="tasks-card" shadow="hover"> <el-card class="tasks-card" shadow="hover">
<template #header> <template #header>
<div class="card-header-row"> <div class="card-header-row">
<span class="card-title">训练记录</span> <span class="card-title">训练记录</span>
<el-button <el-button :icon="Refresh" size="small" plain :loading="loadingTasks" @click="loadTasks">刷新</el-button>
:icon="Refresh"
size="small"
plain
:loading="loadingTasks"
@click="loadTasks"
>
刷新
</el-button>
</div> </div>
</template> </template>
...@@ -317,99 +417,179 @@ onBeforeUnmount(stopPolling) ...@@ -317,99 +417,179 @@ onBeforeUnmount(stopPolling)
v-loading="loadingTasks" v-loading="loadingTasks"
border border
stripe stripe
:height="tableHeight" highlight-current-row
height="100%"
style="cursor:pointer"
@row-click="handleRowClick"
> >
<el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip /> <el-table-column prop="model_name" label="模型名称" min-width="130" show-overflow-tooltip />
<el-table-column prop="package_name" label="数据包" min-width="130" show-overflow-tooltip /> <el-table-column label="训练集" min-width="120" show-overflow-tooltip>
<template #default="{ row }">{{ row.package_name }}</template>
<el-table-column label="参数" min-width="180" show-overflow-tooltip>
<template #default="{ row }">
<el-tooltip :content="formatParams(row.params)" placement="top">
<span class="params-cell">{{ formatParams(row.params) }}</span>
</el-tooltip>
</template>
</el-table-column> </el-table-column>
<el-table-column label="验证集" min-width="120" show-overflow-tooltip>
<el-table-column label="状态" width="100" align="center"> <template #default="{ row }">{{ row.val_package_name || '-' }}</template>
</el-table-column>
<el-table-column label="状态" width="90" align="center">
<template #default="{ row }"> <template #default="{ row }">
<el-tag :type="getStatusTag(row.status).type" size="small"> <el-tag :type="getStatusTag(row.status).type" size="small">
{{ getStatusTag(row.status).text }} {{ getStatusTag(row.status).text }}
</el-tag> </el-tag>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="进度" width="140" align="center">
<el-table-column label="进度" width="150" align="center">
<template #default="{ row }"> <template #default="{ row }">
<template v-if="row.status === 'running'"> <el-progress v-if="row.status === 'running'" :percentage="row.progress" :stroke-width="6" style="width:110px" />
<el-progress <el-progress v-else-if="row.status === 'completed'" :percentage="100" status="success" :stroke-width="6" style="width:110px" />
:percentage="row.progress"
:stroke-width="6"
:show-text="true"
style="width: 120px"
/>
</template>
<template v-else-if="row.status === 'completed'">
<el-progress
:percentage="100"
status="success"
:stroke-width="6"
style="width: 120px"
/>
</template>
<span v-else class="muted"></span> <span v-else class="muted"></span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="损失" min-width="180" show-overflow-tooltip>
<el-table-column label="损失" min-width="200" show-overflow-tooltip>
<template #default="{ row }"> <template #default="{ row }">
<span v-if="row.status === 'failed'" class="err-text"> <span v-if="row.status === 'failed'" class="err-text">{{ row.error_msg || '未知错误' }}</span>
{{ row.error_msg || '未知错误' }}
</span>
<span v-else>{{ formatLoss(row) }}</span> <span v-else>{{ formatLoss(row) }}</span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column prop="created_at" label="创建时间" width="155" />
<el-table-column prop="created_at" label="创建时间" width="160" /> <el-table-column label="操作" width="190" fixed="right">
<el-table-column label="操作" width="210" fixed="right">
<template #default="{ row }"> <template #default="{ row }">
<!-- running / pending -->
<template v-if="row.status === 'running' || row.status === 'pending'"> <template v-if="row.status === 'running' || row.status === 'pending'">
<el-button link type="warning" @click="handleCancel(row)">取消</el-button> <el-button link type="warning" @click.stop="handleCancel(row)">取消</el-button>
</template> </template>
<!-- completed -->
<template v-else-if="row.status === 'completed'"> <template v-else-if="row.status === 'completed'">
<el-button <el-button v-if="!row.is_saved" link type="primary" @click.stop="handleSave(row)">保存模型</el-button>
v-if="!row.is_saved" <el-tag v-else type="success" size="small" style="margin-right:6px">已保存</el-tag>
link <el-button link type="info" @click.stop="handleRestart(row)">重新训练</el-button>
type="primary" <el-button link type="danger" @click.stop="handleDelete(row)">删除</el-button>
@click="handleSave(row)"
>
保存模型
</el-button>
<el-tag v-else type="success" size="small" style="margin-right: 6px">已保存</el-tag>
<el-button link type="info" @click="handleRestart(row)">重新训练</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button>
</template> </template>
<!-- failed / cancelled -->
<template v-else> <template v-else>
<el-button link type="info" @click="handleRestart(row)">重新训练</el-button> <el-button link type="info" @click.stop="handleRestart(row)">重新训练</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button> <el-button link type="danger" @click.stop="handleDelete(row)">删除</el-button>
</template> </template>
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
</el-card> </el-card>
<!-- epoch detail panel -->
<el-card v-if="selectedTask" class="detail-card" shadow="hover">
<template #header>
<div class="card-header-row">
<span class="card-title">训练过程 — {{ selectedTask.model_name }}</span>
<el-radio-group v-model="detailMode" size="small">
<el-radio-button value="table">表格</el-radio-button>
<el-radio-button value="chart">曲线</el-radio-button>
</el-radio-group>
</div> </div>
</template> </template>
<script> <div class="detail-body">
// tableHeight is a non-reactive calculation; compute once <!-- table mode -->
const tableHeight = 'calc(100vh - 520px)' <el-table
export default {} v-if="detailMode === 'table'"
</script> :data="epochLogs"
border
stripe
height="100%"
size="small"
>
<el-table-column prop="epoch" label="轮次" width="70" align="center" />
<el-table-column label="训练损失" align="right">
<template #default="{ row }">
{{ row.train_loss != null ? Number(row.train_loss).toExponential(4) : '-' }}
</template>
</el-table-column>
<el-table-column label="验证损失" align="right">
<template #default="{ row }">
{{ row.val_loss != null ? Number(row.val_loss).toExponential(4) : '-' }}
</template>
</el-table-column>
</el-table>
<!-- chart mode -->
<div v-show="detailMode === 'chart'" ref="chartRef" class="epoch-chart" />
</div>
</el-card>
</div>
</div>
<!-- ── package picker dialog ────────────────────────────────────────── -->
<el-dialog
v-model="pickerVisible"
:title="pickerMode === 'train' ? '选择训练集数据包' : '选择验证集数据包'"
width="640px"
:close-on-click-modal="false"
>
<div class="picker-filters">
<el-select
v-model="pickerCategoryId"
placeholder="所有分类"
clearable
style="width:180px"
@change="loadPickerPackages"
>
<el-option
v-for="opt in pickerCategories"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
<el-input
v-model="pickerSearch"
placeholder="数据包名称"
:prefix-icon="Search"
clearable
style="flex:1"
@keyup.enter="loadPickerPackages"
@clear="loadPickerPackages"
/>
<el-button type="primary" @click="loadPickerPackages">搜索</el-button>
</div>
<el-table
:data="pickerPackages"
v-loading="pickerLoading"
border
stripe
highlight-current-row
height="360px"
style="cursor:pointer;margin-top:10px"
@row-click="selectPackage"
>
<el-table-column prop="name" label="数据包名称" show-overflow-tooltip />
<el-table-column prop="data_count" label="数据量" width="90" align="center" />
<el-table-column prop="created_at" label="创建时间" width="160" />
</el-table>
</el-dialog>
<!-- ── save model dialog ────────────────────────────────────────── -->
<el-dialog
v-model="saveModalVisible"
title="保存模型"
width="480px"
:close-on-click-modal="false"
>
<el-form :model="saveForm" label-position="top">
<el-form-item label="模型名称" required>
<el-input v-model="saveForm.model_name" placeholder="请输入模型名称" maxlength="100" show-word-limit />
</el-form-item>
<el-form-item label="说明">
<el-input
v-model="saveForm.description"
type="textarea"
:rows="3"
placeholder="可选,填写模型用途、训练说明等"
maxlength="500"
show-word-limit
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="saveModalVisible = false">取消</el-button>
<el-button type="primary" :loading="saving" @click="confirmSave">保存</el-button>
</template>
</el-dialog>
</template>
<style lang="scss" scoped> <style lang="scss" scoped>
.train-page { .train-page {
...@@ -417,13 +597,13 @@ export default {} ...@@ -417,13 +597,13 @@ export default {}
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 12px; gap: 12px;
overflow-y: auto; overflow: hidden;
background: var(--bg-page); background: var(--bg-page);
} }
.config-card, .config-card,
.tasks-card { .tasks-card,
flex-shrink: 0; .detail-card {
border: 1px solid var(--border-color); border: 1px solid var(--border-color);
box-shadow: var(--shadow-card) !important; box-shadow: var(--shadow-card) !important;
border-radius: 4px; border-radius: 4px;
...@@ -439,15 +619,45 @@ export default {} ...@@ -439,15 +619,45 @@ export default {}
:deep(.el-card__body) { :deep(.el-card__body) {
padding: 16px 20px; padding: 16px 20px;
height: calc(100% - 44px);
overflow: hidden;
} }
} }
.tasks-card { .config-card {
flex-shrink: 0;
}
.bottom-area {
flex: 1; flex: 1;
min-height: 0; min-height: 0;
display: flex;
gap: 12px;
}
.tasks-card {
flex: 1.4;
min-width: 0;
overflow: hidden; overflow: hidden;
} }
.detail-card {
flex: 1;
min-width: 320px;
overflow: hidden;
}
.detail-body {
height: 100%;
display: flex;
flex-direction: column;
}
.epoch-chart {
flex: 1;
min-height: 0;
}
.card-title { .card-title {
font-size: 14px; font-size: 14px;
font-weight: 600; font-weight: 600;
...@@ -479,6 +689,16 @@ export default {} ...@@ -479,6 +689,16 @@ export default {}
} }
} }
.pkg-picker-row {
display: flex;
gap: 8px;
align-items: center;
.pkg-input {
flex: 1;
}
}
.params-section { .params-section {
border-top: 1px solid var(--border-color); border-top: 1px solid var(--border-color);
padding-top: 12px; padding-top: 12px;
...@@ -509,13 +729,10 @@ export default {} ...@@ -509,13 +729,10 @@ export default {}
text-align: right; text-align: right;
} }
.params-cell { .picker-filters {
font-size: 12px; display: flex;
color: var(--text-secondary); gap: 8px;
white-space: nowrap; align-items: center;
overflow: hidden;
text-overflow: ellipsis;
display: block;
} }
.muted { .muted {
......
<script setup> <script setup>
import { ArrowLeft } from '@element-plus/icons-vue' import { ArrowLeft } from '@element-plus/icons-vue'
import { ElMessage } from 'element-plus' import { ElAutoResizer, ElMessage, ElTableV2 } from 'element-plus'
import { computed, onMounted, reactive, ref, watch } from 'vue' import { computed, onMounted, reactive, ref, watch } from 'vue'
import { createPackage, getAllDataFiles, getPkgCategoryTree, previewPackage } from '@/api/packageManagement' import { createPackage, getAllDataFiles, getPkgCategoryTree, previewPackage } from '@/api/packageManagement'
import { getQualityConfig } from '@/api/dataManagement' import { getCategoryTree, getQualityConfig } from '@/api/dataManagement'
import DataCurve from '@/views/DataManagement/components/DataCurve.vue' import DataCurve from '@/views/DataManagement/components/DataCurve.vue'
const emit = defineEmits(['cancel', 'saved']) const emit = defineEmits(['cancel', 'saved'])
...@@ -13,14 +13,19 @@ const allFiles = ref([]) ...@@ -13,14 +13,19 @@ const allFiles = ref([])
const selectedFileIds = ref([]) const selectedFileIds = ref([])
const fileTableRef = ref(null) const fileTableRef = ref(null)
const fileSearchText = ref('') const fileSearchText = ref('')
const fileFilterCategoryId = ref('')
const filteredFiles = computed(() => { const filteredFiles = computed(() => {
let list = allFiles.value
if (fileFilterCategoryId.value) {
list = list.filter((f) => String(f.category_id) === fileFilterCategoryId.value)
}
const kw = fileSearchText.value.trim().toLowerCase() const kw = fileSearchText.value.trim().toLowerCase()
if (!kw) return allFiles.value if (!kw) return list
return allFiles.value.filter( return list.filter(
(f) => (f) =>
f.filename.toLowerCase().includes(kw) || f.filename.toLowerCase().includes(kw) ||
(f.category_name || '').toLowerCase().includes(kw), (f.remark || '').toLowerCase().includes(kw),
) )
}) })
...@@ -30,6 +35,7 @@ const handleFileSelectionChange = (rows) => { ...@@ -30,6 +35,7 @@ const handleFileSelectionChange = (rows) => {
// ── categories ──────────────────────────────────────────────────────────────── // ── categories ────────────────────────────────────────────────────────────────
const categoryOptions = ref([]) const categoryOptions = ref([])
const dataCategoryOptions = ref([])
const loadCategories = async () => { const loadCategories = async () => {
const data = await getPkgCategoryTree() const data = await getPkgCategoryTree()
...@@ -39,6 +45,14 @@ const loadCategories = async () => { ...@@ -39,6 +45,14 @@ const loadCategories = async () => {
.map((item) => ({ label: item.name, value: String(item.id) })) .map((item) => ({ label: item.name, value: String(item.id) }))
} }
const loadDataCategories = async () => {
const data = await getCategoryTree()
const source = Array.isArray(data) ? data : []
dataCategoryOptions.value = source
.filter((item) => String(item?.id) !== 'all')
.map((item) => ({ label: item.name, value: String(item.id) }))
}
// ── form ────────────────────────────────────────────────────────────────────── // ── form ──────────────────────────────────────────────────────────────────────
const form = reactive({ const form = reactive({
categoryId: '', categoryId: '',
...@@ -48,8 +62,7 @@ const form = reactive({ ...@@ -48,8 +62,7 @@ const form = reactive({
// ── clean rules ─────────────────────────────────────────────────────────────── // ── clean rules ───────────────────────────────────────────────────────────────
const cleanRules = reactive({ const cleanRules = reactive({
enabled: false, enabled: false, newton_interp: false, current_min: null,
current_min: null,
current_max: null, current_max: null,
voltage_min: null, voltage_min: null,
voltage_max: null, voltage_max: null,
...@@ -111,6 +124,7 @@ const cleanRulesPayload = computed(() => { ...@@ -111,6 +124,7 @@ const cleanRulesPayload = computed(() => {
if (!cleanRules.enabled) return null if (!cleanRules.enabled) return null
return { return {
enabled: true, enabled: true,
newton_interp: cleanRules.newton_interp,
current_min: cleanRules.current_min ?? null, current_min: cleanRules.current_min ?? null,
current_max: cleanRules.current_max ?? null, current_max: cleanRules.current_max ?? null,
voltage_min: cleanRules.voltage_min ?? null, voltage_min: cleanRules.voltage_min ?? null,
...@@ -120,7 +134,19 @@ const cleanRulesPayload = computed(() => { ...@@ -120,7 +134,19 @@ const cleanRulesPayload = computed(() => {
} }
}) })
// ── row range ───────────────────────────────────────────────────────────────── // ── smooth ────────────────────────────────────────────────────────────────────
const smooth = reactive({
enabled: false,
window: 5,
})
const smoothPayload = computed(() => {
if (!smooth.enabled) return null
return { enabled: true, window: smooth.window ?? 5 }
})
// ── auto split ───────────────────────────────────────────────────────────────
const autoSplit = ref(false)
const rowRange = reactive({ const rowRange = reactive({
enabled: false, enabled: false,
start: null, start: null,
...@@ -149,6 +175,24 @@ const previewLoading = ref(false) ...@@ -149,6 +175,24 @@ const previewLoading = ref(false)
const previewRecords = ref([]) const previewRecords = ref([])
const previewTotal = ref(0) const previewTotal = ref(0)
const previewMode = ref('table') const previewMode = ref('table')
// 虚拟表格列定义
const previewColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 160 },
{ key: 'current', dataKey: 'current', title: '电流', width: 100 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 100 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 110 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 110 },
]
// 曲线模式下最多采样 2000 个点,避免 canvas 过载
const curveRecords = computed(() => {
const records = previewRecords.value
if (records.length <= 2000) return records
const step = Math.ceil(records.length / 2000)
return records.filter((_, i) => i % step === 0)
})
let previewDebounceTimer = null let previewDebounceTimer = null
const triggerPreview = () => { const triggerPreview = () => {
...@@ -165,9 +209,9 @@ const triggerPreview = () => { ...@@ -165,9 +209,9 @@ const triggerPreview = () => {
{ {
file_ids: selectedFileIds.value, file_ids: selectedFileIds.value,
clean_rules: cleanRulesPayload.value, clean_rules: cleanRulesPayload.value,
smooth: smoothPayload.value,
...rowRangePayload.value, ...rowRangePayload.value,
}, },
{ limit: 300 },
) )
previewRecords.value = result.records previewRecords.value = result.records
previewTotal.value = result.count previewTotal.value = result.count
...@@ -179,6 +223,7 @@ const triggerPreview = () => { ...@@ -179,6 +223,7 @@ const triggerPreview = () => {
watch(selectedFileIds, triggerPreview, { deep: true }) watch(selectedFileIds, triggerPreview, { deep: true })
watch(cleanRules, triggerPreview, { deep: true }) watch(cleanRules, triggerPreview, { deep: true })
watch(smooth, triggerPreview, { deep: true })
watch(rowRange, triggerPreview, { deep: true }) watch(rowRange, triggerPreview, { deep: true })
// ── save ────────────────────────────────────────────────────────────────────── // ── save ──────────────────────────────────────────────────────────────────────
...@@ -206,9 +251,11 @@ const handleGenerate = async () => { ...@@ -206,9 +251,11 @@ const handleGenerate = async () => {
remark: form.remark.trim() || null, remark: form.remark.trim() || null,
file_ids: selectedFileIds.value, file_ids: selectedFileIds.value,
clean_rules: cleanRulesPayload.value, clean_rules: cleanRulesPayload.value,
smooth: smoothPayload.value,
auto_split: autoSplit.value,
...rowRangePayload.value, ...rowRangePayload.value,
}) })
ElMessage.success('数据包创建成功') ElMessage.success(autoSplit.value ? '数据包已自动划分为训练集/验证集/测试集' : '数据包创建成功')
emit('saved') emit('saved')
} finally { } finally {
saving.value = false saving.value = false
...@@ -217,7 +264,7 @@ const handleGenerate = async () => { ...@@ -217,7 +264,7 @@ const handleGenerate = async () => {
// ── init ────────────────────────────────────────────────────────────────────── // ── init ──────────────────────────────────────────────────────────────────────
onMounted(async () => { onMounted(async () => {
const [filesData] = await Promise.all([getAllDataFiles(), loadCategories()]) const [filesData] = await Promise.all([getAllDataFiles(), loadCategories(), loadDataCategories()])
allFiles.value = filesData allFiles.value = filesData
}) })
</script> </script>
...@@ -237,11 +284,24 @@ onMounted(async () => { ...@@ -237,11 +284,24 @@ onMounted(async () => {
<span class="section-hint">已选 {{ selectedFileIds.length }} 个文件</span> <span class="section-hint">已选 {{ selectedFileIds.length }} 个文件</span>
</div> </div>
<div class="file-search-bar"> <div class="file-search-bar">
<el-select
v-model="fileFilterCategoryId"
placeholder="全部分类"
clearable
style="width: 160px"
>
<el-option
v-for="item in dataCategoryOptions"
:key="item.value"
:label="item.label"
:value="item.value"
/>
</el-select>
<el-input <el-input
v-model="fileSearchText" v-model="fileSearchText"
placeholder="按文件名或分类搜索" placeholder="按文件名或备注搜索"
clearable clearable
style="width: 280px" style="width: 240px"
/> />
</div> </div>
<el-table <el-table
...@@ -253,10 +313,11 @@ onMounted(async () => { ...@@ -253,10 +313,11 @@ onMounted(async () => {
@selection-change="handleFileSelectionChange" @selection-change="handleFileSelectionChange"
> >
<el-table-column type="selection" width="46" /> <el-table-column type="selection" width="46" />
<el-table-column prop="filename" label="文件名" min-width="200" show-overflow-tooltip /> <el-table-column prop="filename" label="文件名" min-width="180" show-overflow-tooltip />
<el-table-column prop="category_name" label="所属分类" width="130" show-overflow-tooltip /> <el-table-column prop="category_name" label="所属分类" width="110" show-overflow-tooltip />
<el-table-column prop="remark" label="备注" min-width="120" show-overflow-tooltip />
<el-table-column prop="data_count" label="数据量" width="80" align="center" /> <el-table-column prop="data_count" label="数据量" width="80" align="center" />
<el-table-column prop="uploaded_at" label="上传时间" width="160" /> <el-table-column prop="uploaded_at" label="上传时间" width="150" />
</el-table> </el-table>
</div> </div>
...@@ -281,7 +342,7 @@ onMounted(async () => { ...@@ -281,7 +342,7 @@ onMounted(async () => {
</el-form-item> </el-form-item>
<el-form-item label="清洗规则"> <el-form-item label="清洗规则">
<div class="clean-rules-wrap"> <div class="clean-rules-wrap">
<el-checkbox v-model="cleanRules.enabled">野值清理</el-checkbox> <el-checkbox v-model="cleanRules.enabled">剔除野值</el-checkbox>
<div v-if="cleanRules.enabled" class="clean-range-grid"> <div v-if="cleanRules.enabled" class="clean-range-grid">
<div class="clean-range-row"> <div class="clean-range-row">
<span class="clean-range-label">电流范围 (A)</span> <span class="clean-range-label">电流范围 (A)</span>
...@@ -337,10 +398,30 @@ onMounted(async () => { ...@@ -337,10 +398,30 @@ onMounted(async () => {
class="range-input" class="range-input"
/> />
</div> </div>
<div class="clean-range-row newton-row">
<el-checkbox v-model="cleanRules.newton_interp">牛顿插值填补(剔除野值后自动插值)</el-checkbox>
</div>
</div> </div>
</div> </div>
</el-form-item> </el-form-item>
<el-form-item label="数据行范围"> <el-form-item label="数据转换">
<div class="smooth-wrap">
<el-checkbox v-model="smooth.enabled">滑动均值法</el-checkbox>
<div v-if="smooth.enabled" class="smooth-inputs">
<span class="smooth-label">窗口长度</span>
<el-input-number
v-model="smooth.window"
:min="2"
:max="500"
:precision="0"
:value-on-clear="5"
style="width: 90px"
/>
<span class="smooth-unit">个点</span>
</div>
</div>
</el-form-item>
<el-form-item label="截取范围">
<div class="row-range-wrap"> <div class="row-range-wrap">
<el-checkbox v-model="rowRange.enabled">启用行范围截取</el-checkbox> <el-checkbox v-model="rowRange.enabled">启用行范围截取</el-checkbox>
<div v-if="rowRange.enabled" class="row-range-inputs"> <div v-if="rowRange.enabled" class="row-range-inputs">
...@@ -368,6 +449,14 @@ onMounted(async () => { ...@@ -368,6 +449,14 @@ onMounted(async () => {
</div> </div>
</div> </div>
</el-form-item> </el-form-item>
<el-form-item label="数据集划分">
<div class="auto-split-wrap">
<el-checkbox v-model="autoSplit">自动划分数据集(70/15/15)</el-checkbox>
<div v-if="autoSplit" class="auto-split-hint">
将生成三个数据包:《名称-训练集》《名称-验证集》《名称-测试集》
</div>
</div>
</el-form-item>
<el-form-item label="备注"> <el-form-item label="备注">
<el-input <el-input
v-model="form.remark" v-model="form.remark"
...@@ -385,7 +474,7 @@ onMounted(async () => { ...@@ -385,7 +474,7 @@ onMounted(async () => {
style="width: 100%" style="width: 100%"
@click="handleGenerate" @click="handleGenerate"
> >
生成数据包 {{ autoSplit ? '划分生成数据包 (70/15/15)' : '生成数据包' }}
</el-button> </el-button>
</el-form-item> </el-form-item>
</el-form> </el-form>
...@@ -395,7 +484,7 @@ onMounted(async () => { ...@@ -395,7 +484,7 @@ onMounted(async () => {
<div class="section-card preview-section"> <div class="section-card preview-section">
<div class="section-title"> <div class="section-title">
生成结果预览 生成结果预览
<span v-if="previewTotal" class="section-hint">{{ previewTotal }}(展示前 300 条)</span> <span v-if="previewTotal" class="section-hint">{{ previewTotal }}</span>
<el-radio-group v-model="previewMode" size="small" style="margin-left: auto"> <el-radio-group v-model="previewMode" size="small" style="margin-left: auto">
<el-radio-button value="table">表格</el-radio-button> <el-radio-button value="table">表格</el-radio-button>
<el-radio-button value="curve">曲线</el-radio-button> <el-radio-button value="curve">曲线</el-radio-button>
...@@ -409,21 +498,21 @@ onMounted(async () => { ...@@ -409,21 +498,21 @@ onMounted(async () => {
:image-size="80" :image-size="80"
/> />
<el-table <el-auto-resizer v-else-if="previewMode === 'table'">
v-else-if="previewMode === 'table'" <template #default="{ height, width }">
<el-table-v2
:columns="previewColumns"
:data="previewRecords" :data="previewRecords"
border :width="width"
stripe :height="height"
height="100%" :row-height="36"
> :header-height="40"
<el-table-column prop="time" label="时间" min-width="140" /> fixed
<el-table-column prop="current" label="电流" min-width="80" /> />
<el-table-column prop="voltage" label="电压" min-width="80" /> </template>
<el-table-column prop="set_temperature" label="设定温度" min-width="100" /> </el-auto-resizer>
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" />
</el-table>
<DataCurve v-else :records="previewRecords" /> <DataCurve v-else :records="curveRecords" :total-count="previewTotal" />
</div> </div>
</div> </div>
</div> </div>
...@@ -480,6 +569,9 @@ onMounted(async () => { ...@@ -480,6 +569,9 @@ onMounted(async () => {
} }
.file-search-bar { .file-search-bar {
display: flex;
gap: 8px;
align-items: center;
margin-bottom: 10px; margin-bottom: 10px;
} }
...@@ -564,6 +656,42 @@ onMounted(async () => { ...@@ -564,6 +656,42 @@ onMounted(async () => {
color: #94a3b8; color: #94a3b8;
} }
.newton-row {
margin-top: 4px;
}
.smooth-wrap {
width: 100%;
}
.smooth-inputs {
display: flex;
align-items: center;
gap: 8px;
margin-top: 8px;
}
.smooth-label {
font-size: 12px;
color: #475569;
}
.smooth-unit {
font-size: 12px;
color: #94a3b8;
}
.auto-split-wrap {
width: 100%;
}
.auto-split-hint {
margin-top: 6px;
font-size: 12px;
color: #64748b;
line-height: 1.6;
}
.row-range-wrap { .row-range-wrap {
width: 100%; width: 100%;
} }
......
<script setup> <script setup>
import { ref, watch } from 'vue' import { computed, ref, watch } from 'vue'
import { ElAutoResizer, ElTableV2 } from 'element-plus'
import { getPackageRecords } from '@/api/packageManagement' import { getPackageRecords } from '@/api/packageManagement'
import DataCurve from '@/views/DataManagement/components/DataCurve.vue' import DataCurve from '@/views/DataManagement/components/DataCurve.vue'
...@@ -18,6 +19,21 @@ const loading = ref(false) ...@@ -18,6 +19,21 @@ const loading = ref(false)
const records = ref([]) const records = ref([])
const contentMode = ref('table') const contentMode = ref('table')
const detailColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = records.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const loadRecords = async (pkg) => { const loadRecords = async (pkg) => {
if (!pkg) { if (!pkg) {
records.value = [] records.value = []
...@@ -25,7 +41,7 @@ const loadRecords = async (pkg) => { ...@@ -25,7 +41,7 @@ const loadRecords = async (pkg) => {
} }
loading.value = true loading.value = true
try { try {
const result = await getPackageRecords(pkg.id, { limit: 500 }) const result = await getPackageRecords(pkg.id)
records.value = result.records records.value = result.records
} finally { } finally {
loading.value = false loading.value = false
...@@ -56,21 +72,21 @@ watch(() => props.package, (pkg) => { ...@@ -56,21 +72,21 @@ watch(() => props.package, (pkg) => {
<div class="content-wrap" v-loading="loading"> <div class="content-wrap" v-loading="loading">
<el-empty v-if="!props.package" description="请选择一个数据包查看" /> <el-empty v-if="!props.package" description="请选择一个数据包查看" />
<el-table <el-auto-resizer v-else-if="contentMode === 'table'">
v-else-if="contentMode === 'table'" <template #default="{ height, width }">
<el-table-v2
:columns="detailColumns"
:data="records" :data="records"
border :width="width"
stripe :height="height"
height="calc(100vh - 250px)" :row-height="36"
> :header-height="40"
<el-table-column prop="time" label="时间" min-width="140" /> fixed
<el-table-column prop="current" label="电流" min-width="90" /> />
<el-table-column prop="voltage" label="电压" min-width="90" /> </template>
<el-table-column prop="set_temperature" label="设定温度" min-width="100" /> </el-auto-resizer>
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" />
</el-table> <DataCurve v-else :records="curveRecords" :total-count="records.length" />
<DataCurve v-else :records="records" />
</div> </div>
</el-card> </el-card>
</template> </template>
......
<script setup> <script setup>
import { Plus, Search } from '@element-plus/icons-vue' import { Delete, Edit, Plus, Search } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, reactive, ref, watch } from 'vue' import { onMounted, reactive, ref, watch } from 'vue'
import { deletePackage, getPackages } from '@/api/packageManagement' import { deletePackage, getPackages, updatePackage } from '@/api/packageManagement'
import { getPkgCategoryTree } from '@/api/packageManagement'
const props = defineProps({ const props = defineProps({
categoryId: { categoryId: {
...@@ -52,6 +53,61 @@ const handleView = (row) => { ...@@ -52,6 +53,61 @@ const handleView = (row) => {
emit('view', row) emit('view', row)
} }
// ── edit ──────────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editLoading = ref(false)
const editForm = reactive({ id: null, name: '', category_id: null, remark: '' })
const categoryOptions = ref([])
const loadCategoryOptions = async () => {
if (categoryOptions.value.length) return
try {
const tree = await getPkgCategoryTree()
const flatten = (nodes, result = []) => {
for (const n of nodes) {
result.push({ value: n.id, label: n.name })
if (n.children?.length) flatten(n.children, result)
}
return result
}
categoryOptions.value = flatten(tree)
} catch {}
}
const handleEdit = async (row) => {
await loadCategoryOptions()
editForm.id = row.id
editForm.name = row.name
editForm.category_id = row.category_id ?? null
editForm.remark = row.remark ?? ''
editVisible.value = true
}
const submitEdit = async () => {
if (!editForm.name.trim()) {
ElMessage.warning('请输入数据包名称')
return
}
editLoading.value = true
try {
const updated = await updatePackage(editForm.id, {
name: editForm.name.trim(),
category_id: editForm.category_id || null,
remark: editForm.remark.trim() || null,
})
ElMessage.success('修改成功')
editVisible.value = false
// 更新列表中对应行
const idx = packageList.value.findIndex(p => p.id === editForm.id)
if (idx !== -1) Object.assign(packageList.value[idx], updated)
if (currentPackage.value?.id === editForm.id) Object.assign(currentPackage.value, updated)
} catch (e) {
ElMessage.error(e?.response?.data?.detail || '修改失败')
} finally {
editLoading.value = false
}
}
const handleDelete = async (row) => { const handleDelete = async (row) => {
try { try {
await ElMessageBox.confirm(`确定删除数据包"${row.name}"吗?`, '提示', { type: 'warning' }) await ElMessageBox.confirm(`确定删除数据包"${row.name}"吗?`, '提示', { type: 'warning' })
...@@ -105,18 +161,48 @@ onMounted(loadPackages) ...@@ -105,18 +161,48 @@ onMounted(loadPackages)
highlight-current-row highlight-current-row
height="calc(100vh - 265px)" height="calc(100vh - 265px)"
:row-class-name="({ row }) => (currentPackage?.id === row.id ? 'current-row' : '')" :row-class-name="({ row }) => (currentPackage?.id === row.id ? 'current-row' : '')"
style="cursor: pointer"
@row-click="handleView"
> >
<el-table-column prop="name" label="数据包名称" min-width="150" show-overflow-tooltip /> <el-table-column prop="name" label="数据包名称" min-width="100" show-overflow-tooltip />
<el-table-column prop="created_at" label="创建时间" min-width="160" /> <el-table-column prop="created_at" label="创建时间" min-width="100" />
<el-table-column prop="data_count" label="数据量" width="80" align="center" /> <el-table-column prop="data_count" label="数据量" width="80" align="center" />
<el-table-column label="操作" width="130" fixed="right"> <el-table-column label="操作" width="150" fixed="right" align="center">
<template #default="{ row }"> <template #default="{ row }">
<el-button link type="primary" @click="handleView(row)">查看</el-button> <div>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button> <el-button link type="primary" :icon="Edit" @click.stop="handleEdit(row)" />
<el-button link type="danger" :icon="Delete" @click.stop="handleDelete(row)" />
</div>
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
</el-card> </el-card>
<!-- 编辑数据包 -->
<el-dialog v-model="editVisible" title="编辑数据包" width="480px" :close-on-click-modal="false">
<el-form label-width="90px">
<el-form-item label="名称" required>
<el-input v-model="editForm.name" placeholder="数据包名称" clearable />
</el-form-item>
<el-form-item label="分类">
<el-select v-model="editForm.category_id" placeholder="选择分类(可选)" clearable style="width:100%">
<el-option
v-for="opt in categoryOptions"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input v-model="editForm.remark" type="textarea" :rows="3" placeholder="备注(可选)" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editLoading" @click="submitEdit">保存</el-button>
</template>
</el-dialog>
</template> </template>
<style lang="scss" scoped> <style lang="scss" scoped>
......
<script setup> <script setup>
import * as echarts from 'echarts' import * as echarts from 'echarts'
import { ArrowLeft, VideoPlay, VideoPause, Download, DataAnalysis } from '@element-plus/icons-vue' import { ArrowLeft, VideoPlay, VideoPause, Download } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage } from 'element-plus'
import { import {
onMounted, onMounted,
onBeforeUnmount, onBeforeUnmount,
ref, ref,
computed, computed,
nextTick, nextTick,
watch,
} from 'vue' } from 'vue'
import { useRouter, useRoute } from 'vue-router' import { useRouter, useRoute } from 'vue-router'
import { import {
...@@ -15,8 +16,6 @@ import { ...@@ -15,8 +16,6 @@ import {
startExperiment, startExperiment,
stopExperiment, stopExperiment,
getDataPoints, getDataPoints,
getReport,
exportToHistory,
} from '@/api/realtimeMonitor' } from '@/api/realtimeMonitor'
const router = useRouter() const router = useRouter()
...@@ -106,43 +105,9 @@ const handleStop = async () => { ...@@ -106,43 +105,9 @@ const handleStop = async () => {
} }
} }
// ── 报告 ───────────────────────────────────────────────────────────────────────
const reportVisible = ref(false)
const reportData = ref(null)
const loadingReport = ref(false)
const handleReport = async () => {
loadingReport.value = true
try {
reportData.value = await getReport(expId)
reportVisible.value = true
} finally {
loadingReport.value = false
}
}
// ── 导出 ───────────────────────────────────────────────────────────────────────
const exporting = ref(false)
const handleExport = async () => {
try {
await ElMessageBox.confirm('确定将本次试验数据导出到历史数据吗?', '导出确认', { type: 'info' })
exporting.value = true
await exportToHistory(expId)
experiment.value = await getExperiment(expId)
ElMessage.success('已导出到历史数据')
} catch {
// 取消
} finally {
exporting.value = false
}
}
// ── 图表 ─────────────────────────────────────────────────────────────────────── // ── 图表 ───────────────────────────────────────────────────────────────────────
const tempChartRef = ref(null) const chartRef = ref(null)
const currChartRef = ref(null) let chart = null
let tempChart = null
let currChart = null
const MAX_CHART_POINTS = 300 // 超过时抽样显示 const MAX_CHART_POINTS = 300 // 超过时抽样显示
...@@ -154,11 +119,8 @@ const buildDisplayData = () => { ...@@ -154,11 +119,8 @@ const buildDisplayData = () => {
} }
const initCharts = () => { const initCharts = () => {
if (tempChartRef.value && !tempChart) { if (chartRef.value && !chart) {
tempChart = echarts.init(tempChartRef.value) chart = echarts.init(chartRef.value)
}
if (currChartRef.value && !currChart) {
currChart = echarts.init(currChartRef.value)
} }
updateCharts() updateCharts()
} }
...@@ -167,16 +129,14 @@ const updateCharts = () => { ...@@ -167,16 +129,14 @@ const updateCharts = () => {
const display = buildDisplayData() const display = buildDisplayData()
const xData = display.map((d) => `步${d.step_idx}`) const xData = display.map((d) => `步${d.step_idx}`)
const actuals = display.map((d) => d.actual_temp) const actuals = display.map((d) => d.actual_temp)
const refs = display.map((d) => d.reference_temp)
const currents = display.map((d) => d.current_output) const currents = display.map((d) => d.current_output)
const targetLine = display.map(() => experiment.value?.target_temp ?? null) const targetLine = display.map(() => experiment.value?.target_temp ?? null)
// ── 温度图 ──────────────────────────────────────────────────────────────── if (chart) {
if (tempChart) { chart.setOption(
tempChart.setOption(
{ {
animation: false, animation: false,
color: ['#409EFF', '#67C23A', '#F56C6C'], color: ['#409EFF', '#F56C6C', '#E6A23C'],
tooltip: { tooltip: {
trigger: 'axis', trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)', backgroundColor: 'rgba(255,255,255,0.96)',
...@@ -186,7 +146,8 @@ const updateCharts = () => { ...@@ -186,7 +146,8 @@ const updateCharts = () => {
if (!params?.length) return '' if (!params?.length) return ''
const lines = [`<div style="margin-bottom:4px;font-weight:600">${params[0].axisValue}</div>`] const lines = [`<div style="margin-bottom:4px;font-weight:600">${params[0].axisValue}</div>`]
params.forEach((p) => { params.forEach((p) => {
const v = p.data != null ? Number(p.data).toFixed(3) + ' °C' : '--' const unit = p.seriesName === '电流输出' ? ' A' : ' °C'
const v = p.data != null ? Number(p.data).toFixed(3) + unit : '--'
lines.push( lines.push(
`<div style="display:flex;justify-content:space-between;gap:16px"> `<div style="display:flex;justify-content:space-between;gap:16px">
<span>${p.marker}${p.seriesName}</span><strong>${v}</strong> <span>${p.marker}${p.seriesName}</span><strong>${v}</strong>
...@@ -202,7 +163,7 @@ const updateCharts = () => { ...@@ -202,7 +163,7 @@ const updateCharts = () => {
itemHeight: 8, itemHeight: 8,
textStyle: { color: '#475569', fontSize: 12 }, textStyle: { color: '#475569', fontSize: 12 },
}, },
grid: { top: 16, left: 16, right: 16, bottom: 52, containLabel: true }, grid: { top: 16, left: 16, right: 60, bottom: 52, containLabel: true },
xAxis: { xAxis: {
type: 'category', type: 'category',
boundaryGap: false, boundaryGap: false,
...@@ -210,92 +171,52 @@ const updateCharts = () => { ...@@ -210,92 +171,52 @@ const updateCharts = () => {
axisLabel: { color: '#64748b', fontSize: 11, interval: Math.max(0, Math.floor(xData.length / 10) - 1) }, axisLabel: { color: '#64748b', fontSize: 11, interval: Math.max(0, Math.floor(xData.length / 10) - 1) },
axisLine: { lineStyle: { color: '#cbd5e1' } }, axisLine: { lineStyle: { color: '#cbd5e1' } },
}, },
yAxis: { yAxis: [
{
type: 'value', type: 'value',
name: '温度 (°C)', name: '温度 (°C)',
position: 'left',
nameTextStyle: { color: '#64748b', fontSize: 11 }, nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 }, axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { color: '#e2e8f0' } }, splitLine: { lineStyle: { color: '#e2e8f0' } },
}, },
{
type: 'value',
name: '电流 (A)',
position: 'right',
nameTextStyle: { color: '#E6A23C', fontSize: 11 },
axisLabel: { color: '#E6A23C', fontSize: 11 },
splitLine: { show: false },
min: 0,
},
],
series: [ series: [
{ {
name: '实际温度', name: '实际温度',
type: 'line', type: 'line',
yAxisIndex: 0,
data: actuals, data: actuals,
smooth: true, smooth: true,
symbol: 'none', symbol: 'none',
lineStyle: { width: 2 }, lineStyle: { width: 2 },
}, },
{
name: '参考轨迹',
type: 'line',
data: refs,
smooth: true,
symbol: 'none',
lineStyle: { width: 1.5, type: 'dashed' },
},
{ {
name: '目标温度', name: '目标温度',
type: 'line', type: 'line',
yAxisIndex: 0,
data: targetLine, data: targetLine,
symbol: 'none', symbol: 'none',
lineStyle: { width: 1.5, type: 'dotted', color: '#F56C6C' }, lineStyle: { width: 1.5, type: 'dotted', color: '#F56C6C' },
}, },
],
},
true,
)
}
// ── 电流图 ────────────────────────────────────────────────────────────────
if (currChart) {
currChart.setOption(
{
animation: false,
color: ['#E6A23C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const v = params[0].data != null ? Number(params[0].data).toFixed(3) + ' A' : '--'
return `<div style="font-weight:600">${params[0].axisValue}</div>
<div>${params[0].marker}电流输出:<strong>${v}</strong></div>`
},
},
legend: {
bottom: 4,
itemWidth: 18,
itemHeight: 8,
textStyle: { color: '#475569', fontSize: 12 },
},
grid: { top: 16, left: 16, right: 16, bottom: 52, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#64748b', fontSize: 11, interval: Math.max(0, Math.floor(xData.length / 10) - 1) },
axisLine: { lineStyle: { color: '#cbd5e1' } },
},
yAxis: {
type: 'value',
name: '电流 (A)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
min: 0,
},
series: [
{ {
name: '电流输出', name: '电流输出',
type: 'line', type: 'line',
yAxisIndex: 1,
data: currents, data: currents,
smooth: false, smooth: false,
symbol: 'none', symbol: 'none',
lineStyle: { width: 2 }, lineStyle: { width: 2, color: '#E6A23C' },
areaStyle: { opacity: 0.08 }, areaStyle: { color: '#E6A23C', opacity: 0.06 },
}, },
], ],
}, },
...@@ -306,8 +227,7 @@ const updateCharts = () => { ...@@ -306,8 +227,7 @@ const updateCharts = () => {
// ── 窗口 resize ──────────────────────────────────────────────────────────────── // ── 窗口 resize ────────────────────────────────────────────────────────────────
const onResize = () => { const onResize = () => {
tempChart?.resize() chart?.resize()
currChart?.resize()
} }
// ── 生命周期 ─────────────────────────────────────────────────────────────────── // ── 生命周期 ───────────────────────────────────────────────────────────────────
...@@ -331,8 +251,7 @@ onMounted(async () => { ...@@ -331,8 +251,7 @@ onMounted(async () => {
onBeforeUnmount(() => { onBeforeUnmount(() => {
stopPolling() stopPolling()
window.removeEventListener('resize', onResize) window.removeEventListener('resize', onResize)
tempChart?.dispose() chart?.dispose()
currChart?.dispose()
}) })
// ── 状态辅助 ─────────────────────────────────────────────────────────────────── // ── 状态辅助 ───────────────────────────────────────────────────────────────────
...@@ -343,6 +262,37 @@ const latestPoint = computed(() => { ...@@ -343,6 +262,37 @@ const latestPoint = computed(() => {
const pts = dataPoints.value const pts = dataPoints.value
return pts.length ? pts[pts.length - 1] : null return pts.length ? pts[pts.length - 1] : null
}) })
// ── 曲线 / 表格 切换 ────────────────────────────────────────────────────────────
const viewMode = ref('chart')
watch(viewMode, async (val) => {
if (val === 'chart') {
await nextTick()
chart?.dispose()
chart = null
initCharts()
}
})
const downloadCSV = () => {
const target = experiment.value?.target_temp ?? ''
const header = ['步骤', '实际温度(°C)', '目标温度(°C)', '电流输出(A)']
const rows = dataPoints.value.map((pt) => [
pt.step_idx,
pt.actual_temp,
target,
pt.current_output,
])
const csv = [header, ...rows].map((r) => r.join(',')).join('\n')
const blob = new Blob(['\uFEFF' + csv], { type: 'text/csv;charset=utf-8;' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `${experiment.value?.name ?? 'data'}_数据.csv`
a.click()
URL.revokeObjectURL(url)
}
</script> </script>
<template> <template>
...@@ -367,10 +317,6 @@ const latestPoint = computed(() => { ...@@ -367,10 +317,6 @@ const latestPoint = computed(() => {
<span class="label">预测模型</span> <span class="label">预测模型</span>
<span class="value">{{ experiment.model_name }}</span> <span class="value">{{ experiment.model_name }}</span>
</div> </div>
<div class="info-item">
<span class="label">初始数据包</span>
<span class="value">{{ experiment.package_name }}</span>
</div>
<div class="info-item"> <div class="info-item">
<span class="label">目标温度</span> <span class="label">目标温度</span>
<span class="value highlight">{{ experiment.target_temp?.toFixed(1) }} °C</span> <span class="value highlight">{{ experiment.target_temp?.toFixed(1) }} °C</span>
...@@ -383,6 +329,14 @@ const latestPoint = computed(() => { ...@@ -383,6 +329,14 @@ const latestPoint = computed(() => {
<span class="label">采样周期</span> <span class="label">采样周期</span>
<span class="value">{{ experiment.sampling_interval ?? 1.0 }} s</span> <span class="value">{{ experiment.sampling_interval ?? 1.0 }} s</span>
</div> </div>
<div class="info-item" style="grid-column: span 2">
<span class="label">输入CSV路径</span>
<span class="value" style="font-size:12px;word-break:break-all">{{ experiment.input_csv_path ?? '--' }}</span>
</div>
<div class="info-item" style="grid-column: span 2">
<span class="label">输出CSV路径</span>
<span class="value" style="font-size:12px;word-break:break-all">{{ experiment.output_csv_path ?? '--' }}</span>
</div>
<div class="info-item"> <div class="info-item">
<span class="label">开始时间</span> <span class="label">开始时间</span>
<span class="value">{{ experiment.start_time ?? '--' }}</span> <span class="value">{{ experiment.start_time ?? '--' }}</span>
...@@ -409,9 +363,9 @@ const latestPoint = computed(() => { ...@@ -409,9 +363,9 @@ const latestPoint = computed(() => {
</div> </div>
</div> </div>
<div class="stat-item"> <div class="stat-item">
<div class="stat-label">参考轨迹</div> <div class="stat-label">目标温度</div>
<div class="stat-value ref"> <div class="stat-value target">
{{ latestPoint ? latestPoint.reference_temp.toFixed(3) : '--' }} °C {{ experiment.target_temp?.toFixed(1) }} °C
</div> </div>
</div> </div>
<div class="stat-item"> <div class="stat-item">
...@@ -420,12 +374,6 @@ const latestPoint = computed(() => { ...@@ -420,12 +374,6 @@ const latestPoint = computed(() => {
{{ latestPoint ? latestPoint.current_output.toFixed(3) : '--' }} A {{ latestPoint ? latestPoint.current_output.toFixed(3) : '--' }} A
</div> </div>
</div> </div>
<div class="stat-item">
<div class="stat-label">目标温度</div>
<div class="stat-value target">
{{ experiment.target_temp?.toFixed(1) }} °C
</div>
</div>
</div> </div>
<div class="control-btns"> <div class="control-btns">
...@@ -451,28 +399,6 @@ const latestPoint = computed(() => { ...@@ -451,28 +399,6 @@ const latestPoint = computed(() => {
停止控制 停止控制
</el-button> </el-button>
<el-button
type="primary"
:icon="DataAnalysis"
:loading="loadingReport"
:disabled="!experiment.total_steps"
style="width:100%;margin-bottom:10px"
plain
@click="handleReport"
>
查看报告
</el-button>
<el-button
:icon="Download"
:loading="exporting"
:disabled="isRunning || !experiment.total_steps || experiment.exported"
style="width:100%"
plain
@click="handleExport"
>
{{ experiment.exported ? '已导出' : '导出到历史数据' }}
</el-button>
</div> </div>
</el-card> </el-card>
</el-col> </el-col>
...@@ -491,78 +417,49 @@ const latestPoint = computed(() => { ...@@ -491,78 +417,49 @@ const latestPoint = computed(() => {
</div> </div>
</el-card> </el-card>
<!-- 温度曲线图 --> <!-- 温度与电流曲线 -->
<el-card shadow="hover" class="chart-card"> <el-card shadow="hover" class="chart-card">
<template #header> <template #header>
<div class="chart-header"> <div class="chart-header">
<span class="card-title">温度曲线</span> <span class="card-title">温度与电流曲线</span>
<span v-if="isRunning" class="live-badge">● 实时更新中</span> <span v-if="isRunning" class="live-badge">● 实时更新中</span>
<div class="chart-toolbar">
<el-radio-group v-model="viewMode" size="small">
<el-radio-button value="chart">曲线</el-radio-button>
<el-radio-button value="table">表格</el-radio-button>
</el-radio-group>
<el-button
v-if="viewMode === 'table' && dataPoints.length"
size="small"
:icon="Download"
@click="downloadCSV"
>下载 CSV</el-button>
</div>
</div> </div>
</template> </template>
<div v-if="!dataPoints.length" class="empty-chart"> <div v-if="!dataPoints.length" class="empty-chart">
<el-empty description="暂无数据,请启动 MPC 控制" :image-size="80" /> <el-empty description="暂无数据,请启动 MPC 控制" :image-size="80" />
</div> </div>
<div v-else ref="tempChartRef" class="chart-body" /> <template v-else>
</el-card> <div v-if="viewMode === 'chart'" ref="chartRef" class="chart-body" />
<div v-else class="table-body">
<!-- 电流曲线图 --> <el-table :data="dataPoints" size="small" height="320" border stripe>
<el-card shadow="hover" class="chart-card"> <el-table-column prop="step_idx" label="步骤" width="80" align="center" />
<template #header> <el-table-column label="实际温度 (°C)" align="right">
<span class="card-title">电流输出曲线</span> <template #default="{ row }">{{ row.actual_temp.toFixed(3) }}</template>
</template> </el-table-column>
<div v-if="!dataPoints.length" class="empty-chart"> <el-table-column label="目标温度 (°C)" align="right">
<el-empty description="暂无数据" :image-size="80" /> <template #default>{{ experiment.target_temp?.toFixed(3) }}</template>
</el-table-column>
<el-table-column label="电流输出 (A)" align="right">
<template #default="{ row }">{{ row.current_output.toFixed(3) }}</template>
</el-table-column>
</el-table>
</div> </div>
<div v-else ref="currChartRef" class="chart-body" /> </template>
</el-card> </el-card>
</template> </template>
<!-- 报告对话框 -->
<el-dialog v-model="reportVisible" title="试验报告" width="620px">
<template v-if="reportData">
<el-descriptions :column="2" border size="small" class="report-desc">
<el-descriptions-item label="试验名称" :span="2">
{{ reportData.experiment?.name }}
</el-descriptions-item>
<el-descriptions-item label="目标温度">
{{ reportData.summary?.target_temp?.toFixed(1) }} °C
</el-descriptions-item>
<el-descriptions-item label="最终温度">
{{ reportData.summary?.final_temp?.toFixed(3) }} °C
</el-descriptions-item>
<el-descriptions-item label="初始温度">
{{ reportData.summary?.initial_temp?.toFixed(3) }} °C
</el-descriptions-item>
<el-descriptions-item label="采集步数">
{{ reportData.summary?.total_steps }}
</el-descriptions-item>
<el-descriptions-item label="仿真时长">
{{ reportData.summary?.duration_s }} s
</el-descriptions-item>
<el-descriptions-item label="调节时间(步)">
{{ reportData.summary?.settling_step ?? '未稳定' }}
</el-descriptions-item>
<el-descriptions-item label="MAE (°C)">
{{ reportData.summary?.mae }}
</el-descriptions-item>
<el-descriptions-item label="RMSE (°C)">
{{ reportData.summary?.rmse }}
</el-descriptions-item>
<el-descriptions-item label="最大超调 (°C)">
{{ reportData.summary?.overshoot }}
</el-descriptions-item>
<el-descriptions-item label="平均电流 (A)">
{{ reportData.summary?.avg_current }}
</el-descriptions-item>
<el-descriptions-item label="最大电流 (A)">
{{ reportData.summary?.max_current }}
</el-descriptions-item>
</el-descriptions>
</template>
<template #footer>
<el-button @click="reportVisible = false">关闭</el-button>
</template>
</el-dialog>
</div> </div>
</template> </template>
...@@ -697,6 +594,13 @@ const latestPoint = computed(() => { ...@@ -697,6 +594,13 @@ const latestPoint = computed(() => {
display: flex; display: flex;
align-items: center; align-items: center;
gap: 12px; gap: 12px;
.chart-toolbar {
margin-left: auto;
display: flex;
align-items: center;
gap: 8px;
}
} }
.live-badge { .live-badge {
...@@ -716,6 +620,10 @@ const latestPoint = computed(() => { ...@@ -716,6 +620,10 @@ const latestPoint = computed(() => {
height: 280px; height: 280px;
width: 100%; width: 100%;
} }
.table-body {
width: 100%;
}
} }
.card-title { .card-title {
......
...@@ -8,7 +8,6 @@ import { ...@@ -8,7 +8,6 @@ import {
createExperiment, createExperiment,
deleteExperiment, deleteExperiment,
getMonitorModels, getMonitorModels,
getMonitorPackages,
} from '@/api/realtimeMonitor' } from '@/api/realtimeMonitor'
const router = useRouter() const router = useRouter()
...@@ -30,7 +29,6 @@ const loadExperiments = async () => { ...@@ -30,7 +29,6 @@ const loadExperiments = async () => {
const dialogVisible = ref(false) const dialogVisible = ref(false)
const submitting = ref(false) const submitting = ref(false)
const models = ref([]) const models = ref([])
const packages = ref([])
const defaultMpcParams = () => ({ const defaultMpcParams = () => ({
P: 20, P: 20,
...@@ -50,7 +48,8 @@ const defaultMpcParams = () => ({ ...@@ -50,7 +48,8 @@ const defaultMpcParams = () => ({
const form = reactive({ const form = reactive({
name: '', name: '',
model_id: '', model_id: '',
package_id: '', input_csv_path: '',
output_csv_path: '',
target_temp: 35.0, target_temp: 35.0,
sampling_interval: 1.0, sampling_interval: 1.0,
mpc_params: defaultMpcParams(), mpc_params: defaultMpcParams(),
...@@ -62,7 +61,8 @@ const openDialog = async () => { ...@@ -62,7 +61,8 @@ const openDialog = async () => {
Object.assign(form, { Object.assign(form, {
name: '', name: '',
model_id: '', model_id: '',
package_id: '', input_csv_path: '',
output_csv_path: '',
target_temp: 35.0, target_temp: 35.0,
sampling_interval: 1.0, sampling_interval: 1.0,
mpc_params: defaultMpcParams(), mpc_params: defaultMpcParams(),
...@@ -70,23 +70,23 @@ const openDialog = async () => { ...@@ -70,23 +70,23 @@ const openDialog = async () => {
showAdvanced.value = false showAdvanced.value = false
dialogVisible.value = true dialogVisible.value = true
if (!models.value.length) { if (!models.value.length) {
const [m, p] = await Promise.all([getMonitorModels(), getMonitorPackages()]) models.value = await getMonitorModels()
models.value = m
packages.value = p
} }
} }
const handleCreate = async () => { const handleCreate = async () => {
if (!form.name.trim()) { ElMessage.warning('请输入试验名称'); return } if (!form.name.trim()) { ElMessage.warning('请输入试验名称'); return }
if (!form.model_id) { ElMessage.warning('请选择预测模型'); return } if (!form.model_id) { ElMessage.warning('请选择预测模型'); return }
if (!form.package_id) { ElMessage.warning('请选择初始数据包'); return } if (!form.input_csv_path.trim()) { ElMessage.warning('请输入输入CSV文件路径'); return }
if (!form.output_csv_path.trim()) { ElMessage.warning('请输入输出CSV文件路径'); return }
submitting.value = true submitting.value = true
try { try {
await createExperiment({ await createExperiment({
name: form.name.trim(), name: form.name.trim(),
model_id: form.model_id, model_id: form.model_id,
package_id: form.package_id, input_csv_path: form.input_csv_path.trim(),
output_csv_path: form.output_csv_path.trim(),
target_temp: form.target_temp, target_temp: form.target_temp,
sampling_interval: form.sampling_interval, sampling_interval: form.sampling_interval,
mpc_params: { ...form.mpc_params }, mpc_params: { ...form.mpc_params },
...@@ -133,7 +133,7 @@ onMounted(loadExperiments) ...@@ -133,7 +133,7 @@ onMounted(loadExperiments)
<el-card shadow="hover" class="page-card"> <el-card shadow="hover" class="page-card">
<template #header> <template #header>
<div class="card-header-row"> <div class="card-header-row">
<span class="card-title">实时监控试验</span> <span class="card-title">试验管理</span>
<div class="header-actions"> <div class="header-actions">
<el-button :icon="Refresh" size="small" plain :loading="loading" @click="loadExperiments"> <el-button :icon="Refresh" size="small" plain :loading="loading" @click="loadExperiments">
刷新 刷新
...@@ -148,7 +148,7 @@ onMounted(loadExperiments) ...@@ -148,7 +148,7 @@ onMounted(loadExperiments)
<el-table :data="experiments" v-loading="loading" stripe> <el-table :data="experiments" v-loading="loading" stripe>
<el-table-column prop="name" label="试验名称" min-width="160" /> <el-table-column prop="name" label="试验名称" min-width="160" />
<el-table-column prop="model_name" label="预测模型" min-width="140" /> <el-table-column prop="model_name" label="预测模型" min-width="140" />
<el-table-column prop="package_name" label="数据包" min-width="140" /> <el-table-column prop="input_csv_path" label="输入CSV" min-width="180" show-overflow-tooltip />
<el-table-column prop="target_temp" label="目标温度(°C)" width="120" align="right"> <el-table-column prop="target_temp" label="目标温度(°C)" width="120" align="right">
<template #default="{ row }">{{ row.target_temp?.toFixed(1) }}</template> <template #default="{ row }">{{ row.target_temp?.toFixed(1) }}</template>
</el-table-column> </el-table-column>
...@@ -180,8 +180,8 @@ onMounted(loadExperiments) ...@@ -180,8 +180,8 @@ onMounted(loadExperiments)
</el-card> </el-card>
<!-- 创建试验对话框 --> <!-- 创建试验对话框 -->
<el-dialog v-model="dialogVisible" title="创建试验" width="600px" :close-on-click-modal="false"> <el-dialog v-model="dialogVisible" title="创建试验" width="640px" :close-on-click-modal="false">
<el-form label-width="130px" @submit.prevent> <el-form label-width="140px" @submit.prevent>
<el-form-item label="试验名称" required> <el-form-item label="试验名称" required>
<el-input v-model="form.name" placeholder="请输入试验名称" /> <el-input v-model="form.name" placeholder="请输入试验名称" />
</el-form-item> </el-form-item>
...@@ -193,10 +193,19 @@ onMounted(loadExperiments) ...@@ -193,10 +193,19 @@ onMounted(loadExperiments)
</el-option> </el-option>
</el-select> </el-select>
</el-form-item> </el-form-item>
<el-form-item label="初始数据包" required> <el-form-item label="输入CSV路径" required>
<el-select v-model="form.package_id" placeholder="选择数据包" style="width:100%"> <el-input
<el-option v-for="p in packages" :key="p.id" :label="p.name" :value="p.id" /> v-model="form.input_csv_path"
</el-select> placeholder="传感器数据源 CSV 文件的绝对路径"
/>
<div class="field-hint">每步控制将从该文件末尾读取最新传感器数据</div>
</el-form-item>
<el-form-item label="输出CSV路径" required>
<el-input
v-model="form.output_csv_path"
placeholder="生成曲线写入的 CSV 文件绝对路径"
/>
<div class="field-hint">MPC 每步控制结果(温度、电流)将追加写入该文件</div>
</el-form-item> </el-form-item>
<el-form-item label="目标温度(°C)" required> <el-form-item label="目标温度(°C)" required>
<el-input-number v-model="form.target_temp" :step="0.5" :precision="1" style="width:160px" /> <el-input-number v-model="form.target_temp" :step="0.5" :precision="1" style="width:160px" />
...@@ -210,7 +219,7 @@ onMounted(loadExperiments) ...@@ -210,7 +219,7 @@ onMounted(loadExperiments)
:precision="1" :precision="1"
style="width:160px" style="width:160px"
/> />
<span style="margin-left:8px;color:#94a3b8;font-size:12px">每步等待时长,需与传感器采集频率一致</span> <span style="margin-left:8px;color:#94a3b8;font-size:12px">需与传感器采集频率一致</span>
</el-form-item> </el-form-item>
<!-- MPC 参数(高级) --> <!-- MPC 参数(高级) -->
...@@ -290,9 +299,17 @@ onMounted(loadExperiments) ...@@ -290,9 +299,17 @@ onMounted(loadExperiments)
gap: 8px; gap: 8px;
} }
.field-hint {
font-size: 12px;
color: #94a3b8;
margin-top: 4px;
line-height: 1.4;
}
.params-grid { .params-grid {
display: grid; display: grid;
grid-template-columns: 1fr 1fr; grid-template-columns: 1fr 1fr;
gap: 0 16px; gap: 0 16px;
} }
</style> </style>
...@@ -12,7 +12,7 @@ export default defineConfig({ ...@@ -12,7 +12,7 @@ export default defineConfig({
server: { server: {
proxy: { proxy: {
'/api': { '/api': {
target: 'http://127.0.0.1:8000', target: 'http://127.0.0.1:8002',
changeOrigin: true, changeOrigin: true,
}, },
}, },
......
...@@ -66,6 +66,9 @@ ALTER TABLE train_tasks ...@@ -66,6 +66,9 @@ ALTER TABLE train_tasks
ALTER TABLE saved_models ALTER TABLE saved_models
ADD COLUMN IF NOT EXISTS test_loss FLOAT NULL COMMENT '测试集损失'; ADD COLUMN IF NOT EXISTS test_loss FLOAT NULL COMMENT '测试集损失';
ALTER TABLE saved_models
ADD COLUMN IF NOT EXISTS description TEXT NULL COMMENT '模型说明';
CREATE TABLE IF NOT EXISTS data_quality_config ( CREATE TABLE IF NOT EXISTS data_quality_config (
id BIGINT PRIMARY KEY AUTO_INCREMENT, id BIGINT PRIMARY KEY AUTO_INCREMENT,
field_name VARCHAR(50) NOT NULL COMMENT '字段名: current/voltage/set_temperature/actual_temperature', field_name VARCHAR(50) NOT NULL COMMENT '字段名: current/voltage/set_temperature/actual_temperature',
...@@ -83,3 +86,25 @@ INSERT IGNORE INTO data_quality_config (field_name, label, unit) VALUES ...@@ -83,3 +86,25 @@ INSERT IGNORE INTO data_quality_config (field_name, label, unit) VALUES
('voltage', '电压', 'V'), ('voltage', '电压', 'V'),
('set_temperature', '设定温度', '℃'), ('set_temperature', '设定温度', '℃'),
('actual_temperature', '实际温度', '℃'); ('actual_temperature', '实际温度', '℃');
ALTER TABLE eval_records
ADD COLUMN IF NOT EXISTS mse FLOAT NULL COMMENT '均方误差';
ALTER TABLE eval_records
ADD COLUMN IF NOT EXISTS mape FLOAT NULL COMMENT '平均绝对百分比误差(%)';
ALTER TABLE eval_records
ADD COLUMN IF NOT EXISTS r2 FLOAT NULL COMMENT '决定系数';
ALTER TABLE monitor_experiments
ADD COLUMN IF NOT EXISTS input_csv_path TEXT NULL COMMENT '输入CSV路径(传感器数据源)';
ALTER TABLE monitor_experiments
ADD COLUMN IF NOT EXISTS output_csv_path TEXT NULL COMMENT '输出CSV路径(生成曲线写入)';
ALTER TABLE monitor_experiments
MODIFY COLUMN package_id BIGINT NULL COMMENT '初始数据包ID(旧版兼容)';
ALTER TABLE monitor_experiments
MODIFY COLUMN package_name VARCHAR(255) NULL COMMENT '数据包名称(旧版兼容)';
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment