diff --git a/backend/cp_train.py b/backend/cp_train.py index e01521b..d6d9aa4 100644 --- a/backend/cp_train.py +++ b/backend/cp_train.py @@ -16,8 +16,12 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR r = redis.Redis(host="127.0.0.1", port=6379, db=0) -def set_status(task_id, status, **extra): - payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra} +def set_status(task_id, status, train_losses, test_losses, **extra): + payload = {"status": status, + "updated_at": datetime.datetime.utcnow().isoformat(), + "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses, + "test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses, + **extra} r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期 def get_status(task_id): @@ -47,13 +51,15 @@ class Cptrain: model = models.CellposeModel(gpu=True, pretrained_model=base_model) - set_status(time, "running") + set_status(time, "running", None, None) model_path, train_losses, test_losses = train.train_seg(model.net, train_data=images, train_labels=labels, test_data=test_images, test_labels=test_labels, weight_decay=0.1, learning_rate=1e-5, n_epochs=100, model_name=model_name, - save_path=MODELS_DIR) + save_path=BASE_DIR) + set_status(time, "done", train_losses, test_losses) print("模型已保存到:", model_path) + return train_losses, test_losses diff --git a/backend/flaskApp.py b/backend/flaskApp.py index d86e2ce..2037bae 100644 --- a/backend/flaskApp.py +++ b/backend/flaskApp.py @@ -43,6 +43,14 @@ def set_status(task_id, status, **extra): payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra} r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期 +def set_train_status(task_id, status, train_losses, test_losses, **extra): + payload = {"status": status, + "updated_at": datetime.datetime.utcnow().isoformat(), + "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses, + "test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses, + **extra} + r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期 + def get_status(task_id): raw = r.get(f"task:{task_id}") return json.loads(raw) if raw else None @@ -138,6 +146,8 @@ def train_upload(): train_files = request.files.getlist("train_files") test_files = request.files.getlist("test_files") + os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True) + os.makedirs(Path(TEST_DIR) / ts, exist_ok=True) set_status(ts, "pending") saved = [] for f in train_files: @@ -167,8 +177,8 @@ def train_upload(): def done_cb(f): try: - f.result() - set_status(ts, "success") + train_losses, test_losses = f.result() + set_train_status(ts, "success", train_losses, test_losses) except Exception as e: set_status(ts, "failed", error=str(e)) @@ -216,4 +226,12 @@ def preview(): @app.get("/models") def list_models(): models_list = os.listdir(MODELS_DIR) - return jsonify({"ok": True, "models": models_list}) \ No newline at end of file + return jsonify({"ok": True, "models": models_list}) + +@app.get("/result") +def list_results(): + task_id = request.args.get('id') + st = get_status(task_id) + if not st: + return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200 + return jsonify({"ok": True, "exists": True, **st}), 200 \ No newline at end of file diff --git a/frontend/run.html b/frontend/run.html index 2ae0aeb..5a4aff3 100644 --- a/frontend/run.html +++ b/frontend/run.html @@ -59,7 +59,7 @@ crossorigin="anonymous">