refactor(path): 新增config.yaml作为配置文件,重构项目路径

This commit is contained in:
ClovertaTheTrilobita 2025-09-22 17:17:33 +00:00
parent 86a460afc5
commit 04e2bf0edc
4 changed files with 88 additions and 15 deletions

16
backend/config.yaml Normal file
View file

@ -0,0 +1,16 @@
model:
data:
root_dir: .
run:
test_output_dir: ${data.root_dir}/run/test_output
output_dir: ${data.root_dir}/run/output
train:
test_test_dir: ${data.root_dir}/train/test_test
test_train_dir: ${data.root_dir}/train/test_train
test_dir: ${data.root_dir}/train/test
train_dir: ${data.root_dir}/train/train
upload_dir: ${data.root_dir}/uploads

View file

@ -4,7 +4,16 @@ from PIL import Image
import numpy as np import numpy as np
import os, datetime import os, datetime
import time import time
from omegaconf import OmegaConf
from pathlib import Path
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
BASE_DIR = cfg.data.root_dir
UPLOAD_DIR = cfg.data.upload_dir
OUTPUT_DIR = cfg.data.run.output_dir
OUTPUT_TEST_DIR = cfg.data.run.test_output_dir
class Cprun: class Cprun:
@ -22,7 +31,7 @@ class Cprun:
) )
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
outdir = os.path.join(os.path.dirname(__file__), "test_output", ts) outdir = os.path.join(OUTPUT_TEST_DIR, ts)
os.makedirs(outdir, exist_ok=True) # 自动创建目录 os.makedirs(outdir, exist_ok=True) # 自动创建目录
for img, mask, flow, name in zip(imgs, masks, flows, files): for img, mask, flow, name in zip(imgs, masks, flows, files):
base = os.path.join(outdir, os.path.splitext(os.path.basename(name))[0]) base = os.path.join(outdir, os.path.splitext(os.path.basename(name))[0])
@ -63,7 +72,7 @@ class Cprun:
) )
ts = time ts = time
outdir = os.path.join(os.path.dirname(__file__), "output", ts) outdir = os.path.join(OUTPUT_DIR, ts)
os.makedirs(outdir, exist_ok=True) # 自动创建目录 os.makedirs(outdir, exist_ok=True) # 自动创建目录
for img, mask, flow, name in zip(imgs, masks, flows, files): for img, mask, flow, name in zip(imgs, masks, flows, files):
base = os.path.join(outdir, os.path.splitext(os.path.basename(name))[0]) base = os.path.join(outdir, os.path.splitext(os.path.basename(name))[0])

View file

@ -0,0 +1,33 @@
import os.path
from cellpose import io, models, train
from pathlib import Path
from omegaconf import OmegaConf
CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
BASE_DIR = cfg.data.root_dir
TEST_TRAIN_DIR = cfg.data.train.test_train_dir
TEST_TEST_DIR = cfg.data.train.test_test_dir
class Cptrain:
@classmethod
def train_test(cls):
train_dir = TEST_TRAIN_DIR
test_dir = TEST_TEST_DIR
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
io.logger_setup()
output = io.load_train_test_data(train_dir, test_dir, image_filter="_img",
mask_filter="_masks", look_one_level_down=False)
images, labels, image_names, test_images, test_labels, image_names_test = output
model = models.CellposeModel(gpu=True)
model_path, train_losses, test_losses = train.train_seg(model.net,
train_data=images, train_labels=labels,
test_data=test_images, test_labels=test_labels,
weight_decay=0.1, learning_rate=1e-5,
n_epochs=100, model_name="my_new_model")

View file

@ -1,17 +1,31 @@
import asyncio import asyncio
import base64 import base64
import datetime
import json
import os
import redis
import shutil
import time
from omegaconf import OmegaConf
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from flask import Flask, send_from_directory, request, jsonify
import os, shutil, time, threading, datetime, json, redis
from werkzeug.utils import secure_filename
from flask_cors import CORS
from pathlib import Path from pathlib import Path
from flask import Flask, send_from_directory, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from cp_run import Cprun from cp_run import Cprun
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
BASE_DIR = Path(__file__).resolve().parent
UPLOAD_DIR = BASE_DIR / "uploads" CONFIG_PATH = Path(__file__).parent / "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)
cfg.data.root_dir = str((CONFIG_PATH.parent / cfg.data.root_dir).resolve())
BASE_DIR = cfg.data.root_dir
UPLOAD_DIR = cfg.data.upload_dir
OUTPUT_DIR = cfg.data.run.output_dir
os.makedirs(UPLOAD_DIR, exist_ok=True) os.makedirs(UPLOAD_DIR, exist_ok=True)
executor = ThreadPoolExecutor(max_workers=4) executor = ThreadPoolExecutor(max_workers=4)
TASKS = {} TASKS = {}
@ -40,12 +54,13 @@ def test_download():
@app.get("/dl") @app.get("/dl")
def download(): def download():
timestamp = request.args.get("id") timestamp = request.args.get("id")
input_dir = os.path.join(BASE_DIR, "output", timestamp) input_dir = os.path.join(OUTPUT_DIR, timestamp)
output_dir = os.path.join(BASE_DIR, "output/tmp", timestamp) # 不要加 .zipmake_archive 会自动加 output_dir = os.path.join(OUTPUT_DIR, "tmp", timestamp) # 不要加 .zipmake_archive 会自动加
os.makedirs(BASE_DIR / "output/tmp", exist_ok=True) # 确保 tmp 存在 os.makedirs(Path(OUTPUT_DIR) / "tmp", exist_ok=True) # 确保 tmp 存在
shutil.make_archive(output_dir, 'zip', input_dir) shutil.make_archive(output_dir, 'zip', input_dir)
print(f"压缩完成: {output_dir}.zip") print(f"压缩完成: {output_dir}.zip")
return send_from_directory("output/tmp", f"{timestamp}.zip", as_attachment=True) print(OUTPUT_DIR)
return send_from_directory(f"{OUTPUT_DIR}/tmp/", f"{timestamp}.zip", as_attachment=True)
@app.post("/upload") @app.post("/upload")
def upload(): def upload():
@ -75,14 +90,14 @@ def upload():
# 将文件保存在本地目录中 # 将文件保存在本地目录中
ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}" ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + f"-{int(time.time()*1000)%1000:03d}"
os.makedirs(UPLOAD_DIR / ts, exist_ok=True) os.makedirs(Path(UPLOAD_DIR) / ts, exist_ok=True)
files = request.files.getlist("files") files = request.files.getlist("files")
saved = [] saved = []
for f in files: for f in files:
if not f or f.filename == "": if not f or f.filename == "":
continue continue
name = secure_filename(f.filename) name = secure_filename(f.filename)
f.save(os.path.join(UPLOAD_DIR / ts, name)) f.save(os.path.join(UPLOAD_DIR, ts, name))
saved.append(os.path.join(UPLOAD_DIR, ts, name)) saved.append(os.path.join(UPLOAD_DIR, ts, name))
# 新建一个线程,防止返回被阻塞 # 新建一个线程,防止返回被阻塞
@ -125,7 +140,7 @@ def status():
@app.get("/preview") @app.get("/preview")
def preview(): def preview():
task_id = request.args.get('id') task_id = request.args.get('id')
task_dir = BASE_DIR / "output" / task_id task_dir = Path(OUTPUT_DIR) / task_id
if not task_dir.exists(): if not task_dir.exists():
return jsonify({"ok": False, "error": "task not found"}), 200 return jsonify({"ok": False, "error": "task not found"}), 200