compare_models.py: 添加 CSV 缓存机制,已有预测数据时跳过重复计算
This commit is contained in:
parent
818d98d06c
commit
547d96cfa9
1 changed files with 36 additions and 6 deletions
|
|
@ -5,7 +5,7 @@ baseline/compare_models.py
|
|||
author: yukun-hh
|
||||
date: 2026-5-14
|
||||
"""
|
||||
import sys, os
|
||||
import sys, os, re
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -29,7 +29,7 @@ matplotlib.rcParams['axes.unicode_minus'] = False
|
|||
# ============================================================
|
||||
# ★★★ 可配置参数 ★★★
|
||||
# ============================================================
|
||||
DATA_ROOT = '../trash_division_data/ultimate_4_class/'
|
||||
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
|
||||
BATCH_SIZE = 32
|
||||
IMAGE_SIZE = 256
|
||||
NUM_WORKERS = 4
|
||||
|
|
@ -45,7 +45,7 @@ NUM_CLASSES = 4
|
|||
|
||||
def get_resnet34_preds(train_loader, val_loader, device):
|
||||
model = Net(num_classes=NUM_CLASSES)
|
||||
state_dict = torch.load('best_model.pth', map_location='cpu')
|
||||
state_dict = torch.load('../best_model.pth', map_location='cpu')
|
||||
if 'model_state_dict' in state_dict:
|
||||
state_dict = state_dict['model_state_dict']
|
||||
elif 'model' in state_dict:
|
||||
|
|
@ -106,6 +106,29 @@ def compute_macro_roc(y_true, y_probs):
|
|||
return all_fpr, mean_tpr, macro_auc
|
||||
|
||||
|
||||
def sanitize_filename(name):
|
||||
return re.sub(r'[^\w\-_]', '_', name).strip('_')
|
||||
|
||||
|
||||
def preds_csv_path(out_dir, model_name):
|
||||
safe = sanitize_filename(model_name)
|
||||
return os.path.join(out_dir, f'{safe}_preds.csv')
|
||||
|
||||
|
||||
def save_preds_csv(path, y_true, y_preds, y_probs):
|
||||
header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
|
||||
data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs])
|
||||
np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f')
|
||||
|
||||
|
||||
def load_preds_csv(path):
|
||||
data = np.loadtxt(path, delimiter=',', skiprows=1)
|
||||
y_true = data[:, 0].astype(int)
|
||||
y_preds = data[:, 1].astype(int)
|
||||
y_probs = data[:, 2:2 + NUM_CLASSES]
|
||||
return y_true, y_preds, y_probs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
out_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
|
@ -132,12 +155,19 @@ if __name__ == '__main__':
|
|||
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
|
||||
num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
|
||||
|
||||
# ———— 评估所有模型 ————
|
||||
# ———— 评估所有模型(有缓存则跳过)————
|
||||
results = {}
|
||||
for name, func in MODELS:
|
||||
print(f"\n{'='*50}")
|
||||
print(f"评估: {name}")
|
||||
y_true, y_preds, y_probs = func(train_loader, val_loader, device)
|
||||
csv_path = preds_csv_path(out_dir, name)
|
||||
if os.path.exists(csv_path):
|
||||
print(f"加载缓存: {os.path.basename(csv_path)}")
|
||||
y_true, y_preds, y_probs = load_preds_csv(csv_path)
|
||||
else:
|
||||
print(f"评估: {name}")
|
||||
y_true, y_preds, y_probs = func(train_loader, val_loader, device)
|
||||
save_preds_csv(csv_path, y_true, y_preds, y_probs)
|
||||
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
|
||||
acc = accuracy_score(y_true, y_preds)
|
||||
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
|
||||
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
|
||||
|
|
|
|||
Loading…
Reference in a new issue