feature(train): 新增训练前后端(未测试)

This commit is contained in:
ClovertaTheTrilobita 2025-09-23 16:27:02 +00:00
parent a6353d1c37
commit d51439dee0
4 changed files with 628 additions and 186 deletions

View file

@ -1,33 +1,53 @@
import os.path import os.path
from cellpose import io, models, train
from pathlib import Path from pathlib import Path
from omegaconf import OmegaConf from omegaconf import OmegaConf
import redis
import datetime
import json
CONFIG_PATH = Path(__file__).parent / "config.yaml" CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH) cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve()) cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
BASE_DIR = cfg.data.root_dir BASE_DIR = cfg.data.root_dir
TEST_TRAIN_DIR = cfg.data.train.test_train_dir TRAIN_DIR = cfg.data.train.train_dir
TEST_TEST_DIR = cfg.data.train.test_test_dir TEST_DIR = cfg.data.train.test_dir
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve()) MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, **extra):
payload = {"status": status, "updated_at": datetime.datetime.utcnow().isoformat(), **extra}
r.set(f"task:{task_id}", json.dumps(payload), ex=86400) # 1 天过期
def get_status(task_id):
raw = r.get(f"task:{task_id}")
return json.loads(raw) if raw else None
from cellpose import io, models, train
class Cptrain: class Cptrain:
@classmethod @classmethod
def start_train(cls, async def start_train(cls,
time: str | None = None, time: str | None = None,
model_name: str | None = None,): model_name: str | None = None,
image_filter: str = "_img",
mask_filter: str = "_masks",
base_model: str = "cpsam"):
train_dir = Path(TEST_TRAIN_DIR) / time train_dir = Path(TRAIN_DIR) / time
test_dir = Path(TEST_TEST_DIR) / time test_dir = Path(TEST_DIR) / time
os.makedirs(train_dir, exist_ok=True) os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True) os.makedirs(test_dir, exist_ok=True)
io.logger_setup() io.logger_setup()
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img", output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter=image_filter,
mask_filter="_masks", look_one_level_down=False) mask_filter=mask_filter, look_one_level_down=False)
images, labels, image_names, test_images, test_labels, image_names_test = output images, labels, image_names, test_images, test_labels, image_names_test = output
model = models.CellposeModel(gpu=True) model = models.CellposeModel(gpu=True, pretrained_model=base_model)
set_status(time, "running")
model_path, train_losses, test_losses = train.train_seg(model.net, model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels, train_data=images, train_labels=labels,
@ -36,4 +56,4 @@ class Cptrain:
n_epochs=100, model_name=model_name, n_epochs=100, model_name=model_name,
save_path=MODELS_DIR) save_path=MODELS_DIR)
print("模型已保存到:", model_path) print("模型已保存到:", model_path)

View file

