feature(train): 新增训练前后端(未测试)

This commit is contained in:
ClovertaTheTrilobita 2025-09-23 16:27:02 +00:00
parent a6353d1c37
commit d51439dee0
4 changed files with 628 additions and 186 deletions

View file

@ -1,33 +1,53 @@
import os.path
from cellpose import io, models, train
from pathlib import Path
from omegaconf import OmegaConf
import redis
import datetime
import json
# Load config.yaml from the directory that contains this module.
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
# Rewrite data.root_dir as an absolute path resolved relative to the config file.
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
BASE_DIR = cfg.data.root_dir
# Directory settings read from the data.train section of the config.
TEST_TRAIN_DIR = cfg.data.train.test_train_dir
TEST_TEST_DIR = cfg.data.train.test_test_dir
TRAIN_DIR = cfg.data.train.train_dir
TEST_DIR = cfg.data.train.test_dir
# Absolute path of the model directory, resolved relative to the config file.
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
# NOTE(review): presumably consumed by cellpose to locate local models — confirm against cellpose docs.
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = MODELS_DIR
# Redis client (localhost:6379, db 0) used by set_status / get_status below.
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, **extra):
    """Write the status record of a task to Redis as a JSON string.

    The record is stored under the key ``task:<task_id>`` and contains
    ``status``, an ``updated_at`` ISO timestamp taken from ``utcnow()``,
    and any additional keyword arguments.

    Args:
        task_id: Identifier used to build the Redis key.
        status: Status value stored under the ``"status"`` key.
        **extra: Additional JSON-serialisable fields merged into the record.
    """
    record = {
        "status": status,
        "updated_at": datetime.datetime.utcnow().isoformat(),
    }
    record.update(extra)
    serialized = json.dumps(record)
    # The key expires after 86400 seconds (one day).
    r.set(f"task:{task_id}", serialized, ex=86400)
def get_status(task_id):
    """Read the status record of a task from Redis.

    Args:
        task_id: Identifier used to build the ``task:<task_id>`` key.

    Returns:
        The decoded JSON record, or ``None`` when the stored value is
        missing or empty.
    """
    stored = r.get(f"task:{task_id}")
    if not stored:
        return None
    return json.loads(stored)
from cellpose import io, models, train
class Cptrain:
@classmethod
def start_train(cls,
async def start_train(cls,
time: str | None = None,
model_name: str | None = None,):
model_name: str | None = None,
image_filter: str = "_img",
mask_filter: str = "_masks",
base_model: str = "cpsam"):
train_dir = Path(TEST_TRAIN_DIR) / time
test_dir = Path(TEST_TEST_DIR) / time
train_dir = Path(TRAIN_DIR) / time
test_dir = Path(TEST_DIR) / time
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
io.logger_setup()
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img",
mask_filter="_masks", look_one_level_down=False)
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter=image_filter,
mask_filter=mask_filter, look_one_level_down=False)
images, labels, image_names, test_images, test_labels, image_names_test = output
model = models.CellposeModel(gpu=True)
model = models.CellposeModel(gpu=True, pretrained_model=base_model)
set_status(time, "running")
model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels,
@ -36,4 +56,4 @@ class Cptrain:
n_epochs=100, model_name=model_name,
save_path=MODELS_DIR)
print("模型已保存到:", model_path)
print("模型已保存到:", model_path)

View file

