From 3fee1c82abbd17a4382023b365d5bf18823493b0 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Sun, 17 May 2026 18:31:32 +0800 Subject: [PATCH] =?UTF-8?q?compare=5Fmodels.py:=20=E6=96=B0=E5=A2=9E=20PR?= =?UTF-8?q?=20=E6=9B=B2=E7=BA=BF=E5=AF=B9=E6=AF=94=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- baseline/compare_models.py | 41 +++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/baseline/compare_models.py b/baseline/compare_models.py index f8fdc59..eb4332e 100644 --- a/baseline/compare_models.py +++ b/baseline/compare_models.py @@ -17,7 +17,10 @@ from torch.utils.data import DataLoader from torchvision import transforms from tqdm import tqdm -from sklearn.metrics import roc_curve, auc, accuracy_score +from sklearn.metrics import ( + roc_curve, auc, accuracy_score, + precision_recall_curve, average_precision_score, +) from Model import Net from Dataloader import RobustImageFolder @@ -108,6 +111,20 @@ def compute_macro_roc(y_true, y_probs): return all_fpr, mean_tpr, macro_auc +def compute_macro_pr(y_true, y_probs): + one_hot = np.eye(NUM_CLASSES)[y_true] + prec_dict, rec_dict = {}, {} + for c in range(NUM_CLASSES): + prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c]) + all_rec = np.linspace(0, 1, 200) + mean_prec = np.zeros_like(all_rec) + for c in range(NUM_CLASSES): + mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1]) + mean_prec /= NUM_CLASSES + macro_ap = average_precision_score(one_hot, y_probs, average='macro') + return all_rec, mean_prec, macro_ap + + def sanitize_filename(name): return re.sub(r'[^\w\-_]', '_', name).strip('_') @@ -172,9 +189,11 @@ if __name__ == '__main__': print(f" 预测数据已保存: {os.path.basename(csv_path)}") acc = accuracy_score(y_true, y_preds) fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs) + rec, prec, macro_ap = compute_macro_pr(y_true, y_probs) results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs, - 'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc} - print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}") + 'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc, + 'rec': rec, 'prec': prec, 'ap': macro_ap} + print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}") # ———— ROC 对比图 ———— fig, ax = plt.subplots(figsize=(8, 7)) @@ -193,6 +212,22 @@ if __name__ == '__main__': plt.show() print(f"\nROC 对比图已保存: {roc_path}") + # ———— PR 对比图 ———— + fig, ax = plt.subplots(figsize=(8, 7)) + for i, (name, r) in enumerate(results.items()): + color = COLORS[i % len(COLORS)] + ax.plot(r['rec'], r['prec'], color=color, lw=2, + label=f"{name} (AP={r['ap']:.4f})") + ax.set_xlim(0, 1); ax.set_ylim(0, 1.05) + ax.set_xlabel('Recall'); ax.set_ylabel('Precision') + ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14) + ax.legend(loc='lower left'); ax.grid(True, alpha=0.3) + plt.tight_layout() + pr_path = os.path.join(out_dir, 'pr_comparison.png') + plt.savefig(pr_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"PR 对比图已保存: {pr_path}") + # ———— 准确率柱状图 ———— names = list(results.keys()) accs = [results[n]['acc'] for n in names]