diff --git a/baseline/compare_models.py b/baseline/compare_models.py
index 7175071..b0e3c4c 100644
--- a/baseline/compare_models.py
+++ b/baseline/compare_models.py
@@ -5,7 +5,7 @@ baseline/compare_models.py
 author: yukun-hh
 date: 2026-5-14
 """
-import sys, os
+import sys, os, re
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import numpy as np
@@ -29,7 +29,7 @@ matplotlib.rcParams['axes.unicode_minus'] = False
 # ============================================================
 # ★★★ 可配置参数 ★★★
 # ============================================================
-DATA_ROOT = '../trash_division_data/ultimate_4_class/'
+DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
 BATCH_SIZE = 32
 IMAGE_SIZE = 256
 NUM_WORKERS = 4
@@ -45,7 +45,7 @@ NUM_CLASSES = 4
 
 def get_resnet34_preds(train_loader, val_loader, device):
     model = Net(num_classes=NUM_CLASSES)
-    state_dict = torch.load('best_model.pth', map_location='cpu')
+    state_dict = torch.load('../best_model.pth', map_location='cpu')
     if 'model_state_dict' in state_dict:
         state_dict = state_dict['model_state_dict']
     elif 'model' in state_dict:
@@ -106,6 +106,53 @@ def compute_macro_roc(y_true, y_probs):
     return all_fpr, mean_tpr, macro_auc
 
 
+def sanitize_filename(name):
+    """Make a model name safe to use as a file name.
+
+    Every character that is not a word character or ``-`` becomes ``_``,
+    and leading/trailing underscores are stripped.
+    """
+    return re.sub(r'[^\w\-_]', '_', name).strip('_')
+
+
+def preds_csv_path(out_dir, model_name):
+    """Return the prediction-cache CSV path for *model_name* inside *out_dir*."""
+    safe = sanitize_filename(model_name)
+    return os.path.join(out_dir, f'{safe}_preds.csv')
+
+
+def save_preds_csv(path, y_true, y_preds, y_probs):
+    """Write labels, predictions and class probabilities to a CSV at *path*.
+
+    Columns are ``y_true, y_pred, prob_0 .. prob_{NUM_CLASSES-1}``.
+    """
+    header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
+    data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs])
+    # Labels are stored as integers; probabilities keep 9 significant digits so
+    # metrics recomputed from the cache stay close to a fresh evaluation (a
+    # fixed '%.6f' rounds small probabilities to zero and can create ROC ties).
+    fmt = ['%d', '%d'] + ['%.9g'] * (data.shape[1] - 2)
+    # Write to a temporary file and rename it into place, so an interrupted
+    # run never leaves a truncated cache that a later run would load.
+    tmp_path = path + '.tmp'
+    np.savetxt(tmp_path, data, delimiter=',', header=header, comments='', fmt=fmt)
+    os.replace(tmp_path, path)
+
+
+def load_preds_csv(path):
+    """Read a CSV written by ``save_preds_csv``.
+
+    Returns ``(y_true, y_preds, y_probs)``: two integer arrays and a 2-D
+    array holding the first ``NUM_CLASSES`` probability columns.
+    """
+    # ndmin=2 keeps a one-row file 2-D; without it ``data[:, 0]`` raises.
+    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
+    y_true = data[:, 0].astype(int)
+    y_preds = data[:, 1].astype(int)
+    y_probs = data[:, 2:2 + NUM_CLASSES]
+    return y_true, y_preds, y_probs
+
+
 if __name__ == '__main__':
     out_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -132,12 +179,19 @@ if __name__ == '__main__':
     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
 
-    # ———— 评估所有模型 ————
+    # ———— Evaluate every model (reuse cached predictions when present) ————
     results = {}
     for name, func in MODELS:
         print(f"\n{'='*50}")
-        print(f"评估: {name}")
-        y_true, y_preds, y_probs = func(train_loader, val_loader, device)
+        csv_path = preds_csv_path(out_dir, name)
+        if os.path.exists(csv_path):
+            print(f"加载缓存: {os.path.basename(csv_path)}")
+            y_true, y_preds, y_probs = load_preds_csv(csv_path)
+        else:
+            print(f"评估: {name}")
+            y_true, y_preds, y_probs = func(train_loader, val_loader, device)
+            save_preds_csv(csv_path, y_true, y_preds, y_probs)
+            print(f" 预测数据已保存: {os.path.basename(csv_path)}")
         acc = accuracy_score(y_true, y_preds)
         fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
         results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,