@ -14,6 +14,7 @@ from flask import Flask, send_from_directory, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from backend.cp_train import Cptrain
from cp_run import Cprun
app = Flask(__name__)
@ -26,6 +27,8 @@ BASE_DIR = cfg.data.root_dir
UPLOAD_DIR = cfg.data.upload_dir
OUTPUT_DIR = cfg.data.run.output_dir
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
TRAIN_DIR = cfg.data.train.train_dir
TEST_DIR = cfg.data.train.test_dir
os.makedirs(UPLOAD_DIR, exist_ok=True)
executor = ThreadPoolExecutor(max_workers=4)
@ -128,16 +131,50 @@ def run_upload():
@app.post("/train_upload")
def train_upload():
    """Accept training/test uploads and start a background training job.

    Query parameters:
        model_name: Name for the trained model. It is sanitised with
            ``secure_filename`` because it is forwarded to the trainer as the
            model name; falls back to ``custom_model-<ts>`` when missing or
            when nothing survives sanitising.
        image_filter: Image file-name filter, default ``"_img"``.
        mask_filter: Mask file-name filter, default ``"_masks"``.
        base_model: Pretrained model to start from, default ``"cpsam"``.

    Multipart fields:
        train_files: Files stored under ``TRAIN_DIR/<ts>``.
        test_files: Files stored under ``TEST_DIR/<ts>``.

    Returns:
        JSON ``{"ok": True, "count": <number of saved files>, "id": <ts>}``.
        ``id`` is the task id whose status is written through ``set_status``.
    """
    ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
    # The name comes from the client and ends up in a file path, so strip any
    # path components before using it.
    model_name = secure_filename(request.args.get("model_name") or "") or f"custom_model-{ts}"
    image_filter = request.args.get("image_filter") or "_img"
    mask_filter = request.args.get("mask_filter") or "_masks"
    base_model = request.args.get("base_model") or "cpsam"
    train_files = request.files.getlist("train_files")
    test_files = request.files.getlist("test_files")
    set_status(ts, "pending")

    saved = []
    for root_dir, uploads in ((TRAIN_DIR, train_files), (TEST_DIR, test_files)):
        target_dir = os.path.join(root_dir, ts)
        # The per-task directory does not exist yet; saving into it would fail.
        os.makedirs(target_dir, exist_ok=True)
        for f in uploads:
            if not f or f.filename == "":
                continue
            name = secure_filename(f.filename)
            if not name:
                # Nothing usable is left of the file name after sanitising.
                continue
            # Each upload is written exactly once: a second save() of the same
            # stream would start at end-of-stream and produce a truncated file.
            destination = os.path.join(target_dir, name)
            f.save(destination)
            saved.append(destination)

    def job():
        # start_train is a coroutine; run it to completion inside the worker thread.
        return asyncio.run(Cptrain.start_train(
            time=ts,
            model_name=model_name,
            image_filter=image_filter,
            mask_filter=mask_filter,
            base_model=base_model
        ))

    fut = executor.submit(job)

    def done_cb(f):
        # Record the final state of the job; f.result() re-raises job errors.
        try:
            f.result()
            set_status(ts, "success")
        except Exception as e:
            set_status(ts, "failed", error=str(e))

    fut.add_done_callback(done_cb)
    return jsonify({"ok": True, "count": len(saved), "id": ts})
@app.get("/status")
def status():

555
backend/train.py Normal file
View file

@ -0,0 +1,555 @@
import time
import os
import numpy as np
from cellpose import io, utils, models, dynamics
from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image
from pathlib import Path
import torch
from torch import nn
from tqdm import trange
import redis
import json
import datetime
import logging
r = redis.Redis(host="127.0.0.1", port=6379, db=0)
def set_status(task_id, status, **extra):
    """Write the status record of a task to Redis as a JSON string.

    The record is stored under the key ``task:<task_id>`` and contains
    ``status``, an ``updated_at`` ISO timestamp taken from ``utcnow()``,
    and any additional keyword arguments.

    Args:
        task_id: Identifier used to build the Redis key.
        status: Status value stored under the ``"status"`` key.
        **extra: Additional JSON-serialisable fields merged into the record.
    """
    record = {
        "status": status,
        "updated_at": datetime.datetime.utcnow().isoformat(),
    }
    record.update(extra)
    serialized = json.dumps(record)
    # The key expires after 86400 seconds (one day).
    r.set(f"task:{task_id}", serialized, ex=86400)
def get_status(task_id):
    """Read the status record of a task from Redis.

    Args:
        task_id: Identifier used to build the ``task:<task_id>`` key.

    Returns:
        The decoded JSON record, or ``None`` when the stored value is
        missing or empty.
    """
    stored = r.get(f"task:{task_id}")
    if not stored:
        return None
    return json.loads(stored)
