mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 23:14:50 +00:00
chore: 增加注释
This commit is contained in:
parent
185b01c467
commit
a4ca19b61a
6 changed files with 118 additions and 30 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
backend:
|
backend:
|
||||||
ip: 10.10.25.240
|
ip: 192.168.193.141
|
||||||
port: 5000
|
port: 5000
|
||||||
|
|
||||||
model:
|
model:
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,20 @@ class Cprun:
|
||||||
diameter: float | None = None,
|
diameter: float | None = None,
|
||||||
flow_threshold: float = 0.4,
|
flow_threshold: float = 0.4,
|
||||||
cellprob_threshold: float = 0.0, ):
|
cellprob_threshold: float = 0.0, ):
|
||||||
|
"""
|
||||||
|
运行 cellpose 分割
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: [list] 图片存储路径
|
||||||
|
time: [str] 开始运行的时间,相当于本次运行的ID,用于存储运行结果
|
||||||
|
model: [str] 图像分割所使用的模型
|
||||||
|
diameter: [float] diameters are used to rescale the image to 30 pix cell diameter.
|
||||||
|
flow_threshold: [float] flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
|
||||||
|
cellprob_threshold: [float] all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
if time is None:
|
if time is None:
|
||||||
return [False, "No time received"]
|
return [False, "No time received"]
|
||||||
|
|
@ -71,9 +85,10 @@ class Cprun:
|
||||||
|
|
||||||
message = [f"Using {model} model"]
|
message = [f"Using {model} model"]
|
||||||
|
|
||||||
|
# 设定模型参数
|
||||||
model = models.CellposeModel(gpu=True, model_type=model)
|
model = models.CellposeModel(gpu=True, model_type=model)
|
||||||
files = images
|
files = images
|
||||||
imgs = [imread(f) for f in files]
|
imgs = [imread(f) for f in files] # 获取目录中的每一个文件
|
||||||
masks, flows, styles = model.eval(
|
masks, flows, styles = model.eval(
|
||||||
imgs,
|
imgs,
|
||||||
flow_threshold=flow_threshold,
|
flow_threshold=flow_threshold,
|
||||||
|
|
@ -90,7 +105,7 @@ class Cprun:
|
||||||
out = base + "_output"
|
out = base + "_output"
|
||||||
save_masks(imgs, mask, flow, out, tif=True)
|
save_masks(imgs, mask, flow, out, tif=True)
|
||||||
|
|
||||||
# 用 plot 生成彩色叠加图(不依赖 skimage)
|
# 用 plot 生成彩色叠加图
|
||||||
rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB
|
rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB
|
||||||
over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例
|
over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例
|
||||||
Image.fromarray(over).save(base + "_overlay.png")
|
Image.fromarray(over).save(base + "_overlay.png")
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,19 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
|
||||||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||||
|
|
||||||
def set_status(task_id, status, train_losses, test_losses, **extra):
|
def set_status(task_id, status, train_losses, test_losses, **extra):
|
||||||
|
"""
|
||||||
|
修改redis数据库中某一任务的运行状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 这一任务的时间戳
|
||||||
|
status: 任务状态
|
||||||
|
train_losses: 此次任务的训练loss
|
||||||
|
test_losses: 此次任务的测试loss
|
||||||
|
**extra:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
payload = {"status": status,
|
payload = {"status": status,
|
||||||
"updated_at": datetime.datetime.utcnow().isoformat(),
|
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||||
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||||
|
|
@ -55,6 +68,32 @@ class Cptrain:
|
||||||
scale_range=None,
|
scale_range=None,
|
||||||
channel_axis: int = None,
|
channel_axis: int = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
开始训练
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: 此次任务的时间戳(即任务ID)
|
||||||
|
model_name: 训练结果命名
|
||||||
|
image_filter:
|
||||||
|
mask_filter:
|
||||||
|
base_model:
|
||||||
|
train_probs:
|
||||||
|
test_probs:
|
||||||
|
batch_size:
|
||||||
|
learning_rate:
|
||||||
|
n_epochs:
|
||||||
|
weight_decay:
|
||||||
|
normalize:
|
||||||
|
compute_flows:
|
||||||
|
min_train_masks:
|
||||||
|
nimg_per_epoch:
|
||||||
|
rescale:
|
||||||
|
scale_range:
|
||||||
|
channel_axis:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
train_dir = Path(TRAIN_DIR) / time
|
train_dir = Path(TRAIN_DIR) / time
|
||||||
test_dir = Path(TEST_DIR) / time
|
test_dir = Path(TEST_DIR) / time
|
||||||
|
|
|
||||||
|
|
@ -93,15 +93,10 @@ def run_upload():
|
||||||
return default
|
return default
|
||||||
|
|
||||||
flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
|
flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
|
||||||
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),
|
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),0.0)
|
||||||
0.0)
|
|
||||||
diameter_raw = request.args.get("diameter") or request.form.get("diameter")
|
diameter_raw = request.args.get("diameter") or request.form.get("diameter")
|
||||||
diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None
|
diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None
|
||||||
|
|
||||||
print("cpt:" + str(cellprob_threshold))
|
|
||||||
print("flow:" + str(flow_threshold))
|
|
||||||
print("diameter:" + str(diameter))
|
|
||||||
|
|
||||||
# 将文件保存在本地目录中
|
# 将文件保存在本地目录中
|
||||||
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
||||||
os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
|
os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
|
||||||
|
|
@ -140,19 +135,39 @@ def run_upload():
|
||||||
|
|
||||||
@app.post("/train_upload")
|
@app.post("/train_upload")
|
||||||
def train_upload():
|
def train_upload():
|
||||||
|
"""
|
||||||
|
从前端获取训练数据和测试数据,并开始训练
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def _to_float(x, default):
|
def _to_float(x, default):
|
||||||
|
"""
|
||||||
|
将变量转为float类型
|
||||||
|
Args:
|
||||||
|
x: 变量
|
||||||
|
default: 默认值
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return float(x)
|
return float(x)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def _to_int(x, default):
|
def _to_int(x, default):
|
||||||
|
"""
|
||||||
|
将变量转为int类型
|
||||||
|
Args:
|
||||||
|
x: 变量
|
||||||
|
default: 默认值
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return int(x)
|
return int(x)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
# 获取从前端传来的参数
|
||||||
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
||||||
model_name = request.args.get("model_name") or f"custom_model-{ts}"
|
model_name = request.args.get("model_name") or f"custom_model-{ts}"
|
||||||
image_filter = request.args.get("image_filter") or "_img"
|
image_filter = request.args.get("image_filter") or "_img"
|
||||||
|
|
@ -182,6 +197,7 @@ def train_upload():
|
||||||
scale_range = _to_float(request.args.get("scale_range"), None)
|
scale_range = _to_float(request.args.get("scale_range"), None)
|
||||||
channel_axis = _to_int(request.args.get("channel_axis"), None)
|
channel_axis = _to_int(request.args.get("channel_axis"), None)
|
||||||
|
|
||||||
|
# 创建工作目录
|
||||||
train_files = request.files.getlist("train_files")
|
train_files = request.files.getlist("train_files")
|
||||||
test_files = request.files.getlist("test_files")
|
test_files = request.files.getlist("test_files")
|
||||||
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
|
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
|
||||||
|
|
@ -203,34 +219,30 @@ def train_upload():
|
||||||
saved.append(os.path.join(TEST_DIR, ts, name))
|
saved.append(os.path.join(TEST_DIR, ts, name))
|
||||||
|
|
||||||
def job():
|
def job():
|
||||||
|
"""
|
||||||
|
子线程方法
|
||||||
|
"""
|
||||||
return asyncio.run(Cptrain.start_train(
|
return asyncio.run(Cptrain.start_train(
|
||||||
time=ts,
|
time=ts, model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, base_model=base_model,
|
||||||
model_name=model_name,
|
batch_size=batch_size, learning_rate=learning_rate, n_epochs=n_epochs, weight_decay=weight_decay,
|
||||||
image_filter=image_filter,
|
normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, nimg_per_epoch=nimg_per_epoch,
|
||||||
mask_filter=mask_filter,
|
rescale=rescale, scale_range=scale_range, channel_axis=channel_axis,
|
||||||
base_model=base_model,
|
|
||||||
batch_size=batch_size,
|
|
||||||
learning_rate=learning_rate,
|
|
||||||
n_epochs=n_epochs,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
normalize=normalize,
|
|
||||||
compute_flows=compute_flows,
|
|
||||||
min_train_masks=min_train_masks,
|
|
||||||
nimg_per_epoch=nimg_per_epoch,
|
|
||||||
rescale=rescale,
|
|
||||||
scale_range=scale_range,
|
|
||||||
channel_axis=channel_axis,
|
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# 创建一个子线程,防止阻塞主线程
|
||||||
fut = executor.submit(job)
|
fut = executor.submit(job)
|
||||||
|
|
||||||
def done_cb(f):
|
def done_cb(f):
|
||||||
|
"""
|
||||||
|
获取训练结果,并存入redis数据库
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
train_losses, test_losses = f.result()
|
train_losses, test_losses = f.result()
|
||||||
set_train_status(ts, "success", train_losses, test_losses)
|
set_train_status(ts, "success", train_losses, test_losses)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
set_status(ts, "failed", error=str(e))
|
set_status(ts, "failed", error=str(e))
|
||||||
|
|
||||||
|
# 添加回调,在子线程执行完后更新redis中任务状态
|
||||||
fut.add_done_callback(done_cb)
|
fut.add_done_callback(done_cb)
|
||||||
|
|
||||||
return jsonify({"ok": True, "count": len(saved), "id": ts})
|
return jsonify({"ok": True, "count": len(saved), "id": ts})
|
||||||
|
|
@ -251,6 +263,13 @@ def status():
|
||||||
|
|
||||||
@app.get("/preview")
|
@app.get("/preview")
|
||||||
def preview():
|
def preview():
|
||||||
|
"""
|
||||||
|
获取本次分割结果的预览
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
task_id = request.args.get('id')
|
task_id = request.args.get('id')
|
||||||
task_dir = Path(OUTPUT_DIR) / task_id
|
task_dir = Path(OUTPUT_DIR) / task_id
|
||||||
if not task_dir.exists():
|
if not task_dir.exists():
|
||||||
|
|
@ -275,11 +294,24 @@ def preview():
|
||||||
|
|
||||||
@app.get("/models")
|
@app.get("/models")
|
||||||
def list_models():
|
def list_models():
|
||||||
models_list = os.listdir(MODELS_DIR)
|
"""
|
||||||
|
获取现有模型列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
models_list = os.listdir(MODELS_DIR) # 查询模型列表中有哪些文件
|
||||||
return jsonify({"ok": True, "models": models_list})
|
return jsonify({"ok": True, "models": models_list})
|
||||||
|
|
||||||
@app.get("/result")
|
@app.get("/result")
|
||||||
def list_results():
|
def list_results():
|
||||||
|
"""
|
||||||
|
获取运行结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
task_id = request.args.get('id')
|
task_id = request.args.get('id')
|
||||||
st = get_status(task_id)
|
st = get_status(task_id)
|
||||||
if not st:
|
if not st:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from multiprocessing import Process
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Cprun.run_test()
|
# 启动测试服务器
|
||||||
p = Process(target=run_dev)
|
p = Process(target=run_dev)
|
||||||
p.start()
|
p.start()
|
||||||
print(f"Flask running in PID {p.pid}")
|
print(f"Flask running in PID {p.pid}")
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
const config = {
|
const config = {
|
||||||
|
/* 请根据需求修改下列IP和端口 */
|
||||||
server: {
|
server: {
|
||||||
protocol: 'http',
|
protocol: 'http', // 网络协议
|
||||||
host: '10.10.25.240',
|
host: '192.168.193.141', // 主机IP
|
||||||
port: 5000
|
port: 5000 // 后端运行端口
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* 生成API链接 */
|
||||||
const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`;
|
const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`;
|
||||||
Loading…
Reference in a new issue