diff --git a/backend/cp_train.py b/backend/cp_train.py index a624ac7..e01521b 100644 --- a/backend/cp_train.py +++ b/backend/cp_train.py @@ -1,33 +1,53 @@ import os.path -from cellpose import io, models, train from pathlib import Path from omegaconf import OmegaConf +import redis +import datetime +import json CONFIG_PATH = Path(__file__).parent / "config.yaml" cfg = OmegaConf.load(CONFIG_PATH) cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve()) BASE_DIR = cfg.data.root_dir -TEST_TRAIN_DIR = cfg.data.train.test_train_dir -TEST_TEST_DIR = cfg.data.train.test_test_dir +TRAIN_DIR = cfg.data.train.train_dir +TEST_DIR = cfg.data.train.test_dir MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve()) +os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR + +r = redis.Redis(host="127.0.0.1", port=6379, db=0) + +def set_status(task_id, status, **extra): + payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra} + r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期 + +def get_status(task_id): + raw = r.get(f"task:{task_id}") + return json.loads(raw) if raw else None + +from cellpose import io, models, train class Cptrain: @classmethod - def start_train(cls, + async def start_train(cls, time: str | None = None, - model_name: str | None = None,): + model_name: str | None = None, + image_filter: str = "_img", + mask_filter: str = "_masks", + base_model: str = "cpsam"): - train_dir = Path(TEST_TRAIN_DIR) / time - test_dir = Path(TEST_TEST_DIR) / time + train_dir = Path(TRAIN_DIR) / time + test_dir = Path(TEST_DIR) / time os.makedirs(train_dir, exist_ok=True) os.makedirs(test_dir, exist_ok=True) io.logger_setup() - output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img", - mask_filter="_masks", look_one_level_down=False) + output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter=image_filter, + mask_filter=mask_filter, look_one_level_down=False) images, labels, image_names, test_images, test_labels, image_names_test = output - model = models.CellposeModel(gpu=True) + model = models.CellposeModel(gpu=True, pretrained_model=base_model) + + set_status(time, "running") model_path, train_losses, test_losses = train.train_seg(model.net, train_data=images, train_labels=labels, @@ -36,4 +56,4 @@ class Cptrain: n_epochs=100, model_name=model_name, save_path=MODELS_DIR) - print("模型已保存到:", model_path) \ No newline at end of file + print("模型已保存到:", model_path) diff --git a/backend/flaskApp.py b/backend/flaskApp.py index 5671347..d86e2ce 100644 --- a/backend/flaskApp.py +++ b/backend/flaskApp.py @@ -14,6 +14,7 @@ from flask import Flask, send_from_directory, request, jsonify from flask_cors import CORS from werkzeug.utils import secure_filename +from backend.cp_train import Cptrain from cp_run import Cprun app = Flask(__name__) @@ -26,6 +27,8 @@ BASE_DIR = cfg.data.root_dir UPLOAD_DIR = cfg.data.upload_dir OUTPUT_DIR = cfg.data.run.output_dir MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve()) +TRAIN_DIR = cfg.data.train.train_dir +TEST_DIR = cfg.data.train.test_dir os.makedirs(UPLOAD_DIR, exist_ok=True) executor = ThreadPoolExecutor(max_workers=4) @@ -128,16 +131,50 @@ def run_upload(): @app.post("/train_upload") def train_upload(): ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" + model_name = request.args.get("model_name") or f"custom_model-{ts}" + image_filter = request.args.get("image_filter") or "_img" + mask_filter = request.args.get("mask_filter") or "_masks" + base_model = request.args.get("base_model") or "cpsam" + train_files = request.files.getlist("train_files") test_files = request.files.getlist("test_files") + set_status(ts, "pending") saved = [] for f in train_files: if not f or f.filename == "": continue name = secure_filename(f.filename) - f.save(os.path.join(UPLOAD_DIR, ts, name)) - saved.append(os.path.join(UPLOAD_DIR, ts, name)) + f.save(os.path.join(TRAIN_DIR, ts, name)) + saved.append(os.path.join(TRAIN_DIR, ts, name)) + for f in test_files: + if not f or f.filename == "": + continue + name = secure_filename(f.filename) + f.save(os.path.join(TEST_DIR, ts, name)) + saved.append(os.path.join(TEST_DIR, ts, name)) + + def job(): + return asyncio.run(Cptrain.start_train( + time=ts, + model_name=model_name, + image_filter=image_filter, + mask_filter=mask_filter, + base_model=base_model + )) + + fut = executor.submit(job) + + def done_cb(f): + try: + f.result() + set_status(ts, "success") + except Exception as e: + set_status(ts, "failed", error=str(e)) + + fut.add_done_callback(done_cb) + + return jsonify({"ok": True, "count": len(saved), "id": ts}) @app.get("/status") def status(): diff --git a/backend/train.py b/backend/train.py new file mode 100644 index 0000000..ebd4720 --- /dev/null +++ b/backend/train.py @@ -0,0 +1,555 @@ +import time +import os +import numpy as np +from cellpose import io, utils, models, dynamics +from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image +from pathlib import Path +import torch +from torch import nn +from tqdm import trange +import redis +import json +import datetime + +import logging + +r = redis.Redis(host="127.0.0.1", port=6379, db=0) + +def set_status(task_id, status, **extra): + payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra} + r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期 + +def get_status(task_id): + raw = r.get(f"task:{task_id}") + return json.loads(raw) if raw else None + +train_logger = logging.getLogger(__name__) + + +def _loss_fn_class(lbl, y, class_weights=None): + """ + Calculates the loss function between true labels lbl and prediction y. + + Args: + lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX). + y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob). + + Returns: + torch.Tensor: Loss value. + + """ + + criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights) + loss3 = criterion3(y[:, :-3], lbl[:, 0].long()) + + return loss3 + + +def _loss_fn_seg(lbl, y, device): + """ + Calculates the loss function between true labels lbl and prediction y. + + Args: + lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX). + y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob). + device (torch.device): Device on which the tensors are located. + + Returns: + torch.Tensor: Loss value. + + """ + criterion = nn.MSELoss(reduction="mean") + criterion2 = nn.BCEWithLogitsLoss(reduction="mean") + veci = 5. * lbl[:, -2:] + loss = criterion(y[:, -3:-1], veci) + loss /= 2. + loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype)) + loss = loss + loss2 + return loss + + +def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}): + """ + Reshapes and normalizes the input data. + + Args: + data (list): List of input data, with channels axis first or last. + normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}. + + Returns: + list: List of reshaped and normalized data. + """ + if (np.array([td.ndim != 3 for td in data]).sum() > 0 or + np.array([td.shape[0] != 3 for td in data]).sum() > 0): + data_new = [] + for td in data: + if td.ndim == 3: + channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin() + # put channel axis first + td = np.moveaxis(td, channel_axis0, 0) + td = td[:3] # keep at most 3 channels + if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1): + td = np.stack((td, 0 * td, 0 * td), axis=0) + elif td.ndim == 3 and td.shape[0] < 3: + td = np.concatenate((td, 0 * td[:1]), axis=0) + data_new.append(td) + data = data_new + if normalize_params["normalize"]: + data = [ + normalize_img(td, normalize=normalize_params, axis=0) + for td in data + ] + return data + + +def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, + normalize_params={"normalize": False}): + """ + Get a batch of images and labels. + + Args: + inds (list): List of indices indicating which images and labels to retrieve. + data (list or None): List of image data. If None, images will be loaded from files. + labels (list or None): List of label data. If None, labels will be loaded from files. + files (list or None): List of file paths for images. + labels_files (list or None): List of file paths for labels. + normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize). + + Returns: + tuple: A tuple containing two lists: the batch of images and the batch of labels. + """ + if data is None: + lbls = None + imgs = [io.imread(files[i]) for i in inds] + imgs = _reshape_norm(imgs, normalize_params=normalize_params) + if labels_files is not None: + lbls = [io.imread(labels_files[i])[1:] for i in inds] + else: + imgs = [data[i] for i in inds] + lbls = [labels[i][1:] for i in inds] + return imgs, lbls + + +def _reshape_norm_save(files, channels=None, channel_axis=None, + normalize_params={"normalize": False}): + """ not currently used -- normalization happening on each batch if not load_files """ + files_new = [] + for f in trange(files): + td = io.imread(f) + if channels is not None: + td = convert_image(td, channels=channels, + channel_axis=channel_axis) + td = td.transpose(2, 0, 1) + if normalize_params["normalize"]: + td = normalize_img(td, normalize=normalize_params, axis=0) + fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif" + io.imsave(fnew, td) + files_new.append(fnew) + return files_new + # else: + # train_files = reshape_norm_save(train_files, channels=channels, + # channel_axis=channel_axis, normalize_params=normalize_params) + # elif test_files is not None: + # test_files = reshape_norm_save(test_files, channels=channels, + # channel_axis=channel_axis, normalize_params=normalize_params) + + +def _process_train_test(train_data=None, train_labels=None, train_files=None, + train_labels_files=None, train_probs=None, test_data=None, + test_labels=None, test_files=None, test_labels_files=None, + test_probs=None, load_files=True, min_train_masks=5, + compute_flows=False, normalize_params={"normalize": False}, + channel_axis=None, device=None): + """ + Process train and test data. + + Args: + train_data (list or None): List of training data arrays. + train_labels (list or None): List of training label arrays. + train_files (list or None): List of training file paths. + train_labels_files (list or None): List of training label file paths. + train_probs (ndarray or None): Array of training probabilities. + test_data (list or None): List of test data arrays. + test_labels (list or None): List of test label arrays. + test_files (list or None): List of test file paths. + test_labels_files (list or None): List of test label file paths. + test_probs (ndarray or None): Array of test probabilities. + load_files (bool): Whether to load data from files. + min_train_masks (int): Minimum number of masks required for training images. + compute_flows (bool): Whether to compute flows. + channels (list or None): List of channel indices to use. + channel_axis (int or None): Axis of channel dimension. + rgb (bool): Convert training/testing images to RGB. + normalize_params (dict): Dictionary of normalization parameters. + device (torch.device): Device to use for computation. + + Returns: + tuple: A tuple containing the processed train and test data and sampling probabilities and diameters. + """ + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device( + 'mps') if torch.backends.mps.is_available() else None + + if train_data is not None and train_labels is not None: + # if data is loaded + nimg = len(train_data) + nimg_test = len(test_data) if test_data is not None else None + else: + # otherwise use files + nimg = len(train_files) + if train_labels_files is None: + train_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files + ] + train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)] + if (test_data is not None or + test_files is not None) and test_labels_files is None: + test_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files + ] + test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)] + if not load_files: + train_logger.info(">>> using files instead of loading dataset") + else: + # load all images + train_logger.info(">>> loading images and labels") + train_data = [io.imread(train_files[i]) for i in trange(nimg)] + train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)] + nimg_test = len(test_files) if test_files is not None else None + if load_files and nimg_test: + test_data = [io.imread(test_files[i]) for i in trange(nimg_test)] + test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)] + + ### check that arrays are correct size + if ((train_labels is not None and nimg != len(train_labels)) or + (train_labels_files is not None and nimg != len(train_labels_files))): + error_message = "train data and labels not same length" + train_logger.critical(error_message) + raise ValueError(error_message) + if ((test_labels is not None and nimg_test != len(test_labels)) or + (test_labels_files is not None and nimg_test != len(test_labels_files))): + train_logger.warning("test data and labels not same length, not using") + test_data, test_files = None, None + if train_labels is not None: + if train_labels[0].ndim < 2 or train_data[0].ndim < 2: + error_message = "training data or labels are not at least two-dimensional" + train_logger.critical(error_message) + raise ValueError(error_message) + if train_data[0].ndim > 3: + error_message = "training data is more than three-dimensional (should be 2D or 3D array)" + train_logger.critical(error_message) + raise ValueError(error_message) + + ### check that flows are computed + if train_labels is not None: + train_labels = dynamics.labels_to_flows(train_labels, files=train_files, + device=device) + if test_labels is not None: + test_labels = dynamics.labels_to_flows(test_labels, files=test_files, + device=device) + elif compute_flows: + for k in trange(nimg): + tl = dynamics.labels_to_flows(io.imread(train_labels_files), + files=train_files, device=device) + if test_files is not None: + for k in trange(nimg_test): + tl = dynamics.labels_to_flows(io.imread(test_labels_files), + files=test_files, device=device) + + ### compute diameters + nmasks = np.zeros(nimg) + diam_train = np.zeros(nimg) + train_logger.info(">>> computing diameters") + for k in trange(nimg): + tl = (train_labels[k][0] + if train_labels is not None else io.imread(train_labels_files[k])[0]) + diam_train[k], dall = utils.diameters(tl) + nmasks[k] = len(dall) + diam_train[diam_train < 5] = 5. + if test_data is not None: + diam_test = np.array( + [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))]) + diam_test[diam_test < 5] = 5. + elif test_labels_files is not None: + diam_test = np.array([ + utils.diameters(io.imread(test_labels_files[k])[0])[0] + for k in trange(len(test_labels_files)) + ]) + diam_test[diam_test < 5] = 5. + else: + diam_test = None + + ### check to remove training images with too few masks + if min_train_masks > 0: + nremove = (nmasks < min_train_masks).sum() + if nremove > 0: + train_logger.warning( + f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set" + ) + ikeep = np.nonzero(nmasks >= min_train_masks)[0] + if train_data is not None: + train_data = [train_data[i] for i in ikeep] + train_labels = [train_labels[i] for i in ikeep] + if train_files is not None: + train_files = [train_files[i] for i in ikeep] + if train_labels_files is not None: + train_labels_files = [train_labels_files[i] for i in ikeep] + if train_probs is not None: + train_probs = train_probs[ikeep] + diam_train = diam_train[ikeep] + nimg = len(train_data) + + ### normalize probabilities + train_probs = 1. / nimg * np.ones(nimg, + "float64") if train_probs is None else train_probs + train_probs /= train_probs.sum() + if test_files is not None or test_data is not None: + test_probs = 1. / nimg_test * np.ones( + nimg_test, "float64") if test_probs is None else test_probs + test_probs /= test_probs.sum() + + ### reshape and normalize train / test data + normed = False + if normalize_params["normalize"]: + train_logger.info(f">>> normalizing {normalize_params}") + if train_data is not None: + train_data = _reshape_norm(train_data, channel_axis=channel_axis, + normalize_params=normalize_params) + normed = True + if test_data is not None: + test_data = _reshape_norm(test_data, channel_axis=channel_axis, + normalize_params=normalize_params) + + return (train_data, train_labels, train_files, train_labels_files, train_probs, + diam_train, test_data, test_labels, test_files, test_labels_files, + test_probs, diam_test, normed) + + +def train_seg(net, train_data=None, train_labels=None, train_files=None, + train_labels_files=None, train_probs=None, test_data=None, + test_labels=None, test_files=None, test_labels_files=None, + test_probs=None, channel_axis=None, + load_files=True, batch_size=1, learning_rate=5e-5, SGD=False, + n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False, + save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, + nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256, + min_train_masks=5, model_name=None, class_weights=None, ts=None): + """ + Train the network with images for segmentation. + + Args: + net (object): The network model to train. + train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. + train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. + train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. + train_labels_files (list or None): List of training label file paths. Defaults to None. + train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None. + test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None. + test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. + test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None. + test_labels_files (list or None): List of test label file paths. Defaults to None. + test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None. + load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True. + batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8. + learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 0.005. + n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 2000. + weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 1e-5. + momentum (float, optional): Float - momentum for the optimizer. Defaults to 0.9. + SGD (bool, optional): Deprecated in v4.0.1+ - AdamW always used. + normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True. + compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False. + save_path (str, optional): String - where to save the trained model. Defaults to None. + save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100. + save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False. + nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None. + nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None. + rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True. + min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5. + model_name (str, optional): String - name of the network. Defaults to None. + + Returns: + tuple: A tuple containing the path to the saved model weights, training losses, and test losses. + + """ + if SGD: + train_logger.warning("SGD is deprecated, using AdamW instead") + + device = net.device + + scale_range = 0.5 if scale_range is None else scale_range + + if isinstance(normalize, dict): + normalize_params = {**models.normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = models.normalize_default + normalize_params["normalize"] = normalize + + out = _process_train_test(train_data=train_data, train_labels=train_labels, + train_files=train_files, train_labels_files=train_labels_files, + train_probs=train_probs, + test_data=test_data, test_labels=test_labels, + test_files=test_files, test_labels_files=test_labels_files, + test_probs=test_probs, + load_files=load_files, min_train_masks=min_train_masks, + compute_flows=compute_flows, channel_axis=channel_axis, + normalize_params=normalize_params, device=net.device) + (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, + test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, + normed) = out + # already normalized, do not normalize during training + if normed: + kwargs = {} + else: + kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis} + + net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device) + + if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)): + class_weights = torch.from_numpy(class_weights).to(device).float() + print(class_weights) + + nimg = len(train_data) if train_data is not None else len(train_files) + nimg_test = len(test_data) if test_data is not None else None + nimg_test = len(test_files) if test_files is not None else nimg_test + nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch + nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch + + # learning rate schedule + LR = np.linspace(0, learning_rate, 10) + LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10))) + if n_epochs > 300: + LR = LR[:-100] + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(10)) + elif n_epochs > 99: + LR = LR[:-50] + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(5)) + + train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}") + train_logger.info( + f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}" + ) + optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate, + weight_decay=weight_decay) + + t0 = time.time() + model_name = f"cellpose_{t0}" if model_name is None else model_name + save_path = Path.cwd() if save_path is None else Path(save_path) + filename = save_path / "models" / model_name + (save_path / "models").mkdir(exist_ok=True) + + train_logger.info(f">>> saving model to {filename}") + + lavg, nsum = 0, 0 + train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs) + for iepoch in range(n_epochs): + np.random.seed(iepoch) + if nimg != nimg_per_epoch: + # choose random images for epoch with probability train_probs + rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), + p=train_probs) + else: + # otherwise use all images + rperm = np.random.permutation(np.arange(0, nimg)) + for param_group in optimizer.param_groups: + param_group["lr"] = LR[iepoch] # set learning rate + net.train() + for k in range(0, nimg_per_epoch, batch_size): + kend = min(k + batch_size, nimg_per_epoch) + inds = rperm[k:kend] + imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, + files=train_files, labels_files=train_labels_files, + **kwargs) + diams = np.array([diam_train[i] for i in inds]) + rsc = diams / net.diam_mean.item() if rescale else np.ones( + len(diams), "float32") + # augmentations + imgi, lbl = random_rotate_and_resize(imgs, Y=lbls, rescale=rsc, + scale_range=scale_range, + xy=(bsize, bsize))[:2] + # network and loss optimization + X = torch.from_numpy(imgi).to(device) + lbl = torch.from_numpy(lbl).to(device) + + if X.dtype != net.dtype: + X = X.to(net.dtype) + lbl = lbl.to(net.dtype) + + y = net(X)[0] + loss = _loss_fn_seg(lbl, y, device) + if y.shape[1] > 3: + loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) + loss += loss3 + optimizer.zero_grad() + loss.backward() + optimizer.step() + train_loss = loss.item() + train_loss *= len(imgi) + + # keep track of average training loss across epochs + lavg += train_loss + nsum += len(imgi) + # per epoch training loss + train_losses[iepoch] += train_loss + train_losses[iepoch] /= nimg_per_epoch + + if iepoch == 5 or iepoch % 10 == 0: + lavgt = 0. + if test_data is not None or test_files is not None: + np.random.seed(42) + if nimg_test != nimg_test_per_epoch: + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + else: + rperm = np.random.permutation(np.arange(0, nimg_test)) + for ibatch in range(0, len(rperm), batch_size): + with torch.no_grad(): + net.eval() + inds = rperm[ibatch:ibatch + batch_size] + imgs, lbls = _get_batch(inds, data=test_data, + labels=test_labels, files=test_files, + labels_files=test_labels_files, + **kwargs) + diams = np.array([diam_test[i] for i in inds]) + rsc = diams / net.diam_mean.item() if rescale else np.ones( + len(diams), "float32") + imgi, lbl = random_rotate_and_resize( + imgs, Y=lbls, rescale=rsc, scale_range=scale_range, + xy=(bsize, bsize))[:2] + X = torch.from_numpy(imgi).to(device) + lbl = torch.from_numpy(lbl).to(device) + + if X.dtype != net.dtype: + X = X.to(net.dtype) + lbl = lbl.to(net.dtype) + + y = net(X)[0] + loss = _loss_fn_seg(lbl, y, device) + if y.shape[1] > 3: + loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) + loss += loss3 + test_loss = loss.item() + test_loss *= len(imgi) + lavgt += test_loss + lavgt /= len(rperm) + test_losses[iepoch] = lavgt + lavg /= nsum + train_logger.info( + f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time() - t0:.2f}s" + ) + lavg, nsum = 0, 0 + + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each and iepoch != n_epochs - 1: # separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:04d}" + else: + filename0 = filename + train_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + + net.save_model(filename) + + return filename, train_losses, test_losses diff --git a/frontend/index.html b/frontend/index.html index 2ae0aeb..7a0a62d 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -1,4 +1,4 @@ - +
@@ -10,179 +10,9 @@ -