Compare commits
20 commits
cb17be247e
...
1557a2c6a0
| Author | SHA1 | Date | |
|---|---|---|---|
| 1557a2c6a0 | |||
| 8008f2da8f | |||
| 52b03e981c | |||
| ed2778972b | |||
| 65c9742a1c | |||
| 25e11c1914 | |||
| b7341c746f | |||
| 307b66e9db | |||
| 562fc4142a | |||
| 3fee1c82ab | |||
| 010dacb533 | |||
| 547d96cfa9 | |||
| 818d98d06c | |||
| 3624f058c2 | |||
| 76b56dd64b | |||
| 44c70250e3 | |||
| 2f4e9df26e | |||
| b4c99489b3 | |||
|
|
ce0c6da36a | ||
| 4575f3390f |
29
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
*
|
||||
!/.gitattributes
|
||||
!/Dataloader.py
|
||||
!/LICENSE
|
||||
!/Merge_classes.py
|
||||
!/Model.py
|
||||
!/README.md
|
||||
!/requirements.txt
|
||||
!/THIRD_PARTY_LICENSES.md
|
||||
!/Train.py
|
||||
!/Baseline.py
|
||||
!/Finetune.py
|
||||
!/Curve.py
|
||||
!/Evaluate.py
|
||||
!/baseline/
|
||||
!/baseline/__init__.py
|
||||
!/baseline/VGG_KNN.py
|
||||
!/baseline/compare_models.py
|
||||
!/baseline/ResNet34_Pretrained_10pct.py
|
||||
!/baseline/HOG_Baseline.py
|
||||
!/baseline/roc_comparison.png
|
||||
!/baseline/pr_comparison.png
|
||||
!/baseline/accuracy_bar.png
|
||||
!/training_log.csv
|
||||
!/confusion_matrix.png
|
||||
!/roc_curve.png
|
||||
!/pr_curve.png
|
||||
!/training_curves.png
|
||||
!.gitignore
|
||||
50
Curve.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""
|
||||
Curve.py
|
||||
从 training_log.csv 读取日志,绘制 Loss / F1 / Accuracy / LR 曲线
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
# Font fallback order for figure text, and render the minus sign as a plain
# hyphen instead of the Unicode minus glyph.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# ============ Load the training log ============
log = pd.read_csv('training_log.csv')
# Rows flagged as the best checkpoint. Not used by any plot below; kept so the
# script still requires the 'best' column exactly as before.
best_rows = log[log['best'] == 'best']
epochs = log['epoch']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One entry per train-vs-validation panel:
# (grid position, train column, train label, val column, val label, y label, title)
paired_panels = [
    ((0, 0), 'train_loss', 'Train Loss', 'val_loss', 'Val Loss',
     'Loss', 'Loss vs Epoch'),
    ((0, 1), 'train_f1', 'Train F1', 'val_f1', 'Val F1',
     'F1 Score', 'F1 Score vs Epoch'),
    ((1, 0), 'train_acc', 'Train Acc', 'val_acc', 'Val Acc',
     'Accuracy (%)', 'Accuracy vs Epoch'),
]

for position, train_col, train_label, val_col, val_label, y_label, title in paired_panels:
    panel = axes[position]
    panel.plot(epochs, log[train_col], label=train_label, color='#1f77b4', lw=1.5)
    panel.plot(epochs, log[val_col], label=val_label, color='#ff7f0e', lw=1.5)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# ---- Learning-rate panel (single series, scientific notation on the y axis) ----
