compare_models.py: 添加 CSV 缓存机制,已有预测数据时跳过重复计算

This commit is contained in:
yukun-hh 2026-05-17 16:55:27 +08:00
parent 818d98d06c
commit 547d96cfa9

View file

@ -5,7 +5,7 @@ baseline/compare_models.py
author: yukun-hh author: yukun-hh
date: 2026-5-14 date: 2026-5-14
""" """
import sys, os import sys, os, re
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np import numpy as np
@ -29,7 +29,7 @@ matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================ # ============================================================
# ★★★ 可配置参数 ★★★ # ★★★ 可配置参数 ★★★
# ============================================================ # ============================================================
DATA_ROOT = '../trash_division_data/ultimate_4_class/' DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32 BATCH_SIZE = 32
IMAGE_SIZE = 256 IMAGE_SIZE = 256
NUM_WORKERS = 4 NUM_WORKERS = 4
@ -45,7 +45,7 @@ NUM_CLASSES = 4
def get_resnet34_preds(train_loader, val_loader, device): def get_resnet34_preds(train_loader, val_loader, device):
model = Net(num_classes=NUM_CLASSES) model = Net(num_classes=NUM_CLASSES)
state_dict = torch.load('best_model.pth', map_location='cpu') state_dict = torch.load('../best_model.pth', map_location='cpu')
if 'model_state_dict' in state_dict: if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict'] state_dict = state_dict['model_state_dict']
elif 'model' in state_dict: elif 'model' in state_dict:
@ -106,6 +106,29 @@ def compute_macro_roc(y_true, y_probs):
return all_fpr, mean_tpr, macro_auc return all_fpr, mean_tpr, macro_auc
def sanitize_filename(name):
return re.sub(r'[^\w\-_]', '_', name).strip('_')
def preds_csv_path(out_dir, model_name):
safe = sanitize_filename(model_name)
return os.path.join(out_dir, f'{safe}_preds.csv')
def save_preds_csv(path, y_true, y_preds, y_probs):
header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs])
np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f')
def load_preds_csv(path):
data = np.loadtxt(path, delimiter=',', skiprows=1)
y_true = data[:, 0].astype(int)
y_preds = data[:, 1].astype(int)
y_probs = data[:, 2:2 + NUM_CLASSES]
return y_true, y_preds, y_probs
if __name__ == '__main__': if __name__ == '__main__':
out_dir = os.path.dirname(os.path.abspath(__file__)) out_dir = os.path.dirname(os.path.abspath(__file__))
@ -132,12 +155,19 @@ if __name__ == '__main__':
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True, drop_last=False) num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
# ———— 评估所有模型 ———— # ———— 评估所有模型(有缓存则跳过)————
results = {} results = {}
for name, func in MODELS: for name, func in MODELS:
print(f"\n{'='*50}") print(f"\n{'='*50}")
csv_path = preds_csv_path(out_dir, name)
if os.path.exists(csv_path):
print(f"加载缓存: {os.path.basename(csv_path)}")
y_true, y_preds, y_probs = load_preds_csv(csv_path)
else:
print(f"评估: {name}") print(f"评估: {name}")
y_true, y_preds, y_probs = func(train_loader, val_loader, device) y_true, y_preds, y_probs = func(train_loader, val_loader, device)
save_preds_csv(csv_path, y_true, y_preds, y_probs)
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
acc = accuracy_score(y_true, y_preds) acc = accuracy_score(y_true, y_preds)
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs) fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs, results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,