diff --git a/backend/cp_train.py b/backend/cp_train.py
index e01521b..d6d9aa4 100644
--- a/backend/cp_train.py
+++ b/backend/cp_train.py
@@ -16,8 +16,12 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
-def set_status(task_id, status, **extra):
- payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
+def set_status(task_id, status, train_losses, test_losses, **extra):
+ payload = {"status": status,
+ "updated_at": datetime.datetime.utcnow().isoformat(),
+ "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
+ "test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
+ **extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
def get_status(task_id):
@@ -47,13 +51,15 @@ class Cptrain:
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
- set_status(time, "running")
+ set_status(time, "running", None, None)
model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels,
test_data=test_images, test_labels=test_labels,
weight_decay=0.1, learning_rate=1e-5,
n_epochs=100, model_name=model_name,
- save_path=MODELS_DIR)
+ save_path=BASE_DIR)
+ set_status(time, "done", train_losses, test_losses)
print("模型已保存到:", model_path)
+ return train_losses, test_losses
diff --git a/backend/flaskApp.py b/backend/flaskApp.py
index d86e2ce..2037bae 100644
--- a/backend/flaskApp.py
+++ b/backend/flaskApp.py
@@ -43,6 +43,14 @@ def set_status(task_id, status, **extra):
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
+def set_train_status(task_id, status, train_losses, test_losses, **extra):
+ payload = {"status": status,
+ "updated_at": datetime.datetime.utcnow().isoformat(),
+ "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
+ "test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
+ **extra}
+ r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
+
def get_status(task_id):
raw = r.get(f"task:{task_id}")
return json.loads(raw) if raw else None
@@ -138,6 +146,8 @@ def train_upload():
train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files")
+ os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
+ os.makedirs(Path(TEST_DIR) / ts, exist_ok=True)
set_status(ts, "pending")
saved = []
for f in train_files:
@@ -167,8 +177,8 @@ def train_upload():
def done_cb(f):
try:
- f.result()
- set_status(ts, "success")
+ train_losses, test_losses = f.result()
+ set_train_status(ts, "success", train_losses, test_losses)
except Exception as e:
set_status(ts, "failed", error=str(e))
@@ -216,4 +226,12 @@ def preview():
@app.get("/models")
def list_models():
models_list = os.listdir(MODELS_DIR)
- return jsonify({"ok": True, "models": models_list})
\ No newline at end of file
+ return jsonify({"ok": True, "models": models_list})
+
+@app.get("/result")
+def list_results():
+ task_id = request.args.get('id')
+ st = get_status(task_id)
+ if not st:
+ return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
+ return jsonify({"ok": True, "exists": True, **st}), 200
\ No newline at end of file
diff --git a/frontend/run.html b/frontend/run.html
index 2ae0aeb..5a4aff3 100644
--- a/frontend/run.html
+++ b/frontend/run.html
@@ -59,7 +59,7 @@
crossorigin="anonymous">