Compare commits

..

20 commits

Author SHA1 Message Date
1557a2c6a0 Update training_log.csv
不小心上传空白csv,现已恢复
2026-05-17 20:22:53 +08:00
8008f2da8f Update README.md
修改Readme.md
2026-05-17 19:57:19 +08:00
52b03e981c README: 补充 PyTorch CUDA/XPU 版本安装提示 2026-05-17 19:52:38 +08:00
ed2778972b 添加精简版 requirements.txt,更新 README 环境要求 2026-05-17 19:51:50 +08:00
65c9742a1c 修改accuracy_bar.png 2026-05-17 19:46:31 +08:00
25e11c1914 完善 README:新增基线对比章节,嵌入 baseline 对比图,更新文件说明和目录结构 2026-05-17 19:41:30 +08:00
b7341c746f 调整准确率柱状图 y 轴范围,放大模型间差异 2026-05-17 19:34:46 +08:00
307b66e9db fix: 移除旧版 sklearn 不兼容的 multi_class 参数 2026-05-17 18:47:14 +08:00
562fc4142a 新增 HOG + LogisticRegression 纯传统基线模型(零神经网络依赖) 2026-05-17 18:43:28 +08:00
3fee1c82ab compare_models.py: 新增 PR 曲线对比图 2026-05-17 18:31:32 +08:00
010dacb533 新增 ResNet34 ImageNet 预训练 + 10% 数据微调模型 (baseline) 2026-05-17 17:20:09 +08:00
547d96cfa9 compare_models.py: 添加 CSV 缓存机制,已有预测数据时跳过重复计算 2026-05-17 16:55:27 +08:00
818d98d06c 重构:将 Baseline.py 迁移至 baseline/ 目录,新增多模型对比脚本 2026-05-17 16:14:48 +08:00
3624f058c2 添加 VGG16 + KNN 基线模型 Baseline.py 2026-05-14 19:50:05 +08:00
76b56dd64b 在 README 中直接展示评估与训练曲线图片 2026-05-14 16:04:31 +08:00
44c70250e3 更新 AGENTS.md、README.md,添加可视化脚本与输出文件到版本控制 2026-05-14 16:02:47 +08:00
2f4e9df26e Merge branch 'Resnet34-test' 2026-05-14 15:51:52 +08:00
b4c99489b3 Update Evaluate.py
修改曲线绘画函数
2026-05-14 15:49:52 +08:00
ywd09
ce0c6da36a 模型评估程序 2026-05-14 00:37:04 +08:00
4575f3390f refactor: replace custom Bottleneck model with standard ResNet-34 + Dropout 2026-05-12 15:56:28 +08:00
19 changed files with 1289 additions and 93 deletions

29
.gitignore vendored Normal file
View file

@ -0,0 +1,29 @@
*
!/.gitattributes
!/Dataloader.py
!/LICENSE
!/Merge_classes.py
!/Model.py
!/README.md
!/requirements.txt
!/THIRD_PARTY_LICENSES.md
!/Train.py
!/Baseline.py
!/Finetune.py
!/Curve.py
!/Evaluate.py
!/baseline/
!/baseline/__init__.py
!/baseline/VGG_KNN.py
!/baseline/compare_models.py
!/baseline/ResNet34_Pretrained_10pct.py
!/baseline/HOG_Baseline.py
!/baseline/roc_comparison.png
!/baseline/pr_comparison.png
!/baseline/accuracy_bar.png
!/training_log.csv
!/confusion_matrix.png
!/roc_curve.png
!/pr_curve.png
!/training_curves.png
!.gitignore

50
Curve.py Normal file
View file

