2025-09-22 17:17:33 +00:00
|
|
|
import asyncio
import datetime
import json
import os.path
import threading
from pathlib import Path

import redis
from omegaconf import OmegaConf
|
2025-09-22 17:17:33 +00:00
|
|
|
|
|
|
|
|
# config.yaml sits next to this module; data.root_dir and model.save_dir from it
# are resolved against the directory that contains it.
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)

_CONFIG_DIR = CONFIG_PATH.parent


def _resolve_from_config(relative):
    """Return *relative* joined onto the config directory, resolved, as a str."""
    return str((_CONFIG_DIR / relative).resolve())


# Overwrite root_dir in place so every later reader of cfg sees the absolute path.
cfg.data.root_dir = _resolve_from_config(cfg.data.root_dir)
BASE_DIR = cfg.data.root_dir

# NOTE(review): unlike root_dir / save_dir these are used exactly as configured,
# so relative values resolve against the process CWD — confirm this is intended.
TRAIN_DIR = cfg.data.train.train_dir
TEST_DIR = cfg.data.train.test_dir

MODELS_DIR = _resolve_from_config(cfg.model.save_dir)
# Point cellpose at the project's model directory; presumably this has to be set
# before the cellpose import further down in this module — confirm.
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR

# Shared Redis client used by set_status / get_status for task records.
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
|
|
|
|
|
2025-09-24 19:03:56 +00:00
|
|
|
def set_status(task_id, status, train_losses=None, test_losses=None, **extra):
    """Store the status record of training task *task_id* in Redis as JSON.

    The record is written under the key ``task:<task_id>`` and expires after
    one day (86400 s); it can be read back with ``get_status``.

    Args:
        task_id: Task identifier, interpolated into the Redis key.
        status: Status label, e.g. ``"running"`` or ``"done"``.
        train_losses: Training loss curve. Objects exposing ``tolist()`` (such
            as NumPy arrays) are converted to lists so they can be JSON-encoded;
            anything else, including ``None``, is stored as-is.
        test_losses: Same as *train_losses*, for the test set.
        **extra: Additional JSON-serializable fields merged into the record;
            they override the standard keys on a name clash.
    """
    def _jsonable(losses):
        # NumPy arrays are not JSON-serializable; lists and None pass through.
        return losses.tolist() if hasattr(losses, "tolist") else losses

    payload = {
        "status": status,
        # Timezone-aware UTC: utcnow() is deprecated since Python 3.12 and its
        # naive ISO string has no offset, so clients may read it as local time.
        "updated_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "train_losses": _jsonable(train_losses),
        "test_losses": _jsonable(test_losses),
        **extra,
    }
    r.set(f"task:{task_id}", json.dumps(payload), ex=86400)  # expire after 1 day
|
|
|
|
|
|
|
|
|
|
def get_status(task_id):
    """Return the decoded status record for *task_id*, or ``None``.

    Reads the JSON stored under ``task:<task_id>``. ``None`` is returned when
    Redis has no (or an empty) value for that key, e.g. it was never written
    or has expired.
    """
    stored = r.get(f"task:{task_id}")
    if not stored:
        return None
    return json.loads(stored)
|
|
|
|
|
|
|
|
|
|
from cellpose import io, models, train
|
2025-09-22 17:17:33 +00:00
|
|
|
|
|
|
|
|
class Cptrain:
|
|
|
|
|
|
|
|
|
|
@classmethod
|
2025-09-23 16:27:02 +00:00
|
|
|
async def start_train(cls,
|
2025-09-22 18:28:19 +00:00
|
|
|
time: str | None = None,
|
2025-09-23 16:27:02 +00:00
|
|
|
model_name: str | None = None,
|
|
|
|
|
image_filter: str = "_img",
|
|
|
|
|
mask_filter: str = "_masks",
|
|
|
|
|
base_model: str = "cpsam"):
|
2025-09-22 17:17:33 +00:00
|
|
|
|
2025-09-23 16:27:02 +00:00
|
|
|
train_dir = Path(TRAIN_DIR) / time
|
|
|
|
|
test_dir = Path(TEST_DIR) / time
|
2025-09-22 17:17:33 +00:00
|
|
|
os.makedirs(train_dir, exist_ok=True)
|
|
|
|
|
os.makedirs(test_dir, exist_ok=True)
|
|
|
|
|
io.logger_setup()
|
2025-09-23 16:27:02 +00:00
|
|
|
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter=image_filter,
|
|
|
|
|
mask_filter=mask_filter, look_one_level_down=False)
|
2025-09-22 17:17:33 +00:00
|
|
|
images, labels, image_names, test_images, test_labels, image_names_test = output
|
|
|
|
|
|
2025-09-23 16:27:02 +00:00
|
|
|
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
|
|
|
|
|
|
2025-09-24 19:03:56 +00:00
|
|
|
set_status(time, "running", None, None)
|
2025-09-22 17:17:33 +00:00
|
|
|
|
|
|
|
|
model_path, train_losses, test_losses = train.train_seg(model.net,
|
|
|
|
|
train_data=images, train_labels=labels,
|
|
|
|
|
test_data=test_images, test_labels=test_labels,
|
|
|
|
|
weight_decay=0.1, learning_rate=1e-5,
|
2025-09-22 18:28:19 +00:00
|
|
|
n_epochs=100, model_name=model_name,
|
2025-09-24 19:03:56 +00:00
|
|
|
save_path=BASE_DIR)
|
2025-09-22 18:28:19 +00:00
|
|
|
|
2025-09-24 19:03:56 +00:00
|
|
|
set_status(time, "done", train_losses, test_losses)
|
2025-09-23 16:27:02 +00:00
|
|
|
print("模型已保存到:", model_path)
|
2025-09-24 19:03:56 +00:00
|
|
|
return train_losses, test_losses
|