@ -14,6 +14,7 @@ from flask import Flask, send_from_directory, request, jsonify
from flask_cors import CORS from flask_cors import CORS
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from backend.cp_train import Cptrain
from cp_run import Cprun from cp_run import Cprun
app = Flask(__name__) app = Flask(__name__)
@ -26,6 +27,8 @@ BASE_DIR = cfg.data.root_dir
UPLOAD_DIR = cfg.data.upload_dir UPLOAD_DIR = cfg.data.upload_dir
OUTPUT_DIR = cfg.data.run.output_dir OUTPUT_DIR = cfg.data.run.output_dir
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve()) MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
TRAIN_DIR = cfg.data.train.train_dir
TEST_DIR = cfg.data.train.test_dir
os.makedirs(UPLOAD_DIR, exist_ok=True) os.makedirs(UPLOAD_DIR, exist_ok=True)
executor = ThreadPoolExecutor(max_workers=4) executor = ThreadPoolExecutor(max_workers=4)
@ -128,16 +131,50 @@ def run_upload():
@app.post("/train_upload") @app.post("/train_upload")
def train_upload(): def train_upload():
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
model_name = request.args.get("model_name") or f"custom_model-{ts}"
image_filter = request.args.get("image_filter") or "_img"
mask_filter = request.args.get("mask_filter") or "_masks"
base_model = request.args.get("base_model") or "cpsam"
train_files = request.files.getlist("train_files") train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files") test_files = request.files.getlist("test_files")
set_status(ts, "pending")
saved = [] saved = []
for f in train_files: for f in train_files:
if not f or f.filename == "": if not f or f.filename == "":
continue continue
name = secure_filename(f.filename) name = secure_filename(f.filename)
f.save(os.path.join(UPLOAD_DIR, ts, name)) f.save(os.path.join(TRAIN_DIR, ts, name))
saved.append(os.path.join(UPLOAD_DIR, ts, name)) saved.append(os.path.join(TRAIN_DIR, ts, name))
for f in test_files:
if not f or f.filename == "":
continue
name = secure_filename(f.filename)
f.save(os.path.join(TEST_DIR, ts, name))
saved.append(os.path.join(TEST_DIR, ts, name))
def job():
return asyncio.run(Cptrain.start_train(
time=ts,
model_name=model_name,
image_filter=image_filter,
mask_filter=mask_filter,
base_model=base_model
))
fut = executor.submit(job)
def done_cb(f):
try:
f.result()
set_status(ts, "success")
except Exception as e:
set_status(ts, "failed", error=str(e))
fut.add_done_callback(done_cb)
return jsonify({"ok": True, "count": len(saved), "id": ts})
@app.get("/status") @app.get("/status")
def status(): def status():

555
backend/train.py Normal file
View file

@ -0,0 +1,555 @@
import time
import os
import numpy as np
from cellpose import io, utils, models, dynamics
from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image
from pathlib import Path
import torch
from torch import nn
from tqdm import trange
import redis
import json
import datetime
import logging
# Shared Redis connection used for task-status bookkeeping.
r = redis.Redis(host="127.0.0.1", port=6379, db=0)


def set_status(task_id, status, **extra):
    """Store the status of a training task in Redis.

    Args:
        task_id: Unique task identifier (typically the upload timestamp).
        status: Status string, e.g. "pending", "running", "success", "failed".
        **extra: Additional fields merged into the stored payload
            (e.g. ``error=...`` for failures).
    """
    # datetime.utcnow() is deprecated and returns a naive timestamp;
    # use an explicitly timezone-aware UTC timestamp instead.
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    payload = {"status": status, "updated_at": now, **extra}
    # expires after 1 day
    r.set(f"task:{task_id}", json.dumps(payload), ex=86400)


def get_status(task_id):
    """Return the stored status payload for ``task_id``, or None if absent/expired."""
    raw = r.get(f"task:{task_id}")
    return json.loads(raw) if raw else None


