compare_models.py: 新增 PR 曲线对比图

This commit is contained in:
yukun-hh 2026-05-17 18:31:32 +08:00
parent 010dacb533
commit 3fee1c82ab

View file

@ -17,7 +17,10 @@ from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, accuracy_score
from sklearn.metrics import (
roc_curve, auc, accuracy_score,
precision_recall_curve, average_precision_score,
)
from Model import Net
from Dataloader import RobustImageFolder
@ -108,6 +111,20 @@ def compute_macro_roc(y_true, y_probs):
return all_fpr, mean_tpr, macro_auc
def compute_macro_pr(y_true, y_probs):
one_hot = np.eye(NUM_CLASSES)[y_true]
prec_dict, rec_dict = {}, {}
for c in range(NUM_CLASSES):
prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
all_rec = np.linspace(0, 1, 200)
mean_prec = np.zeros_like(all_rec)
for c in range(NUM_CLASSES):
mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
mean_prec /= NUM_CLASSES
macro_ap = average_precision_score(one_hot, y_probs, average='macro')
return all_rec, mean_prec, macro_ap
def sanitize_filename(name):
return re.sub(r'[^\w\-_]', '_', name).strip('_')
@ -172,9 +189,11 @@ if __name__ == '__main__':
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
acc = accuracy_score(y_true, y_preds)
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc}
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}")
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
'rec': rec, 'prec': prec, 'ap': macro_ap}
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
# ———— ROC 对比图 ————
fig, ax = plt.subplots(figsize=(8, 7))
@ -193,6 +212,22 @@ if __name__ == '__main__':
plt.show()
print(f"\nROC 对比图已保存: {roc_path}")
# ———— PR 对比图 ————
fig, ax = plt.subplots(figsize=(8, 7))
for i, (name, r) in enumerate(results.items()):
color = COLORS[i % len(COLORS)]
ax.plot(r['rec'], r['prec'], color=color, lw=2,
label=f"{name} (AP={r['ap']:.4f})")
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
plt.tight_layout()
pr_path = os.path.join(out_dir, 'pr_comparison.png')
plt.savefig(pr_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"PR 对比图已保存: {pr_path}")
# ———— 准确率柱状图 ————
names = list(results.keys())
accs = [results[n]['acc'] for n in names]