diff --git a/backend/cp_train.py b/backend/cp_train.py index d6d9aa4..58594ac 100644 --- a/backend/cp_train.py +++ b/backend/cp_train.py @@ -5,6 +5,8 @@ import redis import datetime import json +from sympy import false + CONFIG_PATH = Path(__file__).parent / "config.yaml" cfg = OmegaConf.load(CONFIG_PATH) cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve()) @@ -34,11 +36,25 @@ class Cptrain: @classmethod async def start_train(cls, - time: str | None = None, - model_name: str | None = None, - image_filter: str = "_img", - mask_filter: str = "_masks", - base_model: str = "cpsam"): + time: str | None = None, + model_name: str | None = None, + image_filter: str = "_img", + mask_filter: str = "_masks", + base_model: str = "cpsam", + train_probs: list[float] = None, + test_probs: list[float] = None, + batch_size: int = 8, + learning_rate = 5e-5, + n_epochs: int = 100, + weight_decay=0.1, + normalize: bool =True, + compute_flows: bool = False, + min_train_masks: int = 5, + nimg_per_epoch: int =None, + rescale: bool= False, + scale_range=None, + channel_axis: int = None, + ): train_dir = Path(TRAIN_DIR) / time test_dir = Path(TEST_DIR) / time @@ -56,9 +72,13 @@ class Cptrain: model_path, train_losses, test_losses = train.train_seg(model.net, train_data=images, train_labels=labels, test_data=test_images, test_labels=test_labels, - weight_decay=0.1, learning_rate=1e-5, - n_epochs=100, model_name=model_name, - save_path=BASE_DIR) + train_probs=train_probs, test_probs=test_probs, + weight_decay=weight_decay, learning_rate=learning_rate, + n_epochs=n_epochs, model_name=model_name, + save_path=BASE_DIR, batch_size=batch_size, + normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks, + nimg_per_epoch=nimg_per_epoch, rescale=rescale, scale_range=scale_range, channel_axis=channel_axis + ) set_status(time, "done", train_losses, test_losses) print("模型已保存到:", model_path) diff --git a/backend/flaskApp.py b/backend/flaskApp.py index bcf20d0..e3c14b9 100644 --- a/backend/flaskApp.py +++ b/backend/flaskApp.py @@ -145,6 +145,17 @@ def train_upload(): image_filter = request.args.get("image_filter") or "_img" mask_filter = request.args.get("mask_filter") or "_masks" base_model = request.args.get("base_model") or "cpsam" + batch_size = request.args.get("batch_size") or 8 + learning_rate = request.args.get("learning_rate") or 5e-5 + n_epochs = request.args.get("n_epochs") or 100 + weight_decay = request.args.get("weight_decay") or 0.1 + normalize = request.args.get("normalize") or True + compute_flows = request.args.get("compute_flows") or False + min_train_masks = request.args.get(" min_train_masks") or 5 + nimg_per_epoch = request.args.get("nimg_per_epoch") or None + rescale = request.args.get("rescale") or False + scale_range = request.args.get("scale_range") or None + channel_axis = request.args.get("channel_axis") or None train_files = request.files.getlist("train_files") test_files = request.files.getlist("test_files") @@ -172,7 +183,18 @@ def train_upload(): model_name=model_name, image_filter=image_filter, mask_filter=mask_filter, - base_model=base_model + base_model=base_model, + batch_size=batch_size, + learning_rate=learning_rate, + n_epochs=n_epochs, + weight_decay=weight_decay, + normalize=normalize, + compute_flows=compute_flows, + min_train_masks=min_train_masks, + nimg_per_epoch=nimg_per_epoch, + rescale=rescale, + scale_range=scale_range, + channel_axis=channel_axis, )) fut = executor.submit(job) diff --git a/frontend/index.html b/frontend/index.html index 7a0a62d..c3684bd 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -1,18 +1,139 @@ - - +
- - -