train_logger = logging.getLogger(__name__)
def _loss_fn_class(lbl, y, class_weights=None):
"""
Calculates the loss function between true labels lbl and prediction y.
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
Returns:
torch.Tensor: Loss value.
"""
criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights)
loss3 = criterion3(y[:, :-3], lbl[:, 0].long())
return loss3
def _loss_fn_seg(lbl, y, device):
"""
Calculates the loss function between true labels lbl and prediction y.
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
device (torch.device): Device on which the tensors are located.
Returns:
torch.Tensor: Loss value.
"""
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
veci = 5. * lbl[:, -2:]
loss = criterion(y[:, -3:-1], veci)
loss /= 2.
loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype))
loss = loss + loss2
return loss
def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}):
"""
Reshapes and normalizes the input data.
Args:
data (list): List of input data, with channels axis first or last.
normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.
Returns:
list: List of reshaped and normalized data.
"""
if (np.array([td.ndim != 3 for td in data]).sum() > 0 or
np.array([td.shape[0] != 3 for td in data]).sum() > 0):
data_new = []
for td in data:
if td.ndim == 3:
channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin()
# put channel axis first
td = np.moveaxis(td, channel_axis0, 0)
td = td[:3] # keep at most 3 channels
if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
td = np.stack((td, 0 * td, 0 * td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
td = np.concatenate((td, 0 * td[:1]), axis=0)
data_new.append(td)
data = data_new
if normalize_params["normalize"]:
data = [
normalize_img(td, normalize=normalize_params, axis=0)
for td in data
]
return data
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
normalize_params={"normalize": False}):
"""
Get a batch of images and labels.
Args:
inds (list): List of indices indicating which images and labels to retrieve.
data (list or None): List of image data. If None, images will be loaded from files.
labels (list or None): List of label data. If None, labels will be loaded from files.
files (list or None): List of file paths for images.
labels_files (list or None): List of file paths for labels.
normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
Returns:
tuple: A tuple containing two lists: the batch of images and the batch of labels.
"""
if data is None:
lbls = None
imgs = [io.imread(files[i]) for i in inds]
imgs = _reshape_norm(imgs, normalize_params=normalize_params)
if labels_files is not None:
lbls = [io.imread(labels_files[i])[1:] for i in inds]
else:
imgs = [data[i] for i in inds]
lbls = [labels[i][1:] for i in inds]
return imgs, lbls
def _reshape_norm_save(files, channels=None, channel_axis=None,
                       normalize_params={"normalize": False}):
    """ not currently used -- normalization happening on each batch if not load_files

    Pre-normalize images on disk: read each file, optionally convert to
    channel-first layout, optionally normalize, and write the result next to
    the original with a ``_cpnorm.tif`` suffix.

    Args:
        files (list): Image file paths.
        channels (list or None): Channel selection passed to ``convert_image``.
        channel_axis (int or None): Channel axis passed to ``convert_image``.
        normalize_params (dict): Normalization options; only applied when
            ``normalize_params["normalize"]`` is truthy.

    Returns:
        list: Paths of the newly written ``*_cpnorm.tif`` files.
    """
    files_new = []
    # BUG FIX: the original iterated `trange(files)` -- trange expects an int,
    # so passing the list raised TypeError. Iterate indices to keep the
    # progress bar without needing a new import.
    for i in trange(len(files)):
        f = files[i]
        td = io.imread(f)
        if channels is not None:
            td = convert_image(td, channels=channels,
                               channel_axis=channel_axis)
            td = td.transpose(2, 0, 1)
        if normalize_params["normalize"]:
            td = normalize_img(td, normalize=normalize_params, axis=0)
        fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
        io.imsave(fnew, td)
        files_new.append(fnew)
    return files_new
# else:
# train_files = reshape_norm_save(train_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
# elif test_files is not None:
# test_files = reshape_norm_save(test_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
def _process_train_test(train_data=None, train_labels=None, train_files=None,
                        train_labels_files=None, train_probs=None, test_data=None,
                        test_labels=None, test_files=None, test_labels_files=None,
                        test_probs=None, load_files=True, min_train_masks=5,
                        compute_flows=False, normalize_params={"normalize": False},
                        channel_axis=None, device=None):
    """
    Process train and test data.

    Validates inputs, optionally loads images/labels from disk, converts label
    masks to flow fields, computes per-image diameters, removes training
    images with too few masks, and builds per-image sampling probabilities.

    Args:
        train_data (list or None): List of training image arrays.
        train_labels (list or None): List of training label arrays.
        train_files (list or None): List of training image file paths.
        train_labels_files (list or None): List of training label file paths.
        train_probs (ndarray or None): Per-image training sampling probabilities.
        test_data (list or None): List of test image arrays.
        test_labels (list or None): List of test label arrays.
        test_files (list or None): List of test image file paths.
        test_labels_files (list or None): List of test label file paths.
        test_probs (ndarray or None): Per-image test sampling probabilities.
        load_files (bool): Whether to load data from files up front.
        min_train_masks (int): Minimum number of masks required per training image.
        compute_flows (bool): Whether to compute flows from label files.
        normalize_params (dict): Dictionary of normalization parameters.
        channel_axis (int or None): Axis of the channel dimension.
        device (torch.device or None): Device for flow computation; auto-picked
            when None.

    Returns:
        tuple: (train_data, train_labels, train_files, train_labels_files,
            train_probs, diam_train, test_data, test_labels, test_files,
            test_labels_files, test_probs, diam_test, normed)
    """
    if device == None:
        # pick the best available accelerator (cuda > mps > None)
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
            'mps') if torch.backends.mps.is_available() else None

    if train_data is not None and train_labels is not None:
        # if data is loaded
        nimg = len(train_data)
        nimg_test = len(test_data) if test_data is not None else None
    else:
        # otherwise use files
        nimg = len(train_files)
        if train_labels_files is None:
            # default label files: "<image>_flows.tif" next to each image
            train_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
            ]
            train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
        if (test_data is not None or
                test_files is not None) and test_labels_files is None:
            test_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
            ]
            test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
        if not load_files:
            train_logger.info(">>> using files instead of loading dataset")
        else:
            # load all images
            train_logger.info(">>> loading images and labels")
            train_data = [io.imread(train_files[i]) for i in trange(nimg)]
            train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
        nimg_test = len(test_files) if test_files is not None else None
        if load_files and nimg_test:
            test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
            test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]

    ### check that arrays are correct size
    if ((train_labels is not None and nimg != len(train_labels)) or
            (train_labels_files is not None and nimg != len(train_labels_files))):
        error_message = "train data and labels not same length"
        train_logger.critical(error_message)
        raise ValueError(error_message)
    if ((test_labels is not None and nimg_test != len(test_labels)) or
            (test_labels_files is not None and nimg_test != len(test_labels_files))):
        # the test set is optional -- drop it instead of failing
        train_logger.warning("test data and labels not same length, not using")
        test_data, test_files = None, None
    if train_labels is not None:
        if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
            error_message = "training data or labels are not at least two-dimensional"
            train_logger.critical(error_message)
            raise ValueError(error_message)
        if train_data[0].ndim > 3:
            error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
            train_logger.critical(error_message)
            raise ValueError(error_message)

    ### check that flows are computed
    if train_labels is not None:
        # convert integer masks to (cellprob, flowY, flowX) representation
        train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
                                                device=device)
        if test_labels is not None:
            test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
                                                   device=device)
    elif compute_flows:
        # compute and cache flows on disk without keeping labels in memory
        for k in trange(nimg):
            tl = dynamics.labels_to_flows(io.imread(train_labels_files),
                                          files=train_files, device=device)
        if test_files is not None:
            for k in trange(nimg_test):
                tl = dynamics.labels_to_flows(io.imread(test_labels_files),
                                              files=test_files, device=device)

    ### compute diameters
    nmasks = np.zeros(nimg)
    diam_train = np.zeros(nimg)
    train_logger.info(">>> computing diameters")
    for k in trange(nimg):
        tl = (train_labels[k][0]
              if train_labels is not None else io.imread(train_labels_files[k])[0])
        diam_train[k], dall = utils.diameters(tl)
        nmasks[k] = len(dall)
    # enforce a minimum diameter of 5 pixels
    diam_train[diam_train < 5] = 5.
    if test_data is not None:
        diam_test = np.array(
            [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
        diam_test[diam_test < 5] = 5.
    elif test_labels_files is not None:
        diam_test = np.array([
            utils.diameters(io.imread(test_labels_files[k])[0])[0]
            for k in trange(len(test_labels_files))
        ])
        diam_test[diam_test < 5] = 5.
    else:
        diam_test = None

    ### check to remove training images with too few masks
    if min_train_masks > 0:
        nremove = (nmasks < min_train_masks).sum()
        if nremove > 0:
            train_logger.warning(
                f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
            )
            ikeep = np.nonzero(nmasks >= min_train_masks)[0]
            if train_data is not None:
                train_data = [train_data[i] for i in ikeep]
                train_labels = [train_labels[i] for i in ikeep]
            if train_files is not None:
                train_files = [train_files[i] for i in ikeep]
            if train_labels_files is not None:
                train_labels_files = [train_labels_files[i] for i in ikeep]
            if train_probs is not None:
                train_probs = train_probs[ikeep]
            diam_train = diam_train[ikeep]
            # NOTE(review): assumes train_data is loaded here; when training
            # purely from files this would raise -- confirm upstream guarantees.
            nimg = len(train_data)

    ### normalize probabilities
    train_probs = 1. / nimg * np.ones(nimg,
                                      "float64") if train_probs is None else train_probs
    train_probs /= train_probs.sum()
    if test_files is not None or test_data is not None:
        test_probs = 1. / nimg_test * np.ones(
            nimg_test, "float64") if test_probs is None else test_probs
        test_probs /= test_probs.sum()

    ### reshape and normalize train / test data
    normed = False
    if normalize_params["normalize"]:
        train_logger.info(f">>> normalizing {normalize_params}")
        if train_data is not None:
            train_data = _reshape_norm(train_data, channel_axis=channel_axis,
                                       normalize_params=normalize_params)
            normed = True
        if test_data is not None:
            test_data = _reshape_norm(test_data, channel_axis=channel_axis,
                                      normalize_params=normalize_params)

    return (train_data, train_labels, train_files, train_labels_files, train_probs,
            diam_train, test_data, test_labels, test_files, test_labels_files,
            test_probs, diam_test, normed)
def train_seg(net, train_data=None, train_labels=None, train_files=None,
              train_labels_files=None, train_probs=None, test_data=None,
              test_labels=None, test_files=None, test_labels_files=None,
              test_probs=None, channel_axis=None,
              load_files=True, batch_size=1, learning_rate=5e-5, SGD=False,
              n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
              save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
              nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
              min_train_masks=5, model_name=None, class_weights=None, ts=None):
    """
    Train the network with images for segmentation.

    Args:
        net (object): The network model to train.
        train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
        train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
        train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
        train_labels_files (list or None): List of training label file paths. Defaults to None.
        train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None.
        test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None.
        test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
        test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None.
        test_labels_files (list or None): List of test label file paths. Defaults to None.
        test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None.
        channel_axis (int or None): Axis of the channel dimension. Defaults to None.
        load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True.
        batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 1.
        learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 5e-5.
        SGD (bool, optional): Deprecated - AdamW is always used.
        n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 100.
        weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 0.1.
        normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True.
        compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
        save_path (str, optional): String - where to save the trained model. Defaults to None.
        save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
        save_each (bool, optional): Boolean - save the network to a new filename at every [save_every] epoch. Defaults to False.
        nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
        nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
        rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False.
        scale_range (float, optional): Float - range of image scaling during augmentation. Defaults to None (0.5).
        bsize (int, optional): Integer - size of the square patches used for training. Defaults to 256.
        min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
        model_name (str, optional): String - name of the network. Defaults to None.
        class_weights (list/ndarray/tuple, optional): Per-class weights for the class loss. Defaults to None.
        ts: Unused here -- presumably the task timestamp/id; confirm intended use.

    Returns:
        tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
    """
    if SGD:
        train_logger.warning("SGD is deprecated, using AdamW instead")
    device = net.device
    scale_range = 0.5 if scale_range is None else scale_range
    if isinstance(normalize, dict):
        # merge user overrides on top of the library defaults
        normalize_params = {**models.normalize_default, **normalize}
    elif not isinstance(normalize, bool):
        raise ValueError("normalize parameter must be a bool or a dict")
    else:
        normalize_params = models.normalize_default
        normalize_params["normalize"] = normalize

    # load/validate data, compute flows and diameters, build sampling probs
    out = _process_train_test(train_data=train_data, train_labels=train_labels,
                              train_files=train_files, train_labels_files=train_labels_files,
                              train_probs=train_probs,
                              test_data=test_data, test_labels=test_labels,
                              test_files=test_files, test_labels_files=test_labels_files,
                              test_probs=test_probs,
                              load_files=load_files, min_train_masks=min_train_masks,
                              compute_flows=compute_flows, channel_axis=channel_axis,
                              normalize_params=normalize_params, device=net.device)
    (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
     test_data, test_labels, test_files, test_labels_files, test_probs, diam_test,
     normed) = out
    # already normalized, do not normalize during training
    if normed:
        kwargs = {}
    else:
        kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis}

    # store the mean training diameter on the network for later inference
    net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

    if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)):
        class_weights = torch.from_numpy(class_weights).to(device).float()
        print(class_weights)

    nimg = len(train_data) if train_data is not None else len(train_files)
    nimg_test = len(test_data) if test_data is not None else None
    nimg_test = len(test_files) if test_files is not None else nimg_test
    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    # learning rate schedule: 10-epoch linear warmup to learning_rate, flat,
    # then halving steps near the end for longer runs
    LR = np.linspace(0, learning_rate, 10)
    LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
    if n_epochs > 300:
        LR = LR[:-100]
        for i in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    elif n_epochs > 99:
        LR = LR[:-50]
        for i in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(5))

    train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}")
    train_logger.info(
        f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
    )
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
                                  weight_decay=weight_decay)

    t0 = time.time()
    model_name = f"cellpose_{t0}" if model_name is None else model_name
    save_path = Path.cwd() if save_path is None else Path(save_path)
    filename = save_path / "models" / model_name
    (save_path / "models").mkdir(exist_ok=True)
    # NOTE(review): "(unknown)" looks like a lost f-string interpolation --
    # presumably this should log {filename}; confirm.
    train_logger.info(f">>> saving model to (unknown)")

    lavg, nsum = 0, 0
    train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
    for iepoch in range(n_epochs):
        # deterministic per-epoch sampling
        np.random.seed(iepoch)
        if nimg != nimg_per_epoch:
            # choose random images for epoch with probability train_probs
            rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                     p=train_probs)
        else:
            # otherwise use all images
            rperm = np.random.permutation(np.arange(0, nimg))
        for param_group in optimizer.param_groups:
            param_group["lr"] = LR[iepoch]  # set learning rate
        net.train()
        for k in range(0, nimg_per_epoch, batch_size):
            kend = min(k + batch_size, nimg_per_epoch)
            inds = rperm[k:kend]
            imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
                                    files=train_files, labels_files=train_labels_files,
                                    **kwargs)
            diams = np.array([diam_train[i] for i in inds])
            rsc = diams / net.diam_mean.item() if rescale else np.ones(
                len(diams), "float32")
            # augmentations
            imgi, lbl = random_rotate_and_resize(imgs, Y=lbls, rescale=rsc,
                                                 scale_range=scale_range,
                                                 xy=(bsize, bsize))[:2]
            # network and loss optimization
            X = torch.from_numpy(imgi).to(device)
            lbl = torch.from_numpy(lbl).to(device)
            if X.dtype != net.dtype:
                X = X.to(net.dtype)
                lbl = lbl.to(net.dtype)
            y = net(X)[0]
            loss = _loss_fn_seg(lbl, y, device)
            if y.shape[1] > 3:
                # extra output channels carry class logits
                loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
                loss += loss3
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            train_loss *= len(imgi)
            # keep track of average training loss across epochs
            lavg += train_loss
            nsum += len(imgi)
            # per epoch training loss
            train_losses[iepoch] += train_loss
        train_losses[iepoch] /= nimg_per_epoch

        # evaluate + log at epoch 5 and every 10th epoch
        if iepoch == 5 or iepoch % 10 == 0:
            lavgt = 0.
            if test_data is not None or test_files is not None:
                # fixed seed so the test subset is comparable across epochs
                np.random.seed(42)
                if nimg_test != nimg_test_per_epoch:
                    rperm = np.random.choice(np.arange(0, nimg_test),
                                             size=(nimg_test_per_epoch,), p=test_probs)
                else:
                    rperm = np.random.permutation(np.arange(0, nimg_test))
                for ibatch in range(0, len(rperm), batch_size):
                    with torch.no_grad():
                        net.eval()
                        inds = rperm[ibatch:ibatch + batch_size]
                        imgs, lbls = _get_batch(inds, data=test_data,
                                                labels=test_labels, files=test_files,
                                                labels_files=test_labels_files,
                                                **kwargs)
                        diams = np.array([diam_test[i] for i in inds])
                        rsc = diams / net.diam_mean.item() if rescale else np.ones(
                            len(diams), "float32")
                        imgi, lbl = random_rotate_and_resize(
                            imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
                            xy=(bsize, bsize))[:2]
                        X = torch.from_numpy(imgi).to(device)
                        lbl = torch.from_numpy(lbl).to(device)
                        if X.dtype != net.dtype:
                            X = X.to(net.dtype)
                            lbl = lbl.to(net.dtype)
                        y = net(X)[0]
                        loss = _loss_fn_seg(lbl, y, device)
                        if y.shape[1] > 3:
                            loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
                            loss += loss3
                        test_loss = loss.item()
                        test_loss *= len(imgi)
                        lavgt += test_loss
                lavgt /= len(rperm)
                test_losses[iepoch] = lavgt
            lavg /= nsum
            train_logger.info(
                f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time() - t0:.2f}s"
            )
            lavg, nsum = 0, 0

        # periodic + final checkpointing
        if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
            if save_each and iepoch != n_epochs - 1:  # separate files as model progresses
                filename0 = str(filename) + f"_epoch_{iepoch:04d}"
            else:
                filename0 = filename
            train_logger.info(f"saving network parameters to {filename0}")
            net.save_model(filename0)
    net.save_model(filename)
    return filename, train_losses, test_losses

