mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 23:14:50 +00:00
feature(train): 新增训练前后端(未测试)
This commit is contained in:
parent
a6353d1c37
commit
d51439dee0
4 changed files with 628 additions and 186 deletions
|
|
@ -1,33 +1,53 @@
|
|||
import os.path
|
||||
from cellpose import io, models, train
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import redis
|
||||
import datetime
|
||||
import json
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent / "config.yaml"
|
||||
cfg = OmegaConf.load(CONFIG_PATH)
|
||||
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
|
||||
BASE_DIR = cfg.data.root_dir
|
||||
TEST_TRAIN_DIR = cfg.data.train.test_train_dir
|
||||
TEST_TEST_DIR = cfg.data.train.test_test_dir
|
||||
TRAIN_DIR = cfg.data.train.train_dir
|
||||
TEST_DIR = cfg.data.train.test_dir
|
||||
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
|
||||
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
|
||||
|
||||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||
|
||||
def set_status(task_id, status, **extra):
|
||||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||
|
||||
def get_status(task_id):
|
||||
raw = r.get(f"task:{task_id}")
|
||||
return json.loads(raw) if raw else None
|
||||
|
||||
from cellpose import io, models, train
|
||||
|
||||
class Cptrain:
|
||||
|
||||
@classmethod
|
||||
def start_train(cls,
|
||||
async def start_train(cls,
|
||||
time: str | None = None,
|
||||
model_name: str | None = None,):
|
||||
model_name: str | None = None,
|
||||
image_filter: str = "_img",
|
||||
mask_filter: str = "_masks",
|
||||
base_model: str = "cpsam"):
|
||||
|
||||
train_dir = Path(TEST_TRAIN_DIR) / time
|
||||
test_dir = Path(TEST_TEST_DIR) / time
|
||||
train_dir = Path(TRAIN_DIR) / time
|
||||
test_dir = Path(TEST_DIR) / time
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
os.makedirs(test_dir, exist_ok=True)
|
||||
io.logger_setup()
|
||||
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img",
|
||||
mask_filter="_masks", look_one_level_down=False)
|
||||
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter=image_filter,
|
||||
mask_filter=mask_filter, look_one_level_down=False)
|
||||
images, labels, image_names, test_images, test_labels, image_names_test = output
|
||||
|
||||
model = models.CellposeModel(gpu=True)
|
||||
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
|
||||
|
||||
set_status(time, "running")
|
||||
|
||||
model_path, train_losses, test_losses = train.train_seg(model.net,
|
||||
train_data=images, train_labels=labels,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from flask import Flask, send_from_directory, request, jsonify
|
|||
from flask_cors import CORS
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from backend.cp_train import Cptrain
|
||||
from cp_run import Cprun
|
||||
|
||||
app = Flask(__name__)
|
||||
|
|
@ -26,6 +27,8 @@ BASE_DIR = cfg.data.root_dir
|
|||
UPLOAD_DIR = cfg.data.upload_dir
|
||||
OUTPUT_DIR = cfg.data.run.output_dir
|
||||
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
|
||||
TRAIN_DIR = cfg.data.train.train_dir
|
||||
TEST_DIR = cfg.data.train.test_dir
|
||||
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
|
@ -128,16 +131,50 @@ def run_upload():
|
|||
@app.post("/train_upload")
|
||||
def train_upload():
|
||||
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
|
||||
model_name = request.args.get("model_name") or f"custom_model-{ts}"
|
||||
image_filter = request.args.get("image_filter") or "_img"
|
||||
mask_filter = request.args.get("mask_filter") or "_masks"
|
||||
base_model = request.args.get("base_model") or "cpsam"
|
||||
|
||||
train_files = request.files.getlist("train_files")
|
||||
test_files = request.files.getlist("test_files")
|
||||
set_status(ts, "pending")
|
||||
saved = []
|
||||
for f in train_files:
|
||||
if not f or f.filename == "":
|
||||
continue
|
||||
name = secure_filename(f.filename)
|
||||
f.save(os.path.join(UPLOAD_DIR, ts, name))
|
||||
saved.append(os.path.join(UPLOAD_DIR, ts, name))
|
||||
f.save(os.path.join(TRAIN_DIR, ts, name))
|
||||
saved.append(os.path.join(TRAIN_DIR, ts, name))
|
||||
|
||||
for f in test_files:
|
||||
if not f or f.filename == "":
|
||||
continue
|
||||
name = secure_filename(f.filename)
|
||||
f.save(os.path.join(TEST_DIR, ts, name))
|
||||
saved.append(os.path.join(TEST_DIR, ts, name))
|
||||
|
||||
def job():
|
||||
return asyncio.run(Cptrain.start_train(
|
||||
time=ts,
|
||||
model_name=model_name,
|
||||
image_filter=image_filter,
|
||||
mask_filter=mask_filter,
|
||||
base_model=base_model
|
||||
))
|
||||
|
||||
fut = executor.submit(job)
|
||||
|
||||
def done_cb(f):
|
||||
try:
|
||||
f.result()
|
||||
set_status(ts, "success")
|
||||
except Exception as e:
|
||||
set_status(ts, "failed", error=str(e))
|
||||
|
||||
fut.add_done_callback(done_cb)
|
||||
|
||||
return jsonify({"ok": True, "count": len(saved), "id": ts})
|
||||
|
||||
@app.get("/status")
|
||||
def status():
|
||||
|
|
|
|||
555
backend/train.py
Normal file
555
backend/train.py
Normal file
|
|
@ -0,0 +1,555 @@
|
|||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
from cellpose import io, utils, models, dynamics
|
||||
from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
import redis
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import logging
|
||||
|
||||
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||
|
||||
def set_status(task_id, status, **extra):
|
||||
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
|
||||
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
|
||||
|
||||
def get_status(task_id):
|
||||
raw = r.get(f"task:{task_id}")
|
||||
return json.loads(raw) if raw else None
|
||||
|
||||
train_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _loss_fn_class(lbl, y, class_weights=None):
|
||||
"""
|
||||
Calculates the loss function between true labels lbl and prediction y.
|
||||
|
||||
Args:
|
||||
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
|
||||
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loss value.
|
||||
|
||||
"""
|
||||
|
||||
criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights)
|
||||
loss3 = criterion3(y[:, :-3], lbl[:, 0].long())
|
||||
|
||||
return loss3
|
||||
|
||||
|
||||
def _loss_fn_seg(lbl, y, device):
|
||||
"""
|
||||
Calculates the loss function between true labels lbl and prediction y.
|
||||
|
||||
Args:
|
||||
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
|
||||
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
|
||||
device (torch.device): Device on which the tensors are located.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loss value.
|
||||
|
||||
"""
|
||||
criterion = nn.MSELoss(reduction="mean")
|
||||
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
|
||||
veci = 5. * lbl[:, -2:]
|
||||
loss = criterion(y[:, -3:-1], veci)
|
||||
loss /= 2.
|
||||
loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype))
|
||||
loss = loss + loss2
|
||||
return loss
|
||||
|
||||
|
||||
def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}):
|
||||
"""
|
||||
Reshapes and normalizes the input data.
|
||||
|
||||
Args:
|
||||
data (list): List of input data, with channels axis first or last.
|
||||
normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.
|
||||
|
||||
Returns:
|
||||
list: List of reshaped and normalized data.
|
||||
"""
|
||||
if (np.array([td.ndim != 3 for td in data]).sum() > 0 or
|
||||
np.array([td.shape[0] != 3 for td in data]).sum() > 0):
|
||||
data_new = []
|
||||
for td in data:
|
||||
if td.ndim == 3:
|
||||
channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin()
|
||||
# put channel axis first
|
||||
td = np.moveaxis(td, channel_axis0, 0)
|
||||
td = td[:3] # keep at most 3 channels
|
||||
if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
|
||||
td = np.stack((td, 0 * td, 0 * td), axis=0)
|
||||
elif td.ndim == 3 and td.shape[0] < 3:
|
||||
td = np.concatenate((td, 0 * td[:1]), axis=0)
|
||||
data_new.append(td)
|
||||
data = data_new
|
||||
if normalize_params["normalize"]:
|
||||
data = [
|
||||
normalize_img(td, normalize=normalize_params, axis=0)
|
||||
for td in data
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
|
||||
normalize_params={"normalize": False}):
|
||||
"""
|
||||
Get a batch of images and labels.
|
||||
|
||||
Args:
|
||||
inds (list): List of indices indicating which images and labels to retrieve.
|
||||
data (list or None): List of image data. If None, images will be loaded from files.
|
||||
labels (list or None): List of label data. If None, labels will be loaded from files.
|
||||
files (list or None): List of file paths for images.
|
||||
labels_files (list or None): List of file paths for labels.
|
||||
normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing two lists: the batch of images and the batch of labels.
|
||||
"""
|
||||
if data is None:
|
||||
lbls = None
|
||||
imgs = [io.imread(files[i]) for i in inds]
|
||||
imgs = _reshape_norm(imgs, normalize_params=normalize_params)
|
||||
if labels_files is not None:
|
||||
lbls = [io.imread(labels_files[i])[1:] for i in inds]
|
||||
else:
|
||||
imgs = [data[i] for i in inds]
|
||||
lbls = [labels[i][1:] for i in inds]
|
||||
return imgs, lbls
|
||||
|
||||
|
||||
def _reshape_norm_save(files, channels=None, channel_axis=None,
|
||||
normalize_params={"normalize": False}):
|
||||
""" not currently used -- normalization happening on each batch if not load_files """
|
||||
files_new = []
|
||||
for f in trange(files):
|
||||
td = io.imread(f)
|
||||
if channels is not None:
|
||||
td = convert_image(td, channels=channels,
|
||||
channel_axis=channel_axis)
|
||||
td = td.transpose(2, 0, 1)
|
||||
if normalize_params["normalize"]:
|
||||
td = normalize_img(td, normalize=normalize_params, axis=0)
|
||||
fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
|
||||
io.imsave(fnew, td)
|
||||
files_new.append(fnew)
|
||||
return files_new
|
||||
# else:
|
||||
# train_files = reshape_norm_save(train_files, channels=channels,
|
||||
# channel_axis=channel_axis, normalize_params=normalize_params)
|
||||
# elif test_files is not None:
|
||||
# test_files = reshape_norm_save(test_files, channels=channels,
|
||||
# channel_axis=channel_axis, normalize_params=normalize_params)
|
||||
|
||||
|
||||
def _process_train_test(train_data=None, train_labels=None, train_files=None,
|
||||
train_labels_files=None, train_probs=None, test_data=None,
|
||||
test_labels=None, test_files=None, test_labels_files=None,
|
||||
test_probs=None, load_files=True, min_train_masks=5,
|
||||
compute_flows=False, normalize_params={"normalize": False},
|
||||
channel_axis=None, device=None):
|
||||
"""
|
||||
Process train and test data.
|
||||
|
||||
Args:
|
||||
train_data (list or None): List of training data arrays.
|
||||
train_labels (list or None): List of training label arrays.
|
||||
train_files (list or None): List of training file paths.
|
||||
train_labels_files (list or None): List of training label file paths.
|
||||
train_probs (ndarray or None): Array of training probabilities.
|
||||
test_data (list or None): List of test data arrays.
|
||||
test_labels (list or None): List of test label arrays.
|
||||
test_files (list or None): List of test file paths.
|
||||
test_labels_files (list or None): List of test label file paths.
|
||||
test_probs (ndarray or None): Array of test probabilities.
|
||||
load_files (bool): Whether to load data from files.
|
||||
min_train_masks (int): Minimum number of masks required for training images.
|
||||
compute_flows (bool): Whether to compute flows.
|
||||
channels (list or None): List of channel indices to use.
|
||||
channel_axis (int or None): Axis of channel dimension.
|
||||
rgb (bool): Convert training/testing images to RGB.
|
||||
normalize_params (dict): Dictionary of normalization parameters.
|
||||
device (torch.device): Device to use for computation.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the processed train and test data and sampling probabilities and diameters.
|
||||
"""
|
||||
if device == None:
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
|
||||
'mps') if torch.backends.mps.is_available() else None
|
||||
|
||||
if train_data is not None and train_labels is not None:
|
||||
# if data is loaded
|
||||
nimg = len(train_data)
|
||||
nimg_test = len(test_data) if test_data is not None else None
|
||||
else:
|
||||
# otherwise use files
|
||||
nimg = len(train_files)
|
||||
if train_labels_files is None:
|
||||
train_labels_files = [
|
||||
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
|
||||
]
|
||||
train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
|
||||
if (test_data is not None or
|
||||
test_files is not None) and test_labels_files is None:
|
||||
test_labels_files = [
|
||||
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
|
||||
]
|
||||
test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
|
||||
if not load_files:
|
||||
train_logger.info(">>> using files instead of loading dataset")
|
||||
else:
|
||||
# load all images
|
||||
train_logger.info(">>> loading images and labels")
|
||||
train_data = [io.imread(train_files[i]) for i in trange(nimg)]
|
||||
train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
|
||||
nimg_test = len(test_files) if test_files is not None else None
|
||||
if load_files and nimg_test:
|
||||
test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
|
||||
test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]
|
||||
|
||||
### check that arrays are correct size
|
||||
if ((train_labels is not None and nimg != len(train_labels)) or
|
||||
(train_labels_files is not None and nimg != len(train_labels_files))):
|
||||
error_message = "train data and labels not same length"
|
||||
train_logger.critical(error_message)
|
||||
raise ValueError(error_message)
|
||||
if ((test_labels is not None and nimg_test != len(test_labels)) or
|
||||
(test_labels_files is not None and nimg_test != len(test_labels_files))):
|
||||
train_logger.warning("test data and labels not same length, not using")
|
||||
test_data, test_files = None, None
|
||||
if train_labels is not None:
|
||||
if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
|
||||
error_message = "training data or labels are not at least two-dimensional"
|
||||
train_logger.critical(error_message)
|
||||
raise ValueError(error_message)
|
||||
if train_data[0].ndim > 3:
|
||||
error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
|
||||
train_logger.critical(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
### check that flows are computed
|
||||
if train_labels is not None:
|
||||
train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
|
||||
device=device)
|
||||
if test_labels is not None:
|
||||
test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
|
||||
device=device)
|
||||
elif compute_flows:
|
||||
for k in trange(nimg):
|
||||
tl = dynamics.labels_to_flows(io.imread(train_labels_files),
|
||||
files=train_files, device=device)
|
||||
if test_files is not None:
|
||||
for k in trange(nimg_test):
|
||||
tl = dynamics.labels_to_flows(io.imread(test_labels_files),
|
||||
files=test_files, device=device)
|
||||
|
||||
### compute diameters
|
||||
nmasks = np.zeros(nimg)
|
||||
diam_train = np.zeros(nimg)
|
||||
train_logger.info(">>> computing diameters")
|
||||
for k in trange(nimg):
|
||||
tl = (train_labels[k][0]
|
||||
if train_labels is not None else io.imread(train_labels_files[k])[0])
|
||||
diam_train[k], dall = utils.diameters(tl)
|
||||
nmasks[k] = len(dall)
|
||||
diam_train[diam_train < 5] = 5.
|
||||
if test_data is not None:
|
||||
diam_test = np.array(
|
||||
[utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
|
||||
diam_test[diam_test < 5] = 5.
|
||||
elif test_labels_files is not None:
|
||||
diam_test = np.array([
|
||||
utils.diameters(io.imread(test_labels_files[k])[0])[0]
|
||||
for k in trange(len(test_labels_files))
|
||||
])
|
||||
diam_test[diam_test < 5] = 5.
|
||||
else:
|
||||
diam_test = None
|
||||
|
||||
### check to remove training images with too few masks
|
||||
if min_train_masks > 0:
|
||||
nremove = (nmasks < min_train_masks).sum()
|
||||
if nremove > 0:
|
||||
train_logger.warning(
|
||||
f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
|
||||
)
|
||||
ikeep = np.nonzero(nmasks >= min_train_masks)[0]
|
||||
if train_data is not None:
|
||||
train_data = [train_data[i] for i in ikeep]
|
||||
train_labels = [train_labels[i] for i in ikeep]
|
||||
if train_files is not None:
|
||||
train_files = [train_files[i] for i in ikeep]
|
||||
if train_labels_files is not None:
|
||||
train_labels_files = [train_labels_files[i] for i in ikeep]
|
||||
if train_probs is not None:
|
||||
train_probs = train_probs[ikeep]
|
||||
diam_train = diam_train[ikeep]
|
||||
nimg = len(train_data)
|
||||
|
||||
### normalize probabilities
|
||||
train_probs = 1. / nimg * np.ones(nimg,
|
||||
"float64") if train_probs is None else train_probs
|
||||
train_probs /= train_probs.sum()
|
||||
if test_files is not None or test_data is not None:
|
||||
test_probs = 1. / nimg_test * np.ones(
|
||||
nimg_test, "float64") if test_probs is None else test_probs
|
||||
test_probs /= test_probs.sum()
|
||||
|
||||
### reshape and normalize train / test data
|
||||
normed = False
|
||||
if normalize_params["normalize"]:
|
||||
train_logger.info(f">>> normalizing {normalize_params}")
|
||||
if train_data is not None:
|
||||
train_data = _reshape_norm(train_data, channel_axis=channel_axis,
|
||||
normalize_params=normalize_params)
|
||||
normed = True
|
||||
if test_data is not None:
|
||||
test_data = _reshape_norm(test_data, channel_axis=channel_axis,
|
||||
normalize_params=normalize_params)
|
||||
|
||||
return (train_data, train_labels, train_files, train_labels_files, train_probs,
|
||||
diam_train, test_data, test_labels, test_files, test_labels_files,
|
||||
test_probs, diam_test, normed)
|
||||
|
||||
|
||||
def train_seg(net, train_data=None, train_labels=None, train_files=None,
|
||||
train_labels_files=None, train_probs=None, test_data=None,
|
||||
test_labels=None, test_files=None, test_labels_files=None,
|
||||
test_probs=None, channel_axis=None,
|
||||
load_files=True, batch_size=1, learning_rate=5e-5, SGD=False,
|
||||
n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
|
||||
save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
|
||||
nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
|
||||
min_train_masks=5, model_name=None, class_weights=None, ts=None):
|
||||
"""
|
||||
Train the network with images for segmentation.
|
||||
|
||||
Args:
|
||||
net (object): The network model to train.
|
||||
train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
|
||||
train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
|
||||
train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
|
||||
train_labels_files (list or None): List of training label file paths. Defaults to None.
|
||||
train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None.
|
||||
test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None.
|
||||
test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
|
||||
test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None.
|
||||
test_labels_files (list or None): List of test label file paths. Defaults to None.
|
||||
test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None.
|
||||
load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True.
|
||||
batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8.
|
||||
learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 0.005.
|
||||
n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 2000.
|
||||
weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 1e-5.
|
||||
momentum (float, optional): Float - momentum for the optimizer. Defaults to 0.9.
|
||||
SGD (bool, optional): Deprecated in v4.0.1+ - AdamW always used.
|
||||
normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True.
|
||||
compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
|
||||
save_path (str, optional): String - where to save the trained model. Defaults to None.
|
||||
save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
|
||||
save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False.
|
||||
nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
|
||||
nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
|
||||
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True.
|
||||
min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
|
||||
model_name (str, optional): String - name of the network. Defaults to None.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
|
||||
|
||||
"""
|
||||
if SGD:
|
||||
train_logger.warning("SGD is deprecated, using AdamW instead")
|
||||
|
||||
device = net.device
|
||||
|
||||
scale_range = 0.5 if scale_range is None else scale_range
|
||||
|
||||
if isinstance(normalize, dict):
|
||||
normalize_params = {**models.normalize_default, **normalize}
|
||||
elif not isinstance(normalize, bool):
|
||||
raise ValueError("normalize parameter must be a bool or a dict")
|
||||
else:
|
||||
normalize_params = models.normalize_default
|
||||
normalize_params["normalize"] = normalize
|
||||
|
||||
out = _process_train_test(train_data=train_data, train_labels=train_labels,
|
||||
train_files=train_files, train_labels_files=train_labels_files,
|
||||
train_probs=train_probs,
|
||||
test_data=test_data, test_labels=test_labels,
|
||||
test_files=test_files, test_labels_files=test_labels_files,
|
||||
test_probs=test_probs,
|
||||
load_files=load_files, min_train_masks=min_train_masks,
|
||||
compute_flows=compute_flows, channel_axis=channel_axis,
|
||||
normalize_params=normalize_params, device=net.device)
|
||||
(train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
|
||||
test_data, test_labels, test_files, test_labels_files, test_probs, diam_test,
|
||||
normed) = out
|
||||
# already normalized, do not normalize during training
|
||||
if normed:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis}
|
||||
|
||||
net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)
|
||||
|
||||
if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)):
|
||||
class_weights = torch.from_numpy(class_weights).to(device).float()
|
||||
print(class_weights)
|
||||
|
||||
nimg = len(train_data) if train_data is not None else len(train_files)
|
||||
nimg_test = len(test_data) if test_data is not None else None
|
||||
nimg_test = len(test_files) if test_files is not None else nimg_test
|
||||
nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
|
||||
nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
|
||||
|
||||
# learning rate schedule
|
||||
LR = np.linspace(0, learning_rate, 10)
|
||||
LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
|
||||
if n_epochs > 300:
|
||||
LR = LR[:-100]
|
||||
for i in range(10):
|
||||
LR = np.append(LR, LR[-1] / 2 * np.ones(10))
|
||||
elif n_epochs > 99:
|
||||
LR = LR[:-50]
|
||||
for i in range(10):
|
||||
LR = np.append(LR, LR[-1] / 2 * np.ones(5))
|
||||
|
||||
train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}")
|
||||
train_logger.info(
|
||||
f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
|
||||
)
|
||||
optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
|
||||
weight_decay=weight_decay)
|
||||
|
||||
t0 = time.time()
|
||||
model_name = f"cellpose_{t0}" if model_name is None else model_name
|
||||
save_path = Path.cwd() if save_path is None else Path(save_path)
|
||||
filename = save_path / "models" / model_name
|
||||
(save_path / "models").mkdir(exist_ok=True)
|
||||
|
||||
train_logger.info(f">>> saving model to {filename}")
|
||||
|
||||
lavg, nsum = 0, 0
|
||||
train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
|
||||
for iepoch in range(n_epochs):
|
||||
np.random.seed(iepoch)
|
||||
if nimg != nimg_per_epoch:
|
||||
# choose random images for epoch with probability train_probs
|
||||
rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
|
||||
p=train_probs)
|
||||
else:
|
||||
# otherwise use all images
|
||||
rperm = np.random.permutation(np.arange(0, nimg))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = LR[iepoch] # set learning rate
|
||||
net.train()
|
||||
for k in range(0, nimg_per_epoch, batch_size):
|
||||
kend = min(k + batch_size, nimg_per_epoch)
|
||||
inds = rperm[k:kend]
|
||||
imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
|
||||
files=train_files, labels_files=train_labels_files,
|
||||
**kwargs)
|
||||
diams = np.array([diam_train[i] for i in inds])
|
||||
rsc = diams / net.diam_mean.item() if rescale else np.ones(
|
||||
len(diams), "float32")
|
||||
# augmentations
|
||||
imgi, lbl = random_rotate_and_resize(imgs, Y=lbls, rescale=rsc,
|
||||
scale_range=scale_range,
|
||||
xy=(bsize, bsize))[:2]
|
||||
# network and loss optimization
|
||||
X = torch.from_numpy(imgi).to(device)
|
||||
lbl = torch.from_numpy(lbl).to(device)
|
||||
|
||||
if X.dtype != net.dtype:
|
||||
X = X.to(net.dtype)
|
||||
lbl = lbl.to(net.dtype)
|
||||
|
||||
y = net(X)[0]
|
||||
loss = _loss_fn_seg(lbl, y, device)
|
||||
if y.shape[1] > 3:
|
||||
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
|
||||
loss += loss3
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss = loss.item()
|
||||
train_loss *= len(imgi)
|
||||
|
||||
# keep track of average training loss across epochs
|
||||
lavg += train_loss
|
||||
nsum += len(imgi)
|
||||
# per epoch training loss
|
||||
train_losses[iepoch] += train_loss
|
||||
train_losses[iepoch] /= nimg_per_epoch
|
||||
|
||||
if iepoch == 5 or iepoch % 10 == 0:
|
||||
lavgt = 0.
|
||||
if test_data is not None or test_files is not None:
|
||||
np.random.seed(42)
|
||||
if nimg_test != nimg_test_per_epoch:
|
||||
rperm = np.random.choice(np.arange(0, nimg_test),
|
||||
size=(nimg_test_per_epoch,), p=test_probs)
|
||||
else:
|
||||
rperm = np.random.permutation(np.arange(0, nimg_test))
|
||||
for ibatch in range(0, len(rperm), batch_size):
|
||||
with torch.no_grad():
|
||||
net.eval()
|
||||
inds = rperm[ibatch:ibatch + batch_size]
|
||||
imgs, lbls = _get_batch(inds, data=test_data,
|
||||
labels=test_labels, files=test_files,
|
||||
labels_files=test_labels_files,
|
||||
**kwargs)
|
||||
diams = np.array([diam_test[i] for i in inds])
|
||||
rsc = diams / net.diam_mean.item() if rescale else np.ones(
|
||||
len(diams), "float32")
|
||||
imgi, lbl = random_rotate_and_resize(
|
||||
imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
|
||||
xy=(bsize, bsize))[:2]
|
||||
X = torch.from_numpy(imgi).to(device)
|
||||
lbl = torch.from_numpy(lbl).to(device)
|
||||
|
||||
if X.dtype != net.dtype:
|
||||
X = X.to(net.dtype)
|
||||
lbl = lbl.to(net.dtype)
|
||||
|
||||
y = net(X)[0]
|
||||
loss = _loss_fn_seg(lbl, y, device)
|
||||
if y.shape[1] > 3:
|
||||
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
|
||||
loss += loss3
|
||||
test_loss = loss.item()
|
||||
test_loss *= len(imgi)
|
||||
lavgt += test_loss
|
||||
lavgt /= len(rperm)
|
||||
test_losses[iepoch] = lavgt
|
||||
lavg /= nsum
|
||||
train_logger.info(
|
||||
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time() - t0:.2f}s"
|
||||
)
|
||||
lavg, nsum = 0, 0
|
||||
|
||||
if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
|
||||
if save_each and iepoch != n_epochs - 1: # separate files as model progresses
|
||||
filename0 = str(filename) + f"_epoch_{iepoch:04d}"
|
||||
else:
|
||||
filename0 = filename
|
||||
train_logger.info(f"saving network parameters to {filename0}")
|
||||
net.save_model(filename0)
|
||||
|
||||
net.save_model(filename)
|
||||
|
||||
return filename, train_losses, test_losses
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
<!doctype html>
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
|
|
@ -10,179 +10,9 @@
|
|||
</head>
|
||||
|
||||
<body>
|
||||
<div style="padding: 1rem 1rem;">
|
||||
<div class="mb-3 border" style="padding: 1rem 1rem;">
|
||||
<div class="input-group mb-3">
|
||||
<input id="fileInput" type="file" class="form-control" multiple />
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
<div>
|
||||
<div style="padding-right: 50rem;">
|
||||
<div class="input-group mb-3">
|
||||
<span class="input-group-text">flow threshold:</span>
|
||||
<input id="flow" type="text" class="form-control" placeholder="" />
|
||||
</div>
|
||||
|
||||
<div class="input-group mb-3">
|
||||
<span class="input-group-text">cellprob threshold:</span>
|
||||
<input id="cellprob" type="text" class="form-control" placeholder="" />
|
||||
</div>
|
||||
|
||||
<div class="input-group mb-3">
|
||||
<span class="input-group-text">diameter:</span>
|
||||
<input id="diameter" type="text" class="form-control" placeholder="" />
|
||||
</div>
|
||||
|
||||
<label>
|
||||
<select id="model" class="form-select"></select>
|
||||
</label>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<br>
|
||||
<div>
|
||||
<button id="uploadBtn" class="btn btn-success">Upload</button>
|
||||
<progress id="bar" max="100" value="0" style="width:300px;"></progress>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<br><br><br>
|
||||
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js"
|
||||
integrity="sha384-oBqDVmMz9ATKxIep9tiCxS/Z9fNfEXiDAYTujMAeBAsjFuCZSmKbSSUnQlmh/jp3"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.min.js"
|
||||
integrity="sha384-ep+dxp/oz2RKF89ALMPGc7Z89QFa32C8Uv1A3TcEK8sMzXVysblLA3+eJWTzPJzT"
|
||||
crossorigin="anonymous"></script>
|
||||
<script>
|
||||
const API = "http://10.147.18.141:5000/";
|
||||
const APT_UPLOAD = API + "run_upload";
|
||||
const API_MODEL = API + "models";
|
||||
|
||||
async function loadModels() {
|
||||
const select = document.getElementById('model');
|
||||
// 占位 & 禁用,避免用户在加载中点选
|
||||
select.disabled = true;
|
||||
select.innerHTML = '<option value="">加载中…</option>';
|
||||
|
||||
// 可选:超时控制(5 秒)
|
||||
const controller = new AbortController();
|
||||
const t = setTimeout(() => controller.abort(), 5000);
|
||||
|
||||
try {
|
||||
const resp = await fetch(API_MODEL, {
|
||||
headers: { 'Accept': 'application/json' },
|
||||
signal: controller.signal
|
||||
});
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
|
||||
|
||||
const data = await resp.json();
|
||||
const list = Array.isArray(data.models) ? data.models : [];
|
||||
|
||||
// 记住之前的选择(如果有)
|
||||
const prev = select.value;
|
||||
|
||||
// 重新渲染选项
|
||||
select.innerHTML = '<option value="" disabled selected>请选择模型</option>';
|
||||
const seen = new Set();
|
||||
for (const name of list) {
|
||||
if (!name || seen.has(name)) continue;
|
||||
seen.add(name);
|
||||
const opt = document.createElement('option');
|
||||
opt.value = String(name);
|
||||
opt.textContent = String(name);
|
||||
select.appendChild(opt);
|
||||
}
|
||||
|
||||
// 如果原选择仍然存在,则恢复它
|
||||
if (prev && seen.has(prev)) {
|
||||
select.value = prev;
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('加载模型列表失败:', err);
|
||||
select.innerHTML = '<option value="">加载失败(点击重试)</option>';
|
||||
// 点击下拉框时重试
|
||||
select.addEventListener('click', loadModels, { once: true });
|
||||
} finally {
|
||||
clearTimeout(t);
|
||||
select.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('DOMContentLoaded', loadModels);
|
||||
|
||||
function buildUrl() {
|
||||
// 获取参数
|
||||
const model = document.getElementById('model')?.value;
|
||||
const flow = (document.getElementById('flow')?.value || '').trim();
|
||||
const cellp = (document.getElementById('cellprob')?.value || '').trim();
|
||||
const diameter = (document.getElementById('diameter')?.value || '').trim();
|
||||
|
||||
// 用 URLSearchParams 组装查询串
|
||||
const qs = new URLSearchParams({
|
||||
model: model,
|
||||
flow_threshold: flow,
|
||||
cellprob_threshold: cellp,
|
||||
diameter: diameter
|
||||
});
|
||||
|
||||
return `${API_UPLOAD}?${qs.toString()}`;
|
||||
}
|
||||
|
||||
document.getElementById("uploadBtn").addEventListener("click", async () => {
|
||||
const input = document.getElementById("fileInput");
|
||||
if (!input.files.length) return alert("请选择文件");
|
||||
|
||||
const fd = new FormData();
|
||||
for (const f of input.files) fd.append("files", f);
|
||||
|
||||
const bar = document.getElementById("bar");
|
||||
try {
|
||||
const URL = buildUrl();
|
||||
const res = await axios.post(URL, fd, {
|
||||
onUploadProgress: (e) => {
|
||||
if (e.total) bar.value = Math.round((e.loaded * 100) / e.total);
|
||||
},
|
||||
});
|
||||
// alert("上传成功:" + JSON.stringify(res.data));
|
||||
|
||||
// 创建一个提示元素
|
||||
const notice = document.createElement("div");
|
||||
notice.style.position = "fixed";
|
||||
notice.style.top = "20px";
|
||||
notice.style.left = "50%";
|
||||
notice.style.transform = "translateX(-50%)";
|
||||
notice.style.padding = "10px 20px";
|
||||
notice.style.background = "#4caf50";
|
||||
notice.style.color = "white";
|
||||
notice.style.borderRadius = "8px";
|
||||
notice.style.fontSize = "16px";
|
||||
notice.style.zIndex = "9999";
|
||||
|
||||
let seconds = 3;
|
||||
notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
|
||||
document.body.appendChild(notice);
|
||||
|
||||
const timer = setInterval(() => {
|
||||
seconds--;
|
||||
if (seconds > 0) {
|
||||
notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
|
||||
} else {
|
||||
clearInterval(timer);
|
||||
document.body.removeChild(notice);
|
||||
window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
|
||||
}
|
||||
}, 1000);
|
||||
} catch (e) {
|
||||
alert("上传失败:" + (e.response?.data?.message || e.message));
|
||||
} finally {
|
||||
bar.value = 0;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
<a href="run.html">运行</a>
|
||||
<a href="train.html">训练</a>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
Loading…
Reference in a new issue