mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 15:04:51 +00:00
chore: 增加注释
This commit is contained in:
parent
185b01c467
commit
a4ca19b61a
6 changed files with 118 additions and 30 deletions
|
|
@ -1,5 +1,5 @@
|
|||
backend:
|
||||
ip: 10.10.25.240
|
||||
ip: 192.168.193.141
|
||||
port: 5000
|
||||
|
||||
model:
|
||||
|
|
|
|||
|
|
@ -62,6 +62,20 @@ class Cprun:
|
|||
diameter: float | None = None,
|
||||
flow_threshold: float = 0.4,
|
||||
cellprob_threshold: float = 0.0, ):
|
||||
"""
|
||||
运行 cellpose 分割
|
||||
|
||||
Args:
|
||||
images: [list] 图片存储路径
|
||||
time: [str] 开始运行的时间,相当于本次运行的ID,用于存储运行结果
|
||||
model: [str] 图像分割所使用的模型
|
||||
diameter: [float] diameters are used to rescale the image to 30 pix cell diameter.
|
||||
flow_threshold: [float] flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
|
||||
cellprob_threshold: [float] all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
if time is None:
|
||||
return [False, "No time received"]
|
||||
|
|
@ -71,9 +85,10 @@ class Cprun:
|
|||
|
||||
message = [f"Using {model} model"]
|
||||
|
||||
# 设定模型参数
|
||||
model = models.CellposeModel(gpu=True, model_type=model)
|
||||
files = images
|
||||
imgs = [imread(f) for f in files]
|
||||
imgs = [imread(f) for f in files] # 获取目录中的每一个文件
|
||||
masks, flows, styles = model.eval(
|
||||
imgs,
|
||||
flow_threshold=flow_threshold,
|
||||
|
|
@ -90,7 +105,7 @@ class Cprun:
|
|||
out = base + "_output"
|
||||
save_masks(imgs, mask, flow, out, tif=True)
|
||||
|
||||
# 用 plot 生成彩色叠加图(不依赖 skimage)
|
||||
# 用 plot 生成彩色叠加图
|
||||
rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB
|
||||
over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例
|
||||
Image.fromarray(over).save(base + "_overlay.png")
|
||||
|
|
|
|||
|
|
@ -19,6 +19,19 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
|
|||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||
|
||||
def set_status(task_id, status, train_losses, test_losses, **extra):
|
||||
"""
|
||||
修改redis数据库中某一任务的运行状态
|
||||
|
||||
Args:
|
||||
task_id: 这一任务的时间戳
|
||||
status: 任务状态
|
||||
train_losses: 此次任务的训练loss
|
||||
test_losses: 此次任务的测试loss
|
||||
**extra:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
payload = {"status": status,
|
||||
"updated_at": datetime.datetime.utcnow().isoformat(),
|
||||
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
|
||||
|
|
@ -55,6 +68,32 @@ class Cptrain:
|
|||
scale_range=None,
|
||||
channel_axis: int = None,
|
||||
):
|
||||
"""
|
||||
开始训练
|
||||
|
||||
Args:
|
||||
time: 此次任务的时间戳(即任务ID)
|
||||
model_name: 训练结果命名
|
||||
image_filter:
|
||||
mask_filter:
|
||||
base_model:
|
||||
train_probs:
|
||||
test_probs:
|
||||
batch_size:
|
||||
learning_rate:
|
||||
n_epochs:
|
||||
weight_decay:
|
||||
normalize:
|
||||
compute_flows:
|
||||
min_train_masks:
|
||||
nimg_per_epoch:
|
||||
rescale:
|
||||
scale_range:
|
||||
channel_axis:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
train_dir = Path(TRAIN_DIR) / time
|
||||
test_dir = Path(TEST_DIR) / time
|
||||
|
|
|
|||
|
|
@ -93,15 +93,10 @@ def run_upload():
|
|||
return default
|
||||
|
||||
flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
|
||||
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),
|
||||
0.0)
|
||||
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),0.0)
|
||||
diameter_raw = request.args.get("diameter") or request.form.get("diameter")
|
||||
diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None
|
||||
|
||||
print("cpt:" + str(cellprob_threshold))
|
||||
print("flow:" + str(flow_threshold))
|
||||
print("diameter:" + str(diameter))
|
||||
|
||||
# 将文件保存在本地目录中
|
||||
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
||||
os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
|
||||
|
|
@ -140,19 +135,39 @@ def run_upload():
|
|||
|
||||
@app.post("/train_upload")
|
||||
def train_upload():
|
||||
"""
|
||||
从前端获取训练数据和测试数据,并开始训练
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
def _to_float(x, default):
|
||||
"""
|
||||
将变量转为float类型
|
||||
Args:
|
||||
x: 变量
|
||||
default: 默认值
|
||||
"""
|
||||
|
||||
try:
|
||||
return float(x)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _to_int(x, default):
|
||||
"""
|
||||
将变量转为int类型
|
||||
Args:
|
||||
x: 变量
|
||||
default: 默认值
|
||||
"""
|
||||
try:
|
||||
return int(x)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
# 获取从前端传来的参数
|
||||
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
||||
model_name = request.args.get("model_name") or f"custom_model-{ts}"
|
||||
image_filter = request.args.get("image_filter") or "_img"
|
||||
|
|
@ -182,6 +197,7 @@ def train_upload():
|
|||
scale_range = _to_float(request.args.get("scale_range"), None)
|
||||
channel_axis = _to_int(request.args.get("channel_axis"), None)
|
||||
|
||||
# 创建工作目录
|
||||
train_files = request.files.getlist("train_files")
|
||||
test_files = request.files.getlist("test_files")
|
||||
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
|
||||
|
|
@ -203,34 +219,30 @@ def train_upload():
|
|||
saved.append(os.path.join(TEST_DIR, ts, name))
|
||||
|
||||
def job():
|
||||
"""
|
||||
子线程方法
|
||||
"""
|
||||
return asyncio.run(Cptrain.start_train(
|
||||
time=ts,
|
||||
model_name=model_name,
|
||||
image_filter=image_filter,
|
||||
mask_filter=mask_filter,
|
||||
base_model=base_model,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
n_epochs=n_epochs,
|
||||
weight_decay=weight_decay,
|
||||
normalize=normalize,
|
||||
compute_flows=compute_flows,
|
||||
min_train_masks=min_train_masks,
|
||||
nimg_per_epoch=nimg_per_epoch,
|
||||
rescale=rescale,
|
||||
scale_range=scale_range,
|
||||
channel_axis=channel_axis,
|
||||
time=ts, model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, base_model=base_model,
|
||||
batch_size=batch_size, learning_rate=learning_rate, n_epochs=n_epochs, weight_decay=weight_decay,
|
||||
normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, nimg_per_epoch=nimg_per_epoch,
|
||||
rescale=rescale, scale_range=scale_range, channel_axis=channel_axis,
|
||||
))
|
||||
|
||||
# 创建一个子线程,防止阻塞主线程
|
||||
fut = executor.submit(job)
|
||||
|
||||
def done_cb(f):
|
||||
"""
|
||||
获取训练结果,并存入redis数据库
|
||||
"""
|
||||
try:
|
||||
train_losses, test_losses = f.result()
|
||||
set_train_status(ts, "success", train_losses, test_losses)
|
||||
except Exception as e:
|
||||
set_status(ts, "failed", error=str(e))
|
||||
|
||||
# 添加回调,在子线程执行完后更新redis中任务状态
|
||||
fut.add_done_callback(done_cb)
|
||||
|
||||
return jsonify({"ok": True, "count": len(saved), "id": ts})
|
||||
|
|
@ -251,6 +263,13 @@ def status():
|
|||
|
||||
@app.get("/preview")
|
||||
def preview():
|
||||
"""
|
||||
获取本次分割结果的预览
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
task_id = request.args.get('id')
|
||||
task_dir = Path(OUTPUT_DIR) / task_id
|
||||
if not task_dir.exists():
|
||||
|
|
@ -275,11 +294,24 @@ def preview():
|
|||
|
||||
@app.get("/models")
|
||||
def list_models():
|
||||
models_list = os.listdir(MODELS_DIR)
|
||||
"""
|
||||
获取现有模型列表
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
models_list = os.listdir(MODELS_DIR) # 查询模型列表中有哪些文件
|
||||
return jsonify({"ok": True, "models": models_list})
|
||||
|
||||
@app.get("/result")
|
||||
def list_results():
|
||||
"""
|
||||
获取运行结果
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
task_id = request.args.get('id')
|
||||
st = get_status(task_id)
|
||||
if not st:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from multiprocessing import Process
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Cprun.run_test()
|
||||
# 启动测试服务器
|
||||
p = Process(target=run_dev)
|
||||
p.start()
|
||||
print(f"Flask running in PID {p.pid}")
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
const config = {
|
||||
/* 请根据需求修改下列IP和端口 */
|
||||
server: {
|
||||
protocol: 'http',
|
||||
host: '10.10.25.240',
|
||||
port: 5000
|
||||
protocol: 'http', // 网络协议
|
||||
host: '192.168.193.141', // 主机IP
|
||||
port: 5000 // 后端运行端口
|
||||
}
|
||||
};
|
||||
|
||||
/* 生成API链接 */
|
||||
const API_BASE = `${config.server.protocol}://${config.server.host}:${config.server.port}/`;
|
||||
Loading…
Reference in a new issue