Compare commits

..

No commits in common. "master" and "1.0.0" have entirely different histories.

7 changed files with 29 additions and 122 deletions

View file

@ -6,13 +6,8 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Redis-6.4.0-red"> <img alt="Static Badge" src="https://img.shields.io/badge/Redis-6.4.0-red">
<img alt="Static Badge" src="https://img.shields.io/badge/JSDelivr-in_use-brown"> <img alt="Static Badge" src="https://img.shields.io/badge/JSDelivr-in_use-brown">
<img alt="Static Badge" src="https://img.shields.io/badge/Flask-3.1.2-8ecae6"> <img alt="Static Badge" src="https://img.shields.io/badge/Flask-3.1.2-8ecae6">
<br><br>
<img width="1920" height="1080" alt="image" src="https://github.com/user-attachments/assets/da5b891e-b4ac-484a-885c-0856f18e04fc" style="height: 70%; width: 70%"/>
</p> </p>
<br> <br>
🌈 实现功能: 🌈 实现功能:
@ -27,12 +22,10 @@
## 🚀一键安装 ## 🚀一键安装
<b>[最新Release页面](https://github.com/ClovertaTheTrilobita/cellpose-web/releases/latest)</b>中下载最新的 <b>install.sh</b> 到你的Linux/macOS机器上。 <b>[Release页面](https://github.com/ClovertaTheTrilobita/cellpose-web/releases)</b>中下载最新的 <b>install.sh</b> 到你的Linux/macOS机器上。
将它放到你希望项目存在的位置,并执行它,安装脚本会将项目自动拉取到同一目录下。 将它放到你希望项目存在的位置,并执行它,安装脚本会将项目自动拉取到同一目录下。
**NOTE:** 安装脚本设计上支持Debian系、Arch系、RHEL系Linux,但目前仅测试过Arch Linux。其他发行版暂未经过测试若出现意外错误推荐手动安装。
Windows暂时不支持通过脚本一键安装。 Windows暂时不支持通过脚本一键安装。
<br> <br>

View file

@ -1,5 +1,5 @@
backend: backend:
ip: 192.168.193.141 ip: 10.10.25.240
port: 5000 port: 5000
model: model:

View file

@ -62,20 +62,6 @@ class Cprun:
diameter: float | None = None, diameter: float | None = None,
flow_threshold: float = 0.4, flow_threshold: float = 0.4,
cellprob_threshold: float = 0.0, ): cellprob_threshold: float = 0.0, ):
"""
运行 cellpose 分割
Args:
images: [list] 图片存储路径
time: [str] 开始运行的时间相当于本次运行的ID用于存储运行结果
model: [str] 图像分割所使用的模型
diameter: [float] diameters are used to rescale the image to 30 pix cell diameter.
flow_threshold: [float] flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
cellprob_threshold: [float] all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
Returns:
"""
if time is None: if time is None:
return [False, "No time received"] return [False, "No time received"]
@ -85,10 +71,9 @@ class Cprun:
message = [f"Using {model} model"] message = [f"Using {model} model"]
# 设定模型参数
model = models.CellposeModel(gpu=True, model_type=model) model = models.CellposeModel(gpu=True, model_type=model)
files = images files = images
imgs = [imread(f) for f in files] # 获取目录中的每一个文件 imgs = [imread(f) for f in files]
masks, flows, styles = model.eval( masks, flows, styles = model.eval(
imgs, imgs,
flow_threshold=flow_threshold, flow_threshold=flow_threshold,
@ -105,7 +90,7 @@ class Cprun:
out = base + "_output" out = base + "_output"
save_masks(imgs, mask, flow, out, tif=True) save_masks(imgs, mask, flow, out, tif=True)
# 用 plot 生成彩色叠加图 # 用 plot 生成彩色叠加图(不依赖 skimage
rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB rgb = plot.image_to_rgb(img, channels=[0, 0]) # 原图转 RGB
over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例 over = plot.mask_overlay(rgb, masks=mask, colors=None) # 叠加彩色实例
Image.fromarray(over).save(base + "_overlay.png") Image.fromarray(over).save(base + "_overlay.png")

View file

@ -19,19 +19,6 @@ os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
r = redis.Redis(host="127.0.0.1", port=6379, db=0) r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, train_losses, test_losses, **extra): def set_status(task_id, status, train_losses, test_losses, **extra):
"""
修改redis数据库中某一任务的运行状态
Args:
task_id: 这一任务的时间戳
status: 任务状态
train_losses: 此次任务的训练loss
test_losses: 此次任务的测试loss
**extra:
Returns:
"""
payload = {"status": status, payload = {"status": status,
"updated_at": datetime.datetime.utcnow().isoformat(), "updated_at": datetime.datetime.utcnow().isoformat(),
"train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses, "train_losses": train_losses.tolist() if hasattr(train_losses, "tolist") else train_losses,
@ -68,32 +55,6 @@ class Cptrain:
scale_range=None, scale_range=None,
channel_axis: int = None, channel_axis: int = None,
): ):
"""
开始训练
Args:
time: 此次任务的时间戳即任务ID
model_name: 训练结果命名
image_filter:
mask_filter:
base_model:
train_probs:
test_probs:
batch_size:
learning_rate:
n_epochs:
weight_decay:
normalize:
compute_flows:
min_train_masks:
nimg_per_epoch:
rescale:
scale_range:
channel_axis:
Returns:
"""
train_dir = Path(TRAIN_DIR) / time train_dir = Path(TRAIN_DIR) / time
test_dir = Path(TEST_DIR) / time test_dir = Path(TEST_DIR) / time

View file

@ -93,10 +93,15 @@ def run_upload():
return default return default
flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4) flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),0.0) cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),
0.0)
diameter_raw = request.args.get("diameter") or request.form.get("diameter") diameter_raw = request.args.get("diameter") or request.form.get("diameter")
diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None
print("cpt:" + str(cellprob_threshold))
print("flow:" + str(flow_threshold))
print("diameter:" + str(diameter))
# 将文件保存在本地目录中 # 将文件保存在本地目录中
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True) os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
@ -135,39 +140,19 @@ def run_upload():
@app.post("/train_upload") @app.post("/train_upload")
def train_upload(): def train_upload():
"""
从前端获取训练数据和测试数据并开始训练
Returns:
"""
def _to_float(x, default): def _to_float(x, default):
"""
将变量转为float类型
Args:
x: 变量
default: 默认值
"""
try: try:
return float(x) return float(x)
except (TypeError, ValueError): except (TypeError, ValueError):
return default return default
def _to_int(x, default): def _to_int(x, default):
"""
将变量转为int类型
Args:
x: 变量
default: 默认值
"""
try: try:
return int(x) return int(x)
except (TypeError, ValueError): except (TypeError, ValueError):
return default return default
# 获取从前端传来的参数
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
model_name = request.args.get("model_name") or f"custom_model-{ts}" model_name = request.args.get("model_name") or f"custom_model-{ts}"
image_filter = request.args.get("image_filter") or "_img" image_filter = request.args.get("image_filter") or "_img"
@ -197,7 +182,6 @@ def train_upload():
scale_range = _to_float(request.args.get("scale_range"), None) scale_range = _to_float(request.args.get("scale_range"), None)
channel_axis = _to_int(request.args.get("channel_axis"), None) channel_axis = _to_int(request.args.get("channel_axis"), None)
# 创建工作目录
train_files = request.files.getlist("train_files") train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files") test_files = request.files.getlist("test_files")
os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True) os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
@ -219,30 +203,34 @@ def train_upload():
saved.append(os.path.join(TEST_DIR, ts, name)) saved.append(os.path.join(TEST_DIR, ts, name))
def job(): def job():
"""
子线程方法
"""
return asyncio.run(Cptrain.start_train( return asyncio.run(Cptrain.start_train(
time=ts, model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, base_model=base_model, time=ts,
batch_size=batch_size, learning_rate=learning_rate, n_epochs=n_epochs, weight_decay=weight_decay, model_name=model_name,
normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, nimg_per_epoch=nimg_per_epoch, image_filter=image_filter,
rescale=rescale, scale_range=scale_range, channel_axis=channel_axis, mask_filter=mask_filter,
base_model=base_model,
batch_size=batch_size,
learning_rate=learning_rate,
n_epochs=n_epochs,
weight_decay=weight_decay,
normalize=normalize,
compute_flows=compute_flows,
min_train_masks=min_train_masks,
nimg_per_epoch=nimg_per_epoch,
rescale=rescale,
scale_range=scale_range,
channel_axis=channel_axis,
)) ))
# 创建一个子线程,防止阻塞主线程
fut = executor.submit(job) fut = executor.submit(job)
def done_cb(f): def done_cb(f):
"""
获取训练结果并存入redis数据库
"""
try: try:
train_losses, test_losses = f.result() train_losses, test_losses = f.result()
set_train_status(ts, "success", train_losses, test_losses) set_train_status(ts, "success", train_losses, test_losses)
except Exception as e: except Exception as e:
set_status(ts, "failed", error=str(e)) set_status(ts, "failed", error=str(e))
# 添加回调在子线程执行完后更新redis中任务状态
fut.add_done_callback(done_cb) fut.add_done_callback(done_cb)
return jsonify({"ok": True, "count": len(saved), "id": ts}) return jsonify({"ok": True, "count": len(saved), "id": ts})
@ -263,13 +251,6 @@ def status():
@app.get("/preview") @app.get("/preview")
def preview(): def preview():
"""
获取本次分割结果的预览
Returns:
"""
task_id = request.args.get('id') task_id = request.args.get('id')
task_dir = Path(OUTPUT_DIR) / task_id task_dir = Path(OUTPUT_DIR) / task_id
if not task_dir.exists(): if not task_dir.exists():
@ -294,24 +275,11 @@ def preview():
@app.get("/models") @app.get("/models")
def list_models(): def list_models():
""" models_list = os.listdir(MODELS_DIR)
获取现有模型列表
Returns:
"""
models_list = os.listdir(MODELS_DIR) # 查询模型列表中有哪些文件
return jsonify({"ok": True, "models": models_list}) return jsonify({"ok": True, "models": models_list})
@app.get("/result") @app.get("/result")
def list_results(): def list_results():
"""
获取运行结果
Returns:
"""
task_id = request.args.get('id') task_id = request.args.get('id')
st = get_status(task_id) st = get_status(task_id)
if not st: if not st:

View file

@ -4,7 +4,7 @@ from multiprocessing import Process
if __name__ == "__main__": if __name__ == "__main__":
# 启动测试服务器 # Cprun.run_test()
p = Process(target=run_dev) p = Process(target=run_dev)
p.start() p.start()
print(f"Flask running in PID {p.pid}") print(f"Flask running in PID {p.pid}")

View file

@ -1,7 +1,7 @@
const config = { const config = {
server: { server: {
protocol: 'http', protocol: 'http',
host: '192.168.193.141', host: '10.10.25.240',
port: 5000 port: 5000
} }
}; };