添加 VGG16 + KNN 基线模型 Baseline.py

This commit is contained in:
yukun-hh 2026-05-14 19:50:05 +08:00
parent 76b56dd64b
commit 3624f058c2
3 changed files with 132 additions and 0 deletions

1
.gitignore vendored
View file

@ -7,6 +7,7 @@
!/README.md
!/THIRD_PARTY_LICENSES.md
!/Train.py
!/Baseline.py
!/AGENTS.md
!/Finetune.py
!/Curve.py

129
Baseline.py Normal file
View file

@ -0,0 +1,129 @@
"""
Baseline.py
VGG16 预训练模型特征提取 + KNN 四分类基线
author: yukun-hh
date: 2026-5-14
"""
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
K = 5
# ============================================================
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
def load_vgg16_extractor(device):
try:
model = models.vgg16(weights='IMAGENET1K_V1')
except TypeError:
model = models.vgg16(pretrained=True)
model.classifier = nn.Identity()
model = model.to(device).eval()
for param in model.parameters():
param.requires_grad = False
return model
def extract_features(model, loader, device):
model.eval()
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(loader, desc='Extracting features'):
images = images.to(device)
feats = model(images)
all_features.append(feats.cpu().numpy())
all_labels.append(labels.numpy())
return np.concatenate(all_features), np.concatenate(all_labels)
if __name__ == '__main__':
val_transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_dataset = RobustImageFolder(
root=os.path.join(DATA_ROOT, 'train'),
transform=val_transform,
)
val_dataset = RobustImageFolder(
root=os.path.join(DATA_ROOT, 'val'),
transform=val_transform,
)
print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True, drop_last=False)
device = torch.device('cuda' if torch.cuda.is_available()
else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
else 'cpu')
print(f"Device: {device}")
extractor = load_vgg16_extractor(device)
print("VGG16 特征提取器加载完成 (classifier 已移除)")
print("提取训练集特征 ...")
train_feats, train_labels = extract_features(extractor, train_loader, device)
print(f"训练特征: {train_feats.shape}")
print("提取验证集特征 ...")
val_feats, val_labels = extract_features(extractor, val_loader, device)
print(f"验证特征: {val_feats.shape}")
knn = KNeighborsClassifier(n_neighbors=K, n_jobs=-1)
knn.fit(train_feats, train_labels)
print(f"KNN (K={K}) 训练完成")
val_preds = knn.predict(val_feats)
acc = accuracy_score(val_labels, val_preds)
macro_f1 = f1_score(val_labels, val_preds, average='macro')
print(f"\n验证集 Accuracy: {acc:.4f}")
print(f"验证集 Macro-F1: {macro_f1:.4f}")
print(f"\n分类报告:\n{classification_report(val_labels, val_preds, target_names=CLASS_NAMES)}")
cm = confusion_matrix(val_labels, val_preds)
fig, ax = plt.subplots(figsize=(8, 7))
ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
plt.tight_layout()
plt.savefig('baseline_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
print("混淆矩阵已保存: baseline_confusion_matrix.png")

View file

@ -123,6 +123,7 @@
| 文件 | 功能 |
|---|---|
| `Baseline.py` | 基线模型VGG16 预训练特征提取 + KNN 四分类 |
| `Train.py` | 训练主脚本,包含训练循环、验证、评估 |
| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 |
| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 |
@ -140,6 +141,7 @@
```
trash-division/
├── AGENTS.md # AI 助手指南
├── Baseline.py # 基线模型脚本
├── best_model.pth # 最佳模型权重(不纳入版本控制)
├── Curve.py # 训练曲线绘制脚本
├── Dataloader.py # 数据加载模块