180 lines
7.2 KiB
Python
180 lines
7.2 KiB
Python
"""
|
||
baseline/compare_models.py
|
||
多模型对比:ROC 曲线 + 准确率柱状图
|
||
添加新模型只需在 MODELS 列表加一行,无需修改绘图代码
|
||
author: yukun-hh
|
||
date: 2026-5-14
|
||
"""
|
||
import sys, os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from torchvision import transforms
|
||
from tqdm import tqdm
|
||
|
||
from sklearn.metrics import roc_curve, auc, accuracy_score
|
||
|
||
from Model import Net
|
||
from Dataloader import RobustImageFolder
|
||
from baseline.VGG_KNN import VGGKNNBaseline
|
||
|
||
# Font fallback for plot text: SimHei first (has CJK glyphs), DejaVu Sans
# otherwise. Minus signs are drawn as an ASCII hyphen instead of U+2212 — a
# common workaround when a CJK font such as SimHei lacks that glyph.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
||
|
||
# ============================================================
# ★★★ Configurable parameters ★★★
# ============================================================
# Dataset root; must contain 'train/' and 'val/' sub-folders. A relative
# path is resolved against the current working directory, not this file.
DATA_ROOT = '../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256     # images are resized to IMAGE_SIZE x IMAGE_SIZE
NUM_WORKERS = 4      # DataLoader worker processes
K_KNN = 5            # number of neighbours (k) for the VGG + KNN baseline
# ============================================================

# Class names: food waste, recyclables, other waste, hazardous waste.
# NOTE(review): assumed to be in the same order as the dataset's label
# indices produced by RobustImageFolder — confirm.
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']