train_logger = logging.getLogger(__name__)
def _loss_fn_class(lbl, y, class_weights=None):
"""
Calculates the loss function between true labels lbl and prediction y.
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
Returns:
torch.Tensor: Loss value.
"""
criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights)
loss3 = criterion3(y[:, :-3], lbl[:, 0].long())
return loss3
def _loss_fn_seg(lbl, y, device):
"""
Calculates the loss function between true labels lbl and prediction y.
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
device (torch.device): Device on which the tensors are located.
Returns:
torch.Tensor: Loss value.
"""
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
veci = 5. * lbl[:, -2:]
loss = criterion(y[:, -3:-1], veci)
loss /= 2.
loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype))
loss = loss + loss2
return loss
def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}):
"""
Reshapes and normalizes the input data.
Args:
data (list): List of input data, with channels axis first or last.
normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.
Returns:
list: List of reshaped and normalized data.
"""
if (np.array([td.ndim != 3 for td in data]).sum() > 0 or
np.array([td.shape[0] != 3 for td in data]).sum() > 0):
data_new = []
for td in data:
if td.ndim == 3:
channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin()
# put channel axis first
td = np.moveaxis(td, channel_axis0, 0)
td = td[:3] # keep at most 3 channels
if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
td = np.stack((td, 0 * td, 0 * td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
td = np.concatenate((td, 0 * td[:1]), axis=0)
data_new.append(td)
data = data_new
if normalize_params["normalize"]:
data = [
normalize_img(td, normalize=normalize_params, axis=0)
for td in data
]
return data
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
normalize_params={"normalize": False}):
"""
Get a batch of images and labels.
Args:
inds (list): List of indices indicating which images and labels to retrieve.
data (list or None): List of image data. If None, images will be loaded from files.
labels (list or None): List of label data. If None, labels will be loaded from files.
files (list or None): List of file paths for images.
labels_files (list or None): List of file paths for labels.
normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
Returns:
tuple: A tuple containing two lists: the batch of images and the batch of labels.
"""
if data is None:
lbls = None
imgs = [io.imread(files[i]) for i in inds]
imgs = _reshape_norm(imgs, normalize_params=normalize_params)
if labels_files is not None:
lbls = [io.imread(labels_files[i])[1:] for i in inds]
else:
imgs = [data[i] for i in inds]
lbls = [labels[i][1:] for i in inds]
return imgs, lbls
def _reshape_norm_save(files, channels=None, channel_axis=None,
                       normalize_params={"normalize": False}):
    """
    Convert/normalize each image file and save the result next to it.

    Not currently used -- normalization happens on each batch when files are
    not pre-loaded.

    Args:
        files (list): Image file paths to process.
        channels (list or None): When given, each image is passed through
            ``convert_image`` with these channels and then transposed with
            ``(2, 0, 1)``.
        channel_axis (int or None): Channel axis forwarded to ``convert_image``.
        normalize_params (dict): Normalization parameters; only read, never modified.

    Returns:
        list: Paths of the written files (``<original stem>_cpnorm.tif``).
    """
    files_new = []
    # trange takes a count, not a sequence, so iterate over indices.
    for k in trange(len(files)):
        f = files[k]
        td = io.imread(f)
        if channels is not None:
            td = convert_image(td, channels=channels,
                               channel_axis=channel_axis)
            td = td.transpose(2, 0, 1)
        if normalize_params["normalize"]:
            td = normalize_img(td, normalize=normalize_params, axis=0)
        fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
        io.imsave(fnew, td)
        files_new.append(fnew)
    return files_new
# else:
# train_files = reshape_norm_save(train_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
# elif test_files is not None:
# test_files = reshape_norm_save(test_files, channels=channels,
# channel_axis=channel_axis, normalize_params=normalize_params)
def _process_train_test(train_data=None, train_labels=None, train_files=None,
                        train_labels_files=None, train_probs=None, test_data=None,
                        test_labels=None, test_files=None, test_labels_files=None,
                        test_probs=None, load_files=True, min_train_masks=5,
                        compute_flows=False, normalize_params={"normalize": False},
                        channel_axis=None, device=None):
    """
    Process train and test data.

    Validates the inputs, converts labels to flows, computes per-image
    diameters, drops training images with too few masks, normalizes the
    sampling probabilities and (when requested) reshapes/normalizes the images.

    Args:
        train_data (list or None): List of training data arrays.
        train_labels (list or None): List of training label arrays.
        train_files (list or None): List of training file paths.
        train_labels_files (list or None): List of training label file paths.
        train_probs (ndarray or None): Array of training probabilities.
        test_data (list or None): List of test data arrays.
        test_labels (list or None): List of test label arrays.
        test_files (list or None): List of test file paths.
        test_labels_files (list or None): List of test label file paths.
        test_probs (ndarray or None): Array of test probabilities.
        load_files (bool): Whether to load data from files.
        min_train_masks (int): Minimum number of masks required for training images.
        compute_flows (bool): Whether to compute flows when working from files.
        normalize_params (dict): Dictionary of normalization parameters; only read.
        channel_axis (int or None): Axis of channel dimension.
        device (torch.device): Device to use for computation. When None, CUDA is
            used if available, else MPS if available, else it stays None.

    Returns:
        tuple: ``(train_data, train_labels, train_files, train_labels_files,
        train_probs, diam_train, test_data, test_labels, test_files,
        test_labels_files, test_probs, diam_test, normed)`` where ``normed``
        tells whether the training data was normalized here.

    Raises:
        ValueError: If train data and labels differ in length, or the training
            arrays have an unsupported number of dimensions.
    """
    if device is None:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            device = torch.device('mps')

    if train_data is not None and train_labels is not None:
        # if data is loaded
        nimg = len(train_data)
        nimg_test = len(test_data) if test_data is not None else None
    else:
        # otherwise use files
        nimg = len(train_files)
        if train_labels_files is None:
            train_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
            ]
            train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
        if (test_data is not None or
                test_files is not None) and test_labels_files is None:
            test_labels_files = [
                os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
            ]
            test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
        if not load_files:
            train_logger.info(">>> using files instead of loading dataset")
        else:
            # load all images
            train_logger.info(">>> loading images and labels")
            train_data = [io.imread(train_files[i]) for i in trange(nimg)]
            train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
        nimg_test = len(test_files) if test_files is not None else None
        if load_files and nimg_test:
            test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
            test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]

    ### check that arrays are correct size
    if ((train_labels is not None and nimg != len(train_labels)) or
            (train_labels_files is not None and nimg != len(train_labels_files))):
        error_message = "train data and labels not same length"
        train_logger.critical(error_message)
        raise ValueError(error_message)
    if ((test_labels is not None and nimg_test != len(test_labels)) or
            (test_labels_files is not None and nimg_test != len(test_labels_files))):
        train_logger.warning("test data and labels not same length, not using")
        test_data, test_files = None, None
    if train_labels is not None:
        if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
            error_message = "training data or labels are not at least two-dimensional"
            train_logger.critical(error_message)
            raise ValueError(error_message)
        if train_data[0].ndim > 3:
            error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
            train_logger.critical(error_message)
            raise ValueError(error_message)

    ### check that flows are computed
    if train_labels is not None:
        train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
                                                device=device)
        if test_labels is not None:
            test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
                                                   device=device)
    elif compute_flows:
        # Working from files: process one label file at a time. labels_to_flows
        # takes a list of label arrays and a parallel list of files, exactly as
        # in the in-memory branch above; the return value is not needed here.
        for k in trange(nimg):
            dynamics.labels_to_flows([io.imread(train_labels_files[k])],
                                     files=[train_files[k]], device=device)
        if test_files is not None:
            for k in trange(nimg_test):
                dynamics.labels_to_flows([io.imread(test_labels_files[k])],
                                         files=[test_files[k]], device=device)

    ### compute diameters
    nmasks = np.zeros(nimg)
    diam_train = np.zeros(nimg)
    train_logger.info(">>> computing diameters")
    for k in trange(nimg):
        tl = (train_labels[k][0]
              if train_labels is not None else io.imread(train_labels_files[k])[0])
        diam_train[k], dall = utils.diameters(tl)
        nmasks[k] = len(dall)
    diam_train[diam_train < 5] = 5.
    if test_data is not None:
        diam_test = np.array(
            [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
        diam_test[diam_test < 5] = 5.
    elif test_labels_files is not None:
        diam_test = np.array([
            utils.diameters(io.imread(test_labels_files[k])[0])[0]
            for k in trange(len(test_labels_files))
        ])
        diam_test[diam_test < 5] = 5.
    else:
        diam_test = None

    ### check to remove training images with too few masks
    if min_train_masks > 0:
        nremove = (nmasks < min_train_masks).sum()
        if nremove > 0:
            train_logger.warning(
                f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
            )
            ikeep = np.nonzero(nmasks >= min_train_masks)[0]
            if train_data is not None:
                train_data = [train_data[i] for i in ikeep]
                train_labels = [train_labels[i] for i in ikeep]
            if train_files is not None:
                train_files = [train_files[i] for i in ikeep]
            if train_labels_files is not None:
                train_labels_files = [train_labels_files[i] for i in ikeep]
            if train_probs is not None:
                train_probs = train_probs[ikeep]
            diam_train = diam_train[ikeep]
            # count the kept images directly; train_data is None when the
            # dataset is not loaded into memory
            nimg = len(ikeep)

    ### normalize probabilities
    train_probs = 1. / nimg * np.ones(nimg,
                                      "float64") if train_probs is None else train_probs
    train_probs /= train_probs.sum()
    if test_files is not None or test_data is not None:
        test_probs = 1. / nimg_test * np.ones(
            nimg_test, "float64") if test_probs is None else test_probs
        test_probs /= test_probs.sum()

    ### reshape and normalize train / test data
    normed = False
    if normalize_params["normalize"]:
        train_logger.info(f">>> normalizing {normalize_params}")
        if train_data is not None:
            train_data = _reshape_norm(train_data, channel_axis=channel_axis,
                                       normalize_params=normalize_params)
            normed = True
        if test_data is not None:
            test_data = _reshape_norm(test_data, channel_axis=channel_axis,
                                      normalize_params=normalize_params)

    return (train_data, train_labels, train_files, train_labels_files, train_probs,
            diam_train, test_data, test_labels, test_files, test_labels_files,
            test_probs, diam_test, normed)
def train_seg(net, train_data=None, train_labels=None, train_files=None,
              train_labels_files=None, train_probs=None, test_data=None,
              test_labels=None, test_files=None, test_labels_files=None,
              test_probs=None, channel_axis=None,
              load_files=True, batch_size=1, learning_rate=5e-5, SGD=False,
              n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
              save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
              nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
              min_train_masks=5, model_name=None, class_weights=None, ts=None):
    """
    Train the network with images for segmentation.

    Args:
        net (object): The network model to train.
        train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
        train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
        train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
        train_labels_files (list or None): List of training label file paths. Defaults to None.
        train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None.
        test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None.
        test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
        test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None.
        test_labels_files (list or None): List of test label file paths. Defaults to None.
        test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None.
        channel_axis (int, optional): Axis of the channel dimension of the images. Defaults to None.
        load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True.
        batch_size (int, optional): Integer - number of images per optimization step. Defaults to 1.
        learning_rate (float, optional): Float - peak learning rate of the schedule. Defaults to 5e-5.
        SGD (bool, optional): Deprecated - AdamW is always used; a warning is logged when True.
        n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 100.
        weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 0.1.
        normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True.
        compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
        save_path (str, optional): String - where to save the trained model (a "models" sub-directory is used). Defaults to the current working directory.
        save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
        save_each (bool, optional): Boolean - save intermediate checkpoints to separate "_epoch_XXXX" files. Defaults to False.
        nimg_per_epoch (int, optional): Integer - number of images to train on per epoch. Defaults to the number of training images.
        nimg_test_per_epoch (int, optional): Integer - number of images to test on per epoch. Defaults to the number of test images.
        rescale (bool, optional): Boolean - whether or not to rescale images by diameter during training. Defaults to False.
        scale_range (float, optional): Scale range passed to the augmentation. Defaults to 0.5 when None.
        bsize (int, optional): Side length of the augmented crops. Defaults to 256.
        min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
        model_name (str, optional): String - name of the network. Defaults to "cellpose_<start time>".
        class_weights (list, tuple, np.ndarray or torch.Tensor, optional): Per-class weights for the class loss. Defaults to None.
        ts (optional): Accepted for interface compatibility; not used by this function.

    Returns:
        tuple: A tuple containing the path to the saved model weights, training losses, and test losses.

    Raises:
        ValueError: If ``normalize`` is neither a bool nor a dict.
    """
    if SGD:
        train_logger.warning("SGD is deprecated, using AdamW instead")

    device = net.device

    scale_range = 0.5 if scale_range is None else scale_range

    if isinstance(normalize, dict):
        normalize_params = {**models.normalize_default, **normalize}
    elif not isinstance(normalize, bool):
        raise ValueError("normalize parameter must be a bool or a dict")
    else:
        # build a copy so the shared module-level defaults are never modified
        normalize_params = {**models.normalize_default, "normalize": normalize}

    out = _process_train_test(train_data=train_data, train_labels=train_labels,
                              train_files=train_files, train_labels_files=train_labels_files,
                              train_probs=train_probs,
                              test_data=test_data, test_labels=test_labels,
                              test_files=test_files, test_labels_files=test_labels_files,
                              test_probs=test_probs,
                              load_files=load_files, min_train_masks=min_train_masks,
                              compute_flows=compute_flows, channel_axis=channel_axis,
                              normalize_params=normalize_params, device=net.device)
    (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
     test_data, test_labels, test_files, test_labels_files, test_probs, diam_test,
     normed) = out

    # already normalized, do not normalize during training
    if normed:
        kwargs = {}
    else:
        kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis}

    net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

    if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)):
        # np.asarray lets lists and tuples through torch.from_numpy as well
        class_weights = torch.from_numpy(np.asarray(class_weights)).to(device).float()
        print(class_weights)

    nimg = len(train_data) if train_data is not None else len(train_files)
    nimg_test = len(test_data) if test_data is not None else None
    nimg_test = len(test_files) if test_files is not None else nimg_test
    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    # learning rate schedule: 10-epoch linear warm-up, constant plateau, then
    # (for longer runs) a tail where the rate is halved in steps
    LR = np.linspace(0, learning_rate, 10)
    LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
    if n_epochs > 300:
        LR = LR[:-100]
        for i in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    elif n_epochs > 99:
        LR = LR[:-50]
        for i in range(10):
            LR = np.append(LR, LR[-1] / 2 * np.ones(5))

    train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}")
    train_logger.info(
        f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
    )
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
                                  weight_decay=weight_decay)

    t0 = time.time()
    model_name = f"cellpose_{t0}" if model_name is None else model_name
    save_path = Path.cwd() if save_path is None else Path(save_path)
    filename = save_path / "models" / model_name
    # parents=True: save_path itself may not exist yet
    (save_path / "models").mkdir(parents=True, exist_ok=True)

    train_logger.info(f">>> saving model to {filename}")

    lavg, nsum = 0, 0
    train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
    for iepoch in range(n_epochs):
        np.random.seed(iepoch)
        if nimg != nimg_per_epoch:
            # choose random images for epoch with probability train_probs
            rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                     p=train_probs)
        else:
            # otherwise use all images
            rperm = np.random.permutation(np.arange(0, nimg))
        for param_group in optimizer.param_groups:
            param_group["lr"] = LR[iepoch]  # set learning rate
        net.train()
        for k in range(0, nimg_per_epoch, batch_size):
            kend = min(k + batch_size, nimg_per_epoch)
            inds = rperm[k:kend]
            imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
                                    files=train_files, labels_files=train_labels_files,
                                    **kwargs)
            diams = np.array([diam_train[i] for i in inds])
            rsc = diams / net.diam_mean.item() if rescale else np.ones(
                len(diams), "float32")
            # augmentations
            imgi, lbl = random_rotate_and_resize(imgs, Y=lbls, rescale=rsc,
                                                 scale_range=scale_range,
                                                 xy=(bsize, bsize))[:2]
            # network and loss optimization
            X = torch.from_numpy(imgi).to(device)
            lbl = torch.from_numpy(lbl).to(device)

            if X.dtype != net.dtype:
                X = X.to(net.dtype)
                lbl = lbl.to(net.dtype)

            y = net(X)[0]
            loss = _loss_fn_seg(lbl, y, device)
            if y.shape[1] > 3:
                # extra output channels are class logits
                loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
                loss += loss3
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            train_loss *= len(imgi)

            # keep track of average training loss across epochs
            lavg += train_loss
            nsum += len(imgi)
            # per epoch training loss
            train_losses[iepoch] += train_loss
        train_losses[iepoch] /= nimg_per_epoch

        if iepoch == 5 or iepoch % 10 == 0:
            lavgt = 0.
            if test_data is not None or test_files is not None:
                # fixed seed so every evaluation uses the same test sample
                np.random.seed(42)
                if nimg_test != nimg_test_per_epoch:
                    rperm = np.random.choice(np.arange(0, nimg_test),
                                             size=(nimg_test_per_epoch,), p=test_probs)
                else:
                    rperm = np.random.permutation(np.arange(0, nimg_test))
                for ibatch in range(0, len(rperm), batch_size):
                    with torch.no_grad():
                        net.eval()
                        inds = rperm[ibatch:ibatch + batch_size]
                        imgs, lbls = _get_batch(inds, data=test_data,
                                                labels=test_labels, files=test_files,
                                                labels_files=test_labels_files,
                                                **kwargs)
                        diams = np.array([diam_test[i] for i in inds])
                        rsc = diams / net.diam_mean.item() if rescale else np.ones(
                            len(diams), "float32")
                        imgi, lbl = random_rotate_and_resize(
                            imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
                            xy=(bsize, bsize))[:2]
                        X = torch.from_numpy(imgi).to(device)
                        lbl = torch.from_numpy(lbl).to(device)

                        if X.dtype != net.dtype:
                            X = X.to(net.dtype)
                            lbl = lbl.to(net.dtype)

                        y = net(X)[0]
                        loss = _loss_fn_seg(lbl, y, device)
                        if y.shape[1] > 3:
                            loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
                            loss += loss3
                        test_loss = loss.item()
                        test_loss *= len(imgi)
                        lavgt += test_loss
                lavgt /= len(rperm)
                test_losses[iepoch] = lavgt
            lavg /= nsum
            train_logger.info(
                f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time() - t0:.2f}s"
            )
            lavg, nsum = 0, 0

        if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
            if save_each and iepoch != n_epochs - 1:  # separate files as model progresses
                filename0 = str(filename) + f"_epoch_{iepoch:04d}"
            else:
                filename0 = filename
            train_logger.info(f"saving network parameters to {filename0}")
            net.save_model(filename0)

    net.save_model(filename)

    return filename, train_losses, test_losses