View file

@ -1,4 +1,4 @@
<!doctype html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
@ -10,179 +10,9 @@
</head> </head>
<body> <body>
<div style="padding: 1rem 1rem;">
<div class="mb-3 border" style="padding: 1rem 1rem;">
<div class="input-group mb-3">
<input id="fileInput" type="file" class="form-control" multiple />
</div>
<hr> <a href="run.html">运行</a>
<div> <a href="train.html">训练</a>
<div style="padding-right: 50rem;">
<div class="input-group mb-3">
<span class="input-group-text">flow threshold:</span>
<input id="flow" type="text" class="form-control" placeholder="" />
</div>
<div class="input-group mb-3">
<span class="input-group-text">cellprob threshold:</span>
<input id="cellprob" type="text" class="form-control" placeholder="" />
</div>
<div class="input-group mb-3">
<span class="input-group-text">diameter:</span>
<input id="diameter" type="text" class="form-control" placeholder="" />
</div>
<label>
<select id="model" class="form-select"></select>
</label>
</label>
</div>
</div>
<br>
<div>
<button id="uploadBtn" class="btn btn-success">Upload</button>
<progress id="bar" max="100" value="0" style="width:300px;"></progress>
</div>
</div>
</div>
<br><br><br>
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js"
integrity="sha384-oBqDVmMz9ATKxIep9tiCxS/Z9fNfEXiDAYTujMAeBAsjFuCZSmKbSSUnQlmh/jp3"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.min.js"
integrity="sha384-ep+dxp/oz2RKF89ALMPGc7Z89QFa32C8Uv1A3TcEK8sMzXVysblLA3+eJWTzPJzT"
crossorigin="anonymous"></script>
<script>
// Backend base URL and derived endpoints.
const API = "http://10.147.18.141:5000/";
// FIX: this constant was misspelled `APT_UPLOAD`, while buildUrl() references
// `API_UPLOAD` -- clicking Upload threw a ReferenceError. Renamed here.
const API_UPLOAD = API + "run_upload";
const API_MODEL = API + "models";
// Populate the #model <select> with model names fetched from the backend.
// Disables the control while loading, keeps the previous selection when
// possible, and re-arms itself for a retry on the next click after a failure.
async function loadModels() {
    const select = document.getElementById('model');
    // placeholder & disable so the user cannot pick an option mid-load
    select.disabled = true;
    select.innerHTML = '<option value="">加载中…</option>';
    // optional timeout control (5 seconds)
    const controller = new AbortController();
    const t = setTimeout(() => controller.abort(), 5000);
    try {
        const resp = await fetch(API_MODEL, {
            headers: { 'Accept': 'application/json' },
            signal: controller.signal
        });
        if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
        const data = await resp.json();
        const list = Array.isArray(data.models) ? data.models : [];
        // remember the previous selection (if any)
        const prev = select.value;
        // re-render the options, de-duplicating model names
        select.innerHTML = '<option value="" disabled selected>请选择模型</option>';
        const seen = new Set();
        for (const name of list) {
            if (!name || seen.has(name)) continue;
            seen.add(name);
            const opt = document.createElement('option');
            opt.value = String(name);
            opt.textContent = String(name);
            select.appendChild(opt);
        }
        // restore the previous selection if it still exists
        if (prev && seen.has(prev)) {
            select.value = prev;
        }
    } catch (err) {
        console.error('加载模型列表失败:', err);
        select.innerHTML = '<option value="">加载失败(点击重试)</option>';
        // retry when the dropdown is clicked again
        select.addEventListener('click', loadModels, { once: true });
    } finally {
        clearTimeout(t);
        select.disabled = false;
    }
}