# Derived from CLASS_NAMES so the class count can never drift out of sync.
NUM_CLASSES = len(CLASS_NAMES)
|
||
|
||
# ============================================================
|
||
# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
|
||
# ============================================================
|
||
|
||
def get_resnet34_preds(train_loader, val_loader, device, checkpoint_path='best_model.pth'):
    """Run the trained ResNet-34 (``Model.Net``) over the validation set.

    Follows the shared predictor signature used by ``MODELS``.

    Args:
        train_loader: Unused; accepted only so every predictor in ``MODELS``
            can be called with the same arguments.
        val_loader: DataLoader yielding ``(images, labels)`` batches.
        device: torch device to run inference on.
        checkpoint_path: Saved weights to load. Defaults to
            ``'best_model.pth'``, which is resolved relative to the current
            working directory (not this script's directory).

    Returns:
        Tuple ``(y_true, y_preds, y_probs)`` of numpy arrays: ground-truth
        labels, argmax predictions, and softmax probabilities (one column per
        model output, i.e. presumably one per class).

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model = Net(num_classes=NUM_CLASSES)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Accept either a bare state_dict or a training checkpoint that wraps it
    # under 'model_state_dict' or 'model'.
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    elif 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint)
    model = model.to(device).eval()

    y_true, y_preds, y_probs = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34'):
            # non_blocking lets the copy overlap compute when the loader pins memory;
            # labels are only needed on the host, so they are never moved.
            images = images.to(device, non_blocking=True)
            probs = torch.softmax(model(images), dim=1)
            preds = probs.argmax(dim=1)
            y_true.append(labels.numpy())
            y_preds.append(preds.cpu().numpy())
            y_probs.append(probs.cpu().numpy())

    if not y_true:
        raise ValueError('get_resnet34_preds: val_loader produced no batches')
    return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)
|
||
|
||
|
||
def get_vgg_knn_preds(train_loader, val_loader, device):
    """Fit the VGG + KNN baseline on the training loader, then predict on validation.

    Follows the shared predictor signature used by ``MODELS``; the return value
    is whatever ``VGGKNNBaseline.predict`` returns for ``val_loader``
    (expected to be ``(y_true, y_preds, y_probs)`` — TODO confirm against
    baseline/VGG_KNN.py).
    """
    knn_baseline = VGGKNNBaseline(k=K_KNN, device=device)
    knn_baseline.fit(train_loader)
    predictions = knn_baseline.predict(val_loader)
    return predictions
|
||
|
||
|
||
# ============================================================
# ★ Model registry — to add a new model, append one entry here ★
# Each entry is (display name, predictor function). The display name is the
# results key and the legend / bar label, so names must be unique (a
# duplicate silently overwrites the earlier result).
# ============================================================

MODELS = [
    ('ResNet-34', get_resnet34_preds),
    # Label is built from K_KNN so it stays correct if K is reconfigured.
    (f'VGG16 + KNN (K={K_KNN})', get_vgg_knn_preds),
    # Examples of future entries (these predictor functions do not exist yet):
    # ('ResNet-18 (pretrained)', get_resnet18_preds),
    # ('ResNet-50 (pretrained)', get_resnet50_preds),
    # ('ResNet-34 (finetuned)', get_finetuned_preds),
]
|
||
|
||
# ============================================================
# Colour palette — indexed modulo its length, so adding models needs no edit here
# ============================================================
# Matplotlib's default "tab10" colour cycle (C0..C9), in order.
COLORS = [
    '#1f77b4',  # blue
    '#ff7f0e',  # orange
    '#2ca02c',  # green
    '#d62728',  # red
    '#9467bd',  # purple
    '#8c564b',  # brown
    '#e377c2',  # pink
    '#7f7f7f',  # gray
    '#bcbd22',  # olive
    '#17becf',  # cyan
]
|
||
|
||
|
||
def compute_macro_roc(y_true, y_probs, num_classes=None):
    """Compute a macro-averaged one-vs-rest ROC curve and its AUC.

    Every class's ROC curve is linearly interpolated onto the union of all
    per-class FPR points, and the interpolated TPRs are averaged. The AUC is
    the area under that averaged curve.

    Classes that never occur in ``y_true`` (or that make up every sample) are
    skipped. Their ROC curve is undefined: ``roc_curve`` returns NaN TPR/FPR
    for them, which would otherwise turn the whole macro curve and AUC into NaN.

    Args:
        y_true: Integer class labels, shape (n_samples,).
        y_probs: Per-class scores/probabilities, shape (n_samples, n_classes).
        num_classes: Number of classes. Defaults to ``y_probs.shape[1]``.

    Returns:
        ``(all_fpr, mean_tpr, macro_auc)``: the shared FPR grid, the averaged
        TPR on that grid, and the area under the averaged curve.

    Raises:
        ValueError: If no class has both positive and negative samples.
    """
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    if num_classes is None:
        num_classes = y_probs.shape[1]
    one_hot = np.eye(num_classes)[y_true]
    n_samples = len(y_true)

    curves = []
    for c in range(num_classes):
        n_pos = int(one_hot[:, c].sum())
        # ROC needs at least one positive and one negative sample for this class.
        if n_pos == 0 or n_pos == n_samples:
            continue
        fpr, tpr, _ = roc_curve(one_hot[:, c], y_probs[:, c])
        curves.append((fpr, tpr))

    if not curves:
        raise ValueError('compute_macro_roc: no class has both positive and '
                         'negative samples; ROC is undefined')

    all_fpr = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    mean_tpr = np.zeros_like(all_fpr)
    for fpr, tpr in curves:
        mean_tpr += np.interp(all_fpr, fpr, tpr)
    mean_tpr /= len(curves)

    macro_auc = auc(all_fpr, mean_tpr)
    return all_fpr, mean_tpr, macro_auc
|
||
|
||
|
||
if __name__ == '__main__':
    # Figures are saved next to this script, independent of the working directory.
    out_dir = os.path.dirname(os.path.abspath(__file__))

    # Fail fast, before any data loading, if there is nothing to evaluate.
    if not MODELS:
        raise SystemExit('MODELS is empty: nothing to compare.')

    # Prefer CUDA, then Intel XPU (probed only if this torch build has torch.xpu), else CPU.
    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
                          else 'cpu')
    print(f"Device: {device}")

    # Deterministic preprocessing: fixed-size resize + ImageNet mean/std normalisation.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # The training split also uses val_transform (no augmentation) — presumably
    # so baselines fit on it (e.g. the KNN) see the same preprocessing as the
    # validation images; confirm intent.
    train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
                                      transform=val_transform)
    val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
                                    transform=val_transform)
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    # Pinned host memory only speeds up copies to an accelerator; on a
    # CPU-only run it is wasted work and DataLoader warns about it.
    pin_memory = device.type != 'cpu'
    # shuffle=False / drop_last=False: every sample is scored exactly once, in a fixed order.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)

    # ---- Evaluate every registered model ----
    results = {}
    for name, func in MODELS:
        print(f"\n{'='*50}")
        print(f"评估: {name}")
        y_true, y_preds, y_probs = func(train_loader, val_loader, device)
        acc = accuracy_score(y_true, y_preds)
        fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
        results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
                         'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc}
        print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}")

    # ---- ROC comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        ax.plot(r['fpr'], r['tpr'], color=COLORS[i % len(COLORS)], lw=2,
                label=f"{name} (AUC={r['auc']:.4f})")
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)  # chance-level diagonal
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    roc_path = os.path.join(out_dir, 'roc_comparison.png')
    plt.savefig(roc_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown
    print(f"\nROC 对比图已保存: {roc_path}")

    # ---- Accuracy bar chart ----
    names = list(results.keys())
    accs = [results[n]['acc'] for n in names]
    fig, ax = plt.subplots(figsize=(8, 5))
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
    bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
    for bar, acc in zip(bars, accs):
        # Value label just above each bar.
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    # ~15% headroom for the value labels; the floor avoids a degenerate
    # (0, 0) y-range if every model scores 0.
    ax.set_ylim(0, max(max(accs) * 1.15, 0.05))
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    bar_path = os.path.join(out_dir, 'accuracy_bar.png')
    plt.savefig(bar_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"准确率柱状图已保存: {bar_path}")
|