compare_models.py: 新增 PR 曲线对比图
This commit is contained in:
parent
010dacb533
commit
3fee1c82ab
1 changed files with 38 additions and 3 deletions
|
|
@ -17,7 +17,10 @@ from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sklearn.metrics import roc_curve, auc, accuracy_score
|
from sklearn.metrics import (
|
||||||
|
roc_curve, auc, accuracy_score,
|
||||||
|
precision_recall_curve, average_precision_score,
|
||||||
|
)
|
||||||
|
|
||||||
from Model import Net
|
from Model import Net
|
||||||
from Dataloader import RobustImageFolder
|
from Dataloader import RobustImageFolder
|
||||||
|
|
@ -108,6 +111,20 @@ def compute_macro_roc(y_true, y_probs):
|
||||||
return all_fpr, mean_tpr, macro_auc
|
return all_fpr, mean_tpr, macro_auc
|
||||||
|
|
||||||
|
|
||||||
|
def compute_macro_pr(y_true, y_probs):
|
||||||
|
one_hot = np.eye(NUM_CLASSES)[y_true]
|
||||||
|
prec_dict, rec_dict = {}, {}
|
||||||
|
for c in range(NUM_CLASSES):
|
||||||
|
prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
|
||||||
|
all_rec = np.linspace(0, 1, 200)
|
||||||
|
mean_prec = np.zeros_like(all_rec)
|
||||||
|
for c in range(NUM_CLASSES):
|
||||||
|
mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
|
||||||
|
mean_prec /= NUM_CLASSES
|
||||||
|
macro_ap = average_precision_score(one_hot, y_probs, average='macro')
|
||||||
|
return all_rec, mean_prec, macro_ap
|
||||||
|
|
||||||
|
|
||||||
def sanitize_filename(name):
|
def sanitize_filename(name):
|
||||||
return re.sub(r'[^\w\-_]', '_', name).strip('_')
|
return re.sub(r'[^\w\-_]', '_', name).strip('_')
|
||||||
|
|
||||||
|
|
@ -172,9 +189,11 @@ if __name__ == '__main__':
|
||||||
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
|
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
|
||||||
acc = accuracy_score(y_true, y_preds)
|
acc = accuracy_score(y_true, y_preds)
|
||||||
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
|
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
|
||||||
|
rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
|
||||||
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
|
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
|
||||||
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc}
|
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
|
||||||
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}")
|
'rec': rec, 'prec': prec, 'ap': macro_ap}
|
||||||
|
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
|
||||||
|
|
||||||
# ———— ROC 对比图 ————
|
# ———— ROC 对比图 ————
|
||||||
fig, ax = plt.subplots(figsize=(8, 7))
|
fig, ax = plt.subplots(figsize=(8, 7))
|
||||||
|
|
@ -193,6 +212,22 @@ if __name__ == '__main__':
|
||||||
plt.show()
|
plt.show()
|
||||||
print(f"\nROC 对比图已保存: {roc_path}")
|
print(f"\nROC 对比图已保存: {roc_path}")
|
||||||
|
|
||||||
|
# ———— PR 对比图 ————
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 7))
|
||||||
|
for i, (name, r) in enumerate(results.items()):
|
||||||
|
color = COLORS[i % len(COLORS)]
|
||||||
|
ax.plot(r['rec'], r['prec'], color=color, lw=2,
|
||||||
|
label=f"{name} (AP={r['ap']:.4f})")
|
||||||
|
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
|
||||||
|
ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
|
||||||
|
ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
|
||||||
|
ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
|
||||||
|
plt.tight_layout()
|
||||||
|
pr_path = os.path.join(out_dir, 'pr_comparison.png')
|
||||||
|
plt.savefig(pr_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
print(f"PR 对比图已保存: {pr_path}")
|
||||||
|
|
||||||
# ———— 准确率柱状图 ————
|
# ———— 准确率柱状图 ————
|
||||||
names = list(results.keys())
|
names = list(results.keys())
|
||||||
accs = [results[n]['acc'] for n in names]
|
accs = [results[n]['acc'] for n in names]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue