mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 23:14:50 +00:00
feature(train): 新增预览
TODO: 更改曲线值输出
This commit is contained in:
parent
f8edafb9c0
commit
147c9c0bc4
4 changed files with 36 additions and 11 deletions
|
|
@ -16,8 +16,12 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
|
||||||
|
|
||||||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||||
|
|
||||||
def set_status(task_id, status, **extra):
|
def set_status(task_id, status, train_losses, test_losses, **extra):
|
||||||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
payload = {"status": status,
|
||||||
|
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||||
|
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||||
|
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
|
||||||
|
**extra}
|
||||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||||
|
|
||||||
def get_status(task_id):
|
def get_status(task_id):
|
||||||
|
|
@ -47,13 +51,15 @@ class Cptrain:
|
||||||
|
|
||||||
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
|
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
|
||||||
|
|
||||||
set_status(time, "running")
|
set_status(time, "running", None, None)
|
||||||
|
|
||||||
model_path, train_losses, test_losses = train.train_seg(model.net,
|
model_path, train_losses, test_losses = train.train_seg(model.net,
|
||||||
train_data=images, train_labels=labels,
|
train_data=images, train_labels=labels,
|
||||||
test_data=test_images, test_labels=test_labels,
|
test_data=test_images, test_labels=test_labels,
|
||||||
weight_decay=0.1, learning_rate=1e-5,
|
weight_decay=0.1, learning_rate=1e-5,
|
||||||
n_epochs=100, model_name=model_name,
|
n_epochs=100, model_name=model_name,
|
||||||
save_path=MODELS_DIR)
|
save_path=BASE_DIR)
|
||||||
|
|
||||||
|
set_status(time, "done", train_losses, test_losses)
|
||||||
print("模型已保存到:", model_path)
|
print("模型已保存到:", model_path)
|
||||||
|
return train_losses, test_losses
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,14 @@ def set_status(task_id, status, **extra):
|
||||||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
||||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||||
|
|
||||||
|
def set_train_status(task_id, status, train_losses, test_losses, **extra):
|
||||||
|
payload = {"status": status,
|
||||||
|
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||||
|
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||||
|
"test_losses": test_losses.tolist() if hasattr(test_losses, "tolist") else test_losses,
|
||||||
|
**extra}
|
||||||
|
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||||
|
|
||||||
def get_status(task_id):
|
def get_status(task_id):
|
||||||
raw = r.get(f"task:{task_id}")
|
raw = r.get(f"task:{task_id}")
|
||||||
return json.loads(raw) if raw else None
|
return json.loads(raw) if raw else None
|
||||||
|
|
@ -138,6 +146,8 @@ def train_upload():
|
||||||
|
|
||||||
train_files = request.files.getlist("train_files")
|
train_files = request.files.getlist("train_files")
|
||||||
test_files = request.files.getlist("test_files")
|
test_files = request.files.getlist("test_files")
|
||||||
|
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
|
||||||
|
os.makedirs(Path(TEST_DIR) / ts, exist_ok=True)
|
||||||
set_status(ts, "pending")
|
set_status(ts, "pending")
|
||||||
saved = []
|
saved = []
|
||||||
for f in train_files:
|
for f in train_files:
|
||||||
|
|
@ -167,8 +177,8 @@ def train_upload():
|
||||||
|
|
||||||
def done_cb(f):
|
def done_cb(f):
|
||||||
try:
|
try:
|
||||||
f.result()
|
train_losses, test_losses = f.result()
|
||||||
set_status(ts, "success")
|
set_train_status(ts, "success", train_losses, test_losses)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
set_status(ts, "failed", error=str(e))
|
set_status(ts, "failed", error=str(e))
|
||||||
|
|
||||||
|
|
@ -216,4 +226,12 @@ def preview():
|
||||||
@app.get("/models")
|
@app.get("/models")
|
||||||
def list_models():
|
def list_models():
|
||||||
models_list = os.listdir(MODELS_DIR)
|
models_list = os.listdir(MODELS_DIR)
|
||||||
return jsonify({"ok": True, "models": models_list})
|
return jsonify({"ok": True, "models": models_list})
|
||||||
|
|
||||||
|
@app.get("/result")
|
||||||
|
def list_results():
|
||||||
|
task_id = request.args.get('id')
|
||||||
|
st = get_status(task_id)
|
||||||
|
if not st:
|
||||||
|
return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
|
||||||
|
return jsonify({"ok": True, "exists": True, **st}), 200
|
||||||
|
|
@ -59,7 +59,7 @@
|
||||||
crossorigin="anonymous"></script>
|
crossorigin="anonymous"></script>
|
||||||
<script>
|
<script>
|
||||||
const API = "http://10.147.18.141:5000/";
|
const API = "http://10.147.18.141:5000/";
|
||||||
const APT_UPLOAD = API + "run_upload";
|
const API_UPLOAD = API + "run_upload";
|
||||||
const API_MODEL = API + "models";
|
const API_MODEL = API + "models";
|
||||||
|
|
||||||
async function loadModels() {
|
async function loadModels() {
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@
|
||||||
crossorigin="anonymous"></script>
|
crossorigin="anonymous"></script>
|
||||||
<script>
|
<script>
|
||||||
const API = "http://10.147.18.141:5000/";
|
const API = "http://10.147.18.141:5000/";
|
||||||
const APT_UPLOAD = API + "train_upload";
|
const API_UPLOAD = API + "train_upload";
|
||||||
const API_MODEL = API + "models";
|
const API_MODEL = API + "models";
|
||||||
|
|
||||||
async function loadModels() {
|
async function loadModels() {
|
||||||
|
|
@ -140,7 +140,8 @@
|
||||||
document.getElementById("uploadBtn").addEventListener("click", async () => {
|
document.getElementById("uploadBtn").addEventListener("click", async () => {
|
||||||
const input_train = document.getElementById("trainFileInput");
|
const input_train = document.getElementById("trainFileInput");
|
||||||
const input_test = document.getElementById("testFileInput");
|
const input_test = document.getElementById("testFileInput");
|
||||||
if (!input.files.length) return alert("请选择文件");
|
if (!input_train.files.length) return alert("请选择训练文件");
|
||||||
|
if (!input_test.files.length) return alert("请选择训练文件");
|
||||||
|
|
||||||
const fd = new FormData();
|
const fd = new FormData();
|
||||||
for (const f of input_train.files) fd.append("train_files", f);
|
for (const f of input_train.files) fd.append("train_files", f);
|
||||||
|
|
@ -180,7 +181,7 @@
|
||||||
} else {
|
} else {
|
||||||
clearInterval(timer);
|
clearInterval(timer);
|
||||||
document.body.removeChild(notice);
|
document.body.removeChild(notice);
|
||||||
window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
|
window.location.href = `train_result.html?id=${encodeURIComponent(res.data['id'])}`;
|
||||||
}
|
}
|
||||||
}, 1000);
|
}, 1000);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue