mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 23:14:50 +00:00
feature(train): 新增预览
TODO: 更改曲线值输出
This commit is contained in:
parent
f8edafb9c0
commit
147c9c0bc4
4 changed files with 36 additions and 11 deletions
|
|
@ -16,8 +16,12 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
|
|||
|
||||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||
|
||||
def set_status(task_id, status, **extra):
|
||||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
||||
def set_status(task_id, status, train_losses, test_losses, **extra):
|
||||
payload = {"status": status,
|
||||
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
|
||||
**extra}
|
||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||
|
||||
def get_status(task_id):
|
||||
|
|
@ -47,13 +51,15 @@ class Cptrain:
|
|||
|
||||
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
|
||||
|
||||
set_status(time, "running")
|
||||
set_status(time, "running", None, None)
|
||||
|
||||
model_path, train_losses, test_losses = train.train_seg(model.net,
|
||||
train_data=images, train_labels=labels,
|
||||
test_data=test_images, test_labels=test_labels,
|
||||
weight_decay=0.1, learning_rate=1e-5,
|
||||
n_epochs=100, model_name=model_name,
|
||||
save_path=MODELS_DIR)
|
||||
save_path=BASE_DIR)
|
||||
|
||||
set_status(time, "done", train_losses, test_losses)
|
||||
print("模型已保存到:", model_path)
|
||||
return train_losses, test_losses
|
||||
|
|
|
|||
|
|
@ -43,6 +43,14 @@ def set_status(task_id, status, **extra):
|
|||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||
|
||||
def set_train_status(task_id, status, train_losses, test_losses, **extra):
|
||||
payload = {"status": status,
|
||||
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
|
||||
**extra}
|
||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||
|
||||
def get_status(task_id):
|
||||
raw = r.get(f"task:{task_id}")
|
||||
return json.loads(raw) if raw else None
|
||||
|
|
@ -138,6 +146,8 @@ def train_upload():
|
|||
|
||||
train_files = request.files.getlist("train_files")
|
||||
test_files = request.files.getlist("test_files")
|
||||
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
|
||||
os.makedirs(Path(TEST_DIR) / ts, exist_ok=True)
|
||||
set_status(ts, "pending")
|
||||
saved = []
|
||||
for f in train_files:
|
||||
|
|
@ -167,8 +177,8 @@ def train_upload():
|
|||
|
||||
def done_cb(f):
|
||||
try:
|
||||
f.result()
|
||||
set_status(ts, "success")
|
||||
train_losses, test_losses = f.result()
|
||||
set_train_status(ts, "success", train_losses, test_losses)
|
||||
except Exception as e:
|
||||
set_status(ts, "failed", error=str(e))
|
||||
|
||||
|
|
@ -216,4 +226,12 @@ def preview():
|
|||
@app.get("/models")
|
||||
def list_models():
|
||||
models_list = os.listdir(MODELS_DIR)
|
||||
return jsonify({"ok": True, "models": models_list})
|
||||
return jsonify({"ok": True, "models": models_list})
|
||||
|
||||
@app.get("/result")
|
||||
def list_results():
|
||||
task_id = request.args.get('id')
|
||||
st = get_status(task_id)
|
||||
if not st:
|
||||
return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
|
||||
return jsonify({"ok": True, "exists": True, **st}), 200
|
||||
|
|
@ -59,7 +59,7 @@
|
|||
crossorigin="anonymous"></script>
|
||||
<script>
|
||||
const API = "http://10.147.18.141:5000/";
|
||||
const APT_UPLOAD = API + "run_upload";
|
||||
const API_UPLOAD = API + "run_upload";
|
||||
const API_MODEL = API + "models";
|
||||
|
||||
async function loadModels() {
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@
|
|||
crossorigin="anonymous"></script>
|
||||
<script>
|
||||
const API = "http://10.147.18.141:5000/";
|
||||
const APT_UPLOAD = API + "train_upload";
|
||||
const API_UPLOAD = API + "train_upload";
|
||||
const API_MODEL = API + "models";
|
||||
|
||||
async function loadModels() {
|
||||
|
|
@ -140,7 +140,8 @@
|
|||
document.getElementById("uploadBtn").addEventListener("click", async () => {
|
||||
const input_train = document.getElementById("trainFileInput");
|
||||
const input_test = document.getElementById("testFileInput");
|
||||
if (!input.files.length) return alert("请选择文件");
|
||||
if (!input_train.files.length) return alert("请选择训练文件");
|
||||
if (!input_test.files.length) return alert("请选择训练文件");
|
||||
|
||||
const fd = new FormData();
|
||||
for (const f of input_train.files) fd.append("train_files", f);
|
||||
|
|
@ -180,7 +181,7 @@
|
|||
} else {
|
||||
clearInterval(timer);
|
||||
document.body.removeChild(notice);
|
||||
window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
|
||||
window.location.href = `train_result.html?id=${encodeURIComponent(res.data['id'])}`;
|
||||
}
|
||||
}, 1000);
|
||||
} catch (e) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue