Commit d8561934 authored by luwei's avatar luwei

修改

parent d7a073ab
......@@ -91,6 +91,26 @@ def download_template():
)
class FileUpdateRequest(BaseModel):
filename: str | None = Field(default=None, min_length=1, max_length=255)
remark: str | None = Field(default=None, max_length=500)
category_id: str | None = Field(default=None)
@router.put('/files/{file_id}')
def update_file(file_id: str, request: FileUpdateRequest):
try:
result = service.update_file(
file_id=file_id,
filename=request.filename,
remark=request.remark,
category_id=request.category_id,
)
return success_response(data=result, message='文件更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/files/{file_id}')
def delete_file(file_id: str):
try:
......@@ -115,9 +135,9 @@ def get_file_quality(file_id: str):
@router.get('/files/{file_id}/records')
def get_file_records(file_id: str, limit: int = Query(default=500, ge=1, le=5000)):
def get_file_records(file_id: str):
try:
result = service.get_file_records(file_id=file_id, limit=limit)
result = service.get_file_records(file_id=file_id, limit=None)
return success_response(data=result)
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.services.eval_service import EvalService
......@@ -14,8 +14,11 @@ class EvalRequest(BaseModel):
@router.get('/packages')
def list_packages():
return success_response(data=service.list_packages())
def list_packages(
category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
@router.get('/models')
......
......@@ -28,7 +28,8 @@ class MPCParamsSchema(BaseModel):
class CreateExperimentRequest(BaseModel):
name: str = Field(min_length=1, max_length=255)
model_id: int
package_id: int
input_csv_path: str = Field(min_length=1, description='读取传感器数据的CSV文件路径')
output_csv_path: str = Field(min_length=1, description='输出曲线数据的CSV文件路径')
target_temp: float = Field(description='目标温度(°C)')
sampling_interval: float = Field(default=1.0, gt=0, le=3600, description='采样周期(秒)')
mpc_params: MPCParamsSchema = Field(default_factory=MPCParamsSchema)
......@@ -59,7 +60,8 @@ def create_experiment(req: CreateExperimentRequest):
exp = service.create_experiment(
name=req.name.strip(),
model_id=req.model_id,
package_id=req.package_id,
input_csv_path=req.input_csv_path.strip(),
output_csv_path=req.output_csv_path.strip(),
target_temp=req.target_temp,
sampling_interval=req.sampling_interval,
mpc_params=req.mpc_params.model_dump(),
......@@ -113,25 +115,5 @@ def get_data_points(exp_id: int, from_step: int = Query(default=0, ge=0)):
return success_response(data=service.get_data_points(exp_id, from_step))
@router.get('/experiments/{exp_id}/report')
def get_report(exp_id: int):
try:
return success_response(data=service.get_report(exp_id))
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
@router.post('/experiments/{exp_id}/export')
def export_to_history(exp_id: int):
try:
result = service.export_to_history(exp_id)
return success_response(data=result, message='已导出到历史数据')
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
# ── 历史数据 ──────────────────────────────────────────────────────────────────
@router.get('/history')
def list_history():
return success_response(data=service.list_history_experiments())
......@@ -20,6 +20,7 @@ class CategoryUpdateRequest(BaseModel):
class CleanRules(BaseModel):
enabled: bool = False
newton_interp: bool = False
current_min: float | None = None
current_max: float | None = None
voltage_min: float | None = None
......@@ -28,6 +29,11 @@ class CleanRules(BaseModel):
temperature_max: float | None = None
class SmoothConfig(BaseModel):
enabled: bool = False
window: int = Field(default=5, ge=2, le=500)
@router.get('/categories')
def get_categories():
return success_response(data=service.get_category_tree())
......@@ -63,8 +69,22 @@ def delete_category(category_id: str):
# Must be declared before /{package_id} routes to avoid path conflict
@router.get('/data-files')
def list_all_data_files():
return success_response(data=service.list_all_data_files())
def list_all_data_files(
category_id: str = Query(default=''),
filename: str = Query(default=''),
remark: str = Query(default=''),
):
return success_response(data=service.list_all_data_files(
category_id=category_id,
filename=filename.strip(),
remark=remark.strip(),
))
class PackageUpdateRequest(BaseModel):
name: str = Field(min_length=1, max_length=255)
category_id: str | int | None = Field(default=None)
remark: str | None = Field(default=None)
class PackageCreateRequest(BaseModel):
......@@ -73,6 +93,8 @@ class PackageCreateRequest(BaseModel):
remark: str | None = Field(default=None)
file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
auto_split: bool = Field(default=False)
row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1)
......@@ -80,20 +102,23 @@ class PackageCreateRequest(BaseModel):
class PreviewRequest(BaseModel):
file_ids: list[int] = Field(default_factory=list)
clean_rules: CleanRules | None = Field(default=None)
smooth: SmoothConfig | None = Field(default=None)
row_start: int | None = Field(default=None, ge=1)
row_end: int | None = Field(default=None, ge=1)
@router.post('/preview')
def preview_package(request: PreviewRequest, limit: int = Query(default=300, ge=1, le=2000)):
def preview_package(request: PreviewRequest):
try:
if request.row_start and request.row_end and request.row_start > request.row_end:
raise ValueError('起始行不能大于结束行')
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
result = service.preview_records(
file_ids=request.file_ids,
limit=limit,
limit=None,
clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start,
row_end=request.row_end,
)
......@@ -117,12 +142,27 @@ def create_package(request: PackageCreateRequest):
raise ValueError('起始行不能大于结束行')
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
clean_rules = request.clean_rules.model_dump() if request.clean_rules else None
smooth = request.smooth.model_dump() if request.smooth else None
base_name = request.name.strip()
if request.auto_split:
pkgs = service.create_package_split(
name=base_name,
category_id=category_id,
remark=request.remark,
file_ids=request.file_ids,
clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start,
row_end=request.row_end,
)
return success_response(data=pkgs, message='数据包创建成功,已自动划分训练集/验证集/测试集')
pkg = service.create_package(
name=request.name.strip(),
name=base_name,
category_id=category_id,
remark=request.remark,
file_ids=request.file_ids,
clean_rules=clean_rules,
smooth=smooth,
row_start=request.row_start,
row_end=request.row_end,
)
......@@ -132,14 +172,29 @@ def create_package(request: PackageCreateRequest):
@router.get('/{package_id}/records')
def get_package_records(package_id: str, limit: int = Query(default=500, ge=1, le=5000)):
def get_package_records(package_id: str):
try:
result = service.get_package_records(package_id=package_id, limit=limit)
result = service.get_package_records(package_id=package_id, limit=None)
return success_response(data=result)
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.put('/{package_id}')
def update_package(package_id: str, request: PackageUpdateRequest):
try:
category_id = None if request.category_id in (None, '', 'all') else str(request.category_id)
result = service.update_package(
package_id=package_id,
name=request.name.strip(),
category_id=category_id,
remark=request.remark,
)
return success_response(data=result, message='数据包更新成功')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.delete('/{package_id}')
def delete_package(package_id: str):
try:
......
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from app.services.train_service import TrainService
......@@ -21,15 +21,29 @@ class LSTMParams(BaseModel):
class CreateTaskRequest(BaseModel):
model_name: str = Field(min_length=1, max_length=255)
package_id: int
train_package_id: int
val_package_id: int
params: LSTMParams = Field(default_factory=LSTMParams)
class SaveModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
class UpdateModelRequest(BaseModel):
model_name: str | None = Field(default=None, max_length=255)
description: str | None = None
# ── packages ──────────────────────────────────────────────────────────────────
@router.get('/packages')
def list_packages():
return success_response(data=service.list_packages())
def list_packages(
category_id: str = Query(default=''),
name: str = Query(default=''),
):
return success_response(data=service.list_packages(category_id=category_id, name=name.strip()))
# ── tasks ─────────────────────────────────────────────────────────────────────
......@@ -52,7 +66,8 @@ def create_task(request: CreateTaskRequest):
try:
task = service.create_task(
model_name=request.model_name.strip(),
package_id=request.package_id,
train_package_id=request.train_package_id,
val_package_id=request.val_package_id,
params=request.params.model_dump(),
)
return success_response(data=task, message='训练任务已启动')
......@@ -88,9 +103,9 @@ def delete_task(task_id: int):
@router.post('/tasks/{task_id}/save')
def save_model(task_id: int):
def save_model(task_id: int, request: SaveModelRequest):
try:
model = service.save_model(task_id)
model = service.save_model(task_id, model_name=request.model_name, description=request.description)
return success_response(data=model, message='模型已保存')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
......@@ -110,3 +125,15 @@ def delete_model(model_id: int):
return success_response(data=True, message='模型已删除')
except ValueError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
@router.patch('/models/{model_id}')
def update_model(model_id: int, request: UpdateModelRequest):
try:
model = service.update_saved_model(
model_id,
model_name=request.model_name,
description=request.description,
)
return success_response(data=model, message='已更新')
except ValueError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
\ No newline at end of file
......@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import text
from app.api.data_management import router as data_management_router
from app.api.eval_management import router as eval_management_router
......@@ -16,10 +17,29 @@ from app.models.train_management import SavedModel, TrainTask # noqa: F401
from app.utils.response import error_response, success_response
def _run_migrations() -> None:
"""Add new columns to existing tables without dropping data."""
migrations = [
"ALTER TABLE train_tasks ADD COLUMN val_package_id BIGINT NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305ID' AFTER package_name",
"ALTER TABLE train_tasks ADD COLUMN val_package_name VARCHAR(255) NULL COMMENT '\u9a8c\u8bc1\u96c6\u6570\u636e\u5305\u540d\u79f0' AFTER val_package_id",
"ALTER TABLE train_tasks ADD COLUMN epoch_logs JSON NULL COMMENT '\u6bcf\u8f6e\u8bad\u7ec3\u65e5\u5fd7' AFTER val_package_name",
]
with engine.connect() as conn:
for stmt in migrations:
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
pass # Column already exists
def create_app() -> FastAPI:
# Auto-create any missing tables (safe: uses CREATE TABLE IF NOT EXISTS internally)
Base.metadata.create_all(bind=engine)
# Safe column migrations for existing tables
_run_migrations()
app = FastAPI(
title='Thermal Control System API',
version='0.1.0',
......
......@@ -96,7 +96,22 @@ def predict_lstm(
# ── metrics ───────────────────────────────────────────────────────────────
errors = preds_real - actuals_real
mae = float(np.mean(np.abs(errors)))
rmse = float(math.sqrt(float(np.mean(errors ** 2))))
mse = float(np.mean(errors ** 2))
rmse = float(math.sqrt(mse))
# MAPE — skip points where actual == 0 to avoid division by zero
nonzero_mask = actuals_real != 0
if nonzero_mask.any():
mape: float | None = float(
np.mean(np.abs(errors[nonzero_mask] / actuals_real[nonzero_mask])) * 100
)
else:
mape = None
# R² (coefficient of determination)
ss_res = float(np.sum(errors ** 2))
ss_tot = float(np.sum((actuals_real - float(np.mean(actuals_real))) ** 2))
r2: float | None = (1.0 - ss_res / ss_tot) if ss_tot != 0 else None
# ── build result points ───────────────────────────────────────────────────
times = [str(r.get('time', '')) for r in records]
......@@ -115,6 +130,8 @@ def predict_lstm(
'time': times[seq_len + i] if (seq_len + i) < len(times) else str(seq_len + i),
'actual': round(float(actuals_real[i]), 4),
'predicted': round(float(preds_real[i]), 4),
'current': round(float(raw[seq_len + i, 0]), 4),
'voltage': round(float(raw[seq_len + i, 1]), 4),
}
for i in sample_idx
]
......@@ -122,6 +139,9 @@ def predict_lstm(
return {
'total_count': total,
'mae': round(mae, 6),
'mse': round(mse, 6),
'rmse': round(rmse, 6),
'mape': round(mape, 4) if mape is not None else None,
'r2': round(r2, 6) if r2 is not None else None,
'chart_data': chart_data,
}
......@@ -44,11 +44,6 @@ FEATURE_COLS = ['current', 'voltage', 'set_temperature', 'actual_temperature']
TARGET_COL = 'actual_temperature'
TARGET_IDX = FEATURE_COLS.index(TARGET_COL)
# Fixed dataset split ratios (train / val / test)
_TRAIN_RATIO = 0.70
_VAL_RATIO = 0.15
# test = 1 - _TRAIN_RATIO - _VAL_RATIO (≈ 0.15)
def _check_torch() -> None:
if not _TORCH_AVAILABLE:
......@@ -88,12 +83,137 @@ def _make_sequences(data: np.ndarray, seq_len: int) -> tuple[np.ndarray, np.ndar
# ── public training entry point ───────────────────────────────────────────────
def train_lstm(
records: list[dict],
train_records: list[dict],
val_records: list[dict],
params: dict,
save_path: Path,
on_progress: Callable[[int, float, float | None], None],
on_progress: Callable[[int, int, float, float | None], None],
cancel_event: threading.Event,
) -> dict[str, float | None]:
"""
Train an LSTM model on *train_records* validated against *val_records*.
Args:
train_records: list of dicts for the training set.
val_records: list of dicts for the validation set.
params: hyper-parameter dict (seq_len, hidden_size, num_layers,
epochs, batch_size, learning_rate).
save_path: destination .pt file.
on_progress: callback(pct, epoch, train_loss, val_loss) called after each epoch.
cancel_event: when set, training stops with InterruptedError.
Returns:
{'train_loss': float, 'val_loss': float|None}
"""
_check_torch()
seq_len = max(1, int(params.get('seq_len', 20)))
hidden_size = max(1, int(params.get('hidden_size', 64)))
num_layers = max(1, int(params.get('num_layers', 2)))
epochs = max(1, int(params.get('epochs', 50)))
batch_size = max(1, int(params.get('batch_size', 32)))
lr = float(params.get('learning_rate', 0.001))
# ── data preparation ────────────────────────────────────────────────────
train_data = _extract_features(train_records)
val_data = _extract_features(val_records)
min_required = seq_len + 10
if len(train_data) < min_required:
raise ValueError(
f'\u8bad\u7ec3\u96c6\u6709\u6548\u6570\u636e\u91cf\u4e0d\u8db3\uff1a\u9700\u81f3\u5c11 {min_required} \u6761\uff0c\u5f53\u524d\u4ec5 {len(train_data)} \u6761\u3002'
'\u8bf7\u68c0\u67e5\u6570\u636e\u5305\u5185\u5bb9\u6216\u51cf\u5c0f\u5e8f\u5217\u957f\u5ea6\u3002'
)
# min-max normalisation fitted on train set only
data_min = train_data.min(axis=0)
data_max = train_data.max(axis=0)
data_range = data_max - data_min
data_range[data_range == 0] = 1.0
train_norm = (train_data - data_min) / data_range
X_train, y_train = _make_sequences(train_norm, seq_len)
has_val = len(val_data) >= seq_len + 1
X_val_t = y_val_t = None
if has_val:
val_norm = (val_data - data_min) / data_range
X_val, y_val = _make_sequences(val_norm, seq_len)
if len(X_val) == 0:
has_val = False
else:
X_val_t = torch.tensor(X_val)
y_val_t = torch.tensor(y_val)
device = torch.device('cpu')
X_train_t = torch.tensor(X_train).to(device)
y_train_t = torch.tensor(y_train).to(device)
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=batch_size,
shuffle=True,
)
# ── model ────────────────────────────────────────────────────────────────
input_size = len(FEATURE_COLS)
model = _LSTMModel(input_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
train_loss = 0.0
val_loss: float | None = None
for epoch in range(epochs):
if cancel_event.is_set():
raise InterruptedError('\u8bad\u7ec3\u5df2\u53d6\u6d88')
# ── train step ───────────────────────────────────────────────────────
model.train()
epoch_loss = 0.0
for xb, yb in train_loader:
optimizer.zero_grad()
pred = model(xb)
loss = criterion(pred, yb)
loss.backward()
optimizer.step()
epoch_loss += loss.item() * len(xb)
train_loss = epoch_loss / len(X_train)
# ── val step ─────────────────────────────────────────────────────────
if has_val:
model.eval()
with torch.no_grad():
val_pred = model(X_val_t)
val_loss = criterion(val_pred, y_val_t).item()
pct = int((epoch + 1) / epochs * 100)
on_progress(pct, epoch + 1, train_loss, val_loss)
# ── persist ──────────────────────────────────────────────────────────────
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
{
'model_state': model.state_dict(),
'params': params,
'data_min': data_min.tolist(),
'data_max': data_max.tolist(),
'feature_cols': FEATURE_COLS,
'target_col': TARGET_COL,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'seq_len': seq_len,
},
save_path,
)
return {
'train_loss': round(float(train_loss), 6),
'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
}
"""
Train an LSTM model on *records* and persist it to *save_path*.
......
......@@ -19,7 +19,10 @@ class EvalRecord(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
total_count: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text('0'), comment='评估数据点总数')
mae: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对误差')
mse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方误差')
rmse: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='均方根误差')
mape: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='平均绝对百分比误差(%)')
r2: Mapped[float | None] = mapped_column(FLOAT, nullable=True, comment='决定系数')
# Store up to ~2000 sampled points for chart rendering
chart_data: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='图表数据(采样)')
created_at: Mapped[str] = mapped_column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'))
......@@ -15,8 +15,10 @@ class MonitorExperiment(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False, comment='试验名称')
model_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='模型ID')
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='初始数据包ID')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='初始数据包ID(旧版兼容)')
package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='数据包名称(旧版兼容)')
input_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输入CSV路径(传感器数据源)')
output_csv_path: Mapped[str | None] = mapped_column(Text, nullable=True, comment='输出CSV路径(生成曲线写入)')
target_temp: Mapped[float] = mapped_column(FLOAT, nullable=False, comment='目标温度(°C)')
mpc_params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='MPC参数')
status: Mapped[str] = mapped_column(
......
......@@ -13,8 +13,11 @@ class TrainTask(Base):
id: Mapped[int] = mapped_column(BIGINT, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='模型名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='数据包ID')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
package_id: Mapped[int] = mapped_column(BIGINT, nullable=False, comment='训练集数据包ID')
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='训练集数据包名称')
val_package_id: Mapped[int | None] = mapped_column(BIGINT, nullable=True, comment='验证集数据包ID')
val_package_name: Mapped[str | None] = mapped_column(String(255), nullable=True, comment='验证集数据包名称')
epoch_logs: Mapped[list | None] = mapped_column(JSON, nullable=True, comment='每轮训练日志')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
status: Mapped[str] = mapped_column(
Enum('pending', 'running', 'completed', 'failed', 'cancelled', name='train_status_enum'),
......@@ -50,6 +53,7 @@ class SavedModel(Base):
package_name: Mapped[str] = mapped_column(String(255), nullable=False, comment='数据包名称')
params: Mapped[dict] = mapped_column(JSON, nullable=False, comment='LSTM超参数')
file_path: Mapped[str] = mapped_column(String(500), nullable=False, comment='模型文件路径')
description: Mapped[str | None] = mapped_column(Text, nullable=True, comment='模型说明')
train_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
val_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
test_loss: Mapped[float | None] = mapped_column(FLOAT, nullable=True)
......
......@@ -152,6 +152,7 @@ class DataManagementService:
'category_id': item.category_id,
'uploaded_at': item.uploaded_at.strftime('%Y-%m-%d %H:%M:%S') if item.uploaded_at else '',
'data_count': item.data_count,
'remark': item.remark or '',
}
for item in rows
]
......@@ -212,6 +213,46 @@ class DataManagementService:
target_path.unlink()
raise
def update_file(self, file_id: str, filename: str | None = None, remark: str | None = None, category_id: str | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID')
if category_id is not None:
cat_db_id = self._parse_int_id(category_id, '分类ID')
with db_session() as session:
matched = session.query(DataFile).filter(DataFile.id == file_db_id).first()
if not matched:
raise ValueError('文件不存在')
if filename is not None:
stripped = filename.strip()
if not stripped:
raise ValueError('文件名不能为空')
matched.filename = stripped
if remark is not None:
matched.remark = remark.strip() or None
if category_id is not None:
category = (
session.query(Category)
.filter(Category.id == cat_db_id, Category.type == 'data_file')
.first()
)
if not category:
raise ValueError('分类不存在')
matched.category_id = cat_db_id
session.commit()
session.refresh(matched)
return {
'id': matched.id,
'filename': matched.filename,
'remark': matched.remark or '',
'category_id': matched.category_id,
}
def delete_file(self, file_id: str) -> None:
file_db_id = self._parse_int_id(file_id, '文件ID')
......@@ -227,7 +268,7 @@ class DataManagementService:
session.delete(matched)
session.commit()
def get_file_records(self, file_id: str, limit: int = 500) -> dict[str, Any]:
def get_file_records(self, file_id: str, limit: int | None = None) -> dict[str, Any]:
file_db_id = self._parse_int_id(file_id, '文件ID')
with db_session() as session:
......
......@@ -5,7 +5,7 @@ from typing import Any
from app.database import db_session
from app.ml.lstm_predictor import predict_lstm
from app.models import DataFile, DataPackage, DataPackageFile
from app.models import DataPackage
from app.models.eval_management import EvalRecord
from app.models.train_management import SavedModel
from app.services.data_management_service import DataManagementService
......@@ -24,10 +24,19 @@ class EvalService:
# ── dropdown data ─────────────────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]:
def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session:
rows = session.query(DataPackage).order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows]
query = session.query(DataPackage)
if category_id not in ('', 'all', None):
try:
db_id = int(category_id)
query = query.filter(DataPackage.category_id == db_id)
except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
def list_saved_models(self) -> list[dict[str, Any]]:
with db_session() as session:
......@@ -67,7 +76,10 @@ class EvalService:
package_name=package_name,
total_count=result['total_count'],
mae=result['mae'],
mse=result['mse'],
rmse=result['rmse'],
mape=result['mape'],
r2=result['r2'],
chart_data=result['chart_data'],
)
session.add(record)
......@@ -106,32 +118,13 @@ class EvalService:
def _load_package_records(self, package_id: int) -> list[dict[str, Any]]:
with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first()
clean_rules = pkg.clean_rules if pkg else None
pf_rows = (
session.query(DataPackageFile)
.filter(DataPackageFile.package_id == package_id)
.order_by(DataPackageFile.sort_order.asc())
.all()
)
file_ids = [pf.file_id for pf in pf_rows]
if not file_ids:
if not pkg or not pkg.stored_name:
return []
files = session.query(DataFile).filter(DataFile.id.in_(file_ids)).all()
file_map = {f.id: f for f in files}
all_records: list[dict[str, Any]] = []
for fid in file_ids:
if fid not in file_map:
continue
fmeta = file_map[fid]
path = self._dm._resolve_local_file_path(fmeta.file_path, fmeta.stored_name)
recs, _ = self._dm._read_records(path, limit=None)
all_records.extend(recs)
if clean_rules and clean_rules.get('enabled'):
all_records = self._apply_clean_rules(all_records, clean_rules)
stored_name = pkg.stored_name
return all_records
pkg_path = self._base_dir / 'uploads' / 'packages' / stored_name
recs, _ = self._dm._read_records(pkg_path, limit=None)
return recs
@staticmethod
def _apply_clean_rules(records: list[dict[str, Any]], clean_rules: dict) -> list[dict[str, Any]]:
......@@ -177,7 +170,10 @@ class EvalService:
'package_name': row.package_name,
'total_count': row.total_count,
'mae': row.mae,
'mse': row.mse,
'rmse': row.rmse,
'mape': row.mape,
'r2': row.r2,
'created_at': row.created_at.strftime('%Y-%m-%d %H:%M:%S') if row.created_at else '',
}
if include_chart:
......
This diff is collapsed.
......@@ -29,14 +29,19 @@ class TrainService:
# ── packages (for dropdown) ──────────────────────────────────────────────
def list_packages(self) -> list[dict[str, Any]]:
def list_packages(self, category_id: str = '', name: str = '') -> list[dict[str, Any]]:
with db_session() as session:
rows = (
session.query(DataPackage)
.order_by(DataPackage.created_at.desc())
.all()
)
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count} for p in rows]
query = session.query(DataPackage)
if category_id not in ('', 'all', None):
try:
db_id = int(category_id)
query = query.filter(DataPackage.category_id == db_id)
except (ValueError, TypeError):
pass
if name:
query = query.filter(DataPackage.name.like(f'%{name}%'))
rows = query.order_by(DataPackage.created_at.desc()).all()
return [{'id': p.id, 'name': p.name, 'data_count': p.data_count, 'category_id': p.category_id} for p in rows]
# ── tasks ────────────────────────────────────────────────────────────────
......@@ -52,16 +57,21 @@ class TrainService:
raise ValueError('训练任务不存在')
return self._task_to_dict(task)
def create_task(self, model_name: str, package_id: int, params: dict) -> dict[str, Any]:
def create_task(self, model_name: str, train_package_id: int, val_package_id: int, params: dict) -> dict[str, Any]:
with db_session() as session:
pkg = session.query(DataPackage).filter(DataPackage.id == package_id).first()
if not pkg:
raise ValueError('数据包不存在')
train_pkg = session.query(DataPackage).filter(DataPackage.id == train_package_id).first()
if not train_pkg:
raise ValueError('训练集数据包不存在')
val_pkg = session.query(DataPackage).filter(DataPackage.id == val_package_id).first()
if not val_pkg:
raise ValueError('验证集数据包不存在')
task = TrainTask(
model_name=model_name,
package_id=package_id,
package_name=pkg.name,
package_id=train_package_id,
package_name=train_pkg.name,
val_package_id=val_package_id,
val_package_name=val_pkg.name,
params=params,
status='pending',
progress=0,
......@@ -71,7 +81,7 @@ class TrainService:
session.refresh(task)
task_dict = self._task_to_dict(task)
self._launch_thread(task_dict['id'], package_id, params)
self._launch_thread(task_dict['id'], train_package_id, val_package_id, params)
return task_dict
def cancel_task(self, task_id: int) -> None:
......@@ -100,6 +110,8 @@ class TrainService:
model_name=task.model_name,
package_id=task.package_id,
package_name=task.package_name,
val_package_id=task.val_package_id,
val_package_name=task.val_package_name,
params=task.params,
status='pending',
progress=0,
......@@ -109,7 +121,7 @@ class TrainService:
session.refresh(new_task)
task_dict = self._task_to_dict(new_task)
self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['params'])
self._launch_thread(task_dict['id'], task_dict['package_id'], task_dict['val_package_id'], task_dict['params'])
return task_dict
def delete_task(self, task_id: int) -> None:
......@@ -128,7 +140,7 @@ class TrainService:
session.delete(task)
session.commit()
def save_model(self, task_id: int) -> dict[str, Any]:
def save_model(self, task_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session:
task = session.query(TrainTask).filter(TrainTask.id == task_id).first()
if not task:
......@@ -144,7 +156,8 @@ class TrainService:
saved = SavedModel(
task_id=task_id,
model_name=task.model_name,
model_name=(model_name.strip() if model_name and model_name.strip() else task.model_name),
description=description,
package_id=task.package_id,
package_name=task.package_name,
params=task.params,
......@@ -158,7 +171,18 @@ class TrainService:
session.commit()
session.refresh(saved)
return self._saved_to_dict(saved)
def update_saved_model(self, model_id: int, model_name: str | None = None, description: str | None = None) -> dict[str, Any]:
with db_session() as session:
model = session.query(SavedModel).filter(SavedModel.id == model_id).first()
if not model:
raise ValueError('模型不存在')
if model_name is not None and model_name.strip():
model.model_name = model_name.strip()
if description is not None:
model.description = description
session.commit()
session.refresh(model)
return self._saved_to_dict(model)
# ── saved models ─────────────────────────────────────────────────────────
def list_saved_models(self) -> list[dict[str, Any]]:
......@@ -193,14 +217,15 @@ class TrainService:
if p.is_absolute():
return p
return self._base_dir / p
def _launch_thread(self, task_id: int, package_id: int, params: dict) -> None:
def _launch_thread(self, task_id: int, train_package_id: int, val_package_id: int | None, params: dict) -> None:
cancel_event = threading.Event()
with _registry_lock:
_cancel_events[task_id] = cancel_event
thread = threading.Thread(
target=self._training_worker,
args=(task_id, package_id, params, cancel_event),
args=(task_id, train_package_id, val_package_id, params, cancel_event),
daemon=True,
name=f'train-task-{task_id}',
)
......@@ -209,35 +234,44 @@ class TrainService:
def _training_worker(
self,
task_id: int,
package_id: int,
train_package_id: int,
val_package_id: int | None,
params: dict,
cancel_event: threading.Event,
) -> None:
try:
self._update_task(task_id, status='running', progress=0)
records = self._load_package_records(package_id)
if not records:
raise ValueError('数据包没有有效数据,请检查关联文件')
train_records = self._load_package_records(train_package_id)
if not train_records:
raise ValueError('训练集数据包没有有效数据,请检查关联文件')
val_records: list[dict[str, Any]] = []
if val_package_id:
val_records = self._load_package_records(val_package_id)
save_path = self._models_dir / f'task_{task_id}.pt'
last_pct = [0]
epoch_log_buffer: list[dict[str, Any]] = []
def on_progress(pct: int, train_loss: float, val_loss: float | None) -> None:
def on_progress(pct: int, epoch: int, train_loss: float, val_loss: float | None) -> None:
if cancel_event.is_set():
return
# Throttle: persist at most every 2 % to reduce DB writes
if pct - last_pct[0] >= 2 or pct == 100:
last_pct[0] = pct
self._update_task(
task_id,
progress=pct,
train_loss=round(float(train_loss), 6),
val_loss=round(float(val_loss), 6) if val_loss is not None else None,
)
epoch_log_buffer.append({
'epoch': epoch,
'train_loss': round(float(train_loss), 6),
'val_loss': round(float(val_loss), 6) if val_loss is not None else None,
})
self._update_task(
task_id,
progress=pct,
train_loss=round(float(train_loss), 6),
val_loss=round(float(val_loss), 6) if val_loss is not None else None,
epoch_logs=list(epoch_log_buffer),
)
result = train_lstm(
records=records,
train_records=train_records,
val_records=val_records,
params=params,
save_path=save_path,
on_progress=on_progress,
......@@ -304,6 +338,9 @@ class TrainService:
'model_name': task.model_name,
'package_id': task.package_id,
'package_name': task.package_name,
'val_package_id': task.val_package_id,
'val_package_name': task.val_package_name,
'epoch_logs': task.epoch_logs or [],
'params': task.params,
'status': task.status,
'progress': task.progress,
......@@ -321,6 +358,7 @@ class TrainService:
'id': model.id,
'task_id': model.task_id,
'model_name': model.model_name,
'description': model.description or '',
'package_id': model.package_id,
'package_name': model.package_name,
'params': model.params,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
step,actual_temp,current_output,reference_temp,target_temp
0,20.534,0.0,23.4272,35.0
1,20.534,0.1796,23.4272,35.0
2,20.534,0.1796,23.4272,35.0
3,20.534,0.1949,23.4272,35.0
4,20.534,0.1949,23.4272,35.0
5,20.534,0.1949,23.4272,35.0
6,20.534,0.1949,23.4272,35.0
7,20.534,0.1949,23.4272,35.0
8,20.534,0.1354,23.4272,35.0
9,20.534,0.1354,23.4272,35.0
10,20.534,0.1354,23.4272,35.0
11,20.534,0.1354,23.4272,35.0
12,20.534,0.0,23.4272,35.0
13,20.534,0.1796,23.4272,35.0
14,20.534,0.1796,23.4272,35.0
15,20.534,0.1949,23.4272,35.0
16,20.534,0.1949,23.4272,35.0
17,20.534,0.1949,23.4272,35.0
18,20.534,0.1949,23.4272,35.0
19,20.534,0.1949,23.4272,35.0
20,20.534,0.1354,23.4272,35.0
21,20.534,0.1354,23.4272,35.0
22,20.534,0.1354,23.4272,35.0
23,20.534,0.1354,23.4272,35.0
24,20.534,0.0,23.4272,35.0
25,20.534,0.1796,23.4272,35.0
26,20.534,0.1796,23.4272,35.0
27,20.534,0.1949,23.4272,35.0
......@@ -11,10 +11,8 @@ const tabs = [
{ label: '模型训练', path: '/model-training' },
{ label: '模型库', path: '/model-list' },
{ label: '模型评估', path: '/model-evaluation' },
{ label: '实时监控', path: '/realtime-monitor' },
{ label: '历史数据', path: '/history-data' },
{ label: '试验管理', path: '/realtime-monitor' },
{ label: '实时监控', path: '/live-monitor' },
]
const activeTab = computed(() => {
......
......@@ -32,6 +32,10 @@ export function deleteDataFile(fileId) {
return request.delete(`/data/files/${fileId}`)
}
export function updateDataFile(fileId, payload) {
return request.put(`/data/files/${fileId}`, payload)
}
export function getFileRecords(fileId, params) {
return request.get(`/data/files/${fileId}/records`, { params })
}
......
import request from '@/utils/request'
export function getEvalPackages() {
return request.get('/eval/packages')
export function getEvalPackages(params = {}) {
return request.get('/eval/packages', { params })
}
export function getEvalModels() {
......
import request from '@/utils/request'
export function getHistoryList() {
return request.get('/monitor/history')
}
export function getHistoryDetail(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
......@@ -20,8 +20,8 @@ export function deletePkgCategory(categoryId) {
// ── data files (for selection) ───────────────────────────────────────────────
export function getAllDataFiles() {
return request.get('/packages/data-files')
export function getAllDataFiles(params) {
return request.get('/packages/data-files', { params })
}
// ── packages ─────────────────────────────────────────────────────────────────
......@@ -34,6 +34,10 @@ export function createPackage(payload) {
return request.post('/packages', payload)
}
export function updatePackage(packageId, payload) {
return request.put(`/packages/${packageId}`, payload)
}
export function deletePackage(packageId) {
return request.delete(`/packages/${packageId}`)
}
......
......@@ -40,10 +40,4 @@ export function getDataPoints(expId, fromStep = 0) {
return request.get(`/monitor/experiments/${expId}/data`, { params: { from_step: fromStep } })
}
export function getReport(expId) {
return request.get(`/monitor/experiments/${expId}/report`)
}
export function exportToHistory(expId) {
return request.post(`/monitor/experiments/${expId}/export`)
}
......@@ -28,8 +28,8 @@ export function deleteTrainTask(taskId) {
return request.delete(`/train/tasks/${taskId}`)
}
export function saveTrainModel(taskId) {
return request.post(`/train/tasks/${taskId}/save`)
export function saveTrainModel(taskId, payload = {}) {
return request.post(`/train/tasks/${taskId}/save`, payload)
}
export function getSavedModels() {
......@@ -39,3 +39,7 @@ export function getSavedModels() {
export function deleteSavedModel(modelId) {
return request.delete(`/train/models/${modelId}`)
}
export function updateSavedModel(modelId, payload) {
return request.patch(`/train/models/${modelId}`, payload)
}
......@@ -23,9 +23,9 @@ const router = createRouter({
component: () => import('@/views/RealtimeMonitor/components/ExperimentDetail.vue'),
},
{
path: '/history-data',
name: 'history-data',
component: () => import('@/views/HistoryData/index.vue'),
path: '/live-monitor',
name: 'live-monitor',
component: () => import('@/views/LiveMonitor/index.vue'),
},
{
path: '/model-training',
......
......@@ -7,6 +7,10 @@ const props = defineProps({
type: Array,
default: () => [],
},
totalCount: {
type: Number,
default: null,
},
})
const chartRef = ref(null)
......@@ -225,7 +229,7 @@ onBeforeUnmount(() => {
<div class="stats-row">
<div class="stat-card">
<div class="stat-label">总条数</div>
<div class="stat-value">{{ stats.count }}</div>
<div class="stat-value">{{ props.totalCount ?? stats.count }}</div>
</div>
<div class="stat-card">
<div class="stat-label">最高温度</div>
......
<script setup>
import { Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue'
import { DataAnalysis, Delete, Document, Edit, Folder, FolderOpened, Upload } from '@element-plus/icons-vue'
import { computed, onBeforeUnmount, onMounted, reactive, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { ElAutoResizer, ElMessage, ElMessageBox, ElTableV2 } from 'element-plus'
import {
createCategory,
deleteCategory,
......@@ -12,6 +12,7 @@ import {
getFileQuality,
getFileRecords,
updateCategory,
updateDataFile,
uploadDataFile,
} from '@/api/dataManagement'
import DataCurve from './components/DataCurve.vue'
......@@ -31,6 +32,21 @@ const recordList = ref([])
const contentLoading = ref(false)
const contentMode = ref('table')
const recordColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = recordList.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const categoryDialogVisible = ref(false)
const categoryDialogMode = ref('create')
const editingCategoryId = ref('')
......@@ -48,6 +64,66 @@ const qualityLoading = ref(false)
const qualityFile = ref(null)
const qualityResult = ref(null)
const fileEditDialogVisible = ref(false)
const fileEditSubmitting = ref(false)
const editingFileId = ref(null)
const fileEditForm = reactive({
filename: '',
remark: '',
categoryId: '',
})
const openFileEditDialog = (row) => {
editingFileId.value = row.id
fileEditForm.filename = row.filename
fileEditForm.remark = row.remark || ''
fileEditForm.categoryId = row.category_id ? String(row.category_id) : ''
fileEditDialogVisible.value = true
}
const submitFileEdit = async () => {
const name = fileEditForm.filename.trim()
if (!name) {
ElMessage.warning('文件名不能为空')
return
}
if (!fileEditForm.categoryId) {
ElMessage.warning('请选择分类')
return
}
fileEditSubmitting.value = true
const prevCategoryId = fileList.value.find((f) => f.id === editingFileId.value)?.category_id
try {
const result = await updateDataFile(editingFileId.value, {
filename: name,
remark: fileEditForm.remark.trim() || null,
category_id: fileEditForm.categoryId,
})
ElMessage.success('保存成功')
fileEditDialogVisible.value = false
const categoryChanged = String(result.category_id) !== String(prevCategoryId)
if (categoryChanged) {
await loadFileList()
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = null
recordList.value = []
}
} else {
const idx = fileList.value.findIndex((f) => f.id === editingFileId.value)
if (idx !== -1) {
fileList.value[idx].filename = result.filename
fileList.value[idx].remark = result.remark || ''
}
if (currentFile.value?.id === editingFileId.value) {
currentFile.value = { ...currentFile.value, filename: result.filename, remark: result.remark || '' }
}
}
} finally {
fileEditSubmitting.value = false
}
}
const qualityColorClass = (percent) => {
if (percent >= 90) return 'quality-good'
if (percent >= 70) return 'quality-warn'
......@@ -407,7 +483,7 @@ const handleViewFile = async (row) => {
currentFile.value = row
contentLoading.value = true
try {
const result = await getFileRecords(row.id, { limit: 500 })
const result = await getFileRecords(row.id)
recordList.value = result.records
} finally {
contentLoading.value = false
......@@ -506,15 +582,16 @@ onBeforeUnmount(() => {
</el-form-item>
</el-form>
<el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)">
<el-table :data="fileList" border stripe v-loading="loadingFiles" height="calc(100vh - 250px)" row-class-name="file-row" @row-click="handleViewFile">
<el-table-column prop="filename" label="文件名" min-width="160" />
<el-table-column prop="remark" label="备注" min-width="120" show-overflow-tooltip />
<el-table-column prop="uploaded_at" label="上传时间" min-width="170" />
<el-table-column prop="data_count" label="数据量" width="90" />
<el-table-column label="操作" width="200" fixed="right">
<el-table-column label="操作" width="220" fixed="right">
<template #default="{ row }">
<el-button link type="primary" @click="handleViewFile(row)">查看</el-button>
<el-button link type="success" @click="handleQualityCheck(row)">质量判定</el-button>
<el-button link type="danger" @click="handleDeleteFile(row)">删除</el-button>
<el-button link type="primary" :icon="Edit" @click.stop="openFileEditDialog(row)" />
<el-button link type="success" :icon="DataAnalysis" @click.stop="handleQualityCheck(row)" title="质量判定" />
<el-button link type="danger" :icon="Delete" @click.stop="handleDeleteFile(row)" />
</template>
</el-table-column>
</el-table>
......@@ -528,6 +605,7 @@ onBeforeUnmount(() => {
<span>
文件内容
<span v-if="currentFile" class="file-title">- {{ currentFile.filename }}</span>
<span v-if="currentFile && currentFile.remark" class="file-remark">{{ currentFile.remark }}</span>
</span>
<el-radio-group v-model="contentMode" size="small">
<el-radio-button value="table">表格</el-radio-button>
......@@ -539,21 +617,21 @@ onBeforeUnmount(() => {
<div class="content-wrap" v-loading="contentLoading">
<el-empty v-if="!currentFile" description="请选择并查看一个文件" />
<el-table
v-else-if="contentMode === 'table'"
:data="recordList"
border
stripe
height="calc(100vh - 250px)"
>
<el-table-column prop="time" label="时间" min-width="140" />
<el-table-column prop="current" label="电流" min-width="100" />
<el-table-column prop="voltage" label="电压" min-width="100" />
<el-table-column prop="set_temperature" label="设定温度" min-width="100" />
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" />
</el-table>
<DataCurve v-else :records="recordList" />
<el-auto-resizer v-else-if="contentMode === 'table'">
<template #default="{ height, width }">
<el-table-v2
:columns="recordColumns"
:data="recordList"
:width="width"
:height="height"
:row-height="36"
:header-height="40"
fixed
/>
</template>
</el-auto-resizer>
<DataCurve v-else :records="curveRecords" :total-count="recordList.length" />
</div>
</el-card>
</div>
......@@ -699,6 +777,39 @@ onBeforeUnmount(() => {
<el-button @click="qualityDialogVisible = false">关闭</el-button>
</template>
</el-dialog>
<!-- 编辑文件对话框 -->
<el-dialog v-model="fileEditDialogVisible" title="编辑文件信息" width="460px" destroy-on-close>
<el-form label-position="top">
<el-form-item label="文件名" required>
<el-input v-model="fileEditForm.filename" maxlength="255" show-word-limit clearable />
</el-form-item>
<el-form-item label="所属分类" required>
<el-select v-model="fileEditForm.categoryId" placeholder="请选择分类" style="width: 100%">
<el-option
v-for="item in flatCategoryOptions"
:key="item.value"
:label="item.label"
:value="item.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input
v-model="fileEditForm.remark"
type="textarea"
:rows="3"
maxlength="500"
show-word-limit
placeholder="可选备注"
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="fileEditDialogVisible = false">取消</el-button>
<el-button type="primary" :loading="fileEditSubmitting" @click="submitFileEdit">保存</el-button>
</template>
</el-dialog>
</section>
</template>
......@@ -794,6 +905,13 @@ onBeforeUnmount(() => {
margin-left: 4px;
}
.file-remark {
font-size: 12px;
color: var(--text-tertiary);
font-weight: 400;
margin-left: 2px;
}
.tree-node {
width: 100%;
display: flex;
......@@ -948,4 +1066,8 @@ onBeforeUnmount(() => {
.quality-no-config {
color: #e6a23c;
}
:deep(.file-row) {
cursor: pointer;
}
</style>
This diff is collapsed.
<script setup>
import * as echarts from 'echarts'
import { Refresh } from '@element-plus/icons-vue'
import { ref, onMounted, onBeforeUnmount, nextTick } from 'vue'
import { useRouter } from 'vue-router'
import { getExperiments, getDataPoints } from '@/api/realtimeMonitor'
const router = useRouter()
const runningExps = ref([])
const loadingList = ref(false)
// echarts instances & data state, keyed by experiment id
const chartsMap = {} // id -> echarts instance
const dataMap = {} // id -> data points array
const fromStepMap = {} // id -> next from_step
const MAX_DISPLAY = 300
let listTimer = null
let dataTimer = null
// ── 数据加载 ──────────────────────────────────────────────────────────────────
const loadList = async () => {
loadingList.value = true
try {
const all = await getExperiments()
const running = all.filter((e) => e.status === 'running')
// 清理已停止试验的图表实例
const runningIds = new Set(running.map((e) => e.id))
for (const id of Object.keys(chartsMap)) {
if (!runningIds.has(Number(id))) {
chartsMap[id]?.dispose()
delete chartsMap[id]
delete dataMap[id]
delete fromStepMap[id]
}
}
runningExps.value = running
} finally {
loadingList.value = false
}
}
const fetchAllData = async () => {
for (const exp of runningExps.value) {
try {
const from = fromStepMap[exp.id] ?? 0
const newPts = await getDataPoints(exp.id, from)
if (newPts.length) {
dataMap[exp.id] = [...(dataMap[exp.id] ?? []), ...newPts]
fromStepMap[exp.id] = dataMap[exp.id][dataMap[exp.id].length - 1].step_idx + 1
await nextTick()
updateChart(exp.id, exp.target_temp)
}
} catch {
// ignore per-experiment errors
}
}
}
// ── 图表 ──────────────────────────────────────────────────────────────────────
const getOrInitChart = (expId) => {
if (!chartsMap[expId]) {
const el = document.getElementById(`live-chart-${expId}`)
if (el) chartsMap[expId] = echarts.init(el)
}
return chartsMap[expId]
}
const updateChart = (expId, targetTemp) => {
const c = getOrInitChart(expId)
if (!c) return
const pts = dataMap[expId] ?? []
if (!pts.length) return
const step = pts.length > MAX_DISPLAY ? Math.ceil(pts.length / MAX_DISPLAY) : 1
const display = pts.filter((_, i) => i % step === 0)
const xData = display.map((d) => `${d.step_idx}`)
const actuals = display.map((d) => d.actual_temp)
const currents = display.map((d) => d.current_output)
const targetLine = display.map(() => targetTemp ?? null)
c.setOption(
{
animation: false,
color: ['#409EFF', '#F56C6C', '#E6A23C'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
borderColor: '#e2e8f0',
borderWidth: 1,
formatter(params) {
if (!params?.length) return ''
const lines = [`<b>步 ${params[0].axisValue}</b>`]
params.forEach((p) => {
const unit = p.seriesName === '电流输出' ? ' A' : ' °C'
const v = p.data != null ? Number(p.data).toFixed(2) + unit : '--'
lines.push(`${p.marker}${p.seriesName}: <b>${v}</b>`)
})
return lines.join('<br/>')
},
},
legend: { bottom: 2, itemWidth: 14, itemHeight: 7, textStyle: { fontSize: 11, color: '#475569' } },
grid: { top: 12, left: 8, right: 50, bottom: 40, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
data: xData,
axisLabel: { color: '#94a3b8', fontSize: 10, interval: Math.max(0, Math.floor(xData.length / 6) - 1) },
axisLine: { lineStyle: { color: '#e2e8f0' } },
},
yAxis: [
{
type: 'value',
name: '°C',
nameTextStyle: { color: '#64748b', fontSize: 10 },
axisLabel: { color: '#64748b', fontSize: 10 },
splitLine: { lineStyle: { color: '#e2e8f0' } },
},
{
type: 'value',
name: 'A',
position: 'right',
nameTextStyle: { color: '#E6A23C', fontSize: 10 },
axisLabel: { color: '#E6A23C', fontSize: 10 },
splitLine: { show: false },
min: 0,
},
],
series: [
{
name: '实际温度',
type: 'line',
yAxisIndex: 0,
data: actuals,
smooth: true,
symbol: 'none',
lineStyle: { width: 1.5 },
},
{
name: '目标温度',
type: 'line',
yAxisIndex: 0,
data: targetLine,
symbol: 'none',
lineStyle: { width: 1, type: 'dotted', color: '#F56C6C' },
},
{
name: '电流输出',
type: 'line',
yAxisIndex: 1,
data: currents,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, color: '#E6A23C' },
areaStyle: { color: '#E6A23C', opacity: 0.05 },
},
],
},
true,
)
}
const onResize = () => {
Object.values(chartsMap).forEach((c) => c?.resize())
}
// ── 生命周期 ──────────────────────────────────────────────────────────────────
onMounted(async () => {
await loadList()
await nextTick()
// 加载各试验初始数据
await fetchAllData()
listTimer = setInterval(loadList, 5000)
dataTimer = setInterval(fetchAllData, 2000)
window.addEventListener('resize', onResize)
})
onBeforeUnmount(() => {
clearInterval(listTimer)
clearInterval(dataTimer)
window.removeEventListener('resize', onResize)
Object.values(chartsMap).forEach((c) => c?.dispose())
})
</script>
<template>
<div class="live-page">
<div class="live-header">
<span class="live-title">实时监控</span>
<el-tag type="success" size="small" style="margin-left:10px">
{{ runningExps.length }} 个试验运行中
</el-tag>
<el-button
:icon="Refresh"
size="small"
plain
:loading="loadingList"
style="margin-left:auto"
@click="loadList"
>
刷新
</el-button>
<el-button size="small" plain @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</div>
<div v-if="!runningExps.length && !loadingList" class="empty-state">
<el-empty description="当前没有正在运行的试验" :image-size="120">
<el-button type="primary" @click="router.push('/realtime-monitor')">
前往试验管理
</el-button>
</el-empty>
</div>
<div v-else class="exp-grid">
<el-card
v-for="exp in runningExps"
:key="exp.id"
shadow="hover"
class="exp-card"
>
<template #header>
<div class="exp-card-header">
<span class="exp-name">{{ exp.name }}</span>
<el-tag type="success" size="small" effect="plain">运行中</el-tag>
</div>
<div class="exp-meta">
<span>目标 {{ exp.target_temp?.toFixed(1) }} °C</span>
<span>步数 {{ exp.total_steps }}</span>
<span>{{ exp.model_name }}</span>
</div>
</template>
<div
:id="`live-chart-${exp.id}`"
class="mini-chart"
>
<div v-if="!(dataMap[exp.id]?.length)" class="chart-placeholder">
<el-icon class="is-loading" style="font-size:20px;color:#409EFF"><Refresh /></el-icon>
<span>等待数据…</span>
</div>
</div>
</el-card>
</div>
</div>
</template>
<style lang="scss" scoped>
.live-page {
padding: 20px;
display: flex;
flex-direction: column;
gap: 16px;
height: 100%;
}
.live-header {
display: flex;
align-items: center;
gap: 8px;
flex-shrink: 0;
}
.live-title {
font-size: 16px;
font-weight: 700;
color: var(--text-primary);
}
.empty-state {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.exp-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(480px, 1fr));
gap: 16px;
align-content: start;
}
.exp-card {
:deep(.el-card__header) {
padding: 10px 16px 6px;
}
:deep(.el-card__body) {
padding: 0;
}
}
.exp-card-header {
display: flex;
align-items: center;
gap: 8px;
}
.exp-name {
font-size: 14px;
font-weight: 600;
color: var(--text-primary);
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.exp-meta {
display: flex;
gap: 12px;
font-size: 12px;
color: var(--text-tertiary);
margin-top: 4px;
}
.mini-chart {
height: 240px;
width: 100%;
position: relative;
}
.chart-placeholder {
position: absolute;
inset: 0;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 8px;
color: #94a3b8;
font-size: 13px;
}
</style>
......@@ -23,6 +23,12 @@ const validData = computed(() =>
const xLabels = computed(() => validData.value.map((d) => d.time || String(d.index ?? '')))
const actualSeries = computed(() => validData.value.map((d) => d.actual ?? null))
const predictedSeries = computed(() => validData.value.map((d) => d.predicted ?? null))
const currentSeries = computed(() => validData.value.map((d) => d.current ?? null))
const voltageSeries = computed(() => validData.value.map((d) => d.voltage ?? null))
const hasCurrentVoltage = computed(() =>
validData.value.some((d) => d.current != null || d.voltage != null),
)
const renderChart = () => {
if (!chartRef.value) return
......@@ -30,10 +36,82 @@ const renderChart = () => {
const hasData = validData.value.length > 0
const yAxes = [
{
type: 'value',
name: '温度 (℃)',
position: 'left',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#409EFF' } },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
]
if (hasCurrentVoltage.value) {
yAxes.push({
type: 'value',
name: '电流 (A) / 电压 (V)',
position: 'right',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
axisLine: { show: true, lineStyle: { color: '#E6A23C' } },
splitLine: { show: false },
})
}
const series = [
{
name: '实际温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
yAxisIndex: 0,
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
]
if (hasCurrentVoltage.value) {
series.push(
{
name: '电流 (A)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'solid' },
data: currentSeries.value,
},
{
name: '电压 (V)',
type: 'line',
yAxisIndex: 1,
smooth: false,
symbol: 'none',
lineStyle: { width: 1.5, type: 'dotted' },
data: voltageSeries.value,
},
)
}
const legendData = hasCurrentVoltage.value
? ['实际温度 (℃)', '预测温度 (℃)', '电流 (A)', '电压 (V)']
: ['实际温度 (℃)', '预测温度 (℃)']
chartInstance.setOption(
{
animation: false,
color: ['#409EFF', '#F56C6C'],
color: ['#409EFF', '#F56C6C', '#E6A23C', '#67C23A'],
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(255,255,255,0.96)',
......@@ -48,10 +126,13 @@ const renderChart = () => {
]
params.forEach((item) => {
const val = item.data != null ? Number(item.data).toFixed(4) : '--'
const unit = item.seriesName.includes('℃') ? ' ℃' :
item.seriesName.includes('(A)') ? ' A' :
item.seriesName.includes('(V)') ? ' V' : ''
lines.push(
`<div style="display:flex;align-items:center;gap:8px;min-width:180px;justify-content:space-between;">
`<div style="display:flex;align-items:center;gap:8px;min-width:200px;justify-content:space-between;">
<span>${item.marker}${item.seriesName}</span>
<strong>${val}</strong>
<strong>${val}${unit}</strong>
</div>`,
)
})
......@@ -63,9 +144,9 @@ const renderChart = () => {
itemWidth: 20,
itemHeight: 10,
textStyle: { color: '#475569', fontSize: 12 },
data: ['实际温度 (℃)', '预测温度 (℃)'],
data: legendData,
},
grid: { top: 16, left: 16, right: 20, bottom: 52, containLabel: true },
grid: { top: 16, left: 16, right: hasCurrentVoltage.value ? 80 : 20, bottom: 52, containLabel: true },
xAxis: {
type: 'category',
boundaryGap: false,
......@@ -79,31 +160,8 @@ const renderChart = () => {
},
axisLine: { lineStyle: { color: '#cbd5e1' } },
},
yAxis: {
type: 'value',
name: '温度 (℃)',
nameTextStyle: { color: '#64748b', fontSize: 11 },
axisLabel: { color: '#64748b', fontSize: 11 },
splitLine: { lineStyle: { type: 'dashed', color: 'rgba(148,163,184,0.4)' } },
},
series: [
{
name: '实际温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'solid' },
data: actualSeries.value,
},
{
name: '预测温度 (℃)',
type: 'line',
smooth: false,
symbol: 'none',
lineStyle: { width: 2, type: 'dashed' },
data: predictedSeries.value,
},
],
yAxis: yAxes,
series,
graphic: hasData
? []
: [
......@@ -145,3 +203,4 @@ onBeforeUnmount(() => {
<template>
<div ref="chartRef" :style="{ width: '100%', height: props.height }" />
</template>
<script setup>
import { Refresh } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, ref } from 'vue'
import { deleteSavedModel, getSavedModels } from '@/api/trainManagement'
import { onMounted, reactive, ref } from 'vue'
import { useRouter } from 'vue-router'
import { deleteSavedModel, getSavedModels, updateSavedModel } from '@/api/trainManagement'
const router = useRouter()
const models = ref([])
const loading = ref(false)
......@@ -16,6 +19,43 @@ const loadModels = async () => {
}
}
// ── edit dialog ───────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editRow = ref(null)
const editForm = reactive({ model_name: '', description: '' })
const editSaving = ref(false)
const handleEdit = (row) => {
editRow.value = row
editForm.model_name = row.model_name
editForm.description = row.description || ''
editVisible.value = true
}
const confirmEdit = async () => {
if (!editForm.model_name.trim()) return ElMessage.warning('模型名称不能为空')
editSaving.value = true
try {
await updateSavedModel(editRow.value.id, {
model_name: editForm.model_name.trim(),
description: editForm.description.trim(),
})
ElMessage.success('已更新')
editVisible.value = false
await loadModels()
} catch (e) {
ElMessage.error(e?.message || '更新失败')
} finally {
editSaving.value = false
}
}
// ── evaluate shortcut ─────────────────────────────────────────────────────────
const handleEvaluate = (row) => {
router.push({ name: 'model-evaluation', query: { model_id: row.id } })
}
// ── delete ────────────────────────────────────────────────────────────────────
const handleDelete = async (row) => {
try {
await ElMessageBox.confirm(`确定删除模型"${row.model_name}"吗?删除后无法恢复。`, '提示', {
......@@ -29,14 +69,16 @@ const handleDelete = async (row) => {
}
}
// ── display helpers ───────────────────────────────────────────────────────────
const formatParams = (params) => {
if (!params) return '-'
return [
`seq=${params.seq_len}`,
`hidden=${params.hidden_size}`,
`layers=${params.num_layers}`,
`epochs=${params.epochs}`,
`lr=${params.learning_rate}`,
`序列长度 ${params.seq_len}`,
`隐藏层 ${params.hidden_size}`,
`层数 ${params.num_layers}`,
`轮数 ${params.epochs}`,
`批次 ${params.batch_size}`,
`学习率 ${params.learning_rate}`,
].join(' / ')
}
......@@ -65,10 +107,15 @@ onMounted(loadModels)
<el-table :data="models" v-loading="loading" border stripe height="calc(100vh - 160px)">
<el-table-column type="index" width="55" label="#" align="center" />
<el-table-column prop="model_name" label="模型名称" min-width="150" show-overflow-tooltip />
<el-table-column prop="model_name" label="模型名称" min-width="140" show-overflow-tooltip />
<el-table-column prop="description" label="说明" min-width="160" show-overflow-tooltip>
<template #default="{ row }">
<span class="desc-text">{{ row.description || '—' }}</span>
</template>
</el-table-column>
<el-table-column prop="package_name" label="训练数据包" min-width="140" show-overflow-tooltip />
<el-table-column label="LSTM 参数" min-width="280" show-overflow-tooltip>
<el-table-column label="LSTM 参数" min-width="310" show-overflow-tooltip>
<template #default="{ row }">
<el-tooltip :content="formatParams(row.params)" placement="top">
<span class="params-text">{{ formatParams(row.params) }}</span>
......@@ -84,8 +131,10 @@ onMounted(loadModels)
<el-table-column prop="created_at" label="保存时间" width="170" />
<el-table-column label="操作" width="90" fixed="right" align="center">
<el-table-column label="操作" width="160" fixed="right" align="center">
<template #default="{ row }">
<el-button link type="primary" @click="handleEdit(row)">编辑</el-button>
<el-button link type="success" @click="handleEvaluate(row)">评估</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button>
</template>
</el-table-column>
......@@ -98,6 +147,29 @@ onMounted(loadModels)
/>
</el-card>
</div>
<!-- ── edit dialog ──────────────────────────────────────────────── -->
<el-dialog v-model="editVisible" title="编辑模型信息" width="480px" :close-on-click-modal="false">
<el-form :model="editForm" label-position="top">
<el-form-item label="模型名称" required>
<el-input v-model="editForm.model_name" maxlength="100" show-word-limit />
</el-form-item>
<el-form-item label="说明">
<el-input
v-model="editForm.description"
type="textarea"
:rows="3"
placeholder="模型用途、训练说明等"
maxlength="500"
show-word-limit
/>
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editSaving" @click="confirmEdit">保存</el-button>
</template>
</el-dialog>
</template>
<style lang="scss" scoped>
......@@ -144,7 +216,8 @@ onMounted(loadModels)
}
.params-text,
.loss-text {
.loss-text,
.desc-text {
font-size: 12px;
color: var(--text-secondary);
white-space: nowrap;
......@@ -158,3 +231,5 @@ onMounted(loadModels)
font-weight: 500;
}
</style>
This diff is collapsed.
<script setup>
import { ref, watch } from 'vue'
import { computed, ref, watch } from 'vue'
import { ElAutoResizer, ElTableV2 } from 'element-plus'
import { getPackageRecords } from '@/api/packageManagement'
import DataCurve from '@/views/DataManagement/components/DataCurve.vue'
......@@ -18,6 +19,21 @@ const loading = ref(false)
const records = ref([])
const contentMode = ref('table')
const detailColumns = [
{ key: 'time', dataKey: 'time', title: '时间', width: 170 },
{ key: 'current', dataKey: 'current', title: '电流', width: 110 },
{ key: 'voltage', dataKey: 'voltage', title: '电压', width: 110 },
{ key: 'set_temperature', dataKey: 'set_temperature', title: '设定温度', width: 120 },
{ key: 'actual_temperature', dataKey: 'actual_temperature', title: '实际温度', width: 120 },
]
const curveRecords = computed(() => {
const r = records.value
if (r.length <= 2000) return r
const step = Math.ceil(r.length / 2000)
return r.filter((_, i) => i % step === 0)
})
const loadRecords = async (pkg) => {
if (!pkg) {
records.value = []
......@@ -25,7 +41,7 @@ const loadRecords = async (pkg) => {
}
loading.value = true
try {
const result = await getPackageRecords(pkg.id, { limit: 500 })
const result = await getPackageRecords(pkg.id)
records.value = result.records
} finally {
loading.value = false
......@@ -56,21 +72,21 @@ watch(() => props.package, (pkg) => {
<div class="content-wrap" v-loading="loading">
<el-empty v-if="!props.package" description="请选择一个数据包查看" />
<el-table
v-else-if="contentMode === 'table'"
:data="records"
border
stripe
height="calc(100vh - 250px)"
>
<el-table-column prop="time" label="时间" min-width="140" />
<el-table-column prop="current" label="电流" min-width="90" />
<el-table-column prop="voltage" label="电压" min-width="90" />
<el-table-column prop="set_temperature" label="设定温度" min-width="100" />
<el-table-column prop="actual_temperature" label="实际温度" min-width="100" />
</el-table>
<DataCurve v-else :records="records" />
<el-auto-resizer v-else-if="contentMode === 'table'">
<template #default="{ height, width }">
<el-table-v2
:columns="detailColumns"
:data="records"
:width="width"
:height="height"
:row-height="36"
:header-height="40"
fixed
/>
</template>
</el-auto-resizer>
<DataCurve v-else :records="curveRecords" :total-count="records.length" />
</div>
</el-card>
</template>
......
<script setup>
import { Plus, Search } from '@element-plus/icons-vue'
import { Delete, Edit, Plus, Search } from '@element-plus/icons-vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { onMounted, reactive, ref, watch } from 'vue'
import { deletePackage, getPackages } from '@/api/packageManagement'
import { deletePackage, getPackages, updatePackage } from '@/api/packageManagement'
import { getPkgCategoryTree } from '@/api/packageManagement'
const props = defineProps({
categoryId: {
......@@ -52,6 +53,61 @@ const handleView = (row) => {
emit('view', row)
}
// ── edit ──────────────────────────────────────────────────────────────────
const editVisible = ref(false)
const editLoading = ref(false)
const editForm = reactive({ id: null, name: '', category_id: null, remark: '' })
const categoryOptions = ref([])
const loadCategoryOptions = async () => {
if (categoryOptions.value.length) return
try {
const tree = await getPkgCategoryTree()
const flatten = (nodes, result = []) => {
for (const n of nodes) {
result.push({ value: n.id, label: n.name })
if (n.children?.length) flatten(n.children, result)
}
return result
}
categoryOptions.value = flatten(tree)
} catch {}
}
const handleEdit = async (row) => {
await loadCategoryOptions()
editForm.id = row.id
editForm.name = row.name
editForm.category_id = row.category_id ?? null
editForm.remark = row.remark ?? ''
editVisible.value = true
}
const submitEdit = async () => {
if (!editForm.name.trim()) {
ElMessage.warning('请输入数据包名称')
return
}
editLoading.value = true
try {
const updated = await updatePackage(editForm.id, {
name: editForm.name.trim(),
category_id: editForm.category_id || null,
remark: editForm.remark.trim() || null,
})
ElMessage.success('修改成功')
editVisible.value = false
// 更新列表中对应行
const idx = packageList.value.findIndex(p => p.id === editForm.id)
if (idx !== -1) Object.assign(packageList.value[idx], updated)
if (currentPackage.value?.id === editForm.id) Object.assign(currentPackage.value, updated)
} catch (e) {
ElMessage.error(e?.response?.data?.detail || '修改失败')
} finally {
editLoading.value = false
}
}
const handleDelete = async (row) => {
try {
await ElMessageBox.confirm(`确定删除数据包"${row.name}"吗?`, '提示', { type: 'warning' })
......@@ -105,18 +161,48 @@ onMounted(loadPackages)
highlight-current-row
height="calc(100vh - 265px)"
:row-class-name="({ row }) => (currentPackage?.id === row.id ? 'current-row' : '')"
style="cursor: pointer"
@row-click="handleView"
>
<el-table-column prop="name" label="数据包名称" min-width="150" show-overflow-tooltip />
<el-table-column prop="created_at" label="创建时间" min-width="160" />
<el-table-column prop="name" label="数据包名称" min-width="100" show-overflow-tooltip />
<el-table-column prop="created_at" label="创建时间" min-width="100" />
<el-table-column prop="data_count" label="数据量" width="80" align="center" />
<el-table-column label="操作" width="130" fixed="right">
<el-table-column label="操作" width="150" fixed="right" align="center">
<template #default="{ row }">
<el-button link type="primary" @click="handleView(row)">查看</el-button>
<el-button link type="danger" @click="handleDelete(row)">删除</el-button>
<div>
<el-button link type="primary" :icon="Edit" @click.stop="handleEdit(row)" />
<el-button link type="danger" :icon="Delete" @click.stop="handleDelete(row)" />
</div>
</template>
</el-table-column>
</el-table>
</el-card>
<!-- 编辑数据包 -->
<el-dialog v-model="editVisible" title="编辑数据包" width="480px" :close-on-click-modal="false">
<el-form label-width="90px">
<el-form-item label="名称" required>
<el-input v-model="editForm.name" placeholder="数据包名称" clearable />
</el-form-item>
<el-form-item label="分类">
<el-select v-model="editForm.category_id" placeholder="选择分类(可选)" clearable style="width:100%">
<el-option
v-for="opt in categoryOptions"
:key="opt.value"
:label="opt.label"
:value="opt.value"
/>
</el-select>
</el-form-item>
<el-form-item label="备注">
<el-input v-model="editForm.remark" type="textarea" :rows="3" placeholder="备注(可选)" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" :loading="editLoading" @click="submitEdit">保存</el-button>
</template>
</el-dialog>
</template>
<style lang="scss" scoped>
......
......@@ -8,7 +8,6 @@ import {
createExperiment,
deleteExperiment,
getMonitorModels,
getMonitorPackages,
} from '@/api/realtimeMonitor'
const router = useRouter()
......@@ -30,7 +29,6 @@ const loadExperiments = async () => {
const dialogVisible = ref(false)
const submitting = ref(false)
const models = ref([])
const packages = ref([])
const defaultMpcParams = () => ({
P: 20,
......@@ -50,7 +48,8 @@ const defaultMpcParams = () => ({
const form = reactive({
name: '',
model_id: '',
package_id: '',
input_csv_path: '',
output_csv_path: '',
target_temp: 35.0,
sampling_interval: 1.0,
mpc_params: defaultMpcParams(),
......@@ -62,7 +61,8 @@ const openDialog = async () => {
Object.assign(form, {
name: '',
model_id: '',
package_id: '',
input_csv_path: '',
output_csv_path: '',
target_temp: 35.0,
sampling_interval: 1.0,
mpc_params: defaultMpcParams(),
......@@ -70,23 +70,23 @@ const openDialog = async () => {
showAdvanced.value = false
dialogVisible.value = true
if (!models.value.length) {
const [m, p] = await Promise.all([getMonitorModels(), getMonitorPackages()])
models.value = m
packages.value = p
models.value = await getMonitorModels()
}
}
const handleCreate = async () => {
if (!form.name.trim()) { ElMessage.warning('请输入试验名称'); return }
if (!form.model_id) { ElMessage.warning('请选择预测模型'); return }
if (!form.package_id) { ElMessage.warning('请选择初始数据包'); return }
if (!form.input_csv_path.trim()) { ElMessage.warning('请输入输入CSV文件路径'); return }
if (!form.output_csv_path.trim()) { ElMessage.warning('请输入输出CSV文件路径'); return }
submitting.value = true
try {
await createExperiment({
name: form.name.trim(),
model_id: form.model_id,
package_id: form.package_id,
input_csv_path: form.input_csv_path.trim(),
output_csv_path: form.output_csv_path.trim(),
target_temp: form.target_temp,
sampling_interval: form.sampling_interval,
mpc_params: { ...form.mpc_params },
......@@ -133,7 +133,7 @@ onMounted(loadExperiments)
<el-card shadow="hover" class="page-card">
<template #header>
<div class="card-header-row">
<span class="card-title">实时监控试验</span>
<span class="card-title">试验管理</span>
<div class="header-actions">
<el-button :icon="Refresh" size="small" plain :loading="loading" @click="loadExperiments">
刷新
......@@ -148,7 +148,7 @@ onMounted(loadExperiments)
<el-table :data="experiments" v-loading="loading" stripe>
<el-table-column prop="name" label="试验名称" min-width="160" />
<el-table-column prop="model_name" label="预测模型" min-width="140" />
<el-table-column prop="package_name" label="数据包" min-width="140" />
<el-table-column prop="input_csv_path" label="输入CSV" min-width="180" show-overflow-tooltip />
<el-table-column prop="target_temp" label="目标温度(°C)" width="120" align="right">
<template #default="{ row }">{{ row.target_temp?.toFixed(1) }}</template>
</el-table-column>
......@@ -180,8 +180,8 @@ onMounted(loadExperiments)
</el-card>
<!-- 创建试验对话框 -->
<el-dialog v-model="dialogVisible" title="创建试验" width="600px" :close-on-click-modal="false">
<el-form label-width="130px" @submit.prevent>
<el-dialog v-model="dialogVisible" title="创建试验" width="640px" :close-on-click-modal="false">
<el-form label-width="140px" @submit.prevent>
<el-form-item label="试验名称" required>
<el-input v-model="form.name" placeholder="请输入试验名称" />
</el-form-item>
......@@ -193,10 +193,19 @@ onMounted(loadExperiments)
</el-option>
</el-select>
</el-form-item>
<el-form-item label="初始数据包" required>
<el-select v-model="form.package_id" placeholder="选择数据包" style="width:100%">
<el-option v-for="p in packages" :key="p.id" :label="p.name" :value="p.id" />
</el-select>
<el-form-item label="输入CSV路径" required>
<el-input
v-model="form.input_csv_path"
placeholder="传感器数据源 CSV 文件的绝对路径"
/>
<div class="field-hint">每步控制将从该文件末尾读取最新传感器数据</div>
</el-form-item>
<el-form-item label="输出CSV路径" required>
<el-input
v-model="form.output_csv_path"
placeholder="生成曲线写入的 CSV 文件绝对路径"
/>
<div class="field-hint">MPC 每步控制结果(温度、电流)将追加写入该文件</div>
</el-form-item>
<el-form-item label="目标温度(°C)" required>
<el-input-number v-model="form.target_temp" :step="0.5" :precision="1" style="width:160px" />
......@@ -210,7 +219,7 @@ onMounted(loadExperiments)
:precision="1"
style="width:160px"
/>
<span style="margin-left:8px;color:#94a3b8;font-size:12px">每步等待时长,需与传感器采集频率一致</span>
<span style="margin-left:8px;color:#94a3b8;font-size:12px">需与传感器采集频率一致</span>
</el-form-item>
<!-- MPC 参数(高级) -->
......@@ -290,9 +299,17 @@ onMounted(loadExperiments)
gap: 8px;
}
.field-hint {
font-size: 12px;
color: #94a3b8;
margin-top: 4px;
line-height: 1.4;
}
.params-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 0 16px;
}
</style>
......@@ -12,7 +12,7 @@ export default defineConfig({
server: {
proxy: {
'/api': {
target: 'http://127.0.0.1:8000',
target: 'http://127.0.0.1:8002',
changeOrigin: true,
},
},
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment