2025-09-17 19:43:36 +00:00
|
|
|
|
import asyncio
|
2025-09-17 20:46:25 +00:00
|
|
|
|
import base64
|
2025-09-22 17:17:33 +00:00
|
|
|
|
import datetime
|
|
|
|
|
|
import json
|
|
|
|
|
|
import os
|
|
|
|
|
|
import redis
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
import time
|
|
|
|
|
|
from omegaconf import OmegaConf
|
2025-09-17 19:43:36 +00:00
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
2025-09-22 17:17:33 +00:00
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
2025-09-16 18:53:58 +00:00
|
|
|
|
from flask import Flask, send_from_directory, request, jsonify
|
2025-09-16 19:14:17 +00:00
|
|
|
|
from flask_cors import CORS
|
2025-09-22 17:17:33 +00:00
|
|
|
|
from werkzeug.utils import secure_filename
|
|
|
|
|
|
|
2025-10-16 17:31:03 +00:00
|
|
|
|
from cp_train import Cptrain
|
2025-09-17 19:43:36 +00:00
|
|
|
|
from cp_run import Cprun
|
2025-09-16 18:53:58 +00:00
|
|
|
|
|
|
|
|
|
|
app = Flask(__name__)
CORS(app)

# Resolve all data/model paths relative to the config file's directory so the
# server behaves the same regardless of the working directory it is launched from.
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())

BASE_DIR = cfg.data.root_dir
UPLOAD_DIR = cfg.data.upload_dir
OUTPUT_DIR = cfg.data.run.output_dir
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
TRAIN_DIR = cfg.data.train.train_dir
TEST_DIR = cfg.data.train.test_dir
BACKEND_IP = cfg.backend.ip
BACKEND_PORT = cfg.backend.port

os.makedirs(UPLOAD_DIR, exist_ok=True)

# Worker pool for long-running cellpose jobs; job state is tracked in Redis
# (not in-process) so status survives across requests.
executor = ThreadPoolExecutor(max_workers=4)
TASKS = {}  # NOTE(review): appears unused in this module — kept for compatibility
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
2025-09-16 18:53:58 +00:00
|
|
|
|
|
2025-09-17 19:43:36 +00:00
|
|
|
|
# Start the development server (Flask's built-in server; not for production).
def run_dev():
    """Run the Flask app on the host/port configured in config.yaml."""
    host, port = BACKEND_IP, int(BACKEND_PORT)
    app.run(host=host, port=port)
|
2025-09-16 18:53:58 +00:00
|
|
|
|
|
2025-09-17 19:43:36 +00:00
|
|
|
|
def set_status(task_id, status, **extra):
    """Persist a task's status payload to Redis under ``task:<task_id>``.

    Args:
        task_id: Unique task identifier (the per-request timestamp string).
        status: Status label, e.g. "pending", "running", "success", "failed".
        **extra: Additional JSON-serializable fields merged into the payload.
    """
    payload = {
        "status": status,
        # Fix: datetime.utcnow() is deprecated (Python 3.12+) and returned a
        # naive timestamp; use an explicit timezone-aware UTC time instead.
        "updated_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        **extra,
    }
    # Expire after one day so stale task entries do not accumulate in Redis.
    r.set(f"task:{task_id}", json.dumps(payload), ex=86400)
|
|
|
|
|
|
|
2025-09-24 19:03:56 +00:00
|
|
|
|
def set_train_status(task_id, status, train_losses, test_losses, **extra):
    """Persist a training task's status plus its loss curves to Redis.

    Loss arrays (e.g. numpy arrays) are converted to plain lists so the
    payload is JSON-serializable.

    Args:
        task_id: Unique task identifier (the per-request timestamp string).
        status: Status label, e.g. "success".
        train_losses: Per-epoch training losses (numpy array or list).
        test_losses: Per-epoch test losses (numpy array or list).
        **extra: Additional JSON-serializable fields merged into the payload.
    """
    def _as_list(losses):
        # numpy arrays expose .tolist(); plain lists/None pass through unchanged.
        return losses.tolist() if hasattr(losses, "tolist") else losses

    # Consistency fix: delegate to set_status so the payload shape, timestamp
    # handling, and Redis expiry are defined in exactly one place.
    set_status(
        task_id,
        status,
        train_losses=_as_list(train_losses),
        test_losses=_as_list(test_losses),
        **extra,
    )
|
|
|
|
|
|
|
2025-09-17 19:43:36 +00:00
|
|
|
|
def get_status(task_id):
    """Fetch the stored status payload for a task.

    Returns the decoded dict, or None when the key is absent or expired.
    """
    stored = r.get(f"task:{task_id}")
    if not stored:
        return None
    return json.loads(stored)
|
|
|
|
|
|
|
2025-09-16 18:53:58 +00:00
|
|
|
|
@app.route("/")
def index():
    """Landing page: makes clear this is the API backend, not the frontend."""
    return (
        "<h1>Hello</h1>"
        "<p>This is the backend of our cellpose server, please visit our website.</p>"
    )
|
|
|
|
|
|
|
2025-09-17 19:43:36 +00:00
|
|
|
|
@app.get("/testdl")
def test_download():
    """Debug endpoint: serve one fixed overlay image as an attachment."""
    sample_dir = "test_output/2025-09-16-20-03-51"
    return send_from_directory(sample_dir, "img_overlay.png", as_attachment=True)
|
2025-09-16 18:53:58 +00:00
|
|
|
|
|
2025-09-17 21:12:10 +00:00
|
|
|
|
@app.get("/dl")
def download():
    """Zip a finished run's output directory and send the archive to the client.

    Query args:
        id: the task timestamp whose output directory should be packed.
    """
    timestamp = request.args.get("id")
    # Robustness fix: a missing id previously crashed with a TypeError inside
    # os.path.join; a crafted id like "../.." could escape OUTPUT_DIR.
    if not timestamp:
        return jsonify({"ok": False, "error": "missing id"}), 400
    timestamp = secure_filename(timestamp)

    input_dir = os.path.join(OUTPUT_DIR, timestamp)
    if not os.path.isdir(input_dir):
        return jsonify({"ok": False, "error": "task not found"}), 404

    # make_archive appends ".zip" itself, so the base name has no extension.
    archive_base = os.path.join(OUTPUT_DIR, "tmp", timestamp)
    os.makedirs(Path(OUTPUT_DIR) / "tmp", exist_ok=True)  # ensure tmp exists
    shutil.make_archive(archive_base, 'zip', input_dir)
    print(f"压缩完成: {archive_base}.zip")
    print(OUTPUT_DIR)
    return send_from_directory(f"{OUTPUT_DIR}/tmp/", f"{timestamp}.zip", as_attachment=True)
|
2025-09-16 18:53:58 +00:00
|
|
|
|
|
2025-09-22 18:51:29 +00:00
|
|
|
|
@app.post("/run_upload")
def run_upload():
    """Accept uploaded images and launch a cellpose segmentation job.

    Optional query/form parameters: model, flow_threshold, cellprob_threshold,
    diameter. Image files arrive in the multipart field "files".

    Returns:
        JSON with the number of saved files and the task id, which clients use
        to poll /status and later fetch /preview and /dl.
    """
    # Read parameters from query string or form body, falling back to defaults.
    model = request.args.get("model") or request.form.get("model") or "cpsam"

    def _to_float(x, default):
        # Tolerate missing or malformed numeric parameters.
        try:
            return float(x)
        except (TypeError, ValueError):
            return default

    flow_threshold = _to_float(request.args.get("flow_threshold") or request.form.get("flow_threshold"), 0.4)
    cellprob_threshold = _to_float(request.args.get("cellprob_threshold") or request.form.get("cellprob_threshold"),
                                   0.0)
    diameter_raw = request.args.get("diameter") or request.form.get("diameter")
    diameter = _to_float(diameter_raw, None) if diameter_raw not in (None, "") else None

    print("cpt:" + str(cellprob_threshold))
    print("flow:" + str(flow_threshold))
    print("diameter:" + str(diameter))

    # Save uploads under a per-request timestamp directory; the millisecond
    # suffix keeps near-simultaneous requests from colliding.
    ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
    os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
    saved = []
    for f in request.files.getlist("files"):
        if not f or f.filename == "":
            continue
        name = secure_filename(f.filename)
        dest = os.path.join(UPLOAD_DIR, ts, name)
        f.save(dest)
        saved.append(dest)

    # Robustness fix: previously an empty upload still launched a job and left
    # a bogus "running" task in Redis. Reject it up front instead.
    if not saved:
        return jsonify({"ok": False, "count": 0, "error": "no files uploaded"}), 400

    # Run cellpose on a worker thread so the HTTP response is not blocked.
    def job():
        return asyncio.run(Cprun.run(
            images=saved, model=model,
            cellprob_threshold=cellprob_threshold,
            flow_threshold=flow_threshold,
            diameter=diameter, time=ts
        ))

    # Track the job's lifecycle in Redis so /status can report progress.
    set_status(ts, "running")
    fut = executor.submit(job)

    def done_cb(f):
        try:
            f.result()
            set_status(ts, "success")
        except Exception as e:
            set_status(ts, "failed", error=str(e))

    fut.add_done_callback(done_cb)

    return jsonify({"ok": True, "count": len(saved), "id": ts})
