diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..67fdd7a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,30 @@
+*
+!/.gitattributes
+!/Dataloader.py
+!/LICENSE
+!/Merge_classes.py
+!/Model.py
+!/README.md
+!/requirements.txt
+!/THIRD_PARTY_LICENSES.md
+!/Train.py
+!/app.py
+!/Baseline.py
+!/Finetune.py
+!/Curve.py
+!/Evaluate.py
+!/baseline/
+!/baseline/__init__.py
+!/baseline/VGG_KNN.py
+!/baseline/compare_models.py
+!/baseline/ResNet34_Pretrained_10pct.py
+!/baseline/HOG_Baseline.py
+!/baseline/roc_comparison.png
+!/baseline/pr_comparison.png
+!/baseline/accuracy_bar.png
+!/training_log.csv
+!/confusion_matrix.png
+!/roc_curve.png
+!/pr_curve.png
+!/training_curves.png
+!.gitignore
diff --git a/Curve.py b/Curve.py
new file mode 100644
index 0000000..7c38b37
--- /dev/null
+++ b/Curve.py
@@ -0,0 +1,50 @@
+"""
+plot_training_curves.py
+从 training_log.csv 读取日志,绘制 Loss / F1 / Accuracy / LR 曲线
+"""
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import matplotlib
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+# ============ 读取数据 ============
+df = pd.read_csv('training_log.csv')
+best_rows = df[df['best'] == 'best']
+
+fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+
+# ---- 1. Loss ----
+ax = axes[0, 0]
+ax.plot(df['epoch'], df['train_loss'], label='Train Loss', color='#1f77b4', lw=1.5)
+ax.plot(df['epoch'], df['val_loss'], label='Val Loss', color='#ff7f0e', lw=1.5)
+ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.set_title('Loss vs Epoch')
+ax.legend(); ax.grid(True, alpha=0.3)
+
+# ---- 2. F1 Score ----
+ax = axes[0, 1]
+ax.plot(df['epoch'], df['train_f1'], label='Train F1', color='#1f77b4', lw=1.5)
+ax.plot(df['epoch'], df['val_f1'], label='Val F1', color='#ff7f0e', lw=1.5)
+ax.set_xlabel('Epoch'); ax.set_ylabel('F1 Score'); ax.set_title('F1 Score vs Epoch')
+ax.legend(); ax.grid(True, alpha=0.3)
+
+# ---- 3. Accuracy ----
+ax = axes[1, 0]
+ax.plot(df['epoch'], df['train_acc'], label='Train Acc', color='#1f77b4', lw=1.5)
+ax.plot(df['epoch'], df['val_acc'], label='Val Acc', color='#ff7f0e', lw=1.5)
+ax.set_xlabel('Epoch'); ax.set_ylabel('Accuracy (%)'); ax.set_title('Accuracy vs Epoch')
+ax.legend(); ax.grid(True, alpha=0.3)
+
+# ---- 4. Learning Rate ----
+ax = axes[1, 1]
+ax.plot(df['epoch'], df['lr'], color='#2ca02c', lw=1.5)
+ax.set_xlabel('Epoch'); ax.set_ylabel('Learning Rate'); ax.set_title('Learning Rate vs Epoch')
+ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
+ax.grid(True, alpha=0.3)
+
+plt.tight_layout()
+plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
+plt.show()
+print("训练曲线已保存: training_curves.png")
diff --git a/Dataloader.py b/Dataloader.py
index fab833b..5d51e3f 100644
--- a/Dataloader.py
+++ b/Dataloader.py
@@ -13,9 +13,75 @@ import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
+from torch.utils.data import Dataset
+from PIL import Image, ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from PIL import Image, ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
-def create_dataloaders(data_root='..',
+class RobustImageFolder(Dataset):
+ """包装 ImageFolder,自动跳过损坏图片,带进度条"""
+
+ def __init__(self, root, transform=None):
+ self.transform = transform
+ self.samples = []
+ self.classes = []
+ self.class_to_idx = {}
+
+ # 先构建原始的 ImageFolder 来获取类别信息
+ temp_dataset = datasets.ImageFolder(root, transform=None)
+ self.classes = temp_dataset.classes
+ self.class_to_idx = temp_dataset.class_to_idx
+
+ # 带进度条扫描
+ print(f"\n正在扫描: {root}")
+ print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")
+
+ corrupted_count = 0
+ success_count = 0
+
+ # 使用 tqdm 显示进度
+ for path, label in tqdm(temp_dataset.samples,
+ desc="验证图片完整性",
+ unit="张",
+ ncols=80):
+ try:
+ self.samples.append((path, label))
+ success_count += 1
+ except Exception as e:
+ corrupted_count += 1
+ # 可选:只打印前10个错误,避免刷屏
+ if corrupted_count <= 10:
+ tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
+ elif corrupted_count == 11:
+ tqdm.write(f"⚠️ 后续损坏图片将不再显示...")
+
+ print(f"\n✅ 扫描完成!")
+ print(f" 📁 有效图片: {success_count} 张")
+ print(f" ❌ 损坏跳过: {corrupted_count} 张")
+ print(f" 📊 总计: {len(self.samples)} 张\n")
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ path, label = self.samples[idx]
+ try:
+ img = Image.open(path).convert('RGB')
+ if self.transform:
+ img = self.transform(img)
+ return img, label
+ except Exception as e:
+ # 极少数情况,返回下一个
+ return self.__getitem__((idx + 1) % len(self))
+def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
batch_size=32,
image_size=256,
val_split=0.2,
@@ -77,12 +143,12 @@ def create_dataloaders(data_root='..',
# 2. 加载数据集
# ==================================
print("使用独立的 val 文件夹")
- train_dataset = datasets.ImageFolder(
+ train_dataset = RobustImageFolder(
root=os.path.join(data_root, 'train'),
transform=train_transform if augment else val_transform
)
- val_dataset = datasets.ImageFolder(
+ val_dataset = RobustImageFolder(
root=os.path.join(data_root, 'val'),
transform=val_transform
)
@@ -111,9 +177,8 @@ def create_dataloaders(data_root='..',
)
# 4. 获取类别名称
- class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3']
+ class_names = train_dataset.classes
print(f"类别: {class_names}")
- print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}")
return train_loader, val_loader, class_names
@@ -158,10 +223,10 @@ def visualize_batch(dataloader, class_names, num_images=8):
if __name__ == '__main__':
train_loader, val_loader, class_names = create_dataloaders(
- data_root='..', # 与trash-division同级文件夹
- batch_size=32, # 根据你的显存调整
+ data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
+ batch_size=16, # 根据你的显存调整
image_size=256, # 与你模型输入一致
- num_workers=4, # Windows 可能需设为 0
+ num_workers=16, # Windows 可能需设为 0
augment=True # 训练时使用数据增强
)
visualize_batch(train_loader, class_names, num_images=8)
diff --git a/Evaluate.py b/Evaluate.py
new file mode 100644
index 0000000..7c78f1c
--- /dev/null
+++ b/Evaluate.py
@@ -0,0 +1,147 @@
+"""
+evaluate_and_plot.py
+加载模型,在验证集上推理,绘制混淆矩阵 / ROC / PR 曲线
+"""
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from sklearn.metrics import (
+ confusion_matrix, ConfusionMatrixDisplay,
+ roc_curve, auc,
+ precision_recall_curve, average_precision_score,
+)
+
+from Model import Net
+from Dataloader import RobustImageFolder
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+# ============================================================
+# ★★★ 需要你修改的参数 ★★★
+# ============================================================
+MODEL_PATH = 'best_model.pth' # 模型权重路径
+DATA_ROOT = '../trash_division_data/ultimate_4_class/' # 数据集根目录
+BATCH_SIZE = 32
+IMAGE_SIZE = 256
+NUM_WORKERS = 4
+# ============================================================
+
+# ---------- 1. 加载验证集 ----------
+val_transform = transforms.Compose([
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+])
+
+val_dataset = RobustImageFolder(
+ root=os.path.join(DATA_ROOT, 'val'),
+ transform=val_transform,
+)
+val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
+ shuffle=False, num_workers=NUM_WORKERS,
+ pin_memory=True, drop_last=False)
+
+class_names = val_dataset.classes
+num_classes = len(class_names)
+print(f"类别: {class_names}")
+
+# ---------- 2. 加载模型 ----------
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = Net(num_classes=num_classes)
+state_dict = torch.load(MODEL_PATH, map_location=device)
+if 'model_state_dict' in state_dict:
+ state_dict = state_dict['model_state_dict']
+elif 'model' in state_dict:
+ state_dict = state_dict['model']
+model.load_state_dict(state_dict)
+model = model.to(device).eval()
+print("模型加载完成")
+
+# ---------- 3. 推理 ----------
+all_labels = []
+all_probs = []
+
+with torch.no_grad():
+ for images, labels in val_loader:
+ images = images.to(device)
+ probs = torch.softmax(model(images), dim=1)
+ all_labels.append(labels.numpy())
+ all_probs.append(probs.cpu().numpy())
+
+all_labels = np.concatenate(all_labels)
+all_probs = np.concatenate(all_probs)
+all_preds = np.argmax(all_probs, axis=1)
+print(f"推理完成, 共 {len(all_labels)} 样本")
+
+# ============================================================
+# ① 混淆矩阵
+# ============================================================
+cm = confusion_matrix(all_labels, all_preds)
+fig, ax = plt.subplots(figsize=(8, 7))
+ConfusionMatrixDisplay(cm, display_labels=class_names).plot(
+ ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
+ax.set_title('Confusion Matrix', fontsize=14)
+plt.tight_layout()
+plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
+plt.show()
+print("混淆矩阵已保存: confusion_matrix.png")
+
+# ============================================================
+# ② ROC 曲线 (One-vs-Rest + Macro-average)
+# ============================================================
+one_hot = np.eye(num_classes)[all_labels]
+colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
+
+fig, ax = plt.subplots(figsize=(8, 7))
+fpr_d, tpr_d, auc_d = {}, {}, {}
+
+for i in range(num_classes):
+ fpr_d[i], tpr_d[i], _ = roc_curve(one_hot[:, i], all_probs[:, i])
+ auc_d[i] = auc(fpr_d[i], tpr_d[i])
+ ax.plot(fpr_d[i], tpr_d[i], color=colors[i], lw=2,
+ label=f'{class_names[i]} (AUC={auc_d[i]:.4f})')
+
+# Macro-average
+all_fpr = np.unique(np.concatenate([fpr_d[i] for i in range(num_classes)]))
+mean_tpr = sum(np.interp(all_fpr, fpr_d[i], tpr_d[i]) for i in range(num_classes)) / num_classes
+macro_auc = auc(all_fpr, mean_tpr)
+ax.plot(all_fpr, mean_tpr, 'navy', lw=2, ls='--',
+ label=f'Macro-avg (AUC={macro_auc:.4f})')
+ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
+
+ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
+ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
+ax.set_title('ROC Curve', fontsize=14)
+ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.savefig('roc_curve.png', dpi=150, bbox_inches='tight')
+plt.show()
+print("ROC 曲线已保存: roc_curve.png")
+
+# ============================================================
+# ③ Precision-Recall 曲线
+# ============================================================
+fig, ax = plt.subplots(figsize=(8, 7))
+
+for i in range(num_classes):
+ prec, rec, _ = precision_recall_curve(one_hot[:, i], all_probs[:, i])
+ ap = average_precision_score(one_hot[:, i], all_probs[:, i])
+ ax.plot(rec, prec, color=colors[i], lw=2,
+ label=f'{class_names[i]} (AP={ap:.4f})')
+
+ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
+ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
+ax.set_title('Precision-Recall Curve', fontsize=14)
+ax.legend(loc='best'); ax.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.savefig('pr_curve.png', dpi=150, bbox_inches='tight')
+plt.show()
+print("PR 曲线已保存: pr_curve.png")
diff --git a/Finetune.py b/Finetune.py
new file mode 100644
index 0000000..d70baf7
--- /dev/null
+++ b/Finetune.py
@@ -0,0 +1,225 @@
+"""
+微调脚本:冻结 conv1 + stage2,微调 stage3~fc
+加大少样本类别的 loss 权重
+author: yukun-hh
+date :2026-4-25
+"""
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from Model import Net
+from Dataloader import create_dataloaders
+import os
+import csv
+
+
+def compute_macro_f1(predicted, targets, num_classes=4):
+ tp = torch.zeros(num_classes, device=predicted.device)
+ fp = torch.zeros(num_classes, device=predicted.device)
+ fn = torch.zeros(num_classes, device=predicted.device)
+ for c in range(num_classes):
+ tp[c] = ((predicted == c) & (targets == c)).sum()
+ fp[c] = ((predicted == c) & (targets != c)).sum()
+ fn[c] = ((predicted != c) & (targets == c)).sum()
+ precision = tp / (tp + fp + 1e-8)
+ recall = tp / (tp + fn + 1e-8)
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
+ return f1.mean().item()
+
+
+def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
+ model.train()
+ running_loss = 0.0
+ correct = 0
+ total = 0
+ all_preds, all_labels = [], []
+
+ pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
+
+ for images, labels in pbar:
+ images, labels = images.to(device), labels.to(device)
+
+ outputs = model(images)
+ loss = criterion(outputs, labels)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item() * images.size(0)
+ _, predicted = outputs.max(1)
+ total += labels.size(0)
+ correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
+
+ batch_f1 = compute_macro_f1(predicted, labels)
+ pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', 'Acc': f'{100. * correct / total:.2f}%'})
+
+ epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
+ epoch_acc = 100. * correct / total
+ return epoch_loss, epoch_f1, epoch_acc
+
+
+def validate(model, val_loader, criterion, device):
+ model.eval()
+ running_loss = 0.0
+ correct = 0
+ total = 0
+ all_preds, all_labels = [], []
+
+ with torch.no_grad():
+ for images, labels in tqdm(val_loader, desc='[Validate]'):
+ images, labels = images.to(device), labels.to(device)
+
+ outputs = model(images)
+ loss = criterion(outputs, labels)
+
+ running_loss += loss.item() * images.size(0)
+ _, predicted = outputs.max(1)
+ total += labels.size(0)
+ correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
+
+ epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
+ epoch_acc = 100. * correct / total
+ return epoch_loss, epoch_f1, epoch_acc
+
+
+def compute_class_weights(dataset, num_classes=4, device='cpu', power=1.0):
+ class_counts = torch.zeros(num_classes)
+ for _, label in dataset.samples:
+ lbl = label.item() if isinstance(label, torch.Tensor) else label
+ class_counts[lbl] += 1
+ total = class_counts.sum()
+ weights = total / (num_classes * class_counts)
+ weights = weights ** power
+ return weights.to(device)
+
+
+def freeze_base_layers(model):
+ frozen_layers = []
+ for name, param in model.conv1.named_parameters():
+ param.requires_grad = False
+ frozen_layers.append(f'conv1.{name}')
+ for name, param in model.bn1.named_parameters():
+ param.requires_grad = False
+ frozen_layers.append(f'bn1.{name}')
+ for name, param in model.stage2.named_parameters():
+ param.requires_grad = False
+ frozen_layers.append(f'stage2.{name}')
+ for name, param in model.stage3.named_parameters():
+ param.requires_grad = False
+ frozen_layers.append(f'stage3.{name}')
+
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f'冻结层数: {len(frozen_layers)} 个参数组')
+ print(f'可训练参数量: {trainable:,} / {total:,} ({100. * trainable / total:.1f}%)')
+ return model
+
+
+def finetune(model, train_loader, val_loader, epochs=30, lr=0.0001, device='cuda'):
+ class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device, power=1.5)
+ criterion = nn.CrossEntropyLoss(weight=class_weights)
+
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, model.parameters()),
+ lr=lr, momentum=0.9, weight_decay=1e-4
+ )
+
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+ history = {
+ 'train_loss': [],
+ 'train_f1': [],
+ 'train_acc': [],
+ 'val_loss': [],
+ 'val_f1': [],
+ 'val_acc': []
+ }
+
+ best_val_f1 = 0.0
+ log_file = open('finetune_log.csv', 'w', newline='')
+ log_writer = csv.writer(log_file)
+ log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', 'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])
+
+ for epoch in range(epochs):
+ print(f'\n{"=" * 50}')
+ print(f'Epoch {epoch + 1}/{epochs}')
+
+ train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion,
+ optimizer, device, epoch)
+
+ val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
+
+ scheduler.step()
+
+ history['train_loss'].append(train_loss)
+ history['train_f1'].append(train_f1)
+ history['train_acc'].append(train_acc)
+ history['val_loss'].append(val_loss)
+ history['val_f1'].append(val_f1)
+ history['val_acc'].append(val_acc)
+
+ print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
+ print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
+ print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
+
+ best_mark = ''
+ if val_f1 > best_val_f1:
+ best_val_f1 = val_f1
+ torch.save(model.state_dict(), 'finetuned_model.pth')
+ best_mark = 'best'
+ print(f'✓ 保存最佳微调模型 (Macro-F1: {val_f1:.4f})')
+
+ lr = optimizer.param_groups[0]['lr']
+ log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc, val_loss, val_f1, val_acc, lr, best_mark])
+ log_file.flush()
+
+ print(f'\n{"=" * 50}')
+ print(f'微调完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
+
+ return model, history
+
+
+if __name__ == '__main__':
+ train_loader, val_loader, class_names = create_dataloaders(
+ data_root='../trash_division_data/ultimate_4_class/',
+ batch_size=16,
+ image_size=256,
+ num_workers=8,
+ augment=True
+ )
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
+
+ model = Net(num_classes=4)
+
+ if os.path.exists('best_model.pth'):
+ model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
+ print('✓ 加载预训练权重 best_model.pth')
+ else:
+ print('⚠ 未找到 best_model.pth,使用随机初始化权重')
+
+ model = model.to(device)
+ model = freeze_base_layers(model)
+
+ print(f'Device: {device}')
+ print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')
+
+ trained_model, history = finetune(
+ model=model,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ epochs=30,
+ lr=0.0001,
+ device=device
+ )
+
+ model.load_state_dict(torch.load('finetuned_model.pth'))
diff --git a/merge_classes.py b/Merge_classes.py
similarity index 90%
rename from merge_classes.py
rename to Merge_classes.py
index a605aa2..145c345 100644
--- a/merge_classes.py
+++ b/Merge_classes.py
@@ -1,5 +1,5 @@
"""将原数据集合并为我们需要的四个大类
- 运行时先配置路径
+ 已修改成相对路径 具体配置方法详见README.md
author:
weikaiwen
@@ -18,9 +18,9 @@ import shutil
# ================= 1. 配置你的路径 =================
# 注意:请确保相对路径正确,以下为示例
-ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录
-NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录
-CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置
+ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录
+NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录
+CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置
# ===================================================
diff --git a/Model.py b/Model.py
index d568b97..bf1a714 100644
--- a/Model.py
+++ b/Model.py
@@ -1,99 +1,113 @@
"""
-这个文件是模型的定义文件,请不要擅自修改,如有疑问微信群里反馈
-单独运行本文件将会输出模型结构
-目前的话是一个36层的模型,模型总量应该是在80M左右 如果到时候还是欠拟合的话再考虑去做更深的结构
+模型定义文件 - ResNet-34
author : yukun-hh
date : 2026-4-10
-
"""
-#神经网络模型库
import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
-#残差块
-class Resblock(nn.Module):
- def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1):
- """
- :param input_channels: 进入残差块时的原通道
- :param output_channels: 输出时的通道数
- :param use_1x1conv: 如果输入和输出通道不相等时,需要用一个1x1的卷积层对原来的输入进行一个通道提升
- :param strides: 默认1,如果大于1起到缩小张量的作用
- """
+
+class BasicBlock(nn.Module):
+ """
+ ResNet-34 基础残差块:3x3 -> 3x3
+ 若需要下采样或通道变化,则在跳跃连接中使用 1x1 卷积
+ """
+ expansion = 1
+
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super().__init__()
- self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides)
- self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=1)
- if use_1x1conv:
- self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides)
- else:
- self.conv3 = None
- self.bn1 = nn.BatchNorm2d(output_channels)
- self.bn2 = nn.BatchNorm2d(output_channels)
- def forward(self,X):
- Y = F.relu(self.bn1(self.conv1(X)))
- Y = self.bn2(self.conv2(Y))
- if self.conv3 is not None:
- X = self.conv3(X)
- Y += X
- return F.relu(Y)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
-class Net():
- """
- 模型的主要结构就在这里了,到时也好该和调用
- 现在必须实现的方法:
- 目前还是以图片缩放到256*256构建残差块
- """
- net = nn.Sequential()
- def resnet_block(self,input_channels, num_channels, num_residuals,
- first_block=False):
- """
- :param input_channels: 输入维度
- :param num_channels: 输出维度
- :param num_residuals: 单个残差层的残差块数
- :param first_block: 第一块不用下采样 特殊控制
- :return: list[nn.Module]
- """
- blk = []
+ def forward(self, x):
+ identity = x
- for i in range(num_residuals):
- if i == 0 and not first_block:
- blk.append(Resblock(input_channels, num_channels,
- use_1x1conv=True, strides=2))
- else:
- blk.append(Resblock(num_channels, num_channels))
- return blk
- def __init__(self):
- b1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
- nn.BatchNorm2d(64), nn.ReLU(),
- nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- )
- """
- 7×7 卷积层,输出通道 64,步长 2,填充 3
- (3×256×256)->(64×128×128)
- 批归一化 relu层
- 最大池化
- (64×128×128)->(64×64×64)
- """
- b2 = nn.Sequential(*self.resnet_block(64, 64, num_residuals=3, first_block=True))
- b3 = nn.Sequential(*self.resnet_block(64, 128, num_residuals=4))
- b4 = nn.Sequential(*self.resnet_block(128, 256, num_residuals=6))
- b5 = nn.Sequential(*self.resnet_block(256, 512, num_residuals=3))
- self.net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 4))
- def get_network(self):
- return self.net
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class Net(nn.Module):
+
+ def __init__(self, num_classes=4, dropout=0.5):
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ layers_config = [
+ (3, 64, 1), # layer1
+ (4, 128, 2), # layer2
+ (6, 256, 2), # layer3
+ (3, 512, 2), # layer4
+ ]
+
+ self.in_channels = 64
+ self.layer1 = self._make_layer(layers_config[0])
+ self.layer2 = self._make_layer(layers_config[1])
+ self.layer3 = self._make_layer(layers_config[2])
+ self.layer4 = self._make_layer(layers_config[3])
+
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(512, num_classes)
+
+ def _make_layer(self, config):
+ num_blocks, out_channels, stride = config
+ downsample = None
+ layers = []
+
+ if stride != 1 or self.in_channels != out_channels:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.in_channels, out_channels,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_channels),
+ )
+
+ layers.append(BasicBlock(self.in_channels, out_channels, stride, downsample))
+ self.in_channels = out_channels
+
+ for _ in range(1, num_blocks):
+ layers.append(BasicBlock(self.in_channels, out_channels))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x)
+ return x
if __name__ == '__main__':
- Net_new = Net()
- X = torch.rand(size=(1, 3, 256, 256))
- summary(Net_new.get_network(), input_size=(3, 256, 256))
-
-
-
-
-
-
-
-
+ model = Net(num_classes=4)
+ summary(model, input_size=(3, 256, 256))
diff --git a/README.md b/README.md
index eacd84d..10b86e6 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,249 @@
# trash-division
-## 一个基于卷积神经网络的垃圾分类识别系统
-### 同济大学python人工智能程序设计课程小组作业
+一个基于卷积神经网络的垃圾分类识别系统
+> 同济大学 Python 人工智能程序设计课程小组作业
+
+基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
+
+---
+
+## 目录
+
+- [项目特点](#项目特点)
+- [模型架构](#模型架构)
+- [数据集](#数据集)
+- [环境要求](#环境要求)
+- [快速开始](#快速开始)
+- [文件说明](#文件说明)
+- [目录结构](#目录结构)
+- [训练细节](#训练细节)
+- [评估与可视化](#评估与可视化)
+- [许可证](#许可证)
+
+---
+
+## 项目特点
+
+- **四类垃圾分类**:厨余垃圾(1)、可回收物(2)、其他垃圾(3)、有害垃圾(4)
+- **ResNet-34 架构**:约 21M 参数,34 层深度残差网络,含 Dropout 正则化
+- **数据增强**:训练时使用随机裁剪、水平翻转、旋转、色彩抖动
+- **Macro-F1 评估**:采用宏平均 F1 分数作为主要评估指标,兼顾各类别表现
+- **类别加权损失**:自动计算类别权重,缓解类别不平衡问题
+- **余弦退火学习率调度**:使用 CosineAnnealingLR 平滑调整学习率
+- **断点续训**:自动检测 `best_model.pth` 并加载继续训练
+- **多设备支持**:自动选择 CUDA > Intel XPU > CPU
+
+## 模型架构
+
+模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。
+
+### BasicBlock 块
+
+每个 BasicBlock 包含两个 3x3 卷积层 + 跳跃连接:
+
+| 层 | 卷积 | 作用 |
+|---|---|---|
+| 3x3 Conv | 特征提取 | 第一层卷积 |
+| 3x3 Conv | 特征提取 | 第二层卷积 |
+
+### 网络结构
+
+| 阶段 | 块数 | 输出通道数 | 说明 |
+|---|---|---|---|
+| 初始层 | - | 64 | 7x7 Conv, stride=2 + MaxPool |
+| Layer1 | 3 | 64 | 第一个残差阶段 |
+| Layer2 | 4 | 128 | - |
+| Layer3 | 6 | 256 | - |
+| Layer4 | 3 | 512 | 最终残差阶段 |
+| 分类头 | - | 4 | 全局平均池化 + Dropout + 全连接层 |
+
+## 数据集
+
+本项目使用 [tany0699/garbage265](https://modelscope.cn/datasets/tany0699/garbage265) 中文生活垃圾分类数据集,包含 265 个子类别的生活垃圾图片。
+
+通过 `Merge_classes.py` 脚本将 265 个子类别合并为 4 个顶级类别:
+
+```
+厨余垃圾 -> 1
+可回收物 -> 2
+其他垃圾 -> 3
+有害垃圾 -> 4
+```
+
+数据集预期放置在 `../trash_division_data/`(与项目根目录平级的兄弟目录)。
+
+## 环境要求
+
+```bash
+pip install -r requirements.txt
+```
+
+> **注意**:`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如:
+> - NVIDIA GPU:`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
+> - Intel GPU (XPU):`pip install torch torchvision --index-url https://download.pytorch.org/whl/xpu` 安装
+> - CPU:`pip install torch torchvision` 即可
+
+## 快速开始
+
+1. **数据预处理**:将 265 个子类别合并为 4 个顶级类别
+
+ ```bash
+ python Merge_classes.py
+ ```
+
+2. **训练模型**:
+
+ ```bash
+ python Train.py
+ ```
+
+3. **微调模型**(可选,冻结浅层、微调深层):
+
+ ```bash
+ python Finetune.py
+ ```
+
+4. **评估与可视化**:
+
+ ```bash
+ python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线
+ python Curve.py # 训练过程的 loss/f1/acc/lr 曲线
+ python baseline/compare_models.py # 多模型基线对比(ROC/PR/准确率)
+ ```
+
+> **注意**:
+> - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本
+> - Windows 系统需将 `num_workers` 设为 `0`(参见 `Dataloader.py` 和 `Train.py`)
+> - 训练会自动从 `best_model.pth` 断点续训(若存在)
+
+## 文件说明
+
+| 文件 | 功能 |
+|---|---|
+| `Train.py` | 训练主脚本,包含训练循环、验证、评估 |
+| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 |
+| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 |
+| `Model.py` | 模型定义,ResNet-34(BasicBlock)+ Dropout |
+| `Merge_classes.py` | 数据集预处理,265 类合并为 4 类 |
+| `Evaluate.py` | 模型评估,绘制混淆矩阵、ROC 曲线、PR 曲线 |
+| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 |
+| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 |
+| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
+| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) |
+| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) |
+| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
+| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) |
+| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
+
+## 目录结构
+
+```
+trash-division/
+├── baseline/ # 基线模型目录
+│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本
+│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调
+│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线
+│ ├── compare_models.py # 多模型对比脚本
+│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出)
+│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出)
+│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出)
+├── best_model.pth # 最佳模型权重(不纳入版本控制)
+├── Curve.py # 训练曲线绘制脚本
+├── Dataloader.py # 数据加载模块
+├── Evaluate.py # 模型评估可视化脚本
+├── Finetune.py # 微调脚本
+├── .gitattributes # Git 属性配置
+├── LICENSE # MIT 许可证
+├── Merge_classes.py # 数据集预处理脚本
+├── Model.py # 模型定义
+├── README.md # 项目说明(本文件)
+├── THIRD_PARTY_LICENSES.md # 第三方许可证声明
+├── Train.py # 训练主脚本
+├── training_log.csv # 训练日志
+├── confusion_matrix.png # 混淆矩阵(Evaluate.py 输出)
+├── roc_curve.png # ROC 曲线(Evaluate.py 输出)
+├── pr_curve.png # PR 曲线(Evaluate.py 输出)
+└── training_curves.png # 训练曲线(Curve.py 输出)
+```
+
+## 训练细节
+
+| 配置项 | 说明 |
+|---|---|
+| 输入尺寸 | 256 x 256 RGB |
+| 优化器 | SGD(momentum=0.9, weight_decay=1e-4) |
+| 初始学习率 | 0.001 |
+| 学习率调度 | CosineAnnealingLR |
+| 损失函数 | 类别加权 CrossEntropyLoss |
+| 评估指标 | Macro-F1(宏平均 F1 分数) |
+| 批量大小 | 默认 16(可通过参数调整) |
+| 训练轮数 | 默认 20(可通过参数调整) |
+| 设备选择优先级 | CUDA > Intel XPU > CPU |
+| 断点续训 | 自动检测 best_model.pth 并加载 |
+
+训练时数据增强管线:RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(+-15 deg) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
+
+## 评估与可视化
+
+训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析:
+
+### Evaluate.py — 模型评估
+
+在验证集上推理,生成三张评估图表:
+
+```bash
+python Evaluate.py
+```
+
+脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。
+
+**混淆矩阵**
+
+
+
+**ROC 曲线**
+
+
+
+**PR 曲线**
+
+
+
+### Curve.py — 训练曲线
+
+从 `training_log.csv` 读取训练日志,绘制四张子图:
+
+```bash
+python Curve.py
+```
+
+
+
+### 基线模型对比
+
+`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表:
+
+```bash
+python baseline/compare_models.py
+```
+
+对比阵容:ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。
+
+**ROC 曲线对比**
+
+
+
+**PR 曲线对比**
+
+
+
+**准确率柱状图**
+
+
## 许可证
本项目主代码采用 [MIT 许可证](LICENSE)。
本项目包含的数据集 `tany0699/garbage265` 采用 [Apache License 2.0](THIRD_PARTY_LICENSES.md),详情请参阅 `THIRD_PARTY_LICENSES.md` 文件。
-
-
-
diff --git a/Train.py b/Train.py
index 6c87c69..528d2a4 100644
--- a/Train.py
+++ b/Train.py
@@ -12,52 +12,70 @@ import torch.optim as optim
from tqdm import tqdm # 进度条,可选
import matplotlib.pyplot as plt
from Model import Net
+from Dataloader import create_dataloaders
+import os
+import csv
+
+
+def compute_macro_f1(predicted, targets, num_classes=4):
+ tp = torch.zeros(num_classes, device=predicted.device)
+ fp = torch.zeros(num_classes, device=predicted.device)
+ fn = torch.zeros(num_classes, device=predicted.device)
+ for c in range(num_classes):
+ tp[c] = ((predicted == c) & (targets == c)).sum()
+ fp[c] = ((predicted == c) & (targets != c)).sum()
+ fn[c] = ((predicted != c) & (targets == c)).sum()
+ precision = tp / (tp + fp + 1e-8)
+ recall = tp / (tp + fn + 1e-8)
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
+ return f1.mean().item()
+
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
"""训练一个epoch"""
- model.train() # 设置为训练模式
+ model.train()
running_loss = 0.0
correct = 0
total = 0
+ all_preds, all_labels = [], []
- # 使用 tqdm 显示进度条(可选)
pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
for images, labels in pbar:
- # 将数据移到 GPU/CPU
images, labels = images.to(device), labels.to(device)
- # 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
- # 反向传播
- optimizer.zero_grad() # 清空梯度
- loss.backward() # 计算梯度
- optimizer.step() # 更新参数
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
- # 统计
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
- # 更新进度条信息
- pbar.set_postfix({'loss': loss.item(), 'acc': 100. * correct / total})
+ batch_f1 = compute_macro_f1(predicted, labels)
+ pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', 'Acc': f'{100. * correct / total:.2f}%'})
epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total
- return epoch_loss, epoch_acc
+ return epoch_loss, epoch_f1, epoch_acc
def validate(model, val_loader, criterion, device):
"""验证函数"""
- model.eval() # 设置为评估模式
+ model.eval()
running_loss = 0.0
correct = 0
total = 0
+ all_preds, all_labels = [], []
- with torch.no_grad(): # 不计算梯度,节省内存
+ with torch.no_grad():
for images, labels in tqdm(val_loader, desc='[Validate]'):
images, labels = images.to(device), labels.to(device)
@@ -68,37 +86,52 @@ def validate(model, val_loader, criterion, device):
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total
- return epoch_loss, epoch_acc
+ return epoch_loss, epoch_f1, epoch_acc
+
+
+def compute_class_weights(dataset, num_classes=4, device='cpu'):
+ class_counts = torch.zeros(num_classes)
+ for _, label in dataset.samples:
+ lbl = label.item() if isinstance(label, torch.Tensor) else label
+ class_counts[lbl] += 1
+ total = class_counts.sum()
+ weights = total / (num_classes * class_counts)
+ return weights.to(device)
def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
"""主训练函数"""
# 1. 定义损失函数和优化器
- criterion = nn.CrossEntropyLoss() # 多分类用交叉熵
+ class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device)
+ criterion = nn.CrossEntropyLoss(weight=class_weights) # 多分类用交叉熵
- # 优化器选择(推荐 Adam 或 SGD)
- optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
# 或者使用 SGD + 动量
- # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
+ optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
# 学习率调度器(可选,帮助收敛)
- scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
- # 或者用余弦退火
- # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# 2. 记录训练历史
history = {
'train_loss': [],
+ 'train_f1': [],
'train_acc': [],
'val_loss': [],
+ 'val_f1': [],
'val_acc': []
}
- best_val_acc = 0.0
+ best_val_f1 = 0.0
+ log_file = open('training_log.csv', 'w', newline='')
+ log_writer = csv.writer(log_file)
+ log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', 'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])
# 3. 开始训练
for epoch in range(epochs):
@@ -106,77 +139,65 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
print(f'Epoch {epoch + 1}/{epochs}')
# 训练
- train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
- optimizer, device, epoch)
+ train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion,
+ optimizer, device, epoch)
# 验证
- val_loss, val_acc = validate(model, val_loader, criterion, device)
+ val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
# 更新学习率
scheduler.step()
# 记录
history['train_loss'].append(train_loss)
+ history['train_f1'].append(train_f1)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
+ history['val_f1'].append(val_f1)
history['val_acc'].append(val_acc)
# 打印结果
- print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
- print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
+ print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
+ print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
# 保存最佳模型
- if val_acc > best_val_acc:
- best_val_acc = val_acc
+ best_mark = ''
+ if val_f1 > best_val_f1:
+ best_val_f1 = val_f1
torch.save(model.state_dict(), 'best_model.pth')
- print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)')
+ best_mark = 'best'
+ print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')
+
+ lr = optimizer.param_groups[0]['lr']
+ log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc, val_loss, val_f1, val_acc, lr, best_mark])
+ log_file.flush()
# 4. 绘制训练曲线
- plot_training_history(history)
print(f'\n{"=" * 50}')
- print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%')
+ print(f'训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
return model, history
-def plot_training_history(history):
- """绘制训练曲线"""
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
-
- # 损失曲线
- ax1.plot(history['train_loss'], label='Train Loss')
- ax1.plot(history['val_loss'], label='Val Loss')
- ax1.set_xlabel('Epoch')
- ax1.set_ylabel('Loss')
- ax1.set_title('Training and Validation Loss')
- ax1.legend()
- ax1.grid(True)
-
- # 准确率曲线
- ax2.plot(history['train_acc'], label='Train Acc')
- ax2.plot(history['val_acc'], label='Val Acc')
- ax2.set_xlabel('Epoch')
- ax2.set_ylabel('Accuracy (%)')
- ax2.set_title('Training and Validation Accuracy')
- ax2.legend()
- ax2.grid(True)
-
- plt.tight_layout()
- plt.savefig('training_history.png', dpi=150)
- plt.show()
-
-
# ========== 使用示例 ==========
if __name__ == '__main__':
# 假设你的 dataloader 已经写好了
- # train_loader = ...
- # val_loader = ...
+ train_loader, val_loader, class_names = create_dataloaders(
+ data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
+ batch_size=16, # 根据你的显存调整
+ image_size=256, # 与你模型输入一致
+ num_workers=8, # Windows 可能需设为 0
+ augment=True # 训练时使用数据增强
+ )
# 1. 创建模型
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model = Net().get_network() # 根据你的 Net 类调整
+ device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
+ model = Net(num_classes=4) # 根据你的 Net 类调整
+ #断点继续训练
+ if os.path.exists('best_model.pth'):
+ model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu')))
model = model.to(device)
# 打印模型信息
@@ -188,11 +209,9 @@ if __name__ == '__main__':
model=model,
train_loader=train_loader,
val_loader=val_loader,
- epochs=50,
+ epochs=20,
lr=0.001,
device=device
)
-
# 3. 加载最佳模型用于预测
- model.load_state_dict(torch.load('best_model.pth'))
- print('训练完成,最佳模型已加载')
\ No newline at end of file
+ model.load_state_dict(torch.load('best_model.pth'))
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..0897466
--- /dev/null
+++ b/app.py
@@ -0,0 +1,95 @@
+import gradio as gr
+import torch
+from torchvision import transforms
+from PIL import Image
+from Model import Net # 根据上传的 Model.py,模型类名为 Net
+
+# 1. 基础配置与类别映射
+# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
+class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
+
+# 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致
+device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu')
+print(f"当前使用的推理设备: {device}")
+
+# 2. 初始化模型并加载最佳权重
+model = Net(num_classes=4)
+try:
+ # 采用与 Evaluate.py 一致的健壮加载方式
+ state_dict = torch.load('best_model.pth', map_location=device)
+ if 'model_state_dict' in state_dict:
+ state_dict = state_dict['model_state_dict']
+ elif 'model' in state_dict:
+ state_dict = state_dict['model']
+
+ model.load_state_dict(state_dict)
+ model = model.to(device).eval()
+ print("✅ 成功加载 best_model.pth 权重")
+except Exception as e:
+ print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}")
+
+# 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致)
+transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+])
+
+# 4. 核心推理函数
+def predict(image):
+ if image is None:
+ return None
+
+ # Gradio 传入的 pil 图像,确保转为 RGB 格式
+ image = image.convert('RGB')
+
+ # 预处理并增加 batch 维度
+ input_tensor = transform(image).unsqueeze(0).to(device)
+
+ with torch.no_grad():
+ outputs = model(input_tensor)
+ # 使用 Softmax 将 logits 转换为 0~1 的概率分布
+ probabilities = torch.softmax(outputs, dim=1)[0]
+
+ # 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 }
+ result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
+ return result_dict
+
+# 5. 构建与美化 Gradio 界面
+with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo:
+ gr.Markdown(
+ """
+
+
🗑️ Trash Division - 智能垃圾分类系统
+
基于 ResNet-34 架构,支持精准识别:厨余垃圾、可回收物、其他垃圾、有害垃圾。
+
同济大学 Python 人工智能程序设计课程小组作业
+
+ """
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ # type="pil" 让 Gradio 直接传 PIL Image 对象给预测函数,配合 torchvision 最方便
+ image_input = gr.Image(type="pil", label="上传垃圾图片 (支持拍照)")
+ with gr.Row():
+ clear_btn = gr.Button("清空", variant="secondary")
+ submit_btn = gr.Button("开始识别", variant="primary")
+
+ with gr.Column(scale=1):
+ label_output = gr.Label(num_top_classes=4, label="预测结果与置信度")
+
+ # 绑定点击事件
+ submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
+ clear_btn.click(lambda: (None, None), inputs=None, outputs=[image_input, label_output])
+
+if __name__ == "__main__":
+ # 启动 Web 界面
+ demo.launch(
+ server_name="127.0.0.1",
+ server_port=7860,
+ share=False, # 如果你想生成临时公网链接分享给同学测试,改为 True
+ inbrowser=True # 运行后自动在默认浏览器中打开
+ )
\ No newline at end of file
diff --git a/baseline/HOG_Baseline.py b/baseline/HOG_Baseline.py
new file mode 100644
index 0000000..c72e405
--- /dev/null
+++ b/baseline/HOG_Baseline.py
@@ -0,0 +1,145 @@
+"""
+baseline/HOG_Baseline.py
+HOG + 颜色直方图特征提取 + LogisticRegression 四分类
+纯传统 CV/ML 基线,零神经网络依赖
+可独立运行,也可被 compare_models.py 导入
+author: yukun-hh
+date: 2026-5-14
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import matplotlib
+
+from skimage.feature import hog
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+ accuracy_score, f1_score,
+ confusion_matrix, ConfusionMatrixDisplay,
+ classification_report,
+)
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+# ============================================================
+# ★★★ 可配置参数 ★★★
+# ============================================================
+DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
+IMAGE_SIZE = 128
+HOG_ORIENTATIONS = 9
+HOG_PIXELS_PER_CELL = (8, 8)
+HOG_CELLS_PER_BLOCK = (2, 2)
+COLOR_BINS = 32
+# ============================================================
+
+CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
+NUM_CLASSES = 4
+
+
+def extract_hog_color(image):
+ img = image.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
+ arr = np.array(img, dtype=np.float64) / 255.0
+
+ hog_feat = hog(arr, orientations=HOG_ORIENTATIONS,
+ pixels_per_cell=HOG_PIXELS_PER_CELL,
+ cells_per_block=HOG_CELLS_PER_BLOCK,
+ channel_axis=2, feature_vector=True)
+
+ color_feat = []
+ for c in range(3):
+ hist, _ = np.histogram(arr[:, :, c], bins=COLOR_BINS, range=(0, 1))
+ color_feat.append(hist)
+ color_feat = np.concatenate(color_feat)
+
+ return np.concatenate([hog_feat, color_feat])
+
+
+class HOGLRBaseline:
+ def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE):
+ self.data_root = data_root
+ self.image_size = image_size
+ self.clf = LogisticRegression(
+ C=1.0, max_iter=1000, solver='lbfgs', n_jobs=-1,
+ )
+
+ def _load_data(self, split):
+ dir_path = os.path.join(self.data_root, split)
+ features, labels = [], []
+ for class_id in range(1, NUM_CLASSES + 1):
+ class_dir = os.path.join(dir_path, str(class_id))
+ if not os.path.isdir(class_dir):
+ continue
+ files = sorted(os.listdir(class_dir))
+ for fname in tqdm(files, desc=f'{split}/class_{class_id}'):
+ fpath = os.path.join(class_dir, fname)
+ try:
+ with Image.open(fpath) as img:
+ feat = extract_hog_color(img)
+ features.append(feat)
+ labels.append(class_id - 1)
+ except Exception:
+ pass
+ print(f" {split}: {len(features)} 张")
+ return np.array(features, dtype=np.float32), np.array(labels)
+
+ def fit(self, train_dir=None):
+ if train_dir is None:
+ train_dir = 'train'
+ print(" 提取训练集 HOG 特征 ...")
+ X, y = self._load_data(train_dir)
+ self.clf.fit(X, y)
+
+ def predict(self, val_dir=None):
+ if val_dir is None:
+ val_dir = 'val'
+ print(" 提取验证集 HOG 特征 ...")
+ X, y = self._load_data(val_dir)
+ preds = self.clf.predict(X)
+ probs = self.clf.predict_proba(X)
+ return y, preds, probs
+
+
+# ============================================================
+# compare_models.py 导入接口
+# ============================================================
+
+def get_hog_lr_preds(train_loader, val_loader, device):
+ baseline = HOGLRBaseline()
+ baseline.fit('train')
+ return baseline.predict('val')
+
+
+# ============================================================
+# 独立运行入口
+# ============================================================
+
+if __name__ == '__main__':
+ out_dir = os.path.dirname(os.path.abspath(__file__))
+
+ print("HOG + LogisticRegression 基线")
+ baseline = HOGLRBaseline()
+
+ baseline.fit('train')
+ y_true, y_preds, y_probs = baseline.predict('val')
+
+ acc = accuracy_score(y_true, y_preds)
+ macro_f1 = f1_score(y_true, y_preds, average='macro')
+ print(f"\n验证集 Accuracy: {acc:.4f}")
+ print(f"验证集 Macro-F1: {macro_f1:.4f}")
+ print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")
+
+ cm = confusion_matrix(y_true, y_preds)
+ fig, ax = plt.subplots(figsize=(8, 7))
+ ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
+ ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
+ ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14)
+ plt.tight_layout()
+ cm_path = os.path.join(out_dir, 'hog_lr_confusion_matrix.png')
+ plt.savefig(cm_path, dpi=150, bbox_inches='tight')
+ plt.show()
+ print(f"混淆矩阵已保存: {cm_path}")
diff --git a/baseline/ResNet34_Pretrained_10pct.py b/baseline/ResNet34_Pretrained_10pct.py
new file mode 100644
index 0000000..b7988a8
--- /dev/null
+++ b/baseline/ResNet34_Pretrained_10pct.py
@@ -0,0 +1,278 @@
+"""
+baseline/ResNet34_Pretrained_10pct.py
+ResNet-34 ImageNet 预训练权重 + 10% 训练集微调
+可独立运行训练,也可被 compare_models.py 导入
+author: yukun-hh
+date: 2026-5-14
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, Subset
+from torchvision import models, transforms
+from tqdm import tqdm
+import csv
+import matplotlib.pyplot as plt
+import matplotlib
+
+from Dataloader import RobustImageFolder
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+# ============================================================
+# ★★★ 可配置参数 ★★★
+# ============================================================
+DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
+BATCH_SIZE = 32
+IMAGE_SIZE = 256
+NUM_WORKERS = 4
+EPOCHS = 30
+LR = 0.001
+TRAIN_PCT = 0.1
+SEED = 42
+DROPOUT = 0.3
+MODEL_SAVE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct.pth')
+LOG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct_log.csv')
+# ============================================================
+
+NUM_CLASSES = 4
+CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
+
+
+class PretrainedResNet34(nn.Module):
+ def __init__(self, num_classes=NUM_CLASSES, dropout=DROPOUT):
+ super().__init__()
+ self.backbone = models.resnet34(weights='IMAGENET1K_V1')
+ in_features = self.backbone.fc.in_features
+ self.backbone.fc = nn.Identity()
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(in_features, num_classes)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ x = self.dropout(x)
+ x = self.fc(x)
+ return x
+
+ def freeze_early_layers(self):
+ for param in self.backbone.conv1.parameters():
+ param.requires_grad = False
+ for param in self.backbone.bn1.parameters():
+ param.requires_grad = False
+ for param in self.backbone.layer1.parameters():
+ param.requires_grad = False
+ for param in self.backbone.layer2.parameters():
+ param.requires_grad = False
+
+ def print_trainable_info(self):
+ frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ total = frozen + trainable
+ print(f" 冻结参数: {frozen:,} 可训练参数: {trainable:,} ({100.*trainable/total:.1f}%)")
+
+
+def compute_macro_f1(predicted, targets, num_classes=NUM_CLASSES):
+ tp = torch.zeros(num_classes, device=predicted.device)
+ fp = torch.zeros(num_classes, device=predicted.device)
+ fn = torch.zeros(num_classes, device=predicted.device)
+ for c in range(num_classes):
+ tp[c] = ((predicted == c) & (targets == c)).sum()
+ fp[c] = ((predicted == c) & (targets != c)).sum()
+ fn[c] = ((predicted != c) & (targets == c)).sum()
+ precision = tp / (tp + fp + 1e-8)
+ recall = tp / (tp + fn + 1e-8)
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
+ return f1.mean().item()
+
+
+def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
+ model.train()
+ running_loss, correct, total = 0.0, 0, 0
+ all_preds, all_labels = [], []
+ pbar = tqdm(loader, desc=f'Epoch {epoch+1} [Train]')
+ for images, labels in pbar:
+ images, labels = images.to(device), labels.to(device)
+ outputs = model(images)
+ loss = criterion(outputs, labels)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ running_loss += loss.item() * images.size(0)
+ _, predicted = outputs.max(1)
+ total += labels.size(0)
+ correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
+ batch_f1 = compute_macro_f1(predicted, labels)
+ pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}',
+ 'Acc': f'{100.*correct/total:.2f}%'})
+ epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
+ epoch_acc = 100. * correct / total
+ return epoch_loss, epoch_f1, epoch_acc
+
+
+@torch.no_grad()
+def validate(model, loader, criterion, device):
+ model.eval()
+ running_loss, correct, total = 0.0, 0, 0
+ all_preds, all_labels = [], []
+ for images, labels in tqdm(loader, desc='[Validate]'):
+ images, labels = images.to(device), labels.to(device)
+ outputs = model(images)
+ loss = criterion(outputs, labels)
+ running_loss += loss.item() * images.size(0)
+ _, predicted = outputs.max(1)
+ total += labels.size(0)
+ correct += predicted.eq(labels).sum().item()
+ all_preds.append(predicted)
+ all_labels.append(labels)
+ epoch_loss = running_loss / total
+ epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
+ epoch_acc = 100. * correct / total
+ return epoch_loss, epoch_f1, epoch_acc
+
+
+def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR):
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
+ lr=lr, momentum=0.9, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+ history = {'train_loss': [], 'train_f1': [], 'train_acc': [],
+ 'val_loss': [], 'val_f1': [], 'val_acc': []}
+ best_val_f1 = 0.0
+
+ log_file = open(LOG_PATH, 'w', newline='')
+ log_writer = csv.writer(log_file)
+ log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc',
+ 'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])
+
+ for epoch in range(epochs):
+ print(f'\n{"="*50}')
+ print(f'Epoch {epoch+1}/{epochs}')
+
+ train_loss, train_f1, train_acc = train_one_epoch(
+ model, train_loader, criterion, optimizer, device, epoch)
+ val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
+ scheduler.step()
+
+ history['train_loss'].append(train_loss)
+ history['train_f1'].append(train_f1)
+ history['train_acc'].append(train_acc)
+ history['val_loss'].append(val_loss)
+ history['val_f1'].append(val_f1)
+ history['val_acc'].append(val_acc)
+
+ print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
+ print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
+ print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
+
+ best_mark = ''
+ if val_f1 > best_val_f1:
+ best_val_f1 = val_f1
+ torch.save(model.state_dict(), MODEL_SAVE_PATH)
+ best_mark = 'best'
+ print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')
+
+ lr_val = optimizer.param_groups[0]['lr']
+ log_writer.writerow([epoch+1, train_loss, train_f1, train_acc,
+ val_loss, val_f1, val_acc, lr_val, best_mark])
+ log_file.flush()
+
+ log_file.close()
+ print(f'\n训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
+ return history
+
+
+# ============================================================
+# compare_models.py 导入接口
+# ============================================================
+
+def get_resnet34_10pct_preds(train_loader, val_loader, device):
+ model = PretrainedResNet34(num_classes=NUM_CLASSES)
+ model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
+ model = model.to(device).eval()
+
+ y_true, y_preds, y_probs = [], [], []
+ with torch.no_grad():
+ for images, labels in tqdm(val_loader, desc='ResNet-34 (10%)'):
+ images, labels = images.to(device), labels
+ logits = model(images)
+ probs = torch.softmax(logits, dim=1)
+ preds = probs.argmax(dim=1)
+ y_true.append(labels.numpy())
+ y_preds.append(preds.cpu().numpy())
+ y_probs.append(probs.cpu().numpy())
+ return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)
+
+
+# ============================================================
+# 独立训练入口
+# ============================================================
+
+if __name__ == '__main__':
+ random.seed(SEED)
+ np.random.seed(SEED)
+ torch.manual_seed(SEED)
+
+ device = torch.device('cuda' if torch.cuda.is_available()
+ else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
+ else 'cpu')
+ print(f"Device: {device}")
+
+ val_transform = transforms.Compose([
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomRotation(degrees=15),
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+
+ full_train_dataset = RobustImageFolder(
+ root=os.path.join(DATA_ROOT, 'train'),
+ transform=train_transform,
+ )
+ val_dataset = RobustImageFolder(
+ root=os.path.join(DATA_ROOT, 'val'),
+ transform=val_transform,
+ )
+
+ n_train = len(full_train_dataset)
+ n_subset = max(1, int(n_train * TRAIN_PCT))
+ indices = random.sample(range(n_train), n_subset)
+ train_dataset = Subset(full_train_dataset, indices)
+ print(f"训练集: {len(train_dataset)} / {n_train} ({TRAIN_PCT*100:.0f}%)")
+ print(f"验证集: {len(val_dataset)}")
+
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
+ shuffle=True, num_workers=NUM_WORKERS,
+ pin_memory=True, drop_last=True)
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
+ shuffle=False, num_workers=NUM_WORKERS,
+ pin_memory=True, drop_last=False)
+
+ model = PretrainedResNet34(num_classes=NUM_CLASSES, dropout=DROPOUT)
+ model.freeze_early_layers()
+ model.print_trainable_info()
+ model = model.to(device)
+
+ history = train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR)
+
+ model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
+ print(f"模型已保存: {MODEL_SAVE_PATH}")
+ print(f"训练日志已保存: {LOG_PATH}")
diff --git a/baseline/VGG_KNN.py b/baseline/VGG_KNN.py
new file mode 100644
index 0000000..2e9a466
--- /dev/null
+++ b/baseline/VGG_KNN.py
@@ -0,0 +1,145 @@
+"""
+baseline/VGG_KNN.py
+VGG16 预训练模型特征提取 + KNN 四分类基线
+可独立运行,也可被 compare_models.py 导入复用
+author: yukun-hh
+date: 2026-5-14
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torchvision import models, transforms
+from tqdm import tqdm
+
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.metrics import (
+ accuracy_score, f1_score,
+ confusion_matrix, ConfusionMatrixDisplay,
+ classification_report,
+)
+
+from Dataloader import RobustImageFolder
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+
+CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
+
+
+def load_vgg16_extractor(device):
+ try:
+ model = models.vgg16(weights='IMAGENET1K_V1')
+ except TypeError:
+ model = models.vgg16(pretrained=True)
+ model.classifier = nn.Identity()
+ model = model.to(device).eval()
+ for param in model.parameters():
+ param.requires_grad = False
+ return model
+
+
+def extract_features(model, loader, device):
+ model.eval()
+ all_features = []
+ all_labels = []
+ with torch.no_grad():
+ for images, labels in tqdm(loader, desc='Extracting features'):
+ images = images.to(device)
+ feats = model(images)
+ all_features.append(feats.cpu().numpy())
+ all_labels.append(labels.numpy())
+ return np.concatenate(all_features), np.concatenate(all_labels)
+
+
+class VGGKNNBaseline:
+ def __init__(self, k=5, device='cpu',
+ data_root='../trash_division_data/ultimate_4_class/',
+ image_size=256, batch_size=32, num_workers=4):
+ self.k = k
+ self.device = device
+ self.data_root = data_root
+ self.image_size = image_size
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.extractor = load_vgg16_extractor(device)
+ self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
+
+ def _get_loader(self, split):
+ transform = transforms.Compose([
+ transforms.Resize((self.image_size, self.image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+ dataset = RobustImageFolder(
+ root=os.path.join(self.data_root, split),
+ transform=transform,
+ )
+ print(f" {split}: {len(dataset)} 张")
+ return DataLoader(dataset, batch_size=self.batch_size,
+ shuffle=False, num_workers=self.num_workers,
+ pin_memory=True, drop_last=False)
+
+ def fit(self, train_loader=None):
+ if train_loader is None:
+ train_loader = self._get_loader('train')
+ print(" 提取训练集特征 ...")
+ train_feats, train_labels = extract_features(self.extractor, train_loader, self.device)
+ self.knn.fit(train_feats, train_labels)
+
+ def predict(self, val_loader=None):
+ if val_loader is None:
+ val_loader = self._get_loader('val')
+ print(" 提取验证集特征 ...")
+ val_feats, val_labels = extract_features(self.extractor, val_loader, self.device)
+ preds = self.knn.predict(val_feats)
+ probs = self.knn.predict_proba(val_feats)
+ return val_labels, preds, probs
+
+
+if __name__ == '__main__':
+ DATA_ROOT = '../trash_division_data/ultimate_4_class/'
+ BATCH_SIZE = 32
+ IMAGE_SIZE = 256
+ NUM_WORKERS = 4
+ K = 5
+
+ device = torch.device('cuda' if torch.cuda.is_available()
+ else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
+ else 'cpu')
+ print(f"Device: {device}")
+
+ baseline = VGGKNNBaseline(k=K, device=device,
+ data_root=DATA_ROOT, image_size=IMAGE_SIZE,
+ batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
+
+ train_loader = baseline._get_loader('train')
+ val_loader = baseline._get_loader('val')
+
+ baseline.fit(train_loader)
+ y_true, y_preds, y_probs = baseline.predict(val_loader)
+
+ acc = accuracy_score(y_true, y_preds)
+ macro_f1 = f1_score(y_true, y_preds, average='macro')
+ print(f"\n验证集 Accuracy: {acc:.4f}")
+ print(f"验证集 Macro-F1: {macro_f1:.4f}")
+ print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")
+
+ cm = confusion_matrix(y_true, y_preds)
+ fig, ax = plt.subplots(figsize=(8, 7))
+ ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
+ ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
+ ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
+ plt.tight_layout()
+ out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vgg_knn_confusion_matrix.png')
+ plt.savefig(out_path, dpi=150, bbox_inches='tight')
+ plt.show()
+ print(f"混淆矩阵已保存: {out_path}")
diff --git a/baseline/__init__.py b/baseline/__init__.py
new file mode 100644
index 0000000..7471dfe
--- /dev/null
+++ b/baseline/__init__.py
@@ -0,0 +1 @@
+# baseline package
diff --git a/baseline/accuracy_bar.png b/baseline/accuracy_bar.png
new file mode 100644
index 0000000..786494c
Binary files /dev/null and b/baseline/accuracy_bar.png differ
diff --git a/baseline/compare_models.py b/baseline/compare_models.py
new file mode 100644
index 0000000..5917fe0
--- /dev/null
+++ b/baseline/compare_models.py
@@ -0,0 +1,249 @@
+"""
+baseline/compare_models.py
+多模型对比:ROC 曲线 + 准确率柱状图
+添加新模型只需在 MODELS 列表加一行,无需修改绘图代码
+author: yukun-hh
+date: 2026-5-14
+"""
+import sys, os, re
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from tqdm import tqdm
+
+from sklearn.metrics import (
+ roc_curve, auc, accuracy_score,
+ precision_recall_curve, average_precision_score,
+)
+
+from Model import Net
+from Dataloader import RobustImageFolder
+from baseline.VGG_KNN import VGGKNNBaseline
+from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
+from baseline.HOG_Baseline import get_hog_lr_preds
+
+matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+matplotlib.rcParams['axes.unicode_minus'] = False
+
+# ============================================================
+# ★★★ 可配置参数 ★★★
+# ============================================================
+DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
+BATCH_SIZE = 32
+IMAGE_SIZE = 256
+NUM_WORKERS = 4
+K_KNN = 5
+# ============================================================
+
+CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
+NUM_CLASSES = 4
+
+# ============================================================
+# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
+# ============================================================
+
+def get_resnet34_preds(train_loader, val_loader, device):
+ model = Net(num_classes=NUM_CLASSES)
+ state_dict = torch.load('../best_model.pth', map_location='cpu')
+ if 'model_state_dict' in state_dict:
+ state_dict = state_dict['model_state_dict']
+ elif 'model' in state_dict:
+ state_dict = state_dict['model']
+ model.load_state_dict(state_dict)
+ model = model.to(device).eval()
+
+ y_true, y_preds, y_probs = [], [], []
+ with torch.no_grad():
+ for images, labels in tqdm(val_loader, desc='ResNet-34'):
+ images, labels = images.to(device), labels
+ logits = model(images)
+ probs = torch.softmax(logits, dim=1)
+ preds = probs.argmax(dim=1)
+ y_true.append(labels.numpy())
+ y_preds.append(preds.cpu().numpy())
+ y_probs.append(probs.cpu().numpy())
+ return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs)
+
+
+def get_vgg_knn_preds(train_loader, val_loader, device):
+ baseline = VGGKNNBaseline(k=K_KNN, device=device)
+ baseline.fit(train_loader)
+ return baseline.predict(val_loader)
+
+
+# ============================================================
+# ★ 模型注册表 — 添加新模型只需在这里加一行 ★
+# ============================================================
+
+MODELS = [
+ ('ResNet-34', get_resnet34_preds),
+ ('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
+ ('VGG16 + KNN (K=5)', get_vgg_knn_preds),
+ ('HOG + LogisticRegression', get_hog_lr_preds),
+ # 未来轻松扩展示例:
+ # ('ResNet-18 (pretrained)', get_resnet18_preds),
+ # ('ResNet-50 (pretrained)', get_resnet50_preds),
+ # ('ResNet-34 (finetuned)', get_finetuned_preds),
+]
+
+# ============================================================
+# 调色板 (扩展时无需修改)
+# ============================================================
+COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
+ '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
+
+
+def compute_macro_roc(y_true, y_probs):
+ one_hot = np.eye(NUM_CLASSES)[y_true]
+ fpr_dict, tpr_dict = {}, {}
+ for c in range(NUM_CLASSES):
+ fpr_dict[c], tpr_dict[c], _ = roc_curve(one_hot[:, c], y_probs[:, c])
+ all_fpr = np.unique(np.concatenate([fpr_dict[c] for c in range(NUM_CLASSES)]))
+ mean_tpr = np.zeros_like(all_fpr)
+ for c in range(NUM_CLASSES):
+ mean_tpr += np.interp(all_fpr, fpr_dict[c], tpr_dict[c])
+ mean_tpr /= NUM_CLASSES
+ macro_auc = auc(all_fpr, mean_tpr)
+ return all_fpr, mean_tpr, macro_auc
+
+
+def compute_macro_pr(y_true, y_probs):
+ one_hot = np.eye(NUM_CLASSES)[y_true]
+ prec_dict, rec_dict = {}, {}
+ for c in range(NUM_CLASSES):
+ prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c])
+ all_rec = np.linspace(0, 1, 200)
+ mean_prec = np.zeros_like(all_rec)
+ for c in range(NUM_CLASSES):
+ mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1])
+ mean_prec /= NUM_CLASSES
+ macro_ap = average_precision_score(one_hot, y_probs, average='macro')
+ return all_rec, mean_prec, macro_ap
+
+
+def sanitize_filename(name):
+ return re.sub(r'[^\w\-_]', '_', name).strip('_')
+
+
+def preds_csv_path(out_dir, model_name):
+ safe = sanitize_filename(model_name)
+ return os.path.join(out_dir, f'{safe}_preds.csv')
+
+
+def save_preds_csv(path, y_true, y_preds, y_probs):
+ header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
+ data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs])
+ np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f')
+
+
+def load_preds_csv(path):
+ data = np.loadtxt(path, delimiter=',', skiprows=1)
+ y_true = data[:, 0].astype(int)
+ y_preds = data[:, 1].astype(int)
+ y_probs = data[:, 2:2 + NUM_CLASSES]
+ return y_true, y_preds, y_probs
+
+
+if __name__ == '__main__':
+ out_dir = os.path.dirname(os.path.abspath(__file__))
+
+ device = torch.device('cuda' if torch.cuda.is_available()
+ else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available()
+ else 'cpu')
+ print(f"Device: {device}")
+
+ val_transform = transforms.Compose([
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+
+ train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
+ transform=val_transform)
+ val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
+ transform=val_transform)
+ print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")
+
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,
+ num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
+ num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)
+
+ # ———— 评估所有模型(有缓存则跳过)————
+ results = {}
+ for name, func in MODELS:
+ print(f"\n{'='*50}")
+ csv_path = preds_csv_path(out_dir, name)
+ if os.path.exists(csv_path):
+ print(f"加载缓存: {os.path.basename(csv_path)}")
+ y_true, y_preds, y_probs = load_preds_csv(csv_path)
+ else:
+ print(f"评估: {name}")
+ y_true, y_preds, y_probs = func(train_loader, val_loader, device)
+ save_preds_csv(csv_path, y_true, y_preds, y_probs)
+ print(f" 预测数据已保存: {os.path.basename(csv_path)}")
+ acc = accuracy_score(y_true, y_preds)
+ fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
+ rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
+ results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
+ 'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
+ 'rec': rec, 'prec': prec, 'ap': macro_ap}
+ print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
+
+ # ———— ROC 对比图 ————
+ fig, ax = plt.subplots(figsize=(8, 7))
+ for i, (name, r) in enumerate(results.items()):
+ color = COLORS[i % len(COLORS)]
+ ax.plot(r['fpr'], r['tpr'], color=color, lw=2,
+ label=f"{name} (AUC={r['auc']:.4f})")
+ ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
+ ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
+ ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
+ ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
+ ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)
+ plt.tight_layout()
+ roc_path = os.path.join(out_dir, 'roc_comparison.png')
+ plt.savefig(roc_path, dpi=150, bbox_inches='tight')
+ plt.show()
+ print(f"\nROC 对比图已保存: {roc_path}")
+
+ # ———— PR 对比图 ————
+ fig, ax = plt.subplots(figsize=(8, 7))
+ for i, (name, r) in enumerate(results.items()):
+ color = COLORS[i % len(COLORS)]
+ ax.plot(r['rec'], r['prec'], color=color, lw=2,
+ label=f"{name} (AP={r['ap']:.4f})")
+ ax.set_xlim(0, 1); ax.set_ylim(0, 1.05)
+ ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
+ ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
+ ax.legend(loc='lower left'); ax.grid(True, alpha=0.3)
+ plt.tight_layout()
+ pr_path = os.path.join(out_dir, 'pr_comparison.png')
+ plt.savefig(pr_path, dpi=150, bbox_inches='tight')
+ plt.show()
+ print(f"PR 对比图已保存: {pr_path}")
+
+ # ———— 准确率柱状图 ————
+ names = list(results.keys())
+ accs = [results[n]['acc'] for n in names]
+ fig, ax = plt.subplots(figsize=(8, 5))
+ bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
+ bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
+ for bar, acc in zip(bars, accs):
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
+ f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
+ ax.set_ylim(min(accs) - 0.03, max(accs) * 1.08)
+ ax.set_ylabel('Accuracy'); ax.set_title('Accuracy Comparison', fontsize=14)
+ ax.grid(True, alpha=0.3, axis='y')
+ plt.tight_layout()
+ bar_path = os.path.join(out_dir, 'accuracy_bar.png')
+ plt.savefig(bar_path, dpi=150, bbox_inches='tight')
+ plt.show()
+ print(f"准确率柱状图已保存: {bar_path}")
diff --git a/baseline/pr_comparison.png b/baseline/pr_comparison.png
new file mode 100644
index 0000000..03f80ef
Binary files /dev/null and b/baseline/pr_comparison.png differ
diff --git a/baseline/roc_comparison.png b/baseline/roc_comparison.png
new file mode 100644
index 0000000..66ae8cc
Binary files /dev/null and b/baseline/roc_comparison.png differ
diff --git a/confusion_matrix.png b/confusion_matrix.png
new file mode 100644
index 0000000..765f7c9
Binary files /dev/null and b/confusion_matrix.png differ
diff --git a/pr_curve.png b/pr_curve.png
new file mode 100644
index 0000000..f24c81d
Binary files /dev/null and b/pr_curve.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f612349
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+torch>=1.10
+torchvision>=0.11
+tqdm
+matplotlib
+pandas
+Pillow
+scikit-learn
+scikit-image
+numpy
+torchsummary
diff --git a/roc_curve.png b/roc_curve.png
new file mode 100644
index 0000000..05cdfbd
Binary files /dev/null and b/roc_curve.png differ
diff --git a/training_curves.png b/training_curves.png
new file mode 100644
index 0000000..482cf5e
Binary files /dev/null and b/training_curves.png differ
diff --git a/training_log.csv b/training_log.csv
new file mode 100644
index 0000000..9795ca2
--- /dev/null
+++ b/training_log.csv
@@ -0,0 +1,81 @@
+epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best
+1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best
+2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best
+3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316,
+4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best
+5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076,
+6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best
+7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118,
+8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884,
+9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621,
+10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217,
+11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best
+12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092,
+13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994,
+14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best
+15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364,
+16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369,
+17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863,
+18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best
+19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214,
+20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369,
+21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355,
+22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546,
+23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best
+24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182,
+25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052,
+26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best
+27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best
+28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867,
+29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357,
+30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best
+31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732,
+32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683,
+33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best
+34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633,
+35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205,
+36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577,
+37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594,
+38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612,
+39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671,
+40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996,
+41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best
+42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878,
+43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063,
+44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237,
+45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best
+46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372,
+47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best
+48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best
+49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best
+50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762,
+51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431,
+52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best
+53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best
+54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285,
+55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best
+56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178,
+57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161,
+58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best
+59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best
+60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316,
+61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best
+62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best
+63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best
+64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best
+65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376,
+66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966,
+67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077,
+68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best
+69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best
+70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321,
+71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best
+72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best
+73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05,
+74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best
+75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best
+76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05,
+77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05,
+78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06,
+79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best
+80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,