lr_panel = axes[1, 1]
lr_panel.plot(epochs, log['lr'], color='#2ca02c', lw=1.5)
lr_panel.set_xlabel('Epoch')
lr_panel.set_ylabel('Learning Rate')
lr_panel.set_title('Learning Rate vs Epoch')
lr_panel.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
lr_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("训练曲线已保存: training_curves.png")
|
||||
147
Evaluate.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
"""
|
||||
Evaluate.py
|
||||
加载模型,在验证集上推理,绘制混淆矩阵 / ROC / PR 曲线
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from sklearn.metrics import (
|
||||
confusion_matrix, ConfusionMatrixDisplay,
|
||||
roc_curve, auc,
|
||||
precision_recall_curve, average_precision_score,
|
||||
)
|
||||
|
||||
from Model import Net
|
||||
from Dataloader import RobustImageFolder
|
||||
|
||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False

# ============================================================
# Settings you are expected to adjust
# ============================================================
MODEL_PATH = 'best_model.pth'  # path to the model weights
DATA_ROOT = '../trash_division_data/ultimate_4_class/'  # dataset root directory
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
# ============================================================

# Line colours for the per-class curves. Indexed modulo its length so a
# dataset with more than four classes no longer raises IndexError.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']


def _class_color(index):
    """Return the curve colour for class ``index``, cycling through ``colors``."""
    return colors[index % len(colors)]


def _build_val_loader():
    """Create the validation dataset and its DataLoader.

    Returns:
        A ``(class_names, val_loader)`` pair, where ``class_names`` is the
        dataset's ``classes`` attribute.
    """
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = RobustImageFolder(
        root=os.path.join(DATA_ROOT, 'val'),
        transform=val_transform,
    )
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=NUM_WORKERS,
                            pin_memory=True, drop_last=False)
    return val_dataset.classes, val_loader


def _load_model(num_classes, device):
    """Build ``Net``, load weights from ``MODEL_PATH`` and switch to eval mode.

    The checkpoint may be a bare state dict or a dict that wraps it under
    ``'model_state_dict'`` or ``'model'``.
    """
    model = Net(num_classes=num_classes)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    elif 'model' in state_dict:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def _collect_predictions(model, loader, device):
    """Run the model over ``loader`` without gradients.

    Returns:
        ``(labels, probs)`` as NumPy arrays; ``probs`` holds the softmax
        output of the model for every sample.
    """
    label_batches, prob_batches = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            probs = torch.softmax(model(images), dim=1)
            label_batches.append(labels.numpy())
            prob_batches.append(probs.cpu().numpy())
    return np.concatenate(label_batches), np.concatenate(prob_batches)


def _save_and_show(filename):
    """Tighten the layout, write the current figure to ``filename`` and show it."""
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.show()


def _plot_confusion_matrix(all_labels, all_preds, class_names):
    """Draw the confusion matrix and save it as ``confusion_matrix.png``."""
    cm = confusion_matrix(all_labels, all_preds)
    fig, ax = plt.subplots(figsize=(8, 7))
    ConfusionMatrixDisplay(cm, display_labels=class_names).plot(
        ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title('Confusion Matrix', fontsize=14)
    _save_and_show('confusion_matrix.png')
    print("混淆矩阵已保存: confusion_matrix.png")


def _plot_roc_curves(one_hot, all_probs, class_names):
    """Draw one-vs-rest ROC curves plus their macro average; save ``roc_curve.png``."""
    num_classes = len(class_names)
    fig, ax = plt.subplots(figsize=(8, 7))
    fpr_d, tpr_d, auc_d = {}, {}, {}

    for i in range(num_classes):
        fpr_d[i], tpr_d[i], _ = roc_curve(one_hot[:, i], all_probs[:, i])
        auc_d[i] = auc(fpr_d[i], tpr_d[i])
        ax.plot(fpr_d[i], tpr_d[i], color=_class_color(i), lw=2,
                label=f'{class_names[i]} (AUC={auc_d[i]:.4f})')

    # Macro average: interpolate every class curve onto the union of all FPR
    # points, then average the TPR values.
    all_fpr = np.unique(np.concatenate([fpr_d[i] for i in range(num_classes)]))
    mean_tpr = sum(np.interp(all_fpr, fpr_d[i], tpr_d[i])
                   for i in range(num_classes)) / num_classes
    macro_auc = auc(all_fpr, mean_tpr)
    ax.plot(all_fpr, mean_tpr, 'navy', lw=2, ls='--',
            label=f'Macro-avg (AUC={macro_auc:.4f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)  # chance diagonal

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    _save_and_show('roc_curve.png')
    print("ROC 曲线已保存: roc_curve.png")


def _plot_pr_curves(one_hot, all_probs, class_names):
    """Draw one-vs-rest precision-recall curves and save ``pr_curve.png``."""
    fig, ax = plt.subplots(figsize=(8, 7))

    for i in range(len(class_names)):
        prec, rec, _ = precision_recall_curve(one_hot[:, i], all_probs[:, i])
        ap = average_precision_score(one_hot[:, i], all_probs[:, i])
        ax.plot(rec, prec, color=_class_color(i), lw=2,
                label=f'{class_names[i]} (AP={ap:.4f})')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve', fontsize=14)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)
    _save_and_show('pr_curve.png')
    print("PR 曲线已保存: pr_curve.png")


def main():
    """Evaluate the saved model on the validation split and write three figures."""
    # ---------- 1. Validation data ----------
    class_names, val_loader = _build_val_loader()
    num_classes = len(class_names)
    print(f"类别: {class_names}")

    # ---------- 2. Model ----------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(num_classes, device)
    print("模型加载完成")

    # ---------- 3. Inference ----------
    all_labels, all_probs = _collect_predictions(model, val_loader, device)
    all_preds = np.argmax(all_probs, axis=1)
    print(f"推理完成, 共 {len(all_labels)} 样本")

    # ---------- 4. Figures ----------
    _plot_confusion_matrix(all_labels, all_preds, class_names)
    one_hot = np.eye(num_classes)[all_labels]  # one-hot targets for the one-vs-rest curves
    _plot_roc_curves(one_hot, all_probs, class_names)
    _plot_pr_curves(one_hot, all_probs, class_names)


# DataLoader worker processes re-import this module on spawn-based platforms;
# the guard keeps them from re-running the whole evaluation.
if __name__ == '__main__':
    main()
|
||||
95
Model.py
|
|
@ -1,6 +1,5 @@
|
|||
"""
|
||||
模型定义文件 - 使用瓶颈结构 (Bottleneck) 的深度残差网络
|
||||
目标:约50层,参数量约80M
|
||||
模型定义文件 - ResNet-34
|
||||
author : yukun-hh
|
||||
date : 2026-4-10
|
||||
"""
|
||||
|
|
@ -10,27 +9,19 @@ from torch.nn import functional as F
|
|||
from torchsummary import summary
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
class BasicBlock(nn.Module):
|
||||
"""
|
||||
瓶颈残差块:1x1(降维) -> 3x3 -> 1x1(升维)
|
||||
ResNet-34 基础残差块:3x3 -> 3x3
|
||||
若需要下采样或通道变化,则在跳跃连接中使用 1x1 卷积
|
||||
"""
|
||||
expansion = 4 # 输出通道是中间通道的4倍
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, mid_channels, stride=1, downsample=None):
|
||||
"""
|
||||
:param in_channels: 输入通道数
|
||||
:param mid_channels: 中间层通道数(1x1降维后的通道数)
|
||||
:param stride: 步长,用于下采样
|
||||
:param downsample: 下采样模块(当stride≠1或通道变化时使用)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(mid_channels)
|
||||
self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(mid_channels)
|
||||
self.conv3 = nn.Conv2d(mid_channels, mid_channels * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(mid_channels * self.expansion)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
|
|
@ -43,10 +34,6 @@ class Bottleneck(nn.Module):
|
|||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
|
@ -57,68 +44,49 @@ class Bottleneck(nn.Module):
|
|||
|
||||
|
||||
class Net(nn.Module):
|
||||
"""
|
||||
基于 Bottleneck 的 ResNet 风格模型
|
||||
各阶段配置仿照 ResNet-50,适当调整宽度以达到约80M参数
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=4):
|
||||
def __init__(self, num_classes=4, dropout=0.5):
|
||||
super().__init__()
|
||||
|
||||
# 第一阶段:7x7卷积 + 最大池化
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# 残差阶段定义
|
||||
# 每个阶段的参数:[块数, 中间通道数, 步长]
|
||||
# 为了达到80M参数,我们略微加宽网络(相比标准ResNet-50)
|
||||
layers_config = [
|
||||
(3, 64, 1), # stage2: 3个瓶颈块,输出通道 64*4=256
|
||||
(4, 128, 2), # stage3: 4个瓶颈块,输出通道 128*4=512
|
||||
(14, 256, 2), # stage4: 14个瓶颈块,输出通道 256*4=1024(加深至此阶段)
|
||||
(3, 512, 2) # stage5: 3个瓶颈块,输出通道 512*4=2048
|
||||
(3, 64, 1), # layer1
|
||||
(4, 128, 2), # layer2
|
||||
(6, 256, 2), # layer3
|
||||
(3, 512, 2), # layer4
|
||||
]
|
||||
|
||||
self.in_channels = 64
|
||||
self.stage2 = self._make_layer(layers_config[0])
|
||||
self.stage3 = self._make_layer(layers_config[1])
|
||||
self.stage4 = self._make_layer(layers_config[2])
|
||||
self.stage5 = self._make_layer(layers_config[3])
|
||||
self.layer1 = self._make_layer(layers_config[0])
|
||||
self.layer2 = self._make_layer(layers_config[1])
|
||||
self.layer3 = self._make_layer(layers_config[2])
|
||||
self.layer4 = self._make_layer(layers_config[3])
|
||||
|
||||
# 全局池化与分类层
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc = nn.Linear(512, num_classes)
|
||||
|
||||
def _make_layer(self, config):
|
||||
"""
|
||||
构建一个残差阶段
|
||||
:param config: (块数, 中间通道数, 第一阶段步长)
|
||||
:return: nn.Sequential
|
||||
"""
|
||||
num_blocks, mid_channels, stride = config
|
||||
num_blocks, out_channels, stride = config
|
||||
downsample = None
|
||||
layers = []
|
||||
|
||||
# 第一个块可能需要下采样和通道匹配
|
||||
if stride != 1 or self.in_channels != mid_channels * Bottleneck.expansion:
|
||||
if stride != 1 or self.in_channels != out_channels:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channels, mid_channels * Bottleneck.expansion,
|
||||
nn.Conv2d(self.in_channels, out_channels,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(mid_channels * Bottleneck.expansion),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
)
|
||||
|
||||
layers.append(
|
||||
Bottleneck(self.in_channels, mid_channels, stride, downsample)
|
||||
)
|
||||
self.in_channels = mid_channels * Bottleneck.expansion
|
||||
layers.append(BasicBlock(self.in_channels, out_channels, stride, downsample))
|
||||
self.in_channels = out_channels
|
||||
|
||||
# 后续块
|
||||
for _ in range(1, num_blocks):
|
||||
layers.append(
|
||||
Bottleneck(self.in_channels, mid_channels)
|
||||
)
|
||||
layers.append(BasicBlock(self.in_channels, out_channels))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
|
@ -128,13 +96,14 @@ class Net(nn.Module):
|
|||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.stage2(x)
|
||||
x = self.stage3(x)
|
||||
x = self.stage4(x)
|
||||
x = self.stage5(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
|
|
|||
148
README.md
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
> 同济大学 Python 人工智能程序设计课程小组作业
|
||||
|
||||
基于自定义 ResNet 风格 Bottleneck 架构的 CNN 模型(约 80M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
|
||||
基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -18,6 +18,7 @@
|
|||
- [文件说明](#文件说明)
|
||||
- [目录结构](#目录结构)
|
||||
- [训练细节](#训练细节)
|
||||
- [评估与可视化](#评估与可视化)
|
||||
- [许可证](#许可证)
|
||||
|
||||
---
|
||||
|
|
@ -25,7 +26,7 @@
|
|||
## 项目特点
|
||||
|
||||
- **四类垃圾分类**:厨余垃圾(1)、可回收物(2)、其他垃圾(3)、有害垃圾(4)
|
||||
- **自定义 ResNet Bottleneck 架构**:约 80M 参数,50 层深度残差网络
|
||||
- **ResNet-34 架构**:约 21M 参数,34 层深度残差网络,含 Dropout 正则化
|
||||
- **数据增强**:训练时使用随机裁剪、水平翻转、旋转、色彩抖动
|
||||
- **Macro-F1 评估**:采用宏平均 F1 分数作为主要评估指标,兼顾各类别表现
|
||||
- **类别加权损失**:自动计算类别权重,缓解类别不平衡问题
|
||||
|
|
@ -35,28 +36,27 @@
|
|||
|
||||
## 模型架构
|
||||
|
||||
模型基于残差网络(ResNet)的 Bottleneck 构建块设计。
|
||||
模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。
|
||||
|
||||
### Bottleneck 块
|
||||
### BasicBlock 块
|
||||
|
||||
每个 Bottleneck 块包含三个卷积层:
|
||||
每个 BasicBlock 包含两个 3x3 卷积层 + 跳跃连接:
|
||||
|
||||
| 层 | 卷积 | 作用 |
|
||||
|---|---|---|
|
||||
| 1x1 Conv | 降维 | 减少通道数,降低计算量 |
|
||||
| 3x3 Conv | 特征提取 | 核心卷积操作 |
|
||||
| 1x1 Conv | 升维 (x4) | 恢复通道数至输入的 4 倍 |
|
||||
| 3x3 Conv | 特征提取 | 第一层卷积 |
|
||||
| 3x3 Conv | 特征提取 | 第二层卷积 |
|
||||
|
||||
### 网络结构
|
||||
|
||||
| 阶段 | 块数 | 输出通道数 | 说明 |
|
||||
|---|---|---|---|
|
||||
| 初始层 | - | 64 | 7x7 Conv, stride=2 + MaxPool |
|
||||
| Stage 1 | 3 | 256 | 第一个残差阶段 |
|
||||
| Stage 2 | 4 | 512 | - |
|
||||
| Stage 3 | 14 | 1024 | 最深阶段(比 ResNet-50 加深) |
|
||||
| Stage 4 | 3 | 2048 | 最终残差阶段 |
|
||||
| 分类头 | - | 4 | 全局平均池化 + 全连接层 |
|
||||
| Layer1 | 3 | 64 | 第一个残差阶段 |
|
||||
| Layer2 | 4 | 128 | - |
|
||||
| Layer3 | 6 | 256 | - |
|
||||
| Layer4 | 3 | 512 | 最终残差阶段 |
|
||||
| 分类头 | - | 4 | 全局平均池化 + Dropout + 全连接层 |
|
||||
|
||||
## 数据集
|
||||
|
||||
|
|
@ -75,16 +75,14 @@
|
|||
|
||||
## 环境要求
|
||||
|
||||
本项目无 `requirements.txt`,需手动安装以下依赖:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
- Python 3.8+
|
||||
- PyTorch(推荐 1.10+)
|
||||
- torchvision
|
||||
- tqdm
|
||||
- matplotlib
|
||||
- pandas
|
||||
- Pillow
|
||||
- torchsummary
|
||||
> **注意**:`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如:
|
||||
> - NVIDIA GPU:`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
|
||||
> - Intel GPU (XPU):`pip install torch torchvision --index-url https://download.pytorch.org/whl/xpu`
|
||||
> - CPU:`pip install torch torchvision` 即可
|
||||
|
||||
## 快速开始
|
||||
|
||||
|
|
@ -100,6 +98,20 @@
|
|||
python Train.py
|
||||
```
|
||||
|
||||
3. **微调模型**(可选,冻结浅层、微调深层):
|
||||
|
||||
```bash
|
||||
python Finetune.py
|
||||
```
|
||||
|
||||
4. **评估与可视化**:
|
||||
|
||||
```bash
|
||||
python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线
|
||||
python Curve.py # 训练过程的 loss/f1/acc/lr 曲线
|
||||
python baseline/compare_models.py # 多模型基线对比(ROC/PR/准确率)
|
||||
```
|
||||
|
||||
> **注意**:
|
||||
> - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本
|
||||
> - Windows 系统需将 `num_workers` 设为 `0`(参见 `Dataloader.py` 和 `Train.py`)
|
||||
|
|
@ -110,27 +122,49 @@
|
|||
| 文件 | 功能 |
|
||||
|---|---|
|
||||
| `Train.py` | 训练主脚本,包含训练循环、验证、评估 |
|
||||
| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 |
|
||||
| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 |
|
||||
| `Model.py` | 模型定义,Bottleneck 残差块 + Net 主模型 |
|
||||
| `Model.py` | 模型定义,ResNet-34(BasicBlock)+ Dropout |
|
||||
| `Merge_classes.py` | 数据集预处理,265 类合并为 4 类 |
|
||||
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB) |
|
||||
| `AGENTS.md` | AI 助手指南(开发辅助) |
|
||||
| `Evaluate.py` | 模型评估,绘制混淆矩阵、ROC 曲线、PR 曲线 |
|
||||
| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 |
|
||||
| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 |
|
||||
| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
|
||||
| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) |
|
||||
| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) |
|
||||
| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
|
||||
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) |
|
||||
| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
trash-division/
|
||||
├── AGENTS.md # AI 助手指南
|
||||
├── best_model.pth # 最佳模型权重
|
||||
├── baseline/ # 基线模型目录
|
||||
│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本
|
||||
│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调
|
||||
│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线
|
||||
│ ├── compare_models.py # 多模型对比脚本
|
||||
│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出)
|
||||
│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出)
|
||||
│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出)
|
||||
├── best_model.pth # 最佳模型权重(不纳入版本控制)
|
||||
├── Curve.py # 训练曲线绘制脚本
|
||||
├── Dataloader.py # 数据加载模块
|
||||
├── Evaluate.py # 模型评估可视化脚本
|
||||
├── Finetune.py # 微调脚本
|
||||
├── .gitattributes # Git 属性配置
|
||||
├── LICENSE # MIT 许可证
|
||||
├── Merge_classes.py # 数据集预处理脚本
|
||||
├── Model.py # 模型定义
|
||||
├── README.md # 项目说明(本文件)
|
||||
├── THIRD_PARTY_LICENSES.md # 第三方许可证声明
|
||||
└── Train.py # 训练主脚本
|
||||
├── Train.py # 训练主脚本
|
||||
├── training_log.csv # 训练日志
|
||||
├── confusion_matrix.png # 混淆矩阵(Evaluate.py 输出)
|
||||
├── roc_curve.png # ROC 曲线(Evaluate.py 输出)
|
||||
├── pr_curve.png # PR 曲线(Evaluate.py 输出)
|
||||
└── training_curves.png # 训练曲线(Curve.py 输出)
|
||||
```
|
||||
|
||||
## 训练细节
|
||||
|
|
@ -150,6 +184,64 @@ trash-division/
|
|||
|
||||
训练时数据增强管线:RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(+-15 deg) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
|
||||
|
||||
## 评估与可视化
|
||||
|
||||
训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析:
|
||||
|
||||
### Evaluate.py — 模型评估
|
||||
|
||||
在验证集上推理,生成三张评估图表:
|
||||
|
||||
```bash
|
||||
python Evaluate.py
|
||||
```
|
||||
|
||||
脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。
|
||||
|
||||
**混淆矩阵**
|
||||
|
||||

|
||||
|
||||
**ROC 曲线**
|
||||
|
||||

|
||||
|
||||
**PR 曲线**
|
||||
|
||||

|
||||
|
||||
### Curve.py — 训练曲线
|
||||
|
||||
从 `training_log.csv` 读取训练日志,绘制四张子图:
|
||||
|
||||
```bash
|
||||
python Curve.py
|
||||
```
|
||||
|
||||

|
||||
|
||||
### 基线模型对比
|
||||
|
||||
`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表:
|
||||
|
||||
```bash
|
||||
python baseline/compare_models.py
|
||||
```
|
||||
|
||||
对比阵容:ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。
|
||||
|
||||
**ROC 曲线对比**
|
||||
|
||||

|
||||
|
||||
**PR 曲线对比**
|
||||
|
||||

|
||||
|
||||
**准确率柱状图**
|
||||
|
||||

|
||||
|
||||
## 许可证
|
||||
|
||||
本项目主代码采用 [MIT 许可证](LICENSE)。
|
||||
|
|
|
|||
145
baseline/HOG_Baseline.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
baseline/HOG_Baseline.py
|
||||
HOG + 颜色直方图特征提取 + LogisticRegression 四分类
|
||||
纯传统 CV/ML 基线,零神经网络依赖
|
||||
可独立运行,也可被 compare_models.py 导入
|
||||
author: yukun-hh
|
||||
date: 2026-5-14
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
from skimage.feature import hog
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, f1_score,
|
||||
confusion_matrix, ConfusionMatrixDisplay,
|
||||
classification_report,
|
||||
)
|
||||
|
||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# ============================================================
# Configurable parameters
# ============================================================
# Directory that contains this script. The dataset path is anchored here so
# the script finds the data whether it is launched from the repository root
# or from inside baseline/.
_BASELINE_DIR = os.path.dirname(os.path.abspath(__file__))
# The empty last component keeps the trailing separator the original literal had.
DATA_ROOT = os.path.join(_BASELINE_DIR, '..', '..',
                         'trash_division_data', 'ultimate_4_class', '')
IMAGE_SIZE = 128
HOG_ORIENTATIONS = 9
HOG_PIXELS_PER_CELL = (8, 8)
HOG_CELLS_PER_BLOCK = (2, 2)
COLOR_BINS = 32
# ============================================================

CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
|
||||
|
||||
|
||||
def extract_hog_color(image, image_size=None):
    """Build the HOG + colour-histogram feature vector for one image.

    :param image: PIL image; it is converted to RGB and resized to a square.
    :param image_size: side length in pixels of the resized image. ``None``
        (the default) uses the module-level ``IMAGE_SIZE``.
    :return: 1-D array: the HOG descriptor followed by three per-channel
        histograms of ``COLOR_BINS`` bins each (raw counts, not normalised).
    """
    side = IMAGE_SIZE if image_size is None else image_size
    img = image.convert('RGB').resize((side, side))
    arr = np.array(img, dtype=np.float64) / 255.0  # scale pixel values to [0, 1]

    hog_feat = hog(arr, orientations=HOG_ORIENTATIONS,
                   pixels_per_cell=HOG_PIXELS_PER_CELL,
                   cells_per_block=HOG_CELLS_PER_BLOCK,
                   channel_axis=2, feature_vector=True)

    color_feat = np.concatenate([
        np.histogram(arr[:, :, c], bins=COLOR_BINS, range=(0, 1))[0]
        for c in range(3)
    ])

    return np.concatenate([hog_feat, color_feat])


class HOGLRBaseline:
    """HOG + colour-histogram features fed to a LogisticRegression classifier."""

    def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE):
        """
        :param data_root: directory that holds the split sub-directories
        :param image_size: side length images are resized to before feature
            extraction
        """
        self.data_root = data_root
        self.image_size = image_size
        self.clf = LogisticRegression(
            C=1.0, max_iter=1000, solver='lbfgs', n_jobs=-1,
        )

    def _load_data(self, split):
        """Extract features for every readable image of one split.

        Class folders are expected to be named ``1`` .. ``NUM_CLASSES``;
        labels are the folder number minus one. Missing class folders are
        ignored. Files that cannot be read or processed are skipped and
        counted, and the count is printed.

        :param split: sub-directory of ``data_root`` to read
        :return: ``(features, labels)`` as NumPy arrays
        """
        dir_path = os.path.join(self.data_root, split)
        features, labels = [], []
        skipped = 0
        for class_id in range(1, NUM_CLASSES + 1):
            class_dir = os.path.join(dir_path, str(class_id))
            if not os.path.isdir(class_dir):
                continue
            files = sorted(os.listdir(class_dir))
            for fname in tqdm(files, desc=f'{split}/class_{class_id}'):
                fpath = os.path.join(class_dir, fname)
                try:
                    with Image.open(fpath) as img:
                        # Honour the size configured on this instance.
                        feat = extract_hog_color(img, self.image_size)
                except Exception:
                    # Deliberately broad: loading is best-effort, and one bad
                    # or non-image file must not abort the whole run.
                    skipped += 1
                    continue
                features.append(feat)
                labels.append(class_id - 1)
        print(f" {split}: {len(features)} 张")
        if skipped:
            # Report the skips instead of dropping files silently.
            print(f" {split}: 跳过 {skipped} 个无法读取的文件")
        return np.array(features, dtype=np.float32), np.array(labels)

    def fit(self, train_dir=None):
        """Fit the classifier on the features of ``train_dir`` (default ``'train'``)."""
        if train_dir is None:
            train_dir = 'train'
        print(" 提取训练集 HOG 特征 ...")
        X, y = self._load_data(train_dir)
        self.clf.fit(X, y)

    def predict(self, val_dir=None):
        """Predict on ``val_dir`` (default ``'val'``).

        :return: ``(y_true, preds, probs)`` - labels read from the directory
            layout, predicted labels, and predicted class probabilities
        """
        if val_dir is None:
            val_dir = 'val'
        print(" 提取验证集 HOG 特征 ...")
        X, y = self._load_data(val_dir)
        preds = self.clf.predict(X)
        probs = self.clf.predict_proba(X)
        return y, preds, probs
|
||||
|
||||
|
||||
# ============================================================
|
||||
# compare_models.py 导入接口
|
||||
# ============================================================
|
||||
|
||||
def get_hog_lr_preds(train_loader, val_loader, device):
    """Train the HOG + LogisticRegression baseline and predict on the 'val' split.

    The three parameters are accepted but never read here; the baseline loads
    its images from disk through ``HOGLRBaseline``.

    :return: the ``(y_true, preds, probs)`` tuple from ``HOGLRBaseline.predict``
    """
    model = HOGLRBaseline()
    model.fit('train')
    y_true, preds, probs = model.predict('val')
    return y_true, preds, probs
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 独立运行入口
|
||||
# ============================================================
|
||||
|
||||
if __name__ == '__main__':
    # The figure is written next to this script, whatever the working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    print("HOG + LogisticRegression 基线")
    model = HOGLRBaseline()

    model.fit('train')
    # The probabilities are returned too but are not needed for this report.
    y_true, y_preds, _ = model.predict('val')

    accuracy = accuracy_score(y_true, y_preds)
    f1_macro = f1_score(y_true, y_preds, average='macro')
    print(f"\n验证集 Accuracy: {accuracy:.4f}")
    print(f"验证集 Macro-F1: {f1_macro:.4f}")
    report = classification_report(y_true, y_preds, target_names=CLASS_NAMES)
    print(f"\n分类报告:\n{report}")

    # Confusion matrix figure.
    matrix = confusion_matrix(y_true, y_preds)
    fig, ax = plt.subplots(figsize=(8, 7))
    display = ConfusionMatrixDisplay(matrix, display_labels=CLASS_NAMES)
    display.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14)
    plt.tight_layout()
    figure_path = os.path.join(script_dir, 'hog_lr_confusion_matrix.png')
    plt.savefig(figure_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"混淆矩阵已保存: {figure_path}")
|
||||
278
baseline/ResNet34_Pretrained_10pct.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
baseline/ResNet34_Pretrained_10pct.py
|
||||
ResNet-34 ImageNet 预训练权重 + 10% 训练集微调
|
||||
可独立运行训练,也可被 compare_models.py 导入
|
||||
author: yukun-hh
|
||||
date: 2026-5-14
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torchvision import models, transforms
|
||||
from tqdm import tqdm
|
||||
import csv
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
from Dataloader import RobustImageFolder
|
||||
|
||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# ============================================================
# Configurable parameters
# ============================================================
# Directory that contains this script. Dataset, checkpoint and log paths are
# all anchored here, so they do not depend on the working directory.
_BASELINE_DIR = os.path.dirname(os.path.abspath(__file__))
# The empty last component keeps the trailing separator the original literal had.
DATA_ROOT = os.path.join(_BASELINE_DIR, '..', '..',
                         'trash_division_data', 'ultimate_4_class', '')
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 30
LR = 0.001
TRAIN_PCT = 0.1
SEED = 42
DROPOUT = 0.3
MODEL_SAVE_PATH = os.path.join(_BASELINE_DIR, 'resnet34_10pct.pth')
LOG_PATH = os.path.join(_BASELINE_DIR, 'resnet34_10pct_log.csv')
# ============================================================

NUM_CLASSES = 4
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
||||
|
||||
|
||||
class PretrainedResNet34(nn.Module):
    """torchvision ResNet-34 with ImageNet weights, plus a dropout + linear head.

    The backbone's own ``fc`` layer is replaced by ``nn.Identity``; the
    classifier lives in this module's ``dropout`` and ``fc`` attributes.
    """

    def __init__(self, num_classes=NUM_CLASSES, dropout=DROPOUT):
        """
        :param num_classes: number of output logits
        :param dropout: dropout probability applied before the final linear layer
        """
        super().__init__()
        backbone = models.resnet34(weights='IMAGENET1K_V1')
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()  # expose the pooled features directly
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return class logits: backbone features -> dropout -> linear."""
        return self.fc(self.dropout(self.backbone(x)))

    def freeze_early_layers(self):
        """Turn off gradients for the stem (conv1, bn1) and for layer1 and layer2."""
        frozen_stages = (
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.layer1,
            self.backbone.layer2,
        )
        for stage in frozen_stages:
            for param in stage.parameters():
                param.requires_grad = False

    def print_trainable_info(self):
        """Print the frozen and trainable parameter counts and the trainable share."""
        frozen = 0
        trainable = 0
        for param in self.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                frozen += param.numel()
        total = frozen + trainable
        print(f" 冻结参数: {frozen:,} 可训练参数: {trainable:,} ({100.*trainable/total:.1f}%)")
|
||||
|
||||
|
||||
def compute_macro_f1(predicted, targets, num_classes=NUM_CLASSES):
    """Return the macro-averaged F1 score as a Python float.

    For each class index in ``range(num_classes)`` the true-positive,
    false-positive and false-negative counts are taken from element-wise
    comparisons of ``predicted`` and ``targets``. The per-class F1 values are
    averaged with equal weight. A small epsilon in every denominator avoids
    division by zero, so a class with no predictions and no targets adds an
    F1 of 0.

    :param predicted: tensor of predicted class indices
    :param targets: tensor of ground-truth class indices
    :param num_classes: number of classes to average over
    """
    eps = 1e-8
    # Row 0 = true positives, row 1 = false positives, row 2 = false negatives.
    counts = torch.zeros(3, num_classes, device=predicted.device)
    for cls in range(num_classes):
        is_pred = predicted == cls
        is_true = targets == cls
        counts[0, cls] = (is_pred & is_true).sum()
        counts[1, cls] = (is_pred & ~is_true).sum()
        counts[2, cls] = (~is_pred & is_true).sum()
    tp, fp, fn = counts

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    per_class_f1 = 2 * precision * recall / (precision + recall + eps)
    return per_class_f1.mean().item()
|
||||
|
||||
|
||||
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader``.

    :param model: network to train; it is switched to train mode
    :param loader: iterable of ``(images, labels)`` batches
    :param criterion: loss function called as ``criterion(outputs, labels)``
    :param optimizer: optimiser stepped once per batch
    :param device: device the batches are moved to
    :param epoch: zero-based epoch index, used only for the progress-bar title
    :return: ``(epoch_loss, epoch_f1, epoch_acc)`` - sample-weighted mean loss,
        macro-F1 over all predictions of the epoch, and accuracy in percent
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks = [], []

    progress = tqdm(loader, desc=f'Epoch {epoch+1} [Train]')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_images.size(0)
        batch_preds = logits.max(1).indices
        n_seen += batch_labels.size(0)
        n_correct += (batch_preds == batch_labels).sum().item()
        pred_chunks.append(batch_preds)
        label_chunks.append(batch_labels)

        # Progress bar: current batch loss and F1, running accuracy.
        progress.set_postfix({
            'loss': batch_loss.item(),
            'F1': f'{compute_macro_f1(batch_preds, batch_labels):.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%',
        })

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
|
||||
|
||||
|
||||
@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    :param model: network to evaluate; it is switched to eval mode
    :param loader: iterable of ``(images, labels)`` batches
    :param criterion: loss function called as ``criterion(outputs, labels)``
    :param device: device the batches are moved to
    :return: ``(epoch_loss, epoch_f1, epoch_acc)`` - sample-weighted mean loss,
        macro-F1 over all predictions, and accuracy in percent
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks = [], []

    for batch_images, batch_labels in tqdm(loader, desc='[Validate]'):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        # Weight the batch loss by batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_images.size(0)
        batch_preds = logits.max(1).indices
        n_seen += batch_labels.size(0)
        n_correct += (batch_preds == batch_labels).sum().item()
        pred_chunks.append(batch_preds)
        label_chunks.append(batch_labels)

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
|
||||
|
||||
|
||||
def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR):
    """Train ``model``, log every epoch to CSV and checkpoint the best weights.

    Only parameters with ``requires_grad`` set are handed to the SGD
    optimizer. The learning rate follows a cosine annealing schedule with
    ``T_max=epochs``. After each epoch one row is appended to ``LOG_PATH``
    and flushed, and whenever the validation macro-F1 strictly exceeds the
    best value so far (starting from 0.0) the model's ``state_dict`` is
    written to ``MODEL_SAVE_PATH``.

    Args:
        model: Network to train, already placed on ``device``.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate``.
        device: Device used for training and validation batches.
        epochs: Number of epochs to run.
        lr: Initial learning rate for SGD.

    Returns:
        A dict of per-epoch lists keyed by ``train_loss``, ``train_f1``,
        ``train_acc``, ``val_loss``, ``val_f1`` and ``val_acc``.
    """
    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params,
                          lr=lr, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_f1': [], 'train_acc': [],
               'val_loss': [], 'val_f1': [], 'val_acc': []}
    best_val_f1 = 0.0

    # The context manager closes the log even when an epoch raises part-way
    # through training (previously the handle leaked on any exception).
    with open(LOG_PATH, 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc',
                             'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])

        for epoch in range(epochs):
            print(f'\n{"="*50}')
            print(f'Epoch {epoch+1}/{epochs}')

            train_loss, train_f1, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device, epoch)
            val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
            scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_f1'].append(train_f1)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_f1'].append(val_f1)
            history['val_acc'].append(val_acc)

            # NOTE: read after scheduler.step(), so this is the rate the
            # scheduler has set for the next epoch, not the one just used.
            lr_val = optimizer.param_groups[0]['lr']

            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
            print(f'Learning Rate: {lr_val:.6f}')

            best_mark = ''
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                best_mark = 'best'
                print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')

            log_writer.writerow([epoch+1, train_loss, train_f1, train_acc,
                                 val_loss, val_f1, val_acc, lr_val, best_mark])
            # Flush each epoch so a partial log survives an interrupted run.
            log_file.flush()

    print(f'\n训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
    return history
|
||||
|
||||
|
||||
# ============================================================
|
||||
# compare_models.py 导入接口
|
||||
# ============================================================
|
||||
|
||||
def get_resnet34_10pct_preds(train_loader, val_loader, device):
    """Predict on ``val_loader`` with the checkpoint at ``MODEL_SAVE_PATH``.

    Args:
        train_loader: Not used by this function.
        val_loader: Loader yielding ``(images, labels)`` batches to score.
        device: Device the model and the images are moved to.

    Returns:
        ``(y_true, y_preds, y_probs)`` as NumPy arrays: the labels, the
        argmax class of each sample, and the softmax probabilities.
    """
    net = PretrainedResNet34(num_classes=NUM_CLASSES)
    weights = torch.load(MODEL_SAVE_PATH, map_location='cpu')
    net.load_state_dict(weights)
    net = net.to(device).eval()

    true_chunks = []
    pred_chunks = []
    prob_chunks = []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34 (10%)'):
            # Labels are only collected, so they are not moved to the device.
            logits = net(images.to(device))
            probs = torch.softmax(logits, dim=1)
            true_chunks.append(labels.numpy())
            pred_chunks.append(probs.argmax(dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_preds = np.concatenate(pred_chunks)
    y_probs = np.concatenate(prob_chunks)
    return y_true, y_preds, y_probs
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 独立训练入口
|
||||
# ============================================================
|
||||
|
||||
if __name__ == '__main__':
    # Stand-alone training entry point: fine-tune the pretrained ResNet-34 on
    # a random TRAIN_PCT fraction of the training split.

    # Seed the Python, NumPy and PyTorch RNGs with the same value.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Prefer CUDA, then XPU when this torch build exposes it, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
                          else 'cpu')
    print(f"Device: {device}")

    # Validation pipeline: fixed resize and normalisation, no augmentation.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Training pipeline: random crop/flip/rotation/colour jitter, then the
    # same normalisation as validation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    full_train_dataset = RobustImageFolder(
        root=os.path.join(DATA_ROOT, 'train'),
        transform=train_transform,
    )
    val_dataset = RobustImageFolder(
        root=os.path.join(DATA_ROOT, 'val'),
        transform=val_transform,
    )

    # Draw a random subset of the training set (at least one sample); the
    # draw uses the `random` module seeded above.
    n_train = len(full_train_dataset)
    n_subset = max(1, int(n_train * TRAIN_PCT))
    indices = random.sample(range(n_train), n_subset)
    train_dataset = Subset(full_train_dataset, indices)
    print(f"训练集: {len(train_dataset)} / {n_train} ({TRAIN_PCT*100:.0f}%)")
    print(f"验证集: {len(val_dataset)}")

    # drop_last=True discards the final incomplete training batch; the
    # validation loader keeps every sample.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=NUM_WORKERS,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=NUM_WORKERS,
                            pin_memory=True, drop_last=False)

    # Freeze the early layers before training, then report the split between
    # frozen and trainable parameters.
    model = PretrainedResNet34(num_classes=NUM_CLASSES, dropout=DROPOUT)
    model.freeze_early_layers()
    model.print_trainable_info()
    model = model.to(device)

    history = train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR)

    # Reload the best checkpoint written during training so `model` holds the
    # best weights rather than the last-epoch weights.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
    print(f"模型已保存: {MODEL_SAVE_PATH}")
    print(f"训练日志已保存: {LOG_PATH}")