// Load the model list once the DOM is ready.
document.addEventListener('DOMContentLoaded', loadModels);
// Build the upload URL with the user's model / threshold / diameter settings
// encoded as query-string parameters.
// NOTE(review): `API_UPLOAD` is never defined -- the constant above is
// declared as `APT_UPLOAD` (typo), so this throws a ReferenceError at
// runtime; confirm and fix the declaration.
function buildUrl() {
    // read the form parameters
    const model = document.getElementById('model')?.value;
    const flow = (document.getElementById('flow')?.value || '').trim();
    const cellp = (document.getElementById('cellprob')?.value || '').trim();
    const diameter = (document.getElementById('diameter')?.value || '').trim();
    // assemble the query string with URLSearchParams
    const qs = new URLSearchParams({
        model: model,
        flow_threshold: flow,
        cellprob_threshold: cellp,
        diameter: diameter
    });
    return `${API_UPLOAD}?${qs.toString()}`;
}
// Upload the selected files with progress reporting, then show a success
// banner and redirect to the preview page after a 3-second countdown.
document.getElementById("uploadBtn").addEventListener("click", async () => {
    const input = document.getElementById("fileInput");
    if (!input.files.length) return alert("请选择文件");
    const fd = new FormData();
    for (const f of input.files) fd.append("files", f);
    const bar = document.getElementById("bar");
    try {
        const URL = buildUrl();
        const res = await axios.post(URL, fd, {
            onUploadProgress: (e) => {
                if (e.total) bar.value = Math.round((e.loaded * 100) / e.total);
            },
        });
        // alert("上传成功:" + JSON.stringify(res.data));
        // create a floating notification element
        const notice = document.createElement("div");
        notice.style.position = "fixed";
        notice.style.top = "20px";
        notice.style.left = "50%";
        notice.style.transform = "translateX(-50%)";
        notice.style.padding = "10px 20px";
        notice.style.background = "#4caf50";
        notice.style.color = "white";
        notice.style.borderRadius = "8px";
        notice.style.fontSize = "16px";
        notice.style.zIndex = "9999";
        let seconds = 3;
        notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
        document.body.appendChild(notice);
        // countdown, then navigate to the preview page for this upload id
        const timer = setInterval(() => {
            seconds--;
            if (seconds > 0) {
                notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
            } else {
                clearInterval(timer);
                document.body.removeChild(notice);
                window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
            }
        }, 1000);
    } catch (e) {
        alert("上传失败:" + (e.response?.data?.message || e.message));
    } finally {
        // reset the progress bar whether the upload succeeded or not
        bar.value = 0;
    }
});
</script>
</body> </body>
</html> </html>