# Source: trash-division/Baseline.py
#
# 129 lines
# 4.4 KiB
# Python

"""
Baseline.py
VGG16 预训练模型特征提取 + KNN 四分类基线
author: yukun-hh
date: 2026-5-14
"""
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
from Dataloader import RobustImageFolder
# Put SimHei first in the sans-serif font list so the Chinese class names used
# as plot labels can be rendered; 'DejaVu Sans' is listed as the next choice.
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# NOTE(review): presumably disabled so minus signs still render when the
# selected CJK font lacks the Unicode minus glyph — confirm.
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ Configurable parameters ★★★
# ============================================================
# Dataset root; the script reads the 'train' and 'val' sub-directories below it.
# The path is relative, so it is resolved against the current working directory.
DATA_ROOT = '../trash_division_data/ultimate_4_class/'
# Batch size used by both DataLoaders during feature extraction.
BATCH_SIZE = 32
# Images are resized to IMAGE_SIZE x IMAGE_SIZE before being fed to the model.
IMAGE_SIZE = 256
# Number of DataLoader worker processes.
NUM_WORKERS = 4
# Number of neighbours used by the KNN classifier.
K = 5
# ============================================================
# Display names passed to the classification report and the confusion-matrix
# plot. Position i is shown for integer label i.
# NOTE(review): assumes this order matches the label indices assigned by
# RobustImageFolder — verify against the dataset's class-to-index mapping.
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
def load_vgg16_extractor(device):
    """Build a frozen VGG16 whose classifier head is replaced by an identity.

    Args:
        device: torch device the network is moved to.

    Returns:
        The VGG16 module in eval mode, with ``classifier`` swapped for
        ``nn.Identity`` and every parameter excluded from gradient tracking.
    """
    try:
        backbone = models.vgg16(weights='IMAGENET1K_V1')
    except TypeError:
        # Fall back to the ``pretrained`` flag when the ``weights`` keyword
        # is rejected with a TypeError.
        backbone = models.vgg16(pretrained=True)

    # Drop the classification head so the forward pass returns whatever the
    # network produces right before it.
    backbone.classifier = nn.Identity()

    backbone = backbone.to(device)
    backbone = backbone.eval()

    # The extractor is never trained: turn off gradients for all weights.
    for weight in backbone.parameters():
        weight.requires_grad = False

    return backbone
def extract_features(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and gather the results.

    Args:
        model: module called on each image batch; it is switched to eval mode.
        loader: iterable yielding ``(images, labels)`` batches.
        device: torch device the image batches are moved to before the
            forward pass.

    Returns:
        A tuple ``(features, labels)`` of NumPy arrays, each the concatenation
        (along the first axis) of the per-batch model outputs and labels, in
        the order the loader produced them.

    Raises:
        ValueError: if ``loader`` yields no batches, since there is nothing
            to concatenate.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Extracting features'):
            images = images.to(device)
            feats = model(images)
            # Move outputs back to host memory before converting to NumPy.
            feature_chunks.append(feats.cpu().numpy())
            label_chunks.append(labels.numpy())

    if not feature_chunks:
        # np.concatenate([]) would fail with an opaque "need at least one
        # array" message; point at the actual cause instead.
        raise ValueError(
            "extract_features: the loader produced no batches "
            "(is the dataset empty or the data path wrong?)"
        )

    return np.concatenate(feature_chunks), np.concatenate(label_chunks)
def _select_device():
    """Pick CUDA when available, then XPU, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.device('xpu')
    return torch.device('cpu')


def _make_loader(dataset, pin_memory):
    """Wrap ``dataset`` in an unshuffled DataLoader that keeps every sample."""
    return DataLoader(dataset, batch_size=BATCH_SIZE,
                      shuffle=False, num_workers=NUM_WORKERS,
                      pin_memory=pin_memory, drop_last=False)


def _check_labels(label_ids, *label_arrays):
    """Raise ValueError if any observed label is outside ``label_ids``."""
    observed = np.unique(np.concatenate(label_arrays))
    unexpected = np.setdiff1d(observed, label_ids)
    if unexpected.size:
        raise ValueError(
            f"Dataset labels {unexpected.tolist()} have no entry in "
            f"CLASS_NAMES (expected labels {label_ids})"
        )


def _report_metrics(val_labels, val_preds, label_ids):
    """Print accuracy / macro-F1 / per-class report and save the confusion matrix."""
    acc = accuracy_score(val_labels, val_preds)
    # ``labels`` is passed explicitly everywhere so that a class missing from
    # the validation labels and predictions cannot change the number of rows
    # and break the pairing with CLASS_NAMES.
    macro_f1 = f1_score(val_labels, val_preds, labels=label_ids,
                        average='macro')
    print(f"\n验证集 Accuracy: {acc:.4f}")
    print(f"验证集 Macro-F1: {macro_f1:.4f}")

    report = classification_report(val_labels, val_preds,
                                   labels=label_ids,
                                   target_names=CLASS_NAMES)
    print(f"\n分类报告:\n{report}")

    cm = confusion_matrix(val_labels, val_preds, labels=label_ids)
    fig, ax = plt.subplots(figsize=(8, 7))
    ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
        ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
    plt.tight_layout()
    # Save before show(): the figure is written to disk first.
    plt.savefig('baseline_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("混淆矩阵已保存: baseline_confusion_matrix.png")


def main():
    """Extract VGG16 features for train/val, fit KNN, and report metrics.

    Reads images from ``DATA_ROOT/train`` and ``DATA_ROOT/val``, prints
    progress and metrics to stdout, and writes
    ``baseline_confusion_matrix.png`` to the current working directory.

    Raises:
        ValueError: if the datasets contain a label index that has no
            corresponding entry in ``CLASS_NAMES``.
    """
    # The same deterministic transform is used for both splits: features are
    # extracted once, so no augmentation is applied to the training images.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = RobustImageFolder(
        root=os.path.join(DATA_ROOT, 'train'),
        transform=val_transform,
    )
    val_dataset = RobustImageFolder(
        root=os.path.join(DATA_ROOT, 'val'),
        transform=val_transform,
    )
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    # Choose the device before building the loaders so page-locked memory is
    # only requested when batches are actually copied to an accelerator.
    device = _select_device()
    print(f"Device: {device}")
    pin_memory = device.type != 'cpu'
    train_loader = _make_loader(train_dataset, pin_memory)
    val_loader = _make_loader(val_dataset, pin_memory)

    extractor = load_vgg16_extractor(device)
    print("VGG16 特征提取器加载完成 (classifier 已移除)")

    print("提取训练集特征 ...")
    train_feats, train_labels = extract_features(extractor, train_loader, device)
    print(f"训练特征: {train_feats.shape}")
    print("提取验证集特征 ...")
    val_feats, val_labels = extract_features(extractor, val_loader, device)
    print(f"验证特征: {val_feats.shape}")

    # Label i is reported under CLASS_NAMES[i]; fail loudly if the data holds
    # a label that has no name instead of silently mislabelling the report.
    label_ids = list(range(len(CLASS_NAMES)))
    _check_labels(label_ids, train_labels, val_labels)

    knn = KNeighborsClassifier(n_neighbors=K, n_jobs=-1)
    knn.fit(train_feats, train_labels)
    print(f"KNN (K={K}) 训练完成")

    val_preds = knn.predict(val_feats)
    _report_metrics(val_labels, val_preds, label_ids)


if __name__ == '__main__':
    main()