feature(train): 新增预览

TODO: 更改曲线值输出
This commit is contained in:
ClovertaTheTrilobita 2025-09-24 19:03:56 +00:00
parent f8edafb9c0
commit 147c9c0bc4
4 changed files with 36 additions and 11 deletions

View file

@ -16,8 +16,12 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, **extra):
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
def set_status(task_id, status, train_losses, test_losses, **extra):
payload = {"status": status,
"updated_at": datetime.datetime.utcnow().isoformat(),
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
**extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
def get_status(task_id):
@ -47,13 +51,15 @@ class Cptrain:
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
set_status(time, "running")
set_status(time, "running", None, None)
model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels,
test_data=test_images, test_labels=test_labels,
weight_decay=0.1, learning_rate=1e-5,
n_epochs=100, model_name=model_name,
save_path=MODELS_DIR)
save_path=BASE_DIR)
set_status(time, "done", train_losses, test_losses)
print("模型已保存到:", model_path)
return train_losses, test_losses

View file

@ -43,6 +43,14 @@ def set_status(task_id, status, **extra):
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
def set_train_status(task_id, status, train_losses, test_losses, **extra):
payload = {"status": status,
"updated_at": datetime.datetime.utcnow().isoformat(),
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
**extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
def get_status(task_id):
raw = r.get(f"task:{task_id}")
return json.loads(raw) if raw else None
@ -138,6 +146,8 @@ def train_upload():
train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files")
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
os.makedirs(Path(TEST_DIR) / ts, exist_ok=True)
set_status(ts, "pending")
saved = []
for f in train_files:
@ -167,8 +177,8 @@ def train_upload():
def done_cb(f):
try:
f.result()
set_status(ts, "success")
train_losses, test_losses = f.result()
set_train_status(ts, "success", train_losses, test_losses)
except Exception as e:
set_status(ts, "failed", error=str(e))
@ -217,3 +227,11 @@ def preview():
def list_models():
models_list = os.listdir(MODELS_DIR)
return jsonify({"ok": True, "models": models_list})
@app.get("/result")
def list_results():
task_id = request.args.get('id')
st = get_status(task_id)
if not st:
return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
return jsonify({"ok": True, "exists": True, **st}), 200

View file

@ -59,7 +59,7 @@
crossorigin="anonymous"></script>
<script>
const API = "http://10.147.18.141:5000/";
const APT_UPLOAD = API + "run_upload";
const API_UPLOAD = API + "run_upload";
const API_MODEL = API + "models";
async function loadModels() {

View file

@ -64,7 +64,7 @@
crossorigin="anonymous"></script>
<script>
const API = "http://10.147.18.141:5000/";
const APT_UPLOAD = API + "train_upload";
const API_UPLOAD = API + "train_upload";
const API_MODEL = API + "models";
async function loadModels() {
@ -140,7 +140,8 @@
document.getElementById("uploadBtn").addEventListener("click", async () => {
const input_train = document.getElementById("trainFileInput");
const input_test = document.getElementById("testFileInput");
if (!input.files.length) return alert("请选择文件");
if (!input_train.files.length) return alert("请选择训练文件");
if (!input_test.files.length) return alert("请选择训练文件");
const fd = new FormData();
for (const f of input_train.files) fd.append("train_files", f);
@ -180,7 +181,7 @@
} else {
clearInterval(timer);
document.body.removeChild(notice);
window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
window.location.href = `train_result.html?id=${encodeURIComponent(res.data['id'])}`;
}
}, 1000);
} catch (e) {