"""
baseline/compare_models.py

Multi-model comparison: ROC curves, PR curves and an accuracy bar chart.
To add a new model, append one entry to MODELS; the plotting code needs no changes.

author: yukun-hh
date: 2026-5-14
"""
import sys, os, re
# Put the project root (parent of this file's directory) on sys.path so that
# Model, Dataloader and baseline.* can be imported when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import (
    roc_curve, auc, accuracy_score,
    precision_recall_curve, average_precision_score,
)

from Model import Net
from Dataloader import RobustImageFolder
from baseline.VGG_KNN import VGGKNNBaseline
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
from baseline.HOG_Baseline import get_hog_lr_preds

# SimHei first so Chinese text can render; DejaVu Sans as the fallback font.
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Draw minus signs with ASCII '-' (CJK fonts often lack the Unicode minus glyph).
matplotlib.rcParams['axes.unicode_minus'] = False

# ============================================================
# Configurable parameters
# ============================================================
# Relative path: resolved against the current working directory, not this file.
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
K_KNN = 5
# ============================================================

# Food waste, recyclables, other waste, hazardous waste.
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4


# ============================================================
# Prediction functions.
# Every function has the signature
#     (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
# ============================================================
def get_resnet34_preds(train_loader, val_loader, device):
    """Evaluate the trained ResNet-34 checkpoint on the validation loader.

    ``train_loader`` is unused; it is accepted so that every entry in MODELS
    shares one call signature.

    Returns:
        (y_true, y_preds, y_probs) as numpy arrays: ground-truth labels,
        argmax predictions, and softmax probabilities (one column per class).
    """
    model = Net(num_classes=NUM_CLASSES)
    # Relative to the current working directory.
    state_dict = torch.load('../best_model.pth', map_location='cpu')
    # Accept a bare state dict as well as checkpoints that wrap one.
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    elif 'model' in state_dict:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    y_true, y_preds, y_probs = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34'):
            logits = model(images.to(device))
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            y_true.append(labels.numpy())
            y_preds.append(preds.cpu().numpy())
            y_probs.append(probs.cpu().numpy())
    return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)


def get_vgg_knn_preds(train_loader, val_loader, device):
    """Fit the VGG-feature KNN baseline on ``train_loader`` and predict ``val_loader``."""
    baseline = VGGKNNBaseline(k=K_KNN, device=device)
    baseline.fit(train_loader)
    return baseline.predict(val_loader)


# ============================================================
# Model registry: to add a model, add one (display name, function) entry here.
# ============================================================
MODELS = [
    ('ResNet-34', get_resnet34_preds),
    ('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
    ('VGG16 + KNN (K=5)', get_vgg_knn_preds),
    ('HOG + LogisticRegression', get_hog_lr_preds),
    # Examples of future extensions:
    # ('ResNet-18 (pretrained)', get_resnet18_preds),
    # ('ResNet-50 (pretrained)', get_resnet50_preds),
    # ('ResNet-34 (finetuned)', get_finetuned_preds),
]

# ============================================================
# Colour palette (cycled, so adding models needs no change here)
# ============================================================
COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
          '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']


def compute_macro_roc(y_true, y_probs):
    """Macro-averaged one-vs-rest ROC curve.

    Each class's ROC curve is interpolated onto the union of all classes'
    FPR points and the TPRs are averaged.

    Args:
        y_true: integer labels in [0, NUM_CLASSES).
        y_probs: scores indexed as ``y_probs[:, c]`` for each class c.

    Returns:
        (all_fpr, mean_tpr, macro_auc)
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]
    fpr_dict, tpr_dict = {}, {}
    for c in range(NUM_CLASSES):
        fpr_dict[c], tpr_dict[c], _ = roc_curve(one_hot[:, c], y_probs[:, c])
    all_fpr = np.unique(np.concatenate([fpr_dict[c] for c in range(NUM_CLASSES)]))
    mean_tpr = np.zeros_like(all_fpr)
    for c in range(NUM_CLASSES):
        mean_tpr += np.interp(all_fpr, fpr_dict[c], tpr_dict[c])
    mean_tpr /= NUM_CLASSES
    macro_auc = auc(all_fpr, mean_tpr)
    return all_fpr, mean_tpr, macro_auc


def compute_macro_pr(y_true, y_probs):
    """Macro-averaged one-vs-rest precision-recall curve.

    Per-class precision is interpolated onto a fixed 200-point recall grid
    and averaged. The returned macro AP comes from sklearn's
    ``average_precision_score(average='macro')`` on the raw scores, not from
    the averaged curve.

    Returns:
        (all_rec, mean_prec, macro_ap)
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]
    prec_dict, rec_dict = {}, {}
    for c in range(NUM_CLASSES):
        prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
    all_rec = np.linspace(0, 1, 200)
    mean_prec = np.zeros_like(all_rec)
    for c in range(NUM_CLASSES):
        # precision_recall_curve returns recall in decreasing order; np.interp
        # needs increasing x, hence the reversal.
        mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
    mean_prec /= NUM_CLASSES
    macro_ap = average_precision_score(one_hot, y_probs, average='macro')
    return all_rec, mean_prec, macro_ap


def sanitize_filename(name):
    """Replace every character that is not a word char or '-' with '_', then trim '_'."""
    return re.sub(r'[^\w\-_]', '_', name).strip('_')


def preds_csv_path(out_dir, model_name):
    """Path of the prediction cache file for ``model_name`` inside ``out_dir``."""
    safe = sanitize_filename(model_name)
    return os.path.join(out_dir, f'{safe}_preds.csv')


def save_preds_csv(path, y_true, y_preds, y_probs):
    """Cache predictions as CSV with columns y_true, y_pred, prob_0..prob_{K-1}.

    Labels are written as integers and probabilities with 17 significant
    digits, so reloading the cache reproduces the fresh run's scores exactly
    for float64 (no artificial ties from rounding near-saturated softmax
    outputs).
    """
    header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
    data = np.column_stack([
        np.asarray(y_true, dtype=float),
        np.asarray(y_preds, dtype=float),
        np.asarray(y_probs, dtype=float),
    ])
    fmt = ['%d', '%d'] + ['%.17g'] * (data.shape[1] - 2)
    np.savetxt(path, data, delimiter=',', header=header, comments='', fmt=fmt)


def load_preds_csv(path):
    """Load a cache written by ``save_preds_csv``; returns (y_true, y_preds, y_probs)."""
    # ndmin=2 keeps a single-row file 2-D so the column slicing below works.
    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
    y_true = data[:, 0].astype(int)
    y_preds = data[:, 1].astype(int)
    y_probs = data[:, 2:2 + NUM_CLASSES]
    return y_true, y_preds, y_probs


def _finish_figure(fig, path):
    """Lay out *fig*, save it to *path* at 150 dpi, show it, then close it."""
    fig.tight_layout()
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise figures accumulate under non-interactive backends.
    plt.close(fig)


if __name__ == '__main__':
    out_dir = os.path.dirname(os.path.abspath(__file__))
    device = torch.device(
        'cuda' if torch.cuda.is_available()
        else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
        else 'cpu'
    )
    print(f"Device: {device}")

    # Deterministic preprocessing (no augmentation) for both splits.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'), transform=val_transform)
    val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'), transform=val_transform)
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    # Pinned host memory only helps when batches are copied to an accelerator;
    # on a CPU-only machine it merely triggers a warning.
    pin_memory = device.type != 'cpu'
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)

    # ---- Evaluate every model (reuse the CSV cache when it exists) ----
    results = {}
    for name, func in MODELS:
        print(f"\n{'='*50}")
        csv_path = preds_csv_path(out_dir, name)
        if os.path.exists(csv_path):
            print(f"加载缓存: {os.path.basename(csv_path)}")
            y_true, y_preds, y_probs = load_preds_csv(csv_path)
        else:
            print(f"评估: {name}")
            y_true, y_preds, y_probs = func(train_loader, val_loader, device)
            save_preds_csv(csv_path, y_true, y_preds, y_probs)
            print(f" 预测数据已保存: {os.path.basename(csv_path)}")

        acc = accuracy_score(y_true, y_preds)
        fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
        rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
        results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
                         'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
                         'rec': rec, 'prec': prec, 'ap': macro_ap}
        print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")

    # ---- ROC comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        ax.plot(r['fpr'], r['tpr'], color=COLORS[i % len(COLORS)], lw=2,
                label=f"{name} (AUC={r['auc']:.4f})")
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)  # chance-level diagonal
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    roc_path = os.path.join(out_dir, 'roc_comparison.png')
    _finish_figure(fig, roc_path)
    print(f"\nROC 对比图已保存: {roc_path}")

    # ---- PR comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        ax.plot(r['rec'], r['prec'], color=COLORS[i % len(COLORS)], lw=2,
                label=f"{name} (AP={r['ap']:.4f})")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower left')
    ax.grid(True, alpha=0.3)
    pr_path = os.path.join(out_dir, 'pr_comparison.png')
    _finish_figure(fig, pr_path)
    print(f"PR 对比图已保存: {pr_path}")

    # ---- Accuracy bar chart ----
    names = list(results)
    accs = [results[n]['acc'] for n in names]
    fig, ax = plt.subplots(figsize=(8, 5))
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
    bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    # Headroom for the value labels; fall back to [0, 1] when every accuracy
    # is 0 (or there are no results) to avoid a singular y-range.
    top = max(accs, default=0.0) * 1.15
    ax.set_ylim(0, top if top > 0 else 1.0)
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')
    # Long model names can overlap at this figure width; tilt them.
    plt.setp(ax.get_xticklabels(), rotation=15, ha='right')
    bar_path = os.path.join(out_dir, 'accuracy_bar.png')
    _finish_figure(fig, bar_path)
    print(f"准确率柱状图已保存: {bar_path}")