From 3624f058c261ada9cd6e89541aa8c42cb4898516 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Thu, 14 May 2026 19:50:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20VGG16=20+=20KNN=20?= =?UTF-8?q?=E5=9F=BA=E7=BA=BF=E6=A8=A1=E5=9E=8B=20Baseline.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + Baseline.py | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 2 + 3 files changed, 132 insertions(+) create mode 100644 Baseline.py diff --git a/.gitignore b/.gitignore index c68b013..fb1c79f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ !/README.md !/THIRD_PARTY_LICENSES.md !/Train.py +!/Baseline.py !/AGENTS.md !/Finetune.py !/Curve.py diff --git a/Baseline.py b/Baseline.py new file mode 100644 index 0000000..9be8a36 --- /dev/null +++ b/Baseline.py @@ -0,0 +1,129 @@ +""" +Baseline.py +VGG16 预训练模型特征提取 + KNN 四分类基线 +author: yukun-hh +date: 2026-5-14 +""" +import os +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torchvision import models, transforms +from tqdm import tqdm + +from sklearn.neighbors import KNeighborsClassifier +from sklearn.metrics import ( + accuracy_score, f1_score, + confusion_matrix, ConfusionMatrixDisplay, + classification_report, +) + +from Dataloader import RobustImageFolder + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 可配置参数 ★★★ +# ============================================================ +DATA_ROOT = '../trash_division_data/ultimate_4_class/' +BATCH_SIZE = 32 +IMAGE_SIZE = 256 +NUM_WORKERS = 4 +K = 5 +# ============================================================ + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + + +def load_vgg16_extractor(device): + try: + model = models.vgg16(weights='IMAGENET1K_V1') + except TypeError: + model = models.vgg16(pretrained=True) + model.classifier = nn.Identity() + model = model.to(device).eval() + for param in model.parameters(): + param.requires_grad = False + return model + + +def extract_features(model, loader, device): + model.eval() + all_features = [] + all_labels = [] + with torch.no_grad(): + for images, labels in tqdm(loader, desc='Extracting features'): + images = images.to(device) + feats = model(images) + all_features.append(feats.cpu().numpy()) + all_labels.append(labels.numpy()) + return np.concatenate(all_features), np.concatenate(all_labels) + + +if __name__ == '__main__': + val_transform = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + train_dataset = RobustImageFolder( + root=os.path.join(DATA_ROOT, 'train'), + transform=val_transform, + ) + val_dataset = RobustImageFolder( + root=os.path.join(DATA_ROOT, 'val'), + transform=val_transform, + ) + print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}") + + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, + shuffle=False, num_workers=NUM_WORKERS, + pin_memory=True, drop_last=False) + val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, + shuffle=False, num_workers=NUM_WORKERS, + pin_memory=True, drop_last=False) + + device = torch.device('cuda' if torch.cuda.is_available() + else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() + else 'cpu') + print(f"Device: {device}") + + extractor = load_vgg16_extractor(device) + print("VGG16 特征提取器加载完成 (classifier 已移除)") + + print("提取训练集特征 ...") + train_feats, train_labels = extract_features(extractor, train_loader, device) + print(f"训练特征: {train_feats.shape}") + + print("提取验证集特征 ...") + val_feats, val_labels = extract_features(extractor, val_loader, device) + print(f"验证特征: {val_feats.shape}") + + knn = KNeighborsClassifier(n_neighbors=K, n_jobs=-1) + knn.fit(train_feats, train_labels) + print(f"KNN (K={K}) 训练完成") + + val_preds = knn.predict(val_feats) + + acc = accuracy_score(val_labels, val_preds) + macro_f1 = f1_score(val_labels, val_preds, average='macro') + print(f"\n验证集 Accuracy: {acc:.4f}") + print(f"验证集 Macro-F1: {macro_f1:.4f}") + print(f"\n分类报告:\n{classification_report(val_labels, val_preds, target_names=CLASS_NAMES)}") + + cm = confusion_matrix(val_labels, val_preds) + fig, ax = plt.subplots(figsize=(8, 7)) + ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot( + ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) + ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14) + plt.tight_layout() + plt.savefig('baseline_confusion_matrix.png', dpi=150, bbox_inches='tight') + plt.show() + print("混淆矩阵已保存: baseline_confusion_matrix.png") diff --git a/README.md b/README.md index 0724ca2..1a85aec 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ | 文件 | 功能 | |---|---| +| `Baseline.py` | 基线模型,VGG16 预训练特征提取 + KNN 四分类 | | `Train.py` | 训练主脚本,包含训练循环、验证、评估 | | `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 | | `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 | @@ -140,6 +141,7 @@ ``` trash-division/ ├── AGENTS.md # AI 助手指南 +├── Baseline.py # 基线模型脚本 ├── best_model.pth # 最佳模型权重(不纳入版本控制) ├── Curve.py # 训练曲线绘制脚本 ├── Dataloader.py # 数据加载模块