@ -0,0 +1,50 @@
"""
plot_training_curves.py
training_log.csv 读取日志绘制 Loss / F1 / Accuracy / LR 曲线
"""
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============ 读取数据 ============
df = pd.read_csv('training_log.csv')
best_rows = df[df['best'] == 'best']
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# ---- 1. Loss ----
ax = axes[0, 0]
ax.plot(df['epoch'], df['train_loss'], label='Train Loss', color='#1f77b4', lw=1.5)
ax.plot(df['epoch'], df['val_loss'], label='Val Loss', color='#ff7f0e', lw=1.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.set_title('Loss vs Epoch')
ax.legend(); ax.grid(True, alpha=0.3)
# ---- 2. F1 Score ----
ax = axes[0, 1]
ax.plot(df['epoch'], df['train_f1'], label='Train F1', color='#1f77b4', lw=1.5)
ax.plot(df['epoch'], df['val_f1'], label='Val F1', color='#ff7f0e', lw=1.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('F1 Score'); ax.set_title('F1 Score vs Epoch')
ax.legend(); ax.grid(True, alpha=0.3)
# ---- 3. Accuracy ----
ax = axes[1, 0]
ax.plot(df['epoch'], df['train_acc'], label='Train Acc', color='#1f77b4', lw=1.5)
ax.plot(df['epoch'], df['val_acc'], label='Val Acc', color='#ff7f0e', lw=1.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Accuracy (%)'); ax.set_title('Accuracy vs Epoch')
ax.legend(); ax.grid(True, alpha=0.3)
# ---- 4. Learning Rate ----
ax = axes[1, 1]
ax.plot(df['epoch'], df['lr'], color='#2ca02c', lw=1.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Learning Rate'); ax.set_title('Learning Rate vs Epoch')
ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("训练曲线已保存: training_curves.png")

147
Evaluate.py Normal file
View file

@ -0,0 +1,147 @@
"""
evaluate_and_plot.py
加载模型在验证集上推理绘制混淆矩阵 / ROC / PR 曲线
"""
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import (
confusion_matrix, ConfusionMatrixDisplay,
roc_curve, auc,
precision_recall_curve, average_precision_score,
)
from Model import Net
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 需要你修改的参数 ★★★
# ============================================================
MODEL_PATH = 'best_model.pth' # 模型权重路径
DATA_ROOT = '../trash_division_data/ultimate_4_class/' # 数据集根目录
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
# ============================================================
# ---------- 1. 加载验证集 ----------
val_transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
val_dataset = RobustImageFolder(
root=os.path.join(DATA_ROOT, 'val'),
transform=val_transform,
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True, drop_last=False)
class_names = val_dataset.classes
num_classes = len(class_names)
print(f"类别: {class_names}")
# ---------- 2. 加载模型 ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(num_classes=num_classes)
state_dict = torch.load(MODEL_PATH, map_location=device)
if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict']
elif 'model' in state_dict:
state_dict = state_dict['model']
model.load_state_dict(state_dict)
model = model.to(device).eval()
print("模型加载完成")
# ---------- 3. 推理 ----------
all_labels = []
all_probs = []
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
probs = torch.softmax(model(images), dim=1)
all_labels.append(labels.numpy())
all_probs.append(probs.cpu().numpy())
all_labels = np.concatenate(all_labels)
all_probs = np.concatenate(all_probs)
all_preds = np.argmax(all_probs, axis=1)
print(f"推理完成, 共 {len(all_labels)} 样本")
# ============================================================
# ① 混淆矩阵
# ============================================================
cm = confusion_matrix(all_labels, all_preds)
fig, ax = plt.subplots(figsize=(8, 7))
ConfusionMatrixDisplay(cm, display_labels=class_names).plot(
ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
ax.set_title('Confusion Matrix', fontsize=14)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
print("混淆矩阵已保存: confusion_matrix.png")
# ============================================================
# ② ROC 曲线 (One-vs-Rest + Macro-average)
# ============================================================
one_hot = np.eye(num_classes)[all_labels]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
fig, ax = plt.subplots(figsize=(8, 7))
fpr_d, tpr_d, auc_d = {}, {}, {}
for i in range(num_classes):
fpr_d[i], tpr_d[i], _ = roc_curve(one_hot[:, i], all_probs[:, i])
auc_d[i] = auc(fpr_d[i], tpr_d[i])
ax.plot(fpr_d[i], tpr_d[i], color=colors[i], lw=2,
label=f'{class_names[i]} (AUC={auc_d[i]:.4f})')
# Macro-average
all_fpr = np.unique(np.concatenate([fpr_d[i] for i in range(num_classes)]))
mean_tpr = sum(np.interp(all_fpr, fpr_d[i], tpr_d[i]) for i in range(num_classes)) / num_classes
macro_auc = auc(all_fpr, mean_tpr)
ax.plot(all_fpr, mean_tpr, 'navy', lw=2, ls='--',
label=f'Macro-avg (AUC={macro_auc:.4f})')
ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve', fontsize=14)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('roc_curve.png', dpi=150, bbox_inches='tight')
plt.show()
print("ROC 曲线已保存: roc_curve.png")
# ============================================================
# ③ Precision-Recall 曲线
# ============================================================
fig, ax = plt.subplots(figsize=(8, 7))
for i in range(num_classes):
prec, rec, _ = precision_recall_curve(one_hot[:, i], all_probs[:, i])
ap = average_precision_score(one_hot[:, i], all_probs[:, i])
ax.plot(rec, prec, color=colors[i], lw=2,
label=f'{class_names[i]} (AP={ap:.4f})')
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve', fontsize=14)
ax.legend(loc='best'); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('pr_curve.png', dpi=150, bbox_inches='tight')
plt.show()
print("PR 曲线已保存: pr_curve.png")

View file

@ -1,6 +1,5 @@
"""
模型定义文件 - 使用瓶颈结构 (Bottleneck) 的深度残差网络
目标约50层参数量约80M
模型定义文件 - ResNet-34
author : yukun-hh
date : 2026-4-10
"""
@ -10,27 +9,19 @@ from torch.nn import functional as F
from torchsummary import summary
class Bottleneck(nn.Module):
class BasicBlock(nn.Module):
"""
瓶颈残差块1x1(降维) -> 3x3 -> 1x1(升维)
若需要下采样或通道变化则在跳跃连接中使用1x1卷积
ResNet-34 基础残差块3x3 -> 3x3
若需要下采样或通道变化则在跳跃连接中使用 1x1 卷积
"""
expansion = 4 # 输出通道是中间通道的4倍
expansion = 1
def __init__(self, in_channels, mid_channels, stride=1, downsample=None):
"""
:param in_channels: 输入通道数
:param mid_channels: 中间层通道数1x1降维后的通道数
:param stride: 步长用于下采样
:param downsample: 下采样模块当stride1或通道变化时使用
"""
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_channels)
self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(mid_channels)
self.conv3 = nn.Conv2d(mid_channels, mid_channels * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(mid_channels * self.expansion)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
@ -43,10 +34,6 @@ class Bottleneck(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
@ -57,68 +44,49 @@ class Bottleneck(nn.Module):
class Net(nn.Module):
"""
基于 Bottleneck ResNet 风格模型
各阶段配置仿照 ResNet-50适当调整宽度以达到约80M参数
"""
def __init__(self, num_classes=4):
def __init__(self, num_classes=4, dropout=0.5):
super().__init__()
# 第一阶段7x7卷积 + 最大池化
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 残差阶段定义
# 每个阶段的参数:[块数, 中间通道数, 步长]
# 为了达到80M参数我们略微加宽网络相比标准ResNet-50
layers_config = [
(3, 64, 1), # stage2: 3个瓶颈块输出通道 64*4=256
(4, 128, 2), # stage3: 4个瓶颈块输出通道 128*4=512
(14, 256, 2), # stage4: 14个瓶颈块输出通道 256*4=1024加深至此阶段
(3, 512, 2) # stage5: 3个瓶颈块输出通道 512*4=2048
(3, 64, 1), # layer1
(4, 128, 2), # layer2
(6, 256, 2), # layer3
(3, 512, 2), # layer4
]
self.in_channels = 64
self.stage2 = self._make_layer(layers_config[0])
self.stage3 = self._make_layer(layers_config[1])
self.stage4 = self._make_layer(layers_config[2])
self.stage5 = self._make_layer(layers_config[3])
self.layer1 = self._make_layer(layers_config[0])
self.layer2 = self._make_layer(layers_config[1])
self.layer3 = self._make_layer(layers_config[2])
self.layer4 = self._make_layer(layers_config[3])
# 全局池化与分类层
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(2048, num_classes)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, config):
"""
构建一个残差阶段
:param config: (块数, 中间通道数, 第一阶段步长)
:return: nn.Sequential
"""
num_blocks, mid_channels, stride = config
num_blocks, out_channels, stride = config
downsample = None
layers = []
# 第一个块可能需要下采样和通道匹配
if stride != 1 or self.in_channels != mid_channels * Bottleneck.expansion:
if stride != 1 or self.in_channels != out_channels:
downsample = nn.Sequential(
nn.Conv2d(self.in_channels, mid_channels * Bottleneck.expansion,
nn.Conv2d(self.in_channels, out_channels,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(mid_channels * Bottleneck.expansion),
nn.BatchNorm2d(out_channels),
)
layers.append(
Bottleneck(self.in_channels, mid_channels, stride, downsample)
)
self.in_channels = mid_channels * Bottleneck.expansion
layers.append(BasicBlock(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels
# 后续块
for _ in range(1, num_blocks):
layers.append(
Bottleneck(self.in_channels, mid_channels)
)
layers.append(BasicBlock(self.in_channels, out_channels))
return nn.Sequential(*layers)
@ -128,13 +96,14 @@ class Net(nn.Module):
x = self.relu(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.dropout(x)
x = self.fc(x)
return x

148
README.md
View file

@ -4,7 +4,7 @@
> 同济大学 Python 人工智能程序设计课程小组作业
基于自定义 ResNet 风格 Bottleneck 架构的 CNN 模型(约 80M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
---
@ -18,6 +18,7 @@
- [文件说明](#文件说明)
- [目录结构](#目录结构)
- [训练细节](#训练细节)
- [评估与可视化](#评估与可视化)
- [许可证](#许可证)
---
@ -25,7 +26,7 @@
## 项目特点
- **四类垃圾分类**厨余垃圾1、可回收物2、其他垃圾3、有害垃圾4
- **自定义 ResNet Bottleneck 架构**:约 80M 参数50 层深度残差网络
- **ResNet-34 架构**:约 21M 参数34 层深度残差网络,含 Dropout 正则化
- **数据增强**:训练时使用随机裁剪、水平翻转、旋转、色彩抖动
- **Macro-F1 评估**:采用宏平均 F1 分数作为主要评估指标,兼顾各类别表现
- **类别加权损失**:自动计算类别权重,缓解类别不平衡问题
@ -35,28 +36,27 @@
## 模型架构
模型基于残差网络ResNet的 Bottleneck 构建块设计
模型基于标准 ResNet-34 架构,使用 BasicBlock 构建
### Bottleneck 块
### BasicBlock 块
每个 Bottleneck 块包含三个卷积层
每个 BasicBlock 包含两个 3x3 卷积层 + 跳跃连接
| 层 | 卷积 | 作用 |
|---|---|---|
| 1x1 Conv | 降维 | 减少通道数,降低计算量 |
| 3x3 Conv | 特征提取 | 核心卷积操作 |
| 1x1 Conv | 升维 (x4) | 恢复通道数至输入的 4 倍 |
| 3x3 Conv | 特征提取 | 第一层卷积 |
| 3x3 Conv | 特征提取 | 第二层卷积 |
### 网络结构
| 阶段 | 块数 | 输出通道数 | 说明 |
|---|---|---|---|
| 初始层 | - | 64 | 7x7 Conv, stride=2 + MaxPool |
| Stage 1 | 3 | 256 | 第一个残差阶段 |
| Stage 2 | 4 | 512 | - |
| Stage 3 | 14 | 1024 | 最深阶段(比 ResNet-50 加深) |
| Stage 4 | 3 | 2048 | 最终残差阶段 |
| 分类头 | - | 4 | 全局平均池化 + 全连接层 |
| Layer1 | 3 | 64 | 第一个残差阶段 |
| Layer2 | 4 | 128 | - |
| Layer3 | 6 | 256 | - |
| Layer4 | 3 | 512 | 最终残差阶段 |
| 分类头 | - | 4 | 全局平均池化 + Dropout + 全连接层 |
## 数据集
@ -75,16 +75,14 @@
## 环境要求
本项目无 `requirements.txt`,需手动安装以下依赖:
```bash
pip install -r requirements.txt
```
- Python 3.8+
- PyTorch推荐 1.10+
- torchvision
- tqdm
- matplotlib
- pandas
- Pillow
- torchsummary
> **注意**`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如:
> - NVIDIA GPU`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
> - Intel GPU (XPU)`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121` 安装
> - CPU`pip install torch torchvision` 即可
## 快速开始
@ -100,6 +98,20 @@
python Train.py
```
3. **微调模型**(可选,冻结浅层、微调深层):
```bash
python Finetune.py
```
4. **评估与可视化**
```bash
python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线
python Curve.py # 训练过程的 loss/f1/acc/lr 曲线
python baseline/compare_models.py # 多模型基线对比ROC/PR/准确率)
```
> **注意**
> - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本
> - Windows 系统需将 `num_workers` 设为 `0`(参见 `Dataloader.py``Train.py`
@ -110,27 +122,49 @@
| 文件 | 功能 |
|---|---|
| `Train.py` | 训练主脚本,包含训练循环、验证、评估 |
| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 |
| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 |
| `Model.py` | 模型定义,Bottleneck 残差块 + Net 主模型 |
| `Model.py` | 模型定义,ResNet-34BasicBlock+ Dropout |
| `Merge_classes.py` | 数据集预处理265 类合并为 4 类 |
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB |
| `AGENTS.md` | AI 助手指南(开发辅助) |
| `Evaluate.py` | 模型评估绘制混淆矩阵、ROC 曲线、PR 曲线 |
| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 |
| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 |
| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression纯传统 CV |
| `baseline/compare_models.py` | 多模型对比ROC / PR 曲线 + 准确率柱状图) |
| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB不纳入版本控制 |
| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
## 目录结构
```
trash-division/
├── AGENTS.md # AI 助手指南
├── best_model.pth # 最佳模型权重
├── baseline/ # 基线模型目录
│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本
│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调
│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线
│ ├── compare_models.py # 多模型对比脚本
│ ├── roc_comparison.png # 多模型 ROC 对比compare_models.py 输出)
│ ├── pr_comparison.png # 多模型 PR 对比compare_models.py 输出)
│ └── accuracy_bar.png # 多模型准确率对比compare_models.py 输出)
├── best_model.pth # 最佳模型权重(不纳入版本控制)
├── Curve.py # 训练曲线绘制脚本
├── Dataloader.py # 数据加载模块
├── Evaluate.py # 模型评估可视化脚本
├── Finetune.py # 微调脚本
├── .gitattributes # Git 属性配置
├── LICENSE # MIT 许可证
├── Merge_classes.py # 数据集预处理脚本
├── Model.py # 模型定义
├── README.md # 项目说明(本文件)
├── THIRD_PARTY_LICENSES.md # 第三方许可证声明
└── Train.py # 训练主脚本
├── Train.py # 训练主脚本
├── training_log.csv # 训练日志
├── confusion_matrix.png # 混淆矩阵Evaluate.py 输出)
├── roc_curve.png # ROC 曲线Evaluate.py 输出)
├── pr_curve.png # PR 曲线Evaluate.py 输出)
└── training_curves.png # 训练曲线Curve.py 输出)
```
## 训练细节
@ -150,6 +184,64 @@ trash-division/
训练时数据增强管线RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(+-15 deg) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
## 评估与可视化
训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析:
### Evaluate.py — 模型评估
在验证集上推理,生成三张评估图表:
```bash
python Evaluate.py
```
脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。
**混淆矩阵**
![confusion_matrix](confusion_matrix.png)
**ROC 曲线**
![roc_curve](roc_curve.png)
**PR 曲线**
![pr_curve](pr_curve.png)
### Curve.py — 训练曲线
`training_log.csv` 读取训练日志,绘制四张子图:
```bash
python Curve.py
```
![training_curves](training_curves.png)
### 基线模型对比
`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表:
```bash
python baseline/compare_models.py
```
对比阵容ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。
**ROC 曲线对比**
![roc_comparison](baseline/roc_comparison.png)
**PR 曲线对比**
![pr_comparison](baseline/pr_comparison.png)
**准确率柱状图**
![accuracy_bar](baseline/accuracy_bar.png)
## 许可证
本项目主代码采用 [MIT 许可证](LICENSE)。

145
baseline/HOG_Baseline.py Normal file
View file

@ -0,0 +1,145 @@
"""
baseline/HOG_Baseline.py
HOG + 颜色直方图特征提取 + LogisticRegression 四分类
纯传统 CV/ML 基线零神经网络依赖
可独立运行也可被 compare_models.py 导入
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib
from skimage.feature import hog
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
IMAGE_SIZE = 128
HOG_ORIENTATIONS = 9
HOG_PIXELS_PER_CELL = (8, 8)
HOG_CELLS_PER_BLOCK = (2, 2)
COLOR_BINS = 32
# ============================================================
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
def extract_hog_color(image):
img = image.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
arr = np.array(img, dtype=np.float64) / 255.0
hog_feat = hog(arr, orientations=HOG_ORIENTATIONS,
pixels_per_cell=HOG_PIXELS_PER_CELL,
cells_per_block=HOG_CELLS_PER_BLOCK,
channel_axis=2, feature_vector=True)
color_feat = []
for c in range(3):
hist, _ = np.histogram(arr[:, :, c], bins=COLOR_BINS, range=(0, 1))
color_feat.append(hist)
color_feat = np.concatenate(color_feat)
return np.concatenate([hog_feat, color_feat])
class HOGLRBaseline:
def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE):
self.data_root = data_root
self.image_size = image_size
self.clf = LogisticRegression(
C=1.0, max_iter=1000, solver='lbfgs', n_jobs=-1,
)
def _load_data(self, split):
dir_path = os.path.join(self.data_root, split)
features, labels = [], []
for class_id in range(1, NUM_CLASSES + 1):
class_dir = os.path.join(dir_path, str(class_id))
if not os.path.isdir(class_dir):
continue
files = sorted(os.listdir(class_dir))
for fname in tqdm(files, desc=f'{split}/class_{class_id}'):
fpath = os.path.join(class_dir, fname)
try:
with Image.open(fpath) as img:
feat = extract_hog_color(img)
features.append(feat)
labels.append(class_id - 1)
except Exception:
pass
print(f" {split}: {len(features)}")
return np.array(features, dtype=np.float32), np.array(labels)
def fit(self, train_dir=None):
if train_dir is None:
train_dir = 'train'
print(" 提取训练集 HOG 特征 ...")
X, y = self._load_data(train_dir)
self.clf.fit(X, y)
def predict(self, val_dir=None):
if val_dir is None:
val_dir = 'val'
print(" 提取验证集 HOG 特征 ...")
X, y = self._load_data(val_dir)
preds = self.clf.predict(X)
probs = self.clf.predict_proba(X)
return y, preds, probs
# ============================================================
# compare_models.py 导入接口
# ============================================================
def get_hog_lr_preds(train_loader, val_loader, device):
baseline = HOGLRBaseline()
baseline.fit('train')
return baseline.predict('val')
# ============================================================
# 独立运行入口
# ============================================================
if __name__ == '__main__':
out_dir = os.path.dirname(os.path.abspath(__file__))
print("HOG + LogisticRegression 基线")
baseline = HOGLRBaseline()
baseline.fit('train')
y_true, y_preds, y_probs = baseline.predict('val')
acc = accuracy_score(y_true, y_preds)
macro_f1 = f1_score(y_true, y_preds, average='macro')
print(f"\n验证集 Accuracy: {acc:.4f}")
print(f"验证集 Macro-F1: {macro_f1:.4f}")
print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")
cm = confusion_matrix(y_true, y_preds)
fig, ax = plt.subplots(figsize=(8, 7))
ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14)
plt.tight_layout()
cm_path = os.path.join(out_dir, 'hog_lr_confusion_matrix.png')
plt.savefig(cm_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"混淆矩阵已保存: {cm_path}")

View file

@ -0,0 +1,278 @@
"""
baseline/ResNet34_Pretrained_10pct.py
ResNet-34 ImageNet 预训练权重 + 10% 训练集微调
可独立运行训练也可被 compare_models.py 导入
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import models, transforms
from tqdm import tqdm
import csv
import matplotlib.pyplot as plt
import matplotlib
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 30
LR = 0.001
TRAIN_PCT = 0.1
SEED = 42
DROPOUT = 0.3
MODEL_SAVE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct.pth')
LOG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct_log.csv')
# ============================================================
NUM_CLASSES = 4
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
class PretrainedResNet34(nn.Module):
def __init__(self, num_classes=NUM_CLASSES, dropout=DROPOUT):
super().__init__()
self.backbone = models.resnet34(weights='IMAGENET1K_V1')
in_features = self.backbone.fc.in_features
self.backbone.fc = nn.Identity()
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(in_features, num_classes)
def forward(self, x):
x = self.backbone(x)
x = self.dropout(x)
x = self.fc(x)
return x
def freeze_early_layers(self):
for param in self.backbone.conv1.parameters():
param.requires_grad = False
for param in self.backbone.bn1.parameters():
param.requires_grad = False
for param in self.backbone.layer1.parameters():
param.requires_grad = False
for param in self.backbone.layer2.parameters():
param.requires_grad = False
def print_trainable_info(self):
frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = frozen + trainable
print(f" 冻结参数: {frozen:,} 可训练参数: {trainable:,} ({100.*trainable/total:.1f}%)")
def compute_macro_f1(predicted, targets, num_classes=NUM_CLASSES):
tp = torch.zeros(num_classes, device=predicted.device)
fp = torch.zeros(num_classes, device=predicted.device)
fn = torch.zeros(num_classes, device=predicted.device)
for c in range(num_classes):
tp[c] = ((predicted == c) & (targets == c)).sum()
fp[c] = ((predicted == c) & (targets != c)).sum()
fn[c] = ((predicted != c) & (targets == c)).sum()
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
return f1.mean().item()
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
model.train()
running_loss, correct, total = 0.0, 0, 0
all_preds, all_labels = [], []
pbar = tqdm(loader, desc=f'Epoch {epoch+1} [Train]')
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
all_preds.append(predicted)
all_labels.append(labels)
batch_f1 = compute_macro_f1(predicted, labels)
pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}',
'Acc': f'{100.*correct/total:.2f}%'})
epoch_loss = running_loss / total
epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total
return epoch_loss, epoch_f1, epoch_acc
@torch.no_grad()
def validate(model, loader, criterion, device):
model.eval()
running_loss, correct, total = 0.0, 0, 0
all_preds, all_labels = [], []
for images, labels in tqdm(loader, desc='[Validate]'):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
all_preds.append(predicted)
all_labels.append(labels)
epoch_loss = running_loss / total
epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total
return epoch_loss, epoch_f1, epoch_acc
def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
lr=lr, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
history = {'train_loss': [], 'train_f1': [], 'train_acc': [],
'val_loss': [], 'val_f1': [], 'val_acc': []}
best_val_f1 = 0.0
log_file = open(LOG_PATH, 'w', newline='')
log_writer = csv.writer(log_file)
log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc',
'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])
for epoch in range(epochs):
print(f'\n{"="*50}')
print(f'Epoch {epoch+1}/{epochs}')
train_loss, train_f1, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, device, epoch)
val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
scheduler.step()
history['train_loss'].append(train_loss)
history['train_f1'].append(train_f1)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
history['val_f1'].append(val_f1)
history['val_acc'].append(val_acc)
print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
best_mark = ''
if val_f1 > best_val_f1:
best_val_f1 = val_f1
torch.save(model.state_dict(), MODEL_SAVE_PATH)
best_mark = 'best'
print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')
lr_val = optimizer.param_groups[0]['lr']
log_writer.writerow([epoch+1, train_loss, train_f1, train_acc,
val_loss, val_f1, val_acc, lr_val, best_mark])
log_file.flush()
log_file.close()
print(f'\n训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
return history
# ============================================================
# compare_models.py 导入接口
# ============================================================
def get_resnet34_10pct_preds(train_loader, val_loader, device):
model = PretrainedResNet34(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
model = model.to(device).eval()
y_true, y_preds, y_probs = [], [], []
with torch.no_grad():
for images, labels in tqdm(val_loader, desc='ResNet-34 (10%)'):
images, labels = images.to(device), labels
logits = model(images)
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
y_true.append(labels.numpy())
y_preds.append(preds.cpu().numpy())
y_probs.append(probs.cpu().numpy())
return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)
# ============================================================
# 独立训练入口
# ============================================================
if __name__ == '__main__':
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available()
else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
else 'cpu')
print(f"Device: {device}")
val_transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
full_train_dataset = RobustImageFolder(
root=os.path.join(DATA_ROOT, 'train'),
transform=train_transform,
)
val_dataset = RobustImageFolder(
root=os.path.join(DATA_ROOT, 'val'),
transform=val_transform,
)
n_train = len(full_train_dataset)
n_subset = max(1, int(n_train * TRAIN_PCT))
indices = random.sample(range(n_train), n_subset)
train_dataset = Subset(full_train_dataset, indices)
print(f"训练集: {len(train_dataset)} / {n_train} ({TRAIN_PCT*100:.0f}%)")
print(f"验证集: {len(val_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=NUM_WORKERS,
pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True, drop_last=False)
model = PretrainedResNet34(num_classes=NUM_CLASSES, dropout=DROPOUT)
model.freeze_early_layers()
model.print_trainable_info()
model = model.to(device)
history = train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
print(f"模型已保存: {MODEL_SAVE_PATH}")
print(f"训练日志已保存: {LOG_PATH}")

145
baseline/VGG_KNN.py Normal file
View file

@ -0,0 +1,145 @@
"""
baseline/VGG_KNN.py
VGG16 预训练模型特征提取 + KNN 四分类基线
可独立运行也可被 compare_models.py 导入复用
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
def load_vgg16_extractor(device):
try:
model = models.vgg16(weights='IMAGENET1K_V1')
except TypeError:
model = models.vgg16(pretrained=True)
model.classifier = nn.Identity()
model = model.to(device).eval()
for param in model.parameters():
param.requires_grad = False
return model
def extract_features(model, loader, device):
model.eval()
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(loader, desc='Extracting features'):
images = images.to(device)
feats = model(images)
all_features.append(feats.cpu().numpy())
all_labels.append(labels.numpy())
return np.concatenate(all_features), np.concatenate(all_labels)
class VGGKNNBaseline:
def __init__(self, k=5, device='cpu',
data_root='../trash_division_data/ultimate_4_class/',
image_size=256, batch_size=32, num_workers=4):
self.k = k
self.device = device
self.data_root = data_root
self.image_size = image_size
self.batch_size = batch_size
self.num_workers = num_workers
self.extractor = load_vgg16_extractor(device)
self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
def _get_loader(self, split):
transform = transforms.Compose([
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
dataset = RobustImageFolder(
root=os.path.join(self.data_root, split),
transform=transform,
)
print(f" {split}: {len(dataset)}")
return DataLoader(dataset, batch_size=self.batch_size,
shuffle=False, num_workers=self.num_workers,
pin_memory=True, drop_last=False)
def fit(self, train_loader=None):
if train_loader is None:
train_loader = self._get_loader('train')
print(" 提取训练集特征 ...")
train_feats, train_labels = extract_features(self.extractor, train_loader, self.device)
self.knn.fit(train_feats, train_labels)
def predict(self, val_loader=None):
if val_loader is None:
val_loader = self._get_loader('val')
print(" 提取验证集特征 ...")
val_feats, val_labels = extract_features(self.extractor, val_loader, self.device)
preds = self.knn.predict(val_feats)
probs = self.knn.predict_proba(val_feats)
return val_labels, preds, probs
if __name__ == '__main__':
DATA_ROOT = '../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
K = 5
device = torch.device('cuda' if torch.cuda.is_available()
else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
else 'cpu')
print(f"Device: {device}")
baseline = VGGKNNBaseline(k=K, device=device,
data_root=DATA_ROOT, image_size=IMAGE_SIZE,
batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
train_loader = baseline._get_loader('train')
val_loader = baseline._get_loader('val')
baseline.fit(train_loader)
y_true, y_preds, y_probs = baseline.predict(val_loader)
acc = accuracy_score(y_true, y_preds)
macro_f1 = f1_score(y_true, y_preds, average='macro')
print(f"\n验证集 Accuracy: {acc:.4f}")
print(f"验证集 Macro-F1: {macro_f1:.4f}")
print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")
cm = confusion_matrix(y_true, y_preds)
fig, ax = plt.subplots(figsize=(8, 7))
ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
plt.tight_layout()
out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vgg_knn_confusion_matrix.png')
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"混淆矩阵已保存: {out_path}")

1
baseline/__init__.py Normal file
View file

@ -0,0 +1 @@
# baseline package

BIN
baseline/accuracy_bar.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

249
baseline/compare_models.py Normal file
View file

@ -0,0 +1,249 @@
"""
baseline/compare_models.py
多模型对比ROC 曲线 + 准确率柱状图
添加新模型只需在 MODELS 列表加一行无需修改绘图代码
author: yukun-hh
date: 2026-5-14
"""
import sys, os, re
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import (
roc_curve, auc, accuracy_score,
precision_recall_curve, average_precision_score,
)
from Model import Net
from Dataloader import RobustImageFolder
from baseline.VGG_KNN import VGGKNNBaseline
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
from baseline.HOG_Baseline import get_hog_lr_preds
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
K_KNN = 5
# ============================================================
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
# ============================================================
# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
# ============================================================
def get_resnet34_preds(train_loader, val_loader, device):
model = Net(num_classes=NUM_CLASSES)
state_dict = torch.load('../best_model.pth', map_location='cpu')
if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict']
elif 'model' in state_dict:
state_dict = state_dict['model']
model.load_state_dict(state_dict)
model = model.to(device).eval()
y_true, y_preds, y_probs = [], [], []
with torch.no_grad():
for images, labels in tqdm(val_loader, desc='ResNet-34'):
images, labels = images.to(device), labels
logits = model(images)
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
y_true.append(labels.numpy())
y_preds.append(preds.cpu().numpy())
y_probs.append(probs.cpu().numpy())
return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)
def get_vgg_knn_preds(train_loader, val_loader, device):
baseline = VGGKNNBaseline(k=K_KNN, device=device)
baseline.fit(train_loader)
return baseline.predict(val_loader)
# ============================================================
# ★ 模型注册表 — 添加新模型只需在这里加一行 ★
# ============================================================
MODELS = [
('ResNet-34', get_resnet34_preds),
('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
('VGG16 + KNN (K=5)', get_vgg_knn_preds),
('HOG + LogisticRegression', get_hog_lr_preds),
# 未来轻松扩展示例:
# ('ResNet-18 (pretrained)', get_resnet18_preds),
# ('ResNet-50 (pretrained)', get_resnet50_preds),
# ('ResNet-34 (finetuned)', get_finetuned_preds),
]
# ============================================================
# 调色板 (扩展时无需修改)
# ============================================================
COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
'#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
def compute_macro_roc(y_true, y_probs):
one_hot = np.eye(NUM_CLASSES)[y_true]
fpr_dict, tpr_dict = {}, {}
for c in range(NUM_CLASSES):
fpr_dict[c], tpr_dict[c], _ = roc_curve(one_hot[:, c], y_probs[:, c])
all_fpr = np.unique(np.concatenate([fpr_dict[c] for c in range(NUM_CLASSES)]))
mean_tpr = np.zeros_like(all_fpr)
for c in range(NUM_CLASSES):
mean_tpr += np.interp(all_fpr, fpr_dict[c], tpr_dict[c])
mean_tpr /= NUM_CLASSES
macro_auc = auc(all_fpr, mean_tpr)
return all_fpr, mean_tpr, macro_auc
def compute_macro_pr(y_true, y_probs):
one_hot = np.eye(NUM_CLASSES)[y_true]
prec_dict, rec_dict = {}, {}
for c in range(NUM_CLASSES):
prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
all_rec = np.linspace(0, 1, 200)
mean_prec = np.zeros_like(all_rec)
for c in range(NUM_CLASSES):
mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
mean_prec /= NUM_CLASSES
macro_ap = average_precision_score(one_hot, y_probs, average='macro')
return all_rec, mean_prec, macro_ap
def sanitize_filename(name):
return re.sub(r'[^\w\-_]', '_', name).strip('_')
def preds_csv_path(out_dir, model_name):
safe = sanitize_filename(model_name)
return os.path.join(out_dir, f'{safe}_preds.csv')
def save_preds_csv(path, y_true, y_preds, y_probs):
header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs])
np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f')
def load_preds_csv(path):
data = np.loadtxt(path, delimiter=',', skiprows=1)
y_true = data[:, 0].astype(int)
y_preds = data[:, 1].astype(int)
y_probs = data[:, 2:2 + NUM_CLASSES]
return y_true, y_preds, y_probs
if __name__ == '__main__':
out_dir = os.path.dirname(os.path.abspath(__file__))
device = torch.device('cuda' if torch.cuda.is_available()
else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
else 'cpu')
print(f"Device: {device}")
val_transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
transform=val_transform)
val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
transform=val_transform)
print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
# ———— 评估所有模型(有缓存则跳过)————
results = {}
for name, func in MODELS:
print(f"\n{'='*50}")
csv_path = preds_csv_path(out_dir, name)
if os.path.exists(csv_path):
print(f"加载缓存: {os.path.basename(csv_path)}")
y_true, y_preds, y_probs = load_preds_csv(csv_path)
else:
print(f"评估: {name}")
y_true, y_preds, y_probs = func(train_loader, val_loader, device)
save_preds_csv(csv_path, y_true, y_preds, y_probs)
print(f" 预测数据已保存: {os.path.basename(csv_path)}")
acc = accuracy_score(y_true, y_preds)
fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
'rec': rec, 'prec': prec, 'ap': macro_ap}
print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
# ———— ROC 对比图 ————
fig, ax = plt.subplots(figsize=(8, 7))
for i, (name, r) in enumerate(results.items()):
color = COLORS[i % len(COLORS)]
ax.plot(r['fpr'], r['tpr'], color=color, lw=2,
label=f"{name} (AUC={r['auc']:.4f})")
ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)
plt.tight_layout()
roc_path = os.path.join(out_dir, 'roc_comparison.png')
plt.savefig(roc_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"\nROC 对比图已保存: {roc_path}")
# ———— PR 对比图 ————
fig, ax = plt.subplots(figsize=(8, 7))
for i, (name, r) in enumerate(results.items()):
color = COLORS[i % len(COLORS)]
ax.plot(r['rec'], r['prec'], color=color, lw=2,
label=f"{name} (AP={r['ap']:.4f})")
ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
plt.tight_layout()
pr_path = os.path.join(out_dir, 'pr_comparison.png')
plt.savefig(pr_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"PR 对比图已保存: {pr_path}")
# ———— 准确率柱状图 ————
names = list(results.keys())
accs = [results[n]['acc'] for n in names]
fig, ax = plt.subplots(figsize=(8, 5))
bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
for bar, acc in zip(bars, accs):
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
ax.set_ylim(min(accs) - 0.03, max(accs) * 1.08)
ax.set_ylabel('Accuracy'); ax.set_title('Accuracy Comparison', fontsize=14)
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
bar_path = os.path.join(out_dir, 'accuracy_bar.png')
plt.savefig(bar_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"准确率柱状图已保存: {bar_path}")

BIN
baseline/pr_comparison.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

BIN
baseline/roc_comparison.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 125 KiB

BIN
confusion_matrix.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
pr_curve.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

10
requirements.txt Normal file
View file

@ -0,0 +1,10 @@
torch>=1.10
torchvision>=0.11
tqdm
matplotlib
pandas
Pillow
scikit-learn
scikit-image
numpy
torchsummary

BIN
roc_curve.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

BIN
training_curves.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

81
training_log.csv Normal file
View file

@ -0,0 +1,81 @@
epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best
1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best
2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best
3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316,
4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best
5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076,
6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best
7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118,
8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884,
9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621,
10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217,
11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best
12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092,
13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994,
14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best
15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364,
16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369,
17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863,
18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best
19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214,
20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369,
21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355,
22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546,
23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best
24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182,
25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052,
26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best
27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best
28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867,
29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357,
30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best
31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732,
32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683,
33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best
34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633,
35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205,
36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577,
37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594,
38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612,
39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671,
40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996,
41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best
42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878,
43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063,
44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237,
45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best
46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372,
47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best
48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best
49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best
50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762,
51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431,
52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best
53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best
54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285,
55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best
56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178,
57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161,
58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best
59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best
60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316,
61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best
62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best
63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best
64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best
65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376,
66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966,
67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077,
68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best
69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best
70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321,
71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best
72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best
73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05,
74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best
75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best
76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05,
77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05,
78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06,
79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best
80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,
1 epoch train_loss train_f1 train_acc val_loss val_f1 val_acc lr best
2 1 1.0409312975676923 0.4329540729522705 48.04254100337675 1.1043149583345566 0.4398210048675537 48.66548042704626 0.004998072590601808 best
3 2 0.9862563783744079 0.4695238769054413 52.5943680656054 0.9867177669753319 0.5062971115112305 58.397207774431976 0.00499229333433282 best
4 3 0.9462850892451784 0.49421826004981995 55.40279787747226 1.0144445673589866 0.4907984733581543 53.15494114426499 0.004982671142387316
5 4 0.910117163585685 0.514958381652832 57.832097202122526 0.8787865286946395 0.5453917980194092 62.544483985765126 0.004969220851487844 best
6 5 0.8786031692946986 0.5320333242416382 59.74282440906898 1.0686318878927787 0.4803737998008728 52.73063235696688 0.004951963201008076
7 6 0.8518873820889128 0.5481140613555908 61.51938615533044 0.7650798693964196 0.6073676347732544 68.98439638653161 0.004930924800994191 best
8 7 0.8256270786701512 0.5604796409606934 62.90249638205499 0.8796401012116773 0.5789190530776978 62.63345195729537 0.004906138091134118
9 8 0.8003699506646013 0.5742803812026978 64.3014351181862 0.9246643470014378 0.5521833896636963 60.88831097727895 0.004877641290737884
10 9 0.780536473588097 0.5827116966247559 65.25642185238785 0.8404132533719564 0.5876226425170898 65.89789214344374 0.00484547833980621
11 10 0.7604798049209087 0.595557451248169 66.54079232995659 0.9228097118810533 0.564703643321991 60.77881193539557 0.004809698831278217
12 11 0.7410275131047088 0.6043155789375305 67.35784491075735 0.7576604621266131 0.6295210123062134 69.83985765124555 0.0047703579345627035 best
13 12 0.7195374228732343 0.6127941608428955 68.07766521948867 0.9624476881507583 0.5630610585212708 61.59321105940323 0.00472751631047092
14 13 0.6997139808122973 0.6210756301879883 68.85175470332851 0.7615812296349155 0.6177672147750854 69.12127018888584 0.004681240017681994
15 14 0.6824904908837182 0.630592942237854 69.65448625180898 0.6715762626299165 0.6534035205841064 73.48754448398577 0.004631600410885231 best
16 15 0.6653590450468583 0.6379610300064087 70.39088880849012 0.694461440047988 0.6517682075500488 73.0906104571585 0.004578674030756364
17 16 0.6514209577758935 0.6478185653686523 71.27879281234925 0.7036816360785346 0.6470745801925659 71.76977826444019 0.004522542485937369
18 17 0.6330186040776395 0.6530008316040039 71.7935962373372 0.7222367905930823 0.6418735980987549 71.07856556255133 0.004463292327201863
19 18 0.6166394593717968 0.6634106040000916 72.6038651712494 0.6067476719332303 0.6886636018753052 77.49110320284697 0.004401014914000078 best
20 19 0.5973944908975692 0.6721534729003906 73.47367945007235 0.6952472055509845 0.6622275114059448 71.79715302491103 0.004335806273589214
21 20 0.5820678306183721 0.6758297681808472 73.73145803183792 0.7708474785342401 0.6217234134674072 68.74486723241172 0.004267766952966369
22 21 0.5650806297110982 0.6851130723953247 74.5741377231066 0.7461620579141478 0.6384793519973755 71.35231316725978 0.004197001863832355
23 22 0.5500074683958588 0.6915749311447144 75.099493487699 0.6420613189380593 0.672810435295105 74.9452504790583 0.00412362012082546
24 23 0.5367840825001858 0.6979560852050781 75.66102870236372 0.6252713002082977 0.6949211359024048 75.32849712565014 0.0040477348732745845 best
25 24 0.5234906795055925 0.7052106857299805 76.26025084418717 0.7471277477021352 0.6447888016700745 69.70982753900904 0.003969463130731182
26 25 0.5044557179049829 0.7132176160812378 76.91148094548963 0.6325626891507145 0.6857903003692627 75.52012044894607 0.0038889255825490052
27 26 0.4938347232885195 0.7174828052520752 77.23784973468403 0.5635758755127437 0.70375657081604 78.50396934026827 0.003806246411789872 best
28 27 0.4793313278116239 0.7242900133132935 77.87475880366618 0.5505193975847648 0.7201660871505737 78.89405967697783 0.003721553103742388 best
29 28 0.46573570758837524 0.7336312532424927 78.60437771345876 0.640248272807638 0.6859503984451294 75.13003011223651 0.003634976249348867
30 29 0.44927708754289913 0.737967312335968 78.9352689339122 0.6151526539644867 0.7065733075141907 74.69887763482069 0.00354664934384357
31 30 0.4373503708129221 0.7443608045578003 79.40409430776653 0.5578661908719627 0.7272701263427734 78.20969066520668 0.0034567085809127244 best
32 31 0.42717206794400175 0.7488712072372437 79.7779486251809 0.58909761693554 0.7034546136856079 76.88201478237066 0.003365292642693732
33 32 0.4100124511779706 0.7580570578575134 80.60630728412929 0.6458172935624336 0.6865078210830688 75.32849712565014 0.0032725424859373683
34 33 0.3993677339991451 0.763430118560791 80.9876989869754 0.47995706558097007 0.754202127456665 81.65891048453327 0.003178601124662685 best
35 34 0.3858378949555808 0.7697042226791382 81.54772672455378 0.6427663844838523 0.6931804418563843 74.28141253764029 0.0030836134096397633
36 35 0.37397055771404913 0.7764154672622681 81.9992161119151 0.6085299244000101 0.7046636343002319 77.08048179578428 0.0029877258050403205
37 36 0.3597575335889638 0.7818952798843384 82.53587795465509 0.5254679805051781 0.7415529489517212 78.89405967697783 0.002891086162600577
38 37 0.3487578573732404 0.7871347665786743 82.8999336710082 0.5125140355052995 0.748577356338501 80.1875171092253 0.002793843493644594
39 38 0.3325814358527052 0.7965956330299377 83.59035817655571 0.5408413317834649 0.7290798425674438 79.87270736381056 0.002696147739319612
40 39 0.3261546248608721 0.7988470792770386 83.75316570188133 0.5301555857539602 0.7376729249954224 80.06433068710649 0.002598149539397671
41 40 0.30964642827472305 0.8070269823074341 84.52725518572117 0.5468305750249544 0.7365171313285828 78.73665480427046 0.0024999999999999996
42 41 0.3009217412674191 0.8119726777076721 84.96140858658949 0.46898612490165015 0.7599539756774902 82.18587462359704 0.002401850460602329 best
43 42 0.28925693789887874 0.8200639486312866 85.6458031837916 0.5167866427621677 0.7465909123420715 80.83766767040788 0.0023038522606803878
44 43 0.2707157838268379 0.8313596248626709 86.46360950313556 0.5156203284349763 0.7596548199653625 80.6049822064057 0.0022061565063554063
45 44 0.2580273799019566 0.836384654045105 86.8412325132658 0.5318487190707494 0.746901273727417 80.27648508075555 0.0021089138373994237
46 45 0.2504911308580703 0.8410984873771667 87.30779667149059 0.49164087495639364 0.763725221157074 81.82315904735833 0.00201227419495968 best
47 46 0.24104372076995695 0.8451772928237915 87.64094910757356 0.5290114263981752 0.7580969333648682 80.94032302217356 0.0019163865903602372
48 47 0.22337641519549614 0.8570870161056519 88.56201760733236 0.43634469677838694 0.7913081049919128 85.10128661374213 0.0018213988753373142 best
49 48 0.2128122905210861 0.8645581603050232 89.1152616980222 0.43183545479456625 0.7972898483276367 85.49137695045168 0.001727457514062632 best
50 49 0.2003101470318182 0.8717849254608154 89.69187168355042 0.4289672785945178 0.806715726852417 85.7993430057487 0.0016347073573062686 best
51 50 0.1888495338707803 0.8796613216400146 90.34687047756874 0.4568697272132202 0.7956517338752747 84.64275937585546 0.0015432914190872762
52 51 0.1756466486088274 0.886497437953949 90.97096599131693 0.4541556305781541 0.7938134670257568 84.45113605255953 0.001453350656156431
53 52 0.16742963044469736 0.8907681703567505 91.3101483357453 0.4230425570913775 0.8147625923156738 86.100465370928 0.0013650237506511336 best
54 53 0.15311133117022804 0.9007841944694519 92.11212614568258 0.4146759752586969 0.8208259344100952 86.83274021352314 0.0012784468962576128 best
55 54 0.1423091164071722 0.9078108668327332 92.6714001447178 0.4709351422719488 0.8029968738555908 85.71721872433616 0.0011937535882101285
56 55 0.13160189816137902 0.9135022163391113 93.10932223830197 0.40829240685941226 0.8264325857162476 87.42814125376403 0.0011110744174509947 best
57 56 0.12707359487800943 0.9174070358276367 93.4409671972986 0.42565728100299705 0.8230471611022949 87.65398302764851 0.0010305368692688178
58 57 0.11291898237482914 0.925153374671936 94.10953328509407 0.43774206922898184 0.8247347474098206 87.5376402956474 0.0009522651267254161
59 58 0.10370979833767516 0.9329074025154114 94.62584418716835 0.4068256767521223 0.8333848118782043 88.05091705447578 0.0008763798791745416 best
60 59 0.0946220946482491 0.9380815029144287 95.09768451519537 0.41103389083751746 0.8357677459716797 88.4889132220093 0.0008029981361676465 best
61 60 0.08804238645213414 0.9423004388809204 95.45194163048721 0.4207381762007377 0.8308929204940796 88.17410347659458 0.0007322330470336316
62 61 0.07913849578165794 0.9495129585266113 95.9916184273999 0.4083157278823748 0.8420299291610718 89.00218998083767 0.0006641937264107861 best
63 62 0.06981624146470565 0.9559226036071777 96.49662325132658 0.41332166066585485 0.843511700630188 88.98850260060225 0.0005989850859999229 best
64 63 0.06394793773276639 0.9602090120315552 96.79811866859623 0.41052334102789995 0.8489691019058228 89.35121817684096 0.0005367076727981376 best
65 64 0.057007493751794744 0.9636315107345581 97.11016642547034 0.3970402057488879 0.8546013832092285 90.04927456884752 0.00047745751406263185 best
66 65 0.05285091761448427 0.967146635055542 97.40035576459238 0.4247235585867853 0.8433754444122314 89.34437448672324 0.0004213259692436376
67 66 0.04614944799553407 0.9710246324539185 97.6912988422576 0.4035414461747053 0.8538572788238525 89.85080755543389 0.00036839958911476966
68 67 0.042909558690492254 0.9727644920349121 97.8428002894356 0.41578795896298 0.8530543446540833 89.91924445661101 0.0003187599823180077
69 68 0.03640206224887977 0.9769999980926514 98.15710926193921 0.42073161891477256 0.8551151156425476 89.94661921708185 0.0002724836895290806 best
70 69 0.034432517173010414 0.9790080785751343 98.34328268210324 0.4133664407223553 0.8576940298080444 90.28880372296743 0.00022964206543729668 best
71 70 0.030669637766023813 0.9817556142807007 98.5249336710082 0.418308269951463 0.8569411039352417 90.12455516014235 0.00019030116872178321
72 71 0.028112305183924133 0.9827903509140015 98.606337433671 0.4151474575991667 0.8596312999725342 90.37092800437996 0.00015452166019378966 best
73 72 0.024704152367817256 0.9853801727294922 98.81135431741437 0.4153705465811558 0.8635820746421814 90.68573774979468 0.0001223587092621162 best
74 73 0.024846541488804174 0.9855506420135498 98.8369814278823 0.4177400436290088 0.8632140159606934 90.59676977826444 9.38619088658821e-05
75 74 0.022639600746625622 0.9868491888046265 98.94702725518572 0.41732572613841307 0.8648342490196228 90.78154941144265 6.907519900580863e-05 best
76 75 0.02120214593173326 0.9878177642822266 99.0231548480463 0.4163925825270714 0.866214394569397 90.89789214344374 4.803679899192394e-05 best
77 76 0.019741657997631577 0.9883521795272827 99.06385672937772 0.42005763620917286 0.8647006750106812 90.82945524226663 3.077914851215586e-05
78 77 0.019116416042495393 0.9889511466026306 99.10003617945007 0.4159400745789841 0.8657370805740356 90.84998631261976 1.7328857612684272e-05
79 78 0.019259902796210714 0.9888157844543457 99.0962674867342 0.4192042892654481 0.8641382455825806 90.69942513003011 7.706665667180091e-06
80 79 0.01933925595445387 0.9887675046920776 99.0759165460685 0.4180937778044573 0.8662786483764648 90.84998631261976 1.9274093981927482e-06 best
81 80 0.01922732148408437 0.9889604449272156 99.10078991799324 0.41794140280912484 0.864332914352417 90.82261155214891 0.0