From 547d96cfa9ff85242e4b4e20d4e5d2b278d596f9 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Sun, 17 May 2026 16:55:27 +0800 Subject: [PATCH] =?UTF-8?q?compare=5Fmodels.py:=20=E6=B7=BB=E5=8A=A0=20CSV?= =?UTF-8?q?=20=E7=BC=93=E5=AD=98=E6=9C=BA=E5=88=B6=EF=BC=8C=E5=B7=B2?= =?UTF-8?q?=E6=9C=89=E9=A2=84=E6=B5=8B=E6=95=B0=E6=8D=AE=E6=97=B6=E8=B7=B3?= =?UTF-8?q?=E8=BF=87=E9=87=8D=E5=A4=8D=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- baseline/compare_models.py | 42 ++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/baseline/compare_models.py b/baseline/compare_models.py index 7175071..b0e3c4c 100644 --- a/baseline/compare_models.py +++ b/baseline/compare_models.py @@ -5,7 +5,7 @@ baseline/compare_models.py author: yukun-hh date: 2026-5-14 """ -import sys, os +import sys, os, re sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import numpy as np @@ -29,7 +29,7 @@ matplotlib.rcParams['axes.unicode_minus'] = False # ============================================================ # ★★★ 可配置参数 ★★★ # ============================================================ -DATA_ROOT = '../trash_division_data/ultimate_4_class/' +DATA_ROOT = '../../trash_division_data/ultimate_4_class/' BATCH_SIZE = 32 IMAGE_SIZE = 256 NUM_WORKERS = 4 @@ -45,7 +45,7 @@ NUM_CLASSES = 4 def get_resnet34_preds(train_loader, val_loader, device): model = Net(num_classes=NUM_CLASSES) - state_dict = torch.load('best_model.pth', map_location='cpu') + state_dict = torch.load('../best_model.pth', map_location='cpu') if 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] elif 'model' in state_dict: @@ -106,6 +106,29 @@ def compute_macro_roc(y_true, y_probs): return all_fpr, mean_tpr, macro_auc +def sanitize_filename(name): + return re.sub(r'[^\w\-_]', '_', name).strip('_') + + +def preds_csv_path(out_dir, model_name): + safe = sanitize_filename(model_name) + return os.path.join(out_dir, f'{safe}_preds.csv') + + +def save_preds_csv(path, y_true, y_preds, y_probs): + header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES)) + data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs]) + np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f') + + +def load_preds_csv(path): + data = np.loadtxt(path, delimiter=',', skiprows=1) + y_true = data[:, 0].astype(int) + y_preds = data[:, 1].astype(int) + y_probs = data[:, 2:2 + NUM_CLASSES] + return y_true, y_preds, y_probs + + if __name__ == '__main__': out_dir = os.path.dirname(os.path.abspath(__file__)) @@ -132,12 +155,19 @@ if __name__ == '__main__': val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, drop_last=False) - # ———— 评估所有模型 ———— + # ———— 评估所有模型(有缓存则跳过)———— results = {} for name, func in MODELS: print(f"\n{'='*50}") - print(f"评估: {name}") - y_true, y_preds, y_probs = func(train_loader, val_loader, device) + csv_path = preds_csv_path(out_dir, name) + if os.path.exists(csv_path): + print(f"加载缓存: {os.path.basename(csv_path)}") + y_true, y_preds, y_probs = load_preds_csv(csv_path) + else: + print(f"评估: {name}") + y_true, y_preds, y_probs = func(train_loader, val_loader, device) + save_preds_csv(csv_path, y_true, y_preds, y_probs) + print(f" 预测数据已保存: {os.path.basename(csv_path)}") acc = accuracy_score(y_true, y_preds) fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs) results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,