|
||||
145
baseline/VGG_KNN.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
baseline/VGG_KNN.py
|
||||
VGG16 预训练模型特征提取 + KNN 四分类基线
|
||||
可独立运行,也可被 compare_models.py 导入复用
|
||||
author: yukun-hh
|
||||
date: 2026-5-14
|
||||
"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import models, transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, f1_score,
|
||||
confusion_matrix, ConfusionMatrixDisplay,
|
||||
classification_report,
|
||||
)
|
||||
|
||||
from Dataloader import RobustImageFolder
|
||||
|
||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
||||
|
||||
|
||||
def load_vgg16_extractor(device):
    """Build a frozen VGG16 whose classifier head is replaced by identity.

    The ``weights=`` keyword is tried first; if the installed torchvision
    rejects it with ``TypeError``, the ``pretrained=True`` form is used.

    Args:
        device: Device the network is moved to.

    Returns:
        The VGG16 module in eval mode with ``requires_grad`` disabled on all
        parameters and ``nn.Identity`` as its ``classifier``.
    """
    try:
        backbone = models.vgg16(weights='IMAGENET1K_V1')
    except TypeError:
        backbone = models.vgg16(pretrained=True)

    # With the classifier replaced, the forward pass returns whatever the
    # preceding layers produce instead of class logits.
    backbone.classifier = nn.Identity()
    backbone = backbone.to(device)
    backbone = backbone.eval()
    backbone.requires_grad_(False)
    return backbone
