mirror of
https://github.com/ClovertaTheTrilobita/cellpose-web.git
synced 2026-04-01 23:14:50 +00:00
feature(train): 新增训练模型功能(施工中)
This commit is contained in:
parent
f55b2cc4d7
commit
6990a1e3c0
1 changed files with 11 additions and 5 deletions
|
|
@ -9,18 +9,21 @@ cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
|
||||||
BASE_DIR = cfg.data.root_dir
|
BASE_DIR = cfg.data.root_dir
|
||||||
TEST_TRAIN_DIR = cfg.data.train.test_train_dir
|
TEST_TRAIN_DIR = cfg.data.train.test_train_dir
|
||||||
TEST_TEST_DIR = cfg.data.train.test_test_dir
|
TEST_TEST_DIR = cfg.data.train.test_test_dir
|
||||||
|
MODELS_DIR = str((CONFIG_PATH.parent / cfg.model.save_dir).resolve())
|
||||||
|
|
||||||
class Cptrain:
|
class Cptrain:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def train_test(cls):
|
def start_train(cls,
|
||||||
|
time: str | None = None,
|
||||||
|
model_name: str | None = None,):
|
||||||
|
|
||||||
train_dir = TEST_TRAIN_DIR
|
train_dir = Path(TEST_TRAIN_DIR) / time
|
||||||
test_dir = TEST_TEST_DIR
|
test_dir = Path(TEST_TEST_DIR) / time
|
||||||
os.makedirs(train_dir, exist_ok=True)
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
os.makedirs(test_dir, exist_ok=True)
|
os.makedirs(test_dir, exist_ok=True)
|
||||||
io.logger_setup()
|
io.logger_setup()
|
||||||
output = io.load_train_test_data(train_dir, test_dir, image_filter="_img",
|
output = io.load_train_test_data(str(train_dir), str(test_dir), image_filter="_img",
|
||||||
mask_filter="_masks", look_one_level_down=False)
|
mask_filter="_masks", look_one_level_down=False)
|
||||||
images, labels, image_names, test_images, test_labels, image_names_test = output
|
images, labels, image_names, test_images, test_labels, image_names_test = output
|
||||||
|
|
||||||
|
|
@ -30,4 +33,7 @@ class Cptrain:
|
||||||
train_data=images, train_labels=labels,
|
train_data=images, train_labels=labels,
|
||||||
test_data=test_images, test_labels=test_labels,
|
test_data=test_images, test_labels=test_labels,
|
||||||
weight_decay=0.1, learning_rate=1e-5,
|
weight_decay=0.1, learning_rate=1e-5,
|
||||||
n_epochs=100, model_name="my_new_model")
|
n_epochs=100, model_name=model_name,
|
||||||
|
save_path=MODELS_DIR)
|
||||||
|
|
||||||
|
print("模型已保存到:", model_path)
|
||||||
Loading…
Reference in a new issue