diff --git a/README.md b/README.md
index 89ec028..6892f3b 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
- ⛵️ 训练、分割结果随时下载
- 📚 权重可直接作为后续分割模型
- 🛠️ 一键安装部署脚本
-- 🚧 [TODO] 前端样式美化
+- 🎨 前端样式美化
diff --git a/backend/cp_train.py b/backend/cp_train.py
index d6d9aa4..58594ac 100644
--- a/backend/cp_train.py
+++ b/backend/cp_train.py
@@ -5,6 +5,8 @@ import redis
import datetime
import json
+from sympy import false
+
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
@@ -34,11 +36,25 @@ class Cptrain:
@classmethod
async def start_train(cls,
- time: str | None = None,
- model_name: str | None = None,
- image_filter: str = "_img",
- mask_filter: str = "_masks",
- base_model: str = "cpsam"):
+ time: str | None = None,
+ model_name: str | None = None,
+ image_filter: str = "_img",
+ mask_filter: str = "_masks",
+ base_model: str = "cpsam",
+ train_probs: list[float] = None,
+ test_probs: list[float] = None,
+ batch_size: int = 8,
+ learning_rate = 5e-5,
+ n_epochs: int = 100,
+ weight_decay=0.1,
+ normalize: bool =True,
+ compute_flows: bool = False,
+ min_train_masks: int = 5,
+ nimg_per_epoch: int =None,
+ rescale: bool= False,
+ scale_range=None,
+ channel_axis: int = None,
+ ):
train_dir = Path(TRAIN_DIR) / time
test_dir = Path(TEST_DIR) / time
@@ -56,9 +72,13 @@ class Cptrain:
model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels,
test_data=test_images, test_labels=test_labels,
- weight_decay=0.1, learning_rate=1e-5,
- n_epochs=100, model_name=model_name,
- save_path=BASE_DIR)
+ train_probs=train_probs, test_probs=test_probs,
+ weight_decay=weight_decay, learning_rate=learning_rate,
+ n_epochs=n_epochs, model_name=model_name,
+ save_path=BASE_DIR, batch_size=batch_size,
+ normalize=normalize, compute_flows=compute_flows, min_train_masks=min_train_masks,
+ nimg_per_epoch=nimg_per_epoch, rescale=rescale, scale_range=scale_range, channel_axis=channel_axis
+ )
set_status(time, "done", train_losses, test_losses)
print("模型已保存到:", model_path)
diff --git a/backend/flaskApp.py b/backend/flaskApp.py
index bcf20d0..40109f5 100644
--- a/backend/flaskApp.py
+++ b/backend/flaskApp.py
@@ -140,11 +140,47 @@ def run_upload():
@app.post("/train_upload")
def train_upload():
+
+ def _to_float(x, default):
+ try:
+ return float(x)
+ except (TypeError, ValueError):
+ return default
+
+ def _to_int(x, default):
+ try:
+ return int(x)
+ except (TypeError, ValueError):
+ return default
+
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
model_name = request.args.get("model_name") or f"custom_model-{ts}"
image_filter = request.args.get("image_filter") or "_img"
mask_filter = request.args.get("mask_filter") or "_masks"
base_model = request.args.get("base_model") or "cpsam"
+ batch_size = _to_int(request.args.get("batch_size"), 8)
+ learning_rate = _to_float(request.args.get("learning_rate"), 5e-5)
+ n_epochs = _to_int(request.args.get("n_epochs"), 100)
+ weight_decay = _to_float(request.args.get("weight_decay"), 0.1)
+ normalize = request.args.get(
+ "normalize",
+ default=True,
+ type=lambda v: str(v).strip().lower() in ("1","true","t","yes","y","on")
+ )
+ compute_flows = request.args.get(
+ "compute_flows",
+ default=True,
+ type=lambda v: str(v).strip().lower() in ("1","true","t","yes","y","on")
+ )
+ min_train_masks = _to_int(request.args.get(" min_train_masks"), 5)
+ nimg_per_epoch = _to_int(request.args.get("nimg_per_epoch"), None)
+ rescale = request.args.get(
+ "rescale",
+ default=False,
+ type=lambda v: str(v).strip().lower() in ("1","true","t","yes","y","on")
+ )
+ scale_range = _to_float(request.args.get("scale_range"), None)
+ channel_axis = _to_int(request.args.get("channel_axis"), None)
train_files = request.files.getlist("train_files")
test_files = request.files.getlist("test_files")
@@ -172,7 +208,18 @@ def train_upload():
model_name=model_name,
image_filter=image_filter,
mask_filter=mask_filter,
- base_model=base_model
+ base_model=base_model,
+ batch_size=batch_size,
+ learning_rate=learning_rate,
+ n_epochs=n_epochs,
+ weight_decay=weight_decay,
+ normalize=normalize,
+ compute_flows=compute_flows,
+ min_train_masks=min_train_masks,
+ nimg_per_epoch=nimg_per_epoch,
+ rescale=rescale,
+ scale_range=scale_range,
+ channel_axis=channel_axis,
))
fut = executor.submit(job)
@@ -197,6 +244,7 @@ def status():
"""
task_id = request.args.get('id')
st = get_status(task_id)
+ print(st)
if not st:
return jsonify({"ok": True, "exists": False, "status": "not_found"}), 200
return jsonify({"ok": True, "exists": True, **st}), 200
diff --git a/frontend/index.html b/frontend/index.html
index 7a0a62d..fd98f6d 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -1,18 +1,139 @@
-
-
+