chore: 增加注释

This commit is contained in:
ClovertaTheTrilobita 2025-10-22 14:42:14 +00:00
parent 185b01c467
commit a4ca19b61a
6 changed files with 118 additions and 30 deletions

View file

@ -1,5 +1,5 @@
backend:
ip: 10.10.25.240
ip: 192.168.193.141
port: 5000
model:

View file

@ -62,6 +62,20 @@ class Cprun:
diameter: float | None = None,
flow_threshold: float = 0.4,
cellprob_threshold: float = 0.0, ):
"""
运行 cellpose 分割
Args:
images: [list] 图片存储路径
time: [str] 开始运行的时间相当于本次运行的ID用于存储运行结果
model: [str] 图像分割所使用的模型
diameter: [float] diameters are used to rescale the image to 30 pix cell diameter.
flow_threshold: [float] flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
cellprob_threshold: [float] all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
Returns:
"""
if time is None:
return [False, "No time received"]
@ -71,9 +85,10 @@ class Cprun:
message = [f"Using {model} model"]
# 设定模型参数
model = models.CellposeModel(gpu=True, model_type=model)
files = images
imgs = [imread(f) for f in files]
imgs = [imread(f) for f in files] # 获取目录中的每一个文件
masks, flows, styles = model.eval(
imgs,
flow_threshold=flow_threshold,
@ -90,7 +105,7 @@ class Cprun:
out = base + "_output"
save_masks(imgs, mask, flow, out, tif=True)
# 用 plot 生成彩色叠加图(不依赖 skimage
# 用 plot 生成彩色叠加图
rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB
over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例
Image.fromarray(over).save(base + "_overlay.png")

View file

@ -19,6 +19,19 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, train_losses, test_losses, **extra):
"""
修改redis数据库中某一任务的运行状态
Args:
task_id: 这一任务的时间戳
status: 任务状态
train_losses: 此次任务的训练loss
test_losses: 此次任务的测试loss
**extra:
Returns:
"""
payload = {"status": status,
"updated_at": datetime.datetime.utcnow().isoformat(),
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
@ -55,6 +68,32 @@ class Cptrain:
scale_range=None,
channel_axis: int = None,
):
"""
开始训练
Args:
time: 此次任务的时间戳即任务ID
model_name: 训练结果命名
image_filter:
mask_filter:
base_model:
train_probs:
test_probs:
batch_size:
learning_rate:
n_epochs:
weight_decay:
normalize:
compute_flows:
min_train_masks:
nimg_per_epoch:
rescale:
scale_range:
channel_axis:
Returns:
"""
train_dir = Path(TRAIN_DIR) / time
test_dir = Path(TEST_DIR) / time

View file

@ -93,15 +93,10 @@ def run_upload():
return default
flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),
0.0)
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),0.0)
diameter_raw = request.args.get("diameter") or request.form.get("diameter")
diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None
print("cpt:" + str(cellprob_threshold))
print("flow:" + str(flow_threshold))
print("diameter:" + str(diameter))
# 将文件保存在本地目录中
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
@ -140,19 +135,39 @@ def run_upload():
@app.post("/train_upload")
def train_upload():
"""
从前端获取训练数据和测试数据并开始训练
Returns:
"""
def _to_float(x, default):
"""
将变量转为float类型
Args:
x: 变量
default: 默认值
"""
try:
return float(x)
except (TypeError, ValueError):
return default
def _to_int(x, default):
"""
将变量转为int类型
Args:
x: 变量
default: 默认值
"""
try:
return int(x)
except (TypeError, ValueError):
return default
# 获取从前端传来的参数
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
model_name = request.args.get("model_name") or f"custom_model-{ts}"
image_filter = request.args.get("image_filter") or "_img"
@ -182,6 +197,7 @@ def train_upload():
scale_range = _to_float(request.args.get("scale_range"), None)
channel_axis = _to_int(request.args.get("channel_axis"), None)
# 创建工作目录
train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files")
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
@ -203,34 +219,30 @@ def train_upload():
saved.append(os.path.join(TEST_DIR, ts, name))
def job():
"""
子线程方法
"""
return asyncio.run(Cptrain.start_train(
time=ts,
model_name=model_name,
image_filter=image_filter,
mask_filter=mask_filter,
base_model=base_model,
batch_size=batch_size,
learning_rate=learning_rate,
n_epochs=n_epochs,
weight_decay=weight_decay,
normalize=normalize,
compute_flows=compute_flows,
min_train_masks=min_train_masks,
nimg_per_epoch=nimg_per_epoch,
rescale=rescale,
scale_range=scale_range,
channel_axis=channel_axis,
time=ts, model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, base_model=base_model,
batch_size=batch_size, learning_rate=learning_rate, n_epochs=n_epochs, weight_decay=weight_decay,
normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, nimg_per_epoch=nimg_per_epoch,
rescale=rescale, scale_range=scale_range, channel_axis=channel_axis,
))
# 创建一个子线程,防止阻塞主线程
fut = executor.submit(job)
def done_cb(f):
"""
获取训练结果并存入redis数据库
"""
try:
train_losses, test_losses = f.result()
set_train_status(ts, "success", train_losses, test_losses)
except Exception as e:
set_status(ts, "failed", error=str(e))
# 添加回调在子线程执行完后更新redis中任务状态
fut.add_done_callback(done_cb)
return jsonify({"ok": True, "count": len(saved), "id": ts})
@ -251,6 +263,13 @@ def status():
@app.get("/preview")
def preview():
"""
获取本次分割结果的预览
Returns:
"""
task_id = request.args.get('id')
task_dir = Path(OUTPUT_DIR) / task_id
if not task_dir.exists():
@ -275,11 +294,24 @@ def preview():
@app.get("/models")
def list_models():
models_list = os.listdir(MODELS_DIR)
"""
获取现有模型列表
Returns:
"""
models_list = os.listdir(MODELS_DIR) # 查询模型列表中有哪些文件
return jsonify({"ok": True, "models": models_list})
@app.get("/result")
def list_results():
"""
获取运行结果
Returns:
"""
task_id = request.args.get('id')
st = get_status(task_id)
if not st:

View file

@ -4,7 +4,7 @@ from multiprocessing import Process
if __name__ == "__main__":
# Cprun.run_test()
# 启动测试服务器
p = Process(target=run_dev)
p.start()
print(f"Flask running in PID {p.pid}")

View file

@ -1,9 +1,11 @@
const config = {
/* 请根据需求修改下列IP和端口 */
server: {
protocol: 'http',
host: '10.10.25.240',
port: 5000
protocol: 'http', // 网络协议
host: '192.168.193.141', // 主机IP
port: 5000 // 后端运行端口
}
};
/* 生成API链接 */
const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`;