View file

@ -1,4 +1,4 @@
<!doctype html>
<!DOCTYPE html>
<html lang="en">
<head>
@ -10,179 +10,9 @@
</head>
<body>
<div style="padding: 1rem 1rem;">
<div class="mb-3 border" style="padding: 1rem 1rem;">
<div class="input-group mb-3">
<input id="fileInput" type="file" class="form-control" multiple />
</div>
<hr>
<div>
<div style="padding-right: 50rem;">
<div class="input-group mb-3">
<span class="input-group-text">flow threshold:</span>
<input id="flow" type="text" class="form-control" placeholder="" />
</div>
<div class="input-group mb-3">
<span class="input-group-text">cellprob threshold:</span>
<input id="cellprob" type="text" class="form-control" placeholder="" />
</div>
<div class="input-group mb-3">
<span class="input-group-text">diameter:</span>
<input id="diameter" type="text" class="form-control" placeholder="" />
</div>
<label>
<select id="model" class="form-select"></select>
</label>
</label>
</div>
</div>
<br>
<div>
<button id="uploadBtn" class="btn btn-success">Upload</button>
<progress id="bar" max="100" value="0" style="width:300px;"></progress>
</div>
</div>
</div>
<br><br><br>
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js"
integrity="sha384-oBqDVmMz9ATKxIep9tiCxS/Z9fNfEXiDAYTujMAeBAsjFuCZSmKbSSUnQlmh/jp3"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.min.js"
integrity="sha384-ep+dxp/oz2RKF89ALMPGc7Z89QFa32C8Uv1A3TcEK8sMzXVysblLA3+eJWTzPJzT"
crossorigin="anonymous"></script>
<script>
// Base URL of the backend and the endpoints derived from it.
const API = "http://10.147.18.141:5000/";
// Upload endpoint; this is the name buildUrl() reads.
const API_UPLOAD = API + "run_upload";
// Endpoint that returns the list of available models.
const API_MODEL = API + "models";
// Fetch the model list from API_MODEL and fill the #model <select> with it.
// The request is aborted after 5 seconds; on failure the select shows an
// error entry and the next click on it triggers one retry.
async function loadModels() {
const select = document.getElementById('model');
// Show a placeholder and disable the control so nothing is picked while loading
select.disabled = true;
select.innerHTML = '<option value="">加载中…</option>';
// Timeout control (5 seconds)
const controller = new AbortController();
const t = setTimeout(() => controller.abort(), 5000);
try {
const resp = await fetch(API_MODEL, {
headers: { 'Accept': 'application/json' },
signal: controller.signal
});
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
const data = await resp.json();
// Anything other than an array under "models" is treated as an empty list
const list = Array.isArray(data.models) ? data.models : [];
// Remember the previous selection (if any)
// NOTE(review): the options were already replaced by the loading placeholder
// above, so this presumably always reads "" — confirm whether the value
// should be captured before the placeholder is set.
const prev = select.value;
// Re-render the options
select.innerHTML = '<option value="" disabled selected>请选择模型</option>';
const seen = new Set();
for (const name of list) {
// Skip empty entries and duplicates
if (!name || seen.has(name)) continue;
seen.add(name);
const opt = document.createElement('option');
opt.value = String(name);
opt.textContent = String(name);
select.appendChild(opt);
}
// Restore the previous selection if it is still in the list
if (prev && seen.has(prev)) {
select.value = prev;
}
} catch (err) {
console.error('加载模型列表失败:', err);
select.innerHTML = '<option value="">加载失败(点击重试)</option>';
// Retry once when the select is clicked
select.addEventListener('click', loadModels, { once: true });
} finally {
// Always cancel the timeout and re-enable the control
clearTimeout(t);
select.disabled = false;
}
}
document.addEventListener('DOMContentLoaded', loadModels);
// Build the upload URL: API_UPLOAD plus a query string carrying the selected
// model and the trimmed flow / cellprob / diameter inputs.
// API_UPLOAD is expected to be defined by the surrounding script.
function buildUrl() {
// Read the parameters from the form controls
const model = document.getElementById('model')?.value;
const flow = (document.getElementById('flow')?.value || '').trim();
const cellp = (document.getElementById('cellprob')?.value || '').trim();
const diameter = (document.getElementById('diameter')?.value || '').trim();
// Assemble the query string with URLSearchParams
const qs = new URLSearchParams({
model: model,
flow_threshold: flow,
cellprob_threshold: cellp,
diameter: diameter
});
return `${API_UPLOAD}?${qs.toString()}`;
}
// Click handler of the upload button: posts the selected files to the URL from
// buildUrl(), mirrors upload progress in the #bar element, then shows a
// 3-second countdown banner and redirects to preview.html with the returned id.
document.getElementById("uploadBtn").addEventListener("click", async () => {
const input = document.getElementById("fileInput");
if (!input.files.length) return alert("请选择文件");
// Every selected file is sent under the same "files" field
const fd = new FormData();
for (const f of input.files) fd.append("files", f);
const bar = document.getElementById("bar");
try {
const URL = buildUrl();
const res = await axios.post(URL, fd, {
onUploadProgress: (e) => {
// e.total may be missing, in which case the bar is left unchanged
if (e.total) bar.value = Math.round((e.loaded * 100) / e.total);
},
});
// alert("上传成功:" + JSON.stringify(res.data));
// Create a notification element
const notice = document.createElement("div");
notice.style.position = "fixed";
notice.style.top = "20px";
notice.style.left = "50%";
notice.style.transform = "translateX(-50%)";
notice.style.padding = "10px 20px";
notice.style.background = "#4caf50";
notice.style.color = "white";
notice.style.borderRadius = "8px";
notice.style.fontSize = "16px";
notice.style.zIndex = "9999";
let seconds = 3;
notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
document.body.appendChild(notice);
// Update the countdown once per second; redirect when it reaches zero
const timer = setInterval(() => {
seconds--;
if (seconds > 0) {
notice.textContent = `上传成功!${seconds} 秒后跳转预览页面…`;
} else {
clearInterval(timer);
document.body.removeChild(notice);
window.location.href = `preview.html?id=${encodeURIComponent(res.data['id'])}`;
}
}, 1000);
} catch (e) {
// Prefer the server-provided message when the response carries one
alert("上传失败:" + (e.response?.data?.message || e.message));
} finally {
// Reset the progress bar whether the upload succeeded or failed
bar.value = 0;
}
});
</script>
<a href="run.html">运行</a>
<a href="train.html">训练</a>
</body>
</html>