|
||||
|
||||
|
||||
def extract_features(model, loader, device):
    """Run ``model`` over ``loader`` and collect outputs and labels.

    Args:
        model: Feature extractor; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to.

    Returns:
        ``(features, labels)``: two NumPy arrays obtained by concatenating
        the per-batch model outputs and the per-batch labels.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Extracting features'):
            batch_feats = model(images.to(device))
            feature_chunks.append(batch_feats.cpu().numpy())
            label_chunks.append(labels.numpy())
    features = np.concatenate(feature_chunks)
    targets = np.concatenate(label_chunks)
    return features, targets
|
||||
|
||||
|
||||
class VGGKNNBaseline:
    """Baseline classifier: frozen VGG16 features fed to a k-NN classifier."""

    def __init__(self, k=5, device='cpu',
                 data_root='../trash_division_data/ultimate_4_class/',
                 image_size=256, batch_size=32, num_workers=4):
        """Store the configuration and build the extractor and classifier.

        Args:
            k: Number of neighbours for the k-NN classifier.
            device: Device the VGG16 extractor runs on.
            data_root: Directory containing the split sub-directories read
                by ``_get_loader``.
            image_size: Side length images are resized to.
            batch_size: Batch size of loaders built by ``_get_loader``.
            num_workers: Worker count of loaders built by ``_get_loader``.
        """
        self.k = k
        self.device = device
        self.data_root = data_root
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.extractor = load_vgg16_extractor(device)
        self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)

    def _get_loader(self, split):
        """Return an unshuffled loader over ``<data_root>/<split>``."""
        size = (self.image_size, self.image_size)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        pipeline = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            normalize,
        ])
        split_dir = os.path.join(self.data_root, split)
        dataset = RobustImageFolder(root=split_dir, transform=pipeline)
        print(f" {split}: {len(dataset)} 张")
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    def fit(self, train_loader=None):
        """Fit the k-NN on features of ``train_loader``.

        When no loader is given, one is built for the ``'train'`` split.
        """
        loader = self._get_loader('train') if train_loader is None else train_loader
        print(" 提取训练集特征 ...")
        feats, labels = extract_features(self.extractor, loader, self.device)
        self.knn.fit(feats, labels)

    def predict(self, val_loader=None):
        """Predict classes and class probabilities for ``val_loader``.

        When no loader is given, one is built for the ``'val'`` split.

        Returns:
            ``(labels, preds, probs)``: the labels read from the loader, the
            k-NN predictions, and the k-NN ``predict_proba`` output.
        """
        loader = self._get_loader('val') if val_loader is None else val_loader
        print(" 提取验证集特征 ...")
        feats, labels = extract_features(self.extractor, loader, self.device)
        preds = self.knn.predict(feats)
        probs = self.knn.predict_proba(feats)
        return labels, preds, probs
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Stand-alone run: fit the VGG16 + KNN baseline, print its metrics and
    # save a confusion-matrix figure next to this script.
    DATA_ROOT = '../trash_division_data/ultimate_4_class/'
    BATCH_SIZE = 32
    IMAGE_SIZE = 256
    NUM_WORKERS = 4
    K = 5

    # Prefer CUDA, then XPU when this torch build exposes it, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
                          else 'cpu')
    print(f"Device: {device}")

    baseline = VGGKNNBaseline(k=K, device=device,
                              data_root=DATA_ROOT, image_size=IMAGE_SIZE,
                              batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    train_loader = baseline._get_loader('train')
    val_loader = baseline._get_loader('val')

    baseline.fit(train_loader)
    y_true, y_preds, y_probs = baseline.predict(val_loader)

    # Headline metrics plus the per-class report.
    acc = accuracy_score(y_true, y_preds)
    macro_f1 = f1_score(y_true, y_preds, average='macro')
    print(f"\n验证集 Accuracy: {acc:.4f}")
    print(f"验证集 Macro-F1: {macro_f1:.4f}")
    print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")

    # Confusion matrix, written to the directory containing this file.
    cm = confusion_matrix(y_true, y_preds)
    fig, ax = plt.subplots(figsize=(8, 7))
    ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
        ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
    plt.tight_layout()
    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vgg_knn_confusion_matrix.png')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"混淆矩阵已保存: {out_path}")
|
||||
1
baseline/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# baseline package
|
||||
BIN
baseline/accuracy_bar.png
Normal file
|
After Width: | Height: | Size: 45 KiB |
249
baseline/compare_models.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
"""
|
||||
baseline/compare_models.py
|
||||
多模型对比:ROC 曲线 + 准确率柱状图
|
||||
添加新模型只需在 MODELS 列表加一行,无需修改绘图代码
|
||||
author: yukun-hh
|
||||
date: 2026-5-14
|
||||
"""
|
||||
import sys, os, re
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from sklearn.metrics import (
|
||||
roc_curve, auc, accuracy_score,
|
||||
precision_recall_curve, average_precision_score,
|
||||
)
|
||||
|
||||
from Model import Net
|
||||
from Dataloader import RobustImageFolder
|
||||
from baseline.VGG_KNN import VGGKNNBaseline
|
||||
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
|
||||
from baseline.HOG_Baseline import get_hog_lr_preds
|
||||
|
||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# ============================================================
# Configurable parameters
# ============================================================
# Dataset root; the 'train' and 'val' sub-directories are joined onto it.
# NOTE(review): this is a relative path, and VGG_KNN.py uses
# '../trash_division_data/...' (one level less) -- confirm the intended
# working directory.
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
# Number of neighbours passed to VGGKNNBaseline.
K_KNN = 5
# ============================================================

CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
|
||||
|
||||
# ============================================================
|
||||
# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
|
||||
# ============================================================
|
||||
|
||||
def get_resnet34_preds(train_loader, val_loader, device):
    """Predict on ``val_loader`` with the ``Net`` checkpoint ``../best_model.pth``.

    The checkpoint may be a bare state dict, or a dict that nests the state
    dict under ``'model_state_dict'`` or ``'model'``.

    Args:
        train_loader: Not used by this function.
        val_loader: Loader yielding ``(images, labels)`` batches to score.
        device: Device the model and the images are moved to.

    Returns:
        ``(y_true, y_preds, y_probs)`` as NumPy arrays: the labels, the
        argmax class of each sample, and the softmax probabilities.
    """
    net = Net(num_classes=NUM_CLASSES)
    checkpoint = torch.load('../best_model.pth', map_location='cpu')
    # Unwrap the weights when they are nested; the first matching key wins.
    for wrapper_key in ('model_state_dict', 'model'):
        if wrapper_key in checkpoint:
            checkpoint = checkpoint[wrapper_key]
            break
    net.load_state_dict(checkpoint)
    net = net.to(device).eval()

    true_chunks = []
    pred_chunks = []
    prob_chunks = []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34'):
            # Labels are only collected, so they are not moved to the device.
            logits = net(images.to(device))
            probs = torch.softmax(logits, dim=1)
            true_chunks.append(labels.numpy())
            pred_chunks.append(probs.argmax(dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_preds = np.concatenate(pred_chunks)
    y_probs = np.concatenate(prob_chunks)
    return y_true, y_preds, y_probs
|
||||
|
||||
|
||||
def get_vgg_knn_preds(train_loader, val_loader, device):
    """Fit the VGG16 + KNN baseline and predict on ``val_loader``.

    Args:
        train_loader: Loader whose features are used to fit the k-NN.
        val_loader: Loader to predict on.
        device: Device the VGG16 extractor runs on.

    Returns:
        Whatever ``VGGKNNBaseline.predict`` returns for ``val_loader``.
    """
    knn_baseline = VGGKNNBaseline(k=K_KNN, device=device)
    knn_baseline.fit(train_loader)
    outputs = knn_baseline.predict(val_loader)
    return outputs
|
||||
|
||||
|
||||
# ============================================================
# Model registry -- adding a new model only takes one more line here
# ============================================================

# Each entry is (display name, prediction function). The display name is
# used for plot legends and, after sanitising, for the cached CSV file name.
MODELS = [
    ('ResNet-34', get_resnet34_preds),
    ('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
    ('VGG16 + KNN (K=5)', get_vgg_knn_preds),
    ('HOG + LogisticRegression', get_hog_lr_preds),
    # Examples of future extensions:
    # ('ResNet-18 (pretrained)', get_resnet18_preds),
    # ('ResNet-50 (pretrained)', get_resnet50_preds),
    # ('ResNet-34 (finetuned)', get_finetuned_preds),
]

# ============================================================
# Colour palette (no change needed when models are added; the plotting code
# indexes it modulo its length)
# ============================================================
COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
          '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
|
||||
|
||||
|
||||
def compute_macro_roc(y_true, y_probs):
    """Compute a macro-averaged one-vs-rest ROC curve and its AUC.

    A ROC curve is computed per class, each curve is linearly interpolated
    onto the union of all false-positive-rate points, and the interpolated
    true-positive rates are averaged over the ``NUM_CLASSES`` classes.

    Args:
        y_true: Integer class labels, used to index an identity matrix.
        y_probs: Per-class scores; column ``c`` holds the score of class ``c``.

    Returns:
        ``(all_fpr, mean_tpr, macro_auc)``.
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]
    per_class = [
        roc_curve(one_hot[:, c], y_probs[:, c])[:2]
        for c in range(NUM_CLASSES)
    ]

    all_fpr = np.unique(np.concatenate([fpr for fpr, _ in per_class]))
    mean_tpr = np.zeros_like(all_fpr)
    for fpr, tpr in per_class:
        mean_tpr += np.interp(all_fpr, fpr, tpr)
    mean_tpr /= NUM_CLASSES

    return all_fpr, mean_tpr, auc(all_fpr, mean_tpr)
