compare_models.py: 新增 PR 曲线对比图
This commit is contained in:
parent
010dacb533
commit
3fee1c82ab
1 changed files with 38 additions and 3 deletions
|
|
@ -17,7 +17,10 @@ from torch.utils.data import DataLoader
|
|||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from sklearn.metrics import roc_curve, auc, accuracy_score
|
||||
from sklearn.metrics import (
|
||||
roc_curve, auc, accuracy_score,
|
||||
precision_recall_curve, average_precision_score,
|
||||
)
|
||||
|
||||
from Model import Net
|
||||
from Dataloader import RobustImageFolder
|
||||
|
|
@ -108,6 +111,20 @@ def compute_macro_roc(y_true, y_probs):
|
|||
return all_fpr, mean_tpr, macro_auc
|
||||
|
||||
|
||||
def compute_macro_pr(y_true, y_probs):
|
||||
one_hot = np.eye(NUM_CLASSES)[y_true]
|
||||
prec_dict, rec_dict = {}, {}
|
||||
for c in range(NUM_CLASSES):
|
||||
prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
|
||||
all_rec = np.linspace(0, 1, 200)
|
||||
mean_prec = np.zeros_like(all_rec)
|
||||
for c in range(NUM_CLASSES):
|
||||
mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
|
||||
mean_prec /= NUM_CLASSES
|
||||
macro_ap = average_precision_score(one_hot, y_probs, average='macro')
|
||||
return all_rec, mean_prec, macro_ap
|
||||
|
||||
|
||||
def sanitize_filename(name):
|
||||
return re.sub(r'[^\w\-_]', '_', name).strip('_')
|
||||
|
||||
|
|
@ -172,9 +189,11 @@ if __name__ == '__main__':
|
|||
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
|
||||
acc = accuracy_score(y_true, y_preds)
|
||||
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
|
||||
rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
|
||||
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
|
||||
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc}
|
||||
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}")
|
||||
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
|
||||
'rec': rec, 'prec': prec, 'ap': macro_ap}
|
||||
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
|
||||
|
||||
# ———— ROC 对比图 ————
|
||||
fig, ax = plt.subplots(figsize=(8, 7))
|
||||
|
|
@ -193,6 +212,22 @@ if __name__ == '__main__':
|
|||
plt.show()
|
||||
print(f"\nROC 对比图已保存: {roc_path}")
|
||||
|
||||
# ———— PR 对比图 ————
|
||||
fig, ax = plt.subplots(figsize=(8, 7))
|
||||
for i, (name, r) in enumerate(results.items()):
|
||||
color = COLORS[i % len(COLORS)]
|
||||
ax.plot(r['rec'], r['prec'], color=color, lw=2,
|
||||
label=f"{name} (AP={r['ap']:.4f})")
|
||||
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
|
||||
ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
|
||||
ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
|
||||
ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
|
||||
plt.tight_layout()
|
||||
pr_path = os.path.join(out_dir, 'pr_comparison.png')
|
||||
plt.savefig(pr_path, dpi=150, bbox_inches='tight')
|
||||
plt.show()
|
||||
print(f"PR 对比图已保存: {pr_path}")
|
||||
|
||||
# ———— 准确率柱状图 ————
|
||||
names = list(results.keys())
|
||||
accs = [results[n]['acc'] for n in names]
|
||||
|
|
|
|||
Loading…
Reference in a new issue