From a4ca19b61af71488044336d6c54ed5d265848826 Mon Sep 17 00:00:00 2001 From: ClovertaTheTrilobita Date: Wed, 22 Oct 2025 14:42:14 +0000 Subject: [PATCH 1/2] =?UTF-8?q?chore:=20=E5=A2=9E=E5=8A=A0=E6=B3=A8?= =?UTF-8?q?=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.yaml | 2 +- backend/cp_run.py | 19 +++++++++-- backend/cp_train.py | 39 +++++++++++++++++++++++ backend/flaskApp.py | 78 ++++++++++++++++++++++++++++++++------------- backend/main.py | 2 +- frontend/api.js | 8 +++-- 6 files changed, 118 insertions(+), 30 deletions(-) diff --git a/backend/config.yaml b/backend/config.yaml index 19b138f..9ce8471 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -1,5 +1,5 @@ backend: - ip: 10.10.25.240 + ip: 192.168.193.141 port: 5000 model: diff --git a/backend/cp_run.py b/backend/cp_run.py index dfcb4ab..474bac1 100644 --- a/backend/cp_run.py +++ b/backend/cp_run.py @@ -62,6 +62,20 @@ class Cprun: diameter: float | None = None, flow_threshold: float = 0.4, cellprob_threshold: float = 0.0, ): + """ + 运行 cellpose 分割 + + Args: + images: [list] 图片存储路径 + time: [str] 开始运行的时间,相当于本次运行的ID,用于存储运行结果 + model: [str] 图像分割所使用的模型 + diameter: [float] diameters are used to rescale the image to 30 pix cell diameter. + flow_threshold: [float] flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4. + cellprob_threshold: [float] all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0. + + Returns: + + """ if time is None: return [False, "No time received"] @@ -71,9 +85,10 @@ class Cprun: message = [f"Using {model} model"] + # 设定模型参数 model = models.CellposeModel(gpu=True, model_type=model) files = images - imgs = [imread(f) for f in files] + imgs = [imread(f) for f in files] # 获取目录中的每一个文件 masks, flows, styles = model.eval( imgs, flow_threshold=flow_threshold, @@ -90,7 +105,7 @@ class Cprun: out = base + "_output" save_masks(imgs, mask, flow, out, tif=True) - # 用 plot 生成彩色叠加图(不依赖 skimage) + # 用 plot 生成彩色叠加图 rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例 Image.fromarray(over).save(base + "_overlay.png") diff --git a/backend/cp_train.py b/backend/cp_train.py index 58594ac..f9ea094 100644 --- a/backend/cp_train.py +++ b/backend/cp_train.py @@ -19,6 +19,19 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR r = redis.Redis(host="127.0.0.1", port=6379, db=0) def set_status(task_id, status, train_losses, test_losses, **extra): + """ + 修改redis数据库中某一任务的运行状态 + + Args: + task_id: 这一任务的时间戳 + status: 任务状态 + train_losses: 此次任务的训练loss + test_losses: 此次任务的测试loss + **extra: + + Returns: + + """ payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses, @@ -55,6 +68,32 @@ class Cptrain: scale_range=None, channel_axis: int = None, ): + """ + 开始训练 + + Args: + time: 此次任务的时间戳(即任务ID) + model_name: 训练结果命名 + image_filter: + mask_filter: + base_model: + train_probs: + test_probs: + batch_size: + learning_rate: + n_epochs: + weight_decay: + normalize: + compute_flows: + min_train_masks: + nimg_per_epoch: + rescale: + scale_range: + channel_axis: + + Returns: + + """ train_dir = Path(TRAIN_DIR) / time test_dir = Path(TEST_DIR) / time diff --git a/backend/flaskApp.py b/backend/flaskApp.py index 40109f5..29619cb 100644 --- a/backend/flaskApp.py +++ b/backend/flaskApp.py @@ -93,15 +93,10 @@ def run_upload(): return default flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4) - cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"), - 0.0) + cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),0.0) diameter_raw = request.args.get("diameter") or request.form.get("diameter") diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None - print("cpt:" + str(cellprob_threshold)) - print("flow:" + str(flow_threshold)) - print("diameter:" + str(diameter)) - # 将文件保存在本地目录中 ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True) @@ -140,19 +135,39 @@ def run_upload(): @app.post("/train_upload") def train_upload(): + """ + 从前端获取训练数据和测试数据,并开始训练 + + Returns: + + """ def _to_float(x, default): + """ + 将变量转为float类型 + Args: + x: 变量 + default: 默认值 + """ + try: return float(x) except (TypeError, ValueError): return default def _to_int(x, default): + """ + 将变量转为int类型 + Args: + x: 变量 + default: 默认值 + """ try: return int(x) except (TypeError, ValueError): return default + # 获取从前端传来的参数 ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" model_name = request.args.get("model_name") or f"custom_model-{ts}" image_filter = request.args.get("image_filter") or "_img" @@ -182,6 +197,7 @@ def train_upload(): scale_range = _to_float(request.args.get("scale_range"), None) channel_axis = _to_int(request.args.get("channel_axis"), None) + # 创建工作目录 train_files = request.files.getlist("train_files") test_files = request.files.getlist("test_files") os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True) @@ -203,34 +219,30 @@ def train_upload(): saved.append(os.path.join(TEST_DIR, ts, name)) def job(): + """ + 子线程方法 + """ return asyncio.run(Cptrain.start_train( - time=ts, - model_name=model_name, - image_filter=image_filter, - mask_filter=mask_filter, - base_model=base_model, - batch_size=batch_size, - learning_rate=learning_rate, - n_epochs=n_epochs, - weight_decay=weight_decay, - normalize=normalize, - compute_flows=compute_flows, - min_train_masks=min_train_masks, - nimg_per_epoch=nimg_per_epoch, - rescale=rescale, - scale_range=scale_range, - channel_axis=channel_axis, + time=ts, model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, base_model=base_model, + batch_size=batch_size, learning_rate=learning_rate, n_epochs=n_epochs, weight_decay=weight_decay, + normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, nimg_per_epoch=nimg_per_epoch, + rescale=rescale, scale_range=scale_range, channel_axis=channel_axis, )) + # 创建一个子线程,防止阻塞主线程 fut = executor.submit(job) def done_cb(f): + """ + 获取训练结果,并存入redis数据库 + """ try: train_losses, test_losses = f.result() set_train_status(ts, "success", train_losses, test_losses) except Exception as e: set_status(ts, "failed", error=str(e)) + # 添加回调,在子线程执行完后更新redis中任务状态 fut.add_done_callback(done_cb) return jsonify({"ok": True, "count": len(saved), "id": ts}) @@ -251,6 +263,13 @@ def status(): @app.get("/preview") def preview(): + """ + 获取本次分割结果的预览 + + Returns: + + """ + task_id = request.args.get('id') task_dir = Path(OUTPUT_DIR) / task_id if not task_dir.exists(): @@ -275,11 +294,24 @@ def preview(): @app.get("/models") def list_models(): - models_list = os.listdir(MODELS_DIR) + """ + 获取现有模型列表 + + Returns: + + """ + + models_list = os.listdir(MODELS_DIR) # 查询模型列表中有哪些文件 return jsonify({"ok": True, "models": models_list}) @app.get("/result") def list_results(): + """ + 获取运行结果 + + Returns: + + """ task_id = request.args.get('id') st = get_status(task_id) if not st: diff --git a/backend/main.py b/backend/main.py index be2bcba..3a087f1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,7 +4,7 @@ from multiprocessing import Process if __name__ == "__main__": - # Cprun.run_test() + # 启动测试服务器 p = Process(target=run_dev) p.start() print(f"Flask running in PID {p.pid}") \ No newline at end of file diff --git a/frontend/api.js b/frontend/api.js index a4df8f5..e3a5aab 100644 --- a/frontend/api.js +++ b/frontend/api.js @@ -1,9 +1,11 @@ const config = { + /* 请根据需求修改下列IP和端口 */ server: { - protocol: 'http', - host: '10.10.25.240', - port: 5000 + protocol: 'http', // 网络协议 + host: '192.168.193.141', // 主机IP + port: 5000 // 后端运行端口 } }; +/* 生成API链接 */ const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`; \ No newline at end of file From 09d07561843f7c00741791e0f07f6c11395c5b86 Mon Sep 17 00:00:00 2001 From: ClovertaTheTrilobita Date: Wed, 22 Oct 2025 14:47:22 +0000 Subject: [PATCH 2/2] =?UTF-8?q?fix(api):=20=E5=88=A0=E9=99=A4=E6=B3=A8?= =?UTF-8?q?=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/api.js | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/frontend/api.js b/frontend/api.js index e3a5aab..b84607a 100644 --- a/frontend/api.js +++ b/frontend/api.js @@ -1,11 +1,9 @@ const config = { - /* 请根据需求修改下列IP和端口 */ server: { - protocol: 'http', // 网络协议 - host: '192.168.193.141', // 主机IP - port: 5000 // 后端运行端口 + protocol: 'http', + host: '192.168.193.141', + port: 5000 } }; -/* 生成API链接 */ const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`; \ No newline at end of file