|
|
|
|
|
|
|
2025-09-22 18:51:29 +00:00
|
|
|
|
@app.post("/train_upload")
def train_upload():
    """Accept training/test images and launch a cellpose fine-tuning job.

    All hyperparameters arrive as query-string text. Bug fixes relative to the
    previous version:
      * "min_train_masks" was looked up with a stray leading space, so the
        parameter could never be supplied by a client;
      * numeric parameters (batch_size, learning_rate, n_epochs, ...) were
        forwarded to the trainer as raw strings when provided;
      * normalize/compute_flows treated any non-empty string — including
        "false" — as truthy.
    """
    ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"

    def _int(name, default):
        # Query args are strings; fall back to the default on bad/missing input.
        try:
            return int(request.args.get(name))
        except (TypeError, ValueError):
            return default

    def _float(name, default):
        try:
            return float(request.args.get(name))
        except (TypeError, ValueError):
            return default

    def _bool(name, default):
        raw = request.args.get(name)
        if raw is None or raw == "":
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    model_name = request.args.get("model_name") or f"custom_model-{ts}"
    image_filter = request.args.get("image_filter") or "_img"
    mask_filter = request.args.get("mask_filter") or "_masks"
    base_model = request.args.get("base_model") or "cpsam"
    batch_size = _int("batch_size", 8)
    learning_rate = _float("learning_rate", 5e-5)
    n_epochs = _int("n_epochs", 100)
    weight_decay = _float("weight_decay", 0.1)
    normalize = _bool("normalize", True)
    compute_flows = _bool("compute_flows", False)
    min_train_masks = _int("min_train_masks", 5)
    nimg_per_epoch = _int("nimg_per_epoch", None)
    rescale = _bool("rescale", False)
    scale_range = _float("scale_range", None)
    channel_axis = _int("channel_axis", None)

    train_files = request.files.getlist("train_files")
    test_files = request.files.getlist("test_files")
    os.makedirs(Path(TRAIN_DIR) / ts, exist_ok=True)
    os.makedirs(Path(TEST_DIR) / ts, exist_ok=True)
    set_status(ts, "pending")

    saved = []

    def _save_all(files, root):
        # Store each non-empty upload under <root>/<ts>/ with a sanitized name.
        for f in files:
            if not f or f.filename == "":
                continue
            dest = os.path.join(root, ts, secure_filename(f.filename))
            f.save(dest)
            saved.append(dest)

    _save_all(train_files, TRAIN_DIR)
    _save_all(test_files, TEST_DIR)

    # Run training on a worker thread so the HTTP response is not blocked.
    def job():
        return asyncio.run(Cptrain.start_train(
            time=ts,
            model_name=model_name,
            image_filter=image_filter,
            mask_filter=mask_filter,
            base_model=base_model,
            batch_size=batch_size,
            learning_rate=learning_rate,
            n_epochs=n_epochs,
            weight_decay=weight_decay,
            normalize=normalize,
            compute_flows=compute_flows,
            min_train_masks=min_train_masks,
            nimg_per_epoch=nimg_per_epoch,
            rescale=rescale,
            scale_range=scale_range,
            channel_axis=channel_axis,
        ))

    fut = executor.submit(job)

    def done_cb(f):
        try:
            # start_train is expected to return the two loss curves on success.
            train_losses, test_losses = f.result()
            set_train_status(ts, "success", train_losses, test_losses)
        except Exception as e:
            set_status(ts, "failed", error=str(e))

    fut.add_done_callback(done_cb)

    return jsonify({"ok": True, "count": len(saved), "id": ts})
|
2025-09-22 18:51:29 +00:00
|
|
|
|
|
2025-09-17 20:46:25 +00:00
|
|
|
|
@app.get("/status")
def status():
    """Report whether a cellpose task exists and, if so, its current state.

    The task id is taken from the "id" query argument; the stored Redis
    payload (status, updated_at, losses, ...) is merged into the response.
    """
    st = get_status(request.args.get('id'))
    if not st:
        return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
    return jsonify({"ok": True, "exists": True, **st}), 200
|
2025-09-17 20:46:25 +00:00
|
|
|
|
|
|
|
|
|
|
@app.get("/preview")
def preview():
    """Return every *_overlay.png result of a task as base64-encoded images.

    Query args:
        id: the task timestamp whose output directory is inspected.
    """
    task_id = request.args.get('id')
    # Robustness fix: a missing id previously raised TypeError on `Path / None`.
    if not task_id:
        return jsonify({"ok": False, "error": "missing id"}), 200
    task_dir = Path(OUTPUT_DIR) / task_id
    if not task_dir.exists():
        return jsonify({"ok": False, "error": "task not found"}), 200

    # Collect all overlay images produced by the run, in a stable order.
    files = sorted(task_dir.glob("*_overlay.png"))
    if not files:
        return jsonify({"ok": False, "error": "no overlay images"}), 200

    result = [
        {
            "filename": path.name,
            "image": base64.b64encode(path.read_bytes()).decode("utf-8"),
        }
        for path in files
    ]
    return jsonify({"ok": True, "count": len(result), "images": result})
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/models")
def list_models():
    """List the entries in the saved-models directory available for inference."""
    entries = os.listdir(MODELS_DIR)
    return jsonify({"ok": True, "models": entries})
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/result")
def list_results():
    """Return the stored result payload (status, losses, ...) for a task.

    NOTE(review): behaviorally identical to /status — kept as a separate
    endpoint for frontend compatibility.
    """
    st = get_status(request.args.get('id'))
    if not st:
        return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
    return jsonify({"ok": True, "exists": True, **st}), 200
|