From 6990a1e3c0bb56cb7e764b65a89a8fc09baaa36a Mon Sep 17 00:00:00 2001 From: ClovertaTheTrilobita Date: Mon, 22 Sep 2025 18:28:19 +0000 Subject: [PATCH] =?UTF-8?q?feature(train):=20=E6=96=B0=E5=A2=9E=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E6=A8=A1=E5=9E=8B=E5=8A=9F=E8=83=BD=EF=BC=88=E6=96=BD?= =?UTF-8?q?=E5=B7=A5=E4=B8=AD=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cp_train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/cp_train.py b/backend/cp_train.py index 9a94e38..a624ac7 100644 --- a/backend/cp_train.py +++ b/backend/cp_train.py @@ -9,18 +9,21 @@ cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve()) BASE_DIR = cfg.data.root_dir TEST_TRAIN_DIR = cfg.data.train.test_train_dir TEST_TEST_DIR = cfg.data.train.test_test_dir +MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve()) class Cptrain: @classmethod - def train_test(cls): + def start_train(cls, + time: str | None = None, + model_name: str | None = None,): - train_dir = TEST_TRAIN_DIR - test_dir = TEST_TEST_DIR + train_dir = Path(TEST_TRAIN_DIR) / time + test_dir = Path(TEST_TEST_DIR) / time os.makedirs(train_dir, exist_ok=True) os.makedirs(test_dir, exist_ok=True) io.logger_setup() - output = io.load_train_test_data(train_dir, test_dir, image_filter="_img", + output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img", mask_filter="_masks", look_one_level_down=False) images, labels, image_names, test_images, test_labels, image_names_test = output @@ -30,4 +33,7 @@ class Cptrain: train_data=images, train_labels=labels, test_data=test_images, test_labels=test_labels, weight_decay=0.1, learning_rate=1e-5, - n_epochs=100, model_name="my_new_model") \ No newline at end of file + n_epochs=100, model_name=model_name, + save_path=MODELS_DIR) + + print("模型已保存到:", model_path) \ No newline at end of file