|
||||
|
||||
|
||||
def compute_macro_pr(y_true, y_probs):
    """Compute a macro-averaged precision-recall curve and macro AP.

    Each class's precision-recall curve is interpolated onto 200 evenly
    spaced recall values in ``[0, 1]`` and the precisions are averaged over
    the ``NUM_CLASSES`` classes. The macro average precision is computed
    separately by ``average_precision_score`` on the one-hot labels.

    Args:
        y_true: Integer class labels, used to index an identity matrix.
        y_probs: Per-class scores; column ``c`` holds the score of class ``c``.

    Returns:
        ``(all_rec, mean_prec, macro_ap)``.
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]
    all_rec = np.linspace(0, 1, 200)
    mean_prec = np.zeros_like(all_rec)

    for c in range(NUM_CLASSES):
        prec, rec, _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
        # Both arrays are reversed before interpolation so the recall values
        # handed to np.interp run in the opposite order to what was returned.
        mean_prec += np.interp(all_rec, rec[::-1], prec[::-1])
    mean_prec /= NUM_CLASSES

    macro_ap = average_precision_score(one_hot, y_probs, average='macro')
    return all_rec, mean_prec, macro_ap
|
||||
|
||||
|
||||
def sanitize_filename(name):
    """Turn ``name`` into a string usable as a file-name stem.

    Every character that is not a word character, ``-`` or ``_`` is replaced
    by ``_``; leading and trailing underscores are then stripped.

    >>> sanitize_filename('VGG16 + KNN (K=5)')
    'VGG16___KNN__K_5'
    """
    unsafe_char = re.compile(r'[^\w\-_]')
    replaced = unsafe_char.sub('_', name)
    return replaced.strip('_')
|
||||
|
||||
|
||||
def preds_csv_path(out_dir, model_name):
    """Return the path of the cached-predictions CSV for ``model_name``.

    The file name is the sanitised model name followed by ``_preds.csv``,
    placed inside ``out_dir``.
    """
    filename = f'{sanitize_filename(model_name)}_preds.csv'
    return os.path.join(out_dir, filename)
|
||||
|
||||
|
||||
def save_preds_csv(path, y_true, y_preds, y_probs):
    """Write labels, predictions and class probabilities to a CSV file.

    The file has one header line, ``y_true,y_pred,prob_0,prob_1,...``,
    followed by one row per sample. All values, including the integer
    labels, are written as floats with six decimals.

    The number of ``prob_*`` header columns is taken from the data actually
    written, so the header always matches the rows even when ``y_probs``
    has a different number of columns than the module-level class count.

    Args:
        path: Destination file path.
        y_true: Ground-truth labels, one per sample (array-like).
        y_preds: Predicted labels, one per sample (array-like).
        y_probs: Per-sample class probabilities (array-like).
    """
    data = np.column_stack([
        np.asarray(y_true, dtype=float),
        np.asarray(y_preds, dtype=float),
        np.asarray(y_probs, dtype=float),
    ])
    # The first two columns are y_true / y_pred; the rest are probabilities.
    n_prob_cols = data.shape[1] - 2
    header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(n_prob_cols))
    np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f')
|
||||
|
||||
|
||||
def load_preds_csv(path):
    """Load a predictions CSV written by ``save_preds_csv``.

    The header line is skipped. Column 0 is read as the true label, column 1
    as the predicted label, and the next ``NUM_CLASSES`` columns as the
    class probabilities.

    Args:
        path: Path of the CSV file to read.

    Returns:
        ``(y_true, y_preds, y_probs)``: two integer arrays and one float
        array with one row per sample.
    """
    # ndmin=2 keeps a file with a single data row two-dimensional; without it
    # np.loadtxt returns a 1-D array and the column slicing below raises.
    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
    y_true = data[:, 0].astype(int)
    y_preds = data[:, 1].astype(int)
    y_probs = data[:, 2:2 + NUM_CLASSES]
    return y_true, y_preds, y_probs
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Prediction caches (CSV) and all figures are written next to this script.
    out_dir = os.path.dirname(os.path.abspath(__file__))

    # Prefer CUDA, then XPU when this torch build exposes it, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
                          else 'cpu')
    print(f"Device: {device}")

    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Both splits use the same augmentation-free transform; the training
    # loader is passed unshuffled to every prediction function.
    train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
                                      transform=val_transform)
    val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
                                    transform=val_transform)
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)

    # ---- Evaluate every model (skipped when a cached CSV exists) ----
    # NOTE: a cached CSV is reused whenever the file exists; delete it to
    # force a model to be evaluated again.
    results = {}
    for name, func in MODELS:
        print(f"\n{'='*50}")
        csv_path = preds_csv_path(out_dir, name)
        if os.path.exists(csv_path):
            print(f"加载缓存: {os.path.basename(csv_path)}")
            y_true, y_preds, y_probs = load_preds_csv(csv_path)
        else:
            print(f"评估: {name}")
            y_true, y_preds, y_probs = func(train_loader, val_loader, device)
            save_preds_csv(csv_path, y_true, y_preds, y_probs)
            print(f" 预测数据已保存: {os.path.basename(csv_path)}")
        acc = accuracy_score(y_true, y_preds)
        fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
        rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
        results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
                         'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
                         'rec': rec, 'prec': prec, 'ap': macro_ap}
        print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")

    # ---- ROC comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        # Cycle through the palette if there are more models than colours.
        color = COLORS[i % len(COLORS)]
        ax.plot(r['fpr'], r['tpr'], color=color, lw=2,
                label=f"{name} (AUC={r['auc']:.4f})")
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
    ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    roc_path = os.path.join(out_dir, 'roc_comparison.png')
    plt.savefig(roc_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\nROC 对比图已保存: {roc_path}")

    # ---- PR comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        color = COLORS[i % len(COLORS)]
        ax.plot(r['rec'], r['prec'], color=color, lw=2,
                label=f"{name} (AP={r['ap']:.4f})")
    ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
    ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
    ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    pr_path = os.path.join(out_dir, 'pr_comparison.png')
    plt.savefig(pr_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"PR 对比图已保存: {pr_path}")

    # ---- Accuracy bar chart ----
    names = list(results.keys())
    accs = [results[n]['acc'] for n in names]
    fig, ax = plt.subplots(figsize=(8, 5))
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
    bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
    # Print each accuracy just above its bar.
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    # The y-axis starts slightly below the lowest accuracy rather than at 0.
    ax.set_ylim(min(accs) - 0.03, max(accs) * 1.08)
    ax.set_ylabel('Accuracy'); ax.set_title('Accuracy Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    bar_path = os.path.join(out_dir, 'accuracy_bar.png')
    plt.savefig(bar_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"准确率柱状图已保存: {bar_path}")
|
||||
BIN
baseline/pr_comparison.png
Normal file
|
After Width: | Height: | Size: 122 KiB |
BIN
baseline/roc_comparison.png
Normal file
|
After Width: | Height: | Size: 125 KiB |
BIN
confusion_matrix.png
Normal file
|
After Width: | Height: | Size: 51 KiB |
BIN
pr_curve.png
Normal file
|
After Width: | Height: | Size: 86 KiB |
10
requirements.txt
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
torch>=1.10
|
||||
torchvision>=0.11
|
||||
tqdm
|
||||
matplotlib
|
||||
pandas
|
||||
Pillow
|
||||
scikit-learn
|
||||
scikit-image
|
||||
numpy
|
||||
torchsummary
|
||||
BIN
roc_curve.png
Normal file
|
After Width: | Height: | Size: 105 KiB |
BIN
training_curves.png
Normal file
|
After Width: | Height: | Size: 296 KiB |
81
training_log.csv
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best
|
||||
1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best
|
||||
2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best
|
||||
3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316,
|
||||
4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best
|
||||
5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076,
|
||||
6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best
|
||||
7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118,
|
||||
8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884,
|
||||
9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621,
|
||||
10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217,
|
||||
11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best
|
||||
12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092,
|
||||
13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994,
|
||||
14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best
|
||||
15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364,
|
||||
16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369,
|
||||
17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863,
|
||||
18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best
|
||||
19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214,
|
||||
20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369,
|
||||
21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355,
|
||||
22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546,
|
||||
23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best
|
||||
24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182,
|
||||
25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052,
|
||||
26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best
|
||||
27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best
|
||||
28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867,
|
||||
29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357,
|
||||
30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best
|
||||
31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732,
|
||||
32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683,
|
||||
33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best
|
||||
34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633,
|
||||
35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205,
|
||||
36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577,
|
||||
37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594,
|
||||
38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612,
|
||||
39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671,
|
||||
40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996,
|
||||
41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best
|
||||
42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878,
|
||||
43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063,
|
||||
44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237,
|
||||
45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best
|
||||
46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372,
|
||||
47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best
|
||||
48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best
|
||||
49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best
|
||||
50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762,
|
||||
51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431,
|
||||
52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best
|
||||
53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best
|
||||
54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285,
|
||||
55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best
|
||||
56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178,
|
||||
57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161,
|
||||
58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best
|
||||
59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best
|
||||
60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316,
|
||||
61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best
|
||||
62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best
|
||||
63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best
|
||||
64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best
|
||||
65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376,
|
||||
66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966,
|
||||
67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077,
|
||||
68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best
|
||||
69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best
|
||||
70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321,
|
||||
71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best
|
||||
72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best
|
||||
73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05,
|
||||
74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best
|
||||
75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best
|
||||
76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05,
|
||||
77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05,
|
||||
78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06,
|
||||
79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best
|
||||
80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,
|
||||
|