diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..67fdd7a --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +* +!/.gitattributes +!/Dataloader.py +!/LICENSE +!/Merge_classes.py +!/Model.py +!/README.md +!/requirements.txt +!/THIRD_PARTY_LICENSES.md +!/Train.py +!/app.py +!/Baseline.py +!/Finetune.py +!/Curve.py +!/Evaluate.py +!/baseline/ +!/baseline/__init__.py +!/baseline/VGG_KNN.py +!/baseline/compare_models.py +!/baseline/ResNet34_Pretrained_10pct.py +!/baseline/HOG_Baseline.py +!/baseline/roc_comparison.png +!/baseline/pr_comparison.png +!/baseline/accuracy_bar.png +!/training_log.csv +!/confusion_matrix.png +!/roc_curve.png +!/pr_curve.png +!/training_curves.png +!.gitignore diff --git a/Curve.py b/Curve.py new file mode 100644 index 0000000..7c38b37 --- /dev/null +++ b/Curve.py @@ -0,0 +1,50 @@ +""" +plot_training_curves.py +从 training_log.csv 读取日志,绘制 Loss / F1 / Accuracy / LR 曲线 +""" + +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============ 读取数据 ============ +df = pd.read_csv('training_log.csv') +best_rows = df[df['best'] == 'best'] + +fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + +# ---- 1. Loss ---- +ax = axes[0, 0] +ax.plot(df['epoch'], df['train_loss'], label='Train Loss', color='#1f77b4', lw=1.5) +ax.plot(df['epoch'], df['val_loss'], label='Val Loss', color='#ff7f0e', lw=1.5) +ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.set_title('Loss vs Epoch') +ax.legend(); ax.grid(True, alpha=0.3) + +# ---- 2. F1 Score ---- +ax = axes[0, 1] +ax.plot(df['epoch'], df['train_f1'], label='Train F1', color='#1f77b4', lw=1.5) +ax.plot(df['epoch'], df['val_f1'], label='Val F1', color='#ff7f0e', lw=1.5) +ax.set_xlabel('Epoch'); ax.set_ylabel('F1 Score'); ax.set_title('F1 Score vs Epoch') +ax.legend(); ax.grid(True, alpha=0.3) + +# ---- 3. 
class RobustImageFolder(Dataset):
    """ImageFolder-style dataset that excludes unreadable images.

    Class names and labels are discovered with
    ``torchvision.datasets.ImageFolder``. Each discovered file is then opened
    with Pillow and integrity-checked, and only files that pass are kept in
    ``samples``.

    Args:
        root: Directory that contains one sub-directory per class.
        transform: Optional callable applied to every decoded RGB image.
        verify: When True (default) every file is opened and checked while the
            dataset is built. Pass False to skip the scan (faster start-up);
            unreadable files are then only skipped lazily in ``__getitem__``.

    Attributes:
        samples: List of ``(path, label)`` pairs that survived the scan.
        classes: Class names as reported by ``ImageFolder``.
        class_to_idx: Mapping from class name to label index.
    """

    def __init__(self, root, transform=None, verify=True):
        self.transform = transform
        self.samples = []

        # A throw-away ImageFolder is built only to reuse its class discovery.
        temp_dataset = datasets.ImageFolder(root, transform=None)
        self.classes = temp_dataset.classes
        self.class_to_idx = temp_dataset.class_to_idx

        print(f"\n正在扫描: {root}")
        print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")

        corrupted_count = 0
        for path, label in tqdm(temp_dataset.samples,
                                desc="验证图片完整性",
                                unit="张",
                                ncols=80):
            if verify and not self._is_readable(path):
                corrupted_count += 1
                # Only the first 10 failures are listed to keep the log short.
                if corrupted_count <= 10:
                    tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
                elif corrupted_count == 11:
                    tqdm.write("⚠️ 后续损坏图片将不再显示...")
                continue
            self.samples.append((path, label))

        success_count = len(self.samples)
        print(f"\n✅ 扫描完成!")
        print(f" 📁 有效图片: {success_count} 张")
        print(f" ❌ 损坏跳过: {corrupted_count} 张")
        print(f" 📊 总计: {len(self.samples)} 张\n")

    @staticmethod
    def _is_readable(path):
        """Return True when Pillow can open *path* and ``verify()`` passes."""
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            # Deliberately broad: broken files surface as many different
            # exception types, and any of them means "skip this file".
            return False
        return True

    def __len__(self):
        """Return the number of samples kept after the scan."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for *idx*, skipping files that fail to decode.

        If the requested file cannot be decoded, the following samples are
        tried in order (wrapping around) and the first readable one is
        returned instead.

        Raises:
            IndexError: If *idx* is out of range (same as indexing a list).
            RuntimeError: If no sample in the dataset can be decoded.
        """
        # Index first without wrapping so out-of-range access still raises
        # IndexError, which plain ``for item in dataset`` iteration relies on.
        path, label = self.samples[idx]

        # Bounded retry: at most one attempt per sample, so a dataset made
        # only of broken files terminates instead of recursing forever.
        for _ in range(len(self.samples)):
            try:
                with Image.open(path) as handle:
                    img = handle.convert('RGB')
            except Exception:
                # Rare: the file broke after the scan, or verify=False was used.
                idx = (idx + 1) % len(self.samples)
                path, label = self.samples[idx]
                continue
            # The transform runs outside the try block so that a bug in the
            # transform is reported instead of being mistaken for a bad file.
            if self.transform:
                img = self.transform(img)
            return img, label

        raise RuntimeError("RobustImageFolder: no readable image left in the dataset")
获取类别名称 - class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3'] + class_names = train_dataset.classes print(f"类别: {class_names}") - print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}") return train_loader, val_loader, class_names @@ -158,10 +223,10 @@ def visualize_batch(dataloader, class_names, num_images=8): if __name__ == '__main__': train_loader, val_loader, class_names = create_dataloaders( - data_root='..', # 与trash-division同级文件夹 - batch_size=32, # 根据你的显存调整 + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 + batch_size=16, # 根据你的显存调整 image_size=256, # 与你模型输入一致 - num_workers=4, # Windows 可能需设为 0 + num_workers=16, # Windows 可能需设为 0 augment=True # 训练时使用数据增强 ) visualize_batch(train_loader, class_names, num_images=8) diff --git a/Evaluate.py b/Evaluate.py new file mode 100644 index 0000000..7c78f1c --- /dev/null +++ b/Evaluate.py @@ -0,0 +1,147 @@ +""" +evaluate_and_plot.py +加载模型,在验证集上推理,绘制混淆矩阵 / ROC / PR 曲线 +""" + +import os +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from sklearn.metrics import ( + confusion_matrix, ConfusionMatrixDisplay, + roc_curve, auc, + precision_recall_curve, average_precision_score, +) + +from Model import Net +from Dataloader import RobustImageFolder + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 需要你修改的参数 ★★★ +# ============================================================ +MODEL_PATH = 'best_model.pth' # 模型权重路径 +DATA_ROOT = '../trash_division_data/ultimate_4_class/' # 数据集根目录 +BATCH_SIZE = 32 +IMAGE_SIZE = 256 +NUM_WORKERS = 4 +# ============================================================ + +# ---------- 1. 
加载验证集 ---------- +val_transform = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), +]) + +val_dataset = RobustImageFolder( + root=os.path.join(DATA_ROOT, 'val'), + transform=val_transform, +) +val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, + shuffle=False, num_workers=NUM_WORKERS, + pin_memory=True, drop_last=False) + +class_names = val_dataset.classes +num_classes = len(class_names) +print(f"类别: {class_names}") + +# ---------- 2. 加载模型 ---------- +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = Net(num_classes=num_classes) +state_dict = torch.load(MODEL_PATH, map_location=device) +if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] +elif 'model' in state_dict: + state_dict = state_dict['model'] +model.load_state_dict(state_dict) +model = model.to(device).eval() +print("模型加载完成") + +# ---------- 3. 
推理 ---------- +all_labels = [] +all_probs = [] + +with torch.no_grad(): + for images, labels in val_loader: + images = images.to(device) + probs = torch.softmax(model(images), dim=1) + all_labels.append(labels.numpy()) + all_probs.append(probs.cpu().numpy()) + +all_labels = np.concatenate(all_labels) +all_probs = np.concatenate(all_probs) +all_preds = np.argmax(all_probs, axis=1) +print(f"推理完成, 共 {len(all_labels)} 样本") + +# ============================================================ +# ① 混淆矩阵 +# ============================================================ +cm = confusion_matrix(all_labels, all_preds) +fig, ax = plt.subplots(figsize=(8, 7)) +ConfusionMatrixDisplay(cm, display_labels=class_names).plot( + ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) +ax.set_title('Confusion Matrix', fontsize=14) +plt.tight_layout() +plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight') +plt.show() +print("混淆矩阵已保存: confusion_matrix.png") + +# ============================================================ +# ② ROC 曲线 (One-vs-Rest + Macro-average) +# ============================================================ +one_hot = np.eye(num_classes)[all_labels] +colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'] + +fig, ax = plt.subplots(figsize=(8, 7)) +fpr_d, tpr_d, auc_d = {}, {}, {} + +for i in range(num_classes): + fpr_d[i], tpr_d[i], _ = roc_curve(one_hot[:, i], all_probs[:, i]) + auc_d[i] = auc(fpr_d[i], tpr_d[i]) + ax.plot(fpr_d[i], tpr_d[i], color=colors[i], lw=2, + label=f'{class_names[i]} (AUC={auc_d[i]:.4f})') + +# Macro-average +all_fpr = np.unique(np.concatenate([fpr_d[i] for i in range(num_classes)])) +mean_tpr = sum(np.interp(all_fpr, fpr_d[i], tpr_d[i]) for i in range(num_classes)) / num_classes +macro_auc = auc(all_fpr, mean_tpr) +ax.plot(all_fpr, mean_tpr, 'navy', lw=2, ls='--', + label=f'Macro-avg (AUC={macro_auc:.4f})') +ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5) + +ax.set_xlim(0, 1); ax.set_ylim(0, 1.05) +ax.set_xlabel('False Positive 
def compute_macro_f1(predicted, targets, num_classes=4):
    """Return the macro-averaged F1 score of integer class predictions.

    Per-class precision, recall and F1 are computed for the class ids
    ``0 .. num_classes - 1`` and the unweighted mean of the per-class F1
    values is returned. A class that occurs in neither tensor contributes an
    F1 of 0 to the mean.

    Args:
        predicted: Tensor of predicted class ids.
        targets: Tensor of ground-truth class ids with the same number of
            elements as ``predicted``. Both tensors are flattened.
        num_classes: Number of classes to average over.

    Returns:
        The macro F1 score as a Python float.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1 or the two tensors
            hold a different number of elements.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if predicted.numel() != targets.numel():
        raise ValueError(
            f"predicted has {predicted.numel()} elements but targets has {targets.numel()}"
        )

    # Flatten explicitly: comparing shapes like (N,) and (N, 1) would otherwise
    # broadcast to an (N, N) matrix and silently produce wrong counts.
    predicted = predicted.reshape(-1)
    targets = targets.reshape(-1)

    # Row c of each mask marks the samples predicted as / labelled with class c.
    class_ids = torch.arange(num_classes, device=predicted.device).unsqueeze(1)
    pred_is_c = predicted.unsqueeze(0) == class_ids
    true_is_c = targets.unsqueeze(0) == class_ids

    tp = (pred_is_c & true_is_c).sum(dim=1).float()
    fp = (pred_is_c & ~true_is_c).sum(dim=1).float()
    fn = (~pred_is_c & true_is_c).sum(dim=1).float()

    # The small epsilon keeps the ratios defined when a class has no samples.
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return f1.mean().item()
def freeze_base_layers(model, frozen_modules=None):
    """Disable gradients for the stem and the first two residual stages.

    The ``Net`` class in Model.py names its stages ``layer1`` .. ``layer4``,
    while this script was written against the older ``stage2`` / ``stage3``
    names. Each stage is therefore looked up under both names, so the
    function works with either naming scheme.

    NOTE: ``requires_grad = False`` only stops gradient updates. BatchNorm
    running statistics inside the frozen modules still change while the model
    is in ``train()`` mode (which ``train_one_epoch`` enables each epoch).

    Args:
        model: Module whose sub-modules are frozen in place.
        frozen_modules: Optional iterable describing what to freeze. Each
            entry is either an attribute name or a tuple of alternative
            names, of which the first one present on ``model`` is used.
            Defaults to ``conv1``, ``bn1`` and the first two residual stages.

    Returns:
        The same ``model`` object, for call chaining.

    Raises:
        AttributeError: If none of the names of an entry exists on ``model``.
    """
    if frozen_modules is None:
        frozen_modules = (
            'conv1',
            'bn1',
            ('stage2', 'layer1'),  # first residual stage
            ('stage3', 'layer2'),  # second residual stage
        )

    frozen_layers = []
    for entry in frozen_modules:
        aliases = (entry,) if isinstance(entry, str) else tuple(entry)
        module_name = next((alias for alias in aliases if hasattr(model, alias)), None)
        if module_name is None:
            raise AttributeError(
                f"{type(model).__name__} has no sub-module named any of {aliases}"
            )
        for name, param in getattr(model, module_name).named_parameters():
            param.requires_grad = False
            frozen_layers.append(f'{module_name}.{name}')

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Guard the percentage so a parameter-less module does not divide by zero.
    percent = 100. * trainable / total if total else 0.0
    print(f'冻结层数: {len(frozen_layers)} 个参数组')
    print(f'可训练参数量: {trainable:,} / {total:,} ({percent:.1f}%)')
    return model
best_mark]) + log_file.flush() + + print(f'\n{"=" * 50}') + print(f'微调完成!最佳验证 Macro-F1: {best_val_f1:.4f}') + + return model, history + + +if __name__ == '__main__': + train_loader, val_loader, class_names = create_dataloaders( + data_root='../trash_division_data/ultimate_4_class/', + batch_size=16, + image_size=256, + num_workers=8, + augment=True + ) + + device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu') + + model = Net(num_classes=4) + + if os.path.exists('best_model.pth'): + model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'))) + print('✓ 加载预训练权重 best_model.pth') + else: + print('⚠ 未找到 best_model.pth,使用随机初始化权重') + + model = model.to(device) + model = freeze_base_layers(model) + + print(f'Device: {device}') + print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}') + + trained_model, history = finetune( + model=model, + train_loader=train_loader, + val_loader=val_loader, + epochs=30, + lr=0.0001, + device=device + ) + + model.load_state_dict(torch.load('finetuned_model.pth')) diff --git a/merge_classes.py b/Merge_classes.py similarity index 90% rename from merge_classes.py rename to Merge_classes.py index a605aa2..145c345 100644 --- a/merge_classes.py +++ b/Merge_classes.py @@ -1,5 +1,5 @@ """将原数据集合并为我们需要的四个大类 - 运行时先配置路径 + 已修改成相对路径 具体配置方法详见README.md author: weikaiwen @@ -18,9 +18,9 @@ import shutil # ================= 1. 
配置你的路径 ================= # 注意:请确保相对路径正确,以下为示例 -ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 -NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 -CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 +ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录 +NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录 +CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置 # =================================================== diff --git a/Model.py b/Model.py index d568b97..bf1a714 100644 --- a/Model.py +++ b/Model.py @@ -1,99 +1,113 @@ """ -这个文件是模型的定义文件,请不要擅自修改,如有疑问微信群里反馈 -单独运行本文件将会输出模型结构 -目前的话是一个36层的模型,模型总量应该是在80M左右 如果到时候还是欠拟合的话再考虑去做更深的结构 +模型定义文件 - ResNet-34 author : yukun-hh date : 2026-4-10 - """ -#神经网络模型库 import torch from torch import nn from torch.nn import functional as F from torchsummary import summary -#残差块 -class Resblock(nn.Module): - def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1): - """ - :param input_channels: 进入残差块时的原通道 - :param output_channels: 输出时的通道数 - :param use_1x1conv: 如果输入和输出通道不相等时,需要用一个1x1的卷积层对原来的输入进行一个通道提升 - :param strides: 默认1,如果大于1起到缩小张量的作用 - """ + +class BasicBlock(nn.Module): + """ + ResNet-34 基础残差块:3x3 -> 3x3 + 若需要下采样或通道变化,则在跳跃连接中使用 1x1 卷积 + """ + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None): super().__init__() - self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides) - self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=1) - if use_1x1conv: - self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides) - else: - self.conv3 = None - self.bn1 = nn.BatchNorm2d(output_channels) - self.bn2 = nn.BatchNorm2d(output_channels) - def forward(self,X): - Y = F.relu(self.bn1(self.conv1(X))) - Y = self.bn2(self.conv2(Y)) - if self.conv3 is not 
None: - X = self.conv3(X) - Y += X - return F.relu(Y) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample -class Net(): - """ - 模型的主要结构就在这里了,到时也好该和调用 - 现在必须实现的方法: - 目前还是以图片缩放到256*256构建残差块 - """ - net = nn.Sequential() - def resnet_block(self,input_channels, num_channels, num_residuals, - first_block=False): - """ - :param input_channels: 输入维度 - :param num_channels: 输出维度 - :param num_residuals: 单个残差层的残差块数 - :param first_block: 第一块不用下采样 特殊控制 - :return: list[nn.Module] - """ - blk = [] + def forward(self, x): + identity = x - for i in range(num_residuals): - if i == 0 and not first_block: - blk.append(Resblock(input_channels, num_channels, - use_1x1conv=True, strides=2)) - else: - blk.append(Resblock(num_channels, num_channels)) - return blk - def __init__(self): - b1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), - nn.BatchNorm2d(64), nn.ReLU(), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - ) - """ - 7×7 卷积层,输出通道 64,步长 2,填充 3 - (3×256×256)->(64×128×128) - 批归一化 relu层 - 最大池化 - (64×128×128)->(64×64×64) - """ - b2 = nn.Sequential(*self.resnet_block(64, 64, num_residuals=3, first_block=True)) - b3 = nn.Sequential(*self.resnet_block(64, 128, num_residuals=4)) - b4 = nn.Sequential(*self.resnet_block(128, 256, num_residuals=6)) - b5 = nn.Sequential(*self.resnet_block(256, 512, num_residuals=3)) - self.net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 4)) - def get_network(self): - return self.net + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out 
= self.relu(out) + return out + + +class Net(nn.Module): + + def __init__(self, num_classes=4, dropout=0.5): + super().__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + layers_config = [ + (3, 64, 1), # layer1 + (4, 128, 2), # layer2 + (6, 256, 2), # layer3 + (3, 512, 2), # layer4 + ] + + self.in_channels = 64 + self.layer1 = self._make_layer(layers_config[0]) + self.layer2 = self._make_layer(layers_config[1]) + self.layer3 = self._make_layer(layers_config[2]) + self.layer4 = self._make_layer(layers_config[3]) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512, num_classes) + + def _make_layer(self, config): + num_blocks, out_channels, stride = config + downsample = None + layers = [] + + if stride != 1 or self.in_channels != out_channels: + downsample = nn.Sequential( + nn.Conv2d(self.in_channels, out_channels, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_channels), + ) + + layers.append(BasicBlock(self.in_channels, out_channels, stride, downsample)) + self.in_channels = out_channels + + for _ in range(1, num_blocks): + layers.append(BasicBlock(self.in_channels, out_channels)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x) + return x if __name__ == '__main__': - Net_new = Net() - X = torch.rand(size=(1, 3, 256, 256)) - summary(Net_new.get_network(), input_size=(3, 256, 256)) - - - - - - - - + model = Net(num_classes=4) + summary(model, input_size=(3, 256, 256)) diff --git a/README.md b/README.md index eacd84d..10b86e6 100644 --- a/README.md +++ 
b/README.md @@ -1,14 +1,249 @@ # trash-division -## 一个基于卷积神经网络的垃圾分类识别系统 -### 同济大学python人工智能程序设计课程小组作业 +一个基于卷积神经网络的垃圾分类识别系统 +> 同济大学 Python 人工智能程序设计课程小组作业 + +基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。 + +--- + +## 目录 + +- [项目特点](#项目特点) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [快速开始](#快速开始) +- [文件说明](#文件说明) +- [目录结构](#目录结构) +- [训练细节](#训练细节) +- [评估与可视化](#评估与可视化) +- [许可证](#许可证) + +--- + +## 项目特点 + +- **四类垃圾分类**:厨余垃圾(1)、可回收物(2)、其他垃圾(3)、有害垃圾(4) +- **ResNet-34 架构**:约 21M 参数,34 层深度残差网络,含 Dropout 正则化 +- **数据增强**:训练时使用随机裁剪、水平翻转、旋转、色彩抖动 +- **Macro-F1 评估**:采用宏平均 F1 分数作为主要评估指标,兼顾各类别表现 +- **类别加权损失**:自动计算类别权重,缓解类别不平衡问题 +- **余弦退火学习率调度**:使用 CosineAnnealingLR 平滑调整学习率 +- **断点续训**:自动检测 `best_model.pth` 并加载继续训练 +- **多设备支持**:自动选择 CUDA > Intel XPU > CPU + +## 模型架构 + +模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。 + +### BasicBlock 块 + +每个 BasicBlock 包含两个 3x3 卷积层 + 跳跃连接: + +| 层 | 卷积 | 作用 | +|---|---|---| +| 3x3 Conv | 特征提取 | 第一层卷积 | +| 3x3 Conv | 特征提取 | 第二层卷积 | + +### 网络结构 + +| 阶段 | 块数 | 输出通道数 | 说明 | +|---|---|---|---| +| 初始层 | - | 64 | 7x7 Conv, stride=2 + MaxPool | +| Layer1 | 3 | 64 | 第一个残差阶段 | +| Layer2 | 4 | 128 | - | +| Layer3 | 6 | 256 | - | +| Layer4 | 3 | 512 | 最终残差阶段 | +| 分类头 | - | 4 | 全局平均池化 + Dropout + 全连接层 | + +## 数据集 + +本项目使用 [tany0699/garbage265](https://modelscope.cn/datasets/tany0699/garbage265) 中文生活垃圾分类数据集,包含 265 个子类别的生活垃圾图片。 + +通过 `Merge_classes.py` 脚本将 265 个子类别合并为 4 个顶级类别: + +``` +厨余垃圾 -> 1 +可回收物 -> 2 +其他垃圾 -> 3 +有害垃圾 -> 4 +``` + +数据集预期放置在 `../trash_division_data/`(与项目根目录平级的兄弟目录)。 + +## 环境要求 + +```bash +pip install -r requirements.txt +``` + +> **注意**:`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如: +> - NVIDIA GPU:`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121` +> - Intel GPU (XPU):`pip install torch torchvision --index-url https://download.pytorch.org/whl/xpu` 安装 +> - CPU:`pip install torch torchvision` 即可 + +## 快速开始 + +1. 
**数据预处理**:将 265 个子类别合并为 4 个顶级类别 + + ```bash + python Merge_classes.py + ``` + +2. **训练模型**: + + ```bash + python Train.py + ``` + +3. **微调模型**(可选,冻结浅层、微调深层): + + ```bash + python Finetune.py + ``` + +4. **评估与可视化**: + + ```bash + python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线 + python Curve.py # 训练过程的 loss/f1/acc/lr 曲线 + python baseline/compare_models.py # 多模型基线对比(ROC/PR/准确率) + ``` + +> **注意**: +> - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本 +> - Windows 系统需将 `num_workers` 设为 `0`(参见 `Dataloader.py` 和 `Train.py`) +> - 训练会自动从 `best_model.pth` 断点续训(若存在) + +## 文件说明 + +| 文件 | 功能 | +|---|---| +| `Train.py` | 训练主脚本,包含训练循环、验证、评估 | +| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 | +| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 | +| `Model.py` | 模型定义,ResNet-34(BasicBlock)+ Dropout | +| `Merge_classes.py` | 数据集预处理,265 类合并为 4 类 | +| `Evaluate.py` | 模型评估,绘制混淆矩阵、ROC 曲线、PR 曲线 | +| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 | +| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 | +| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 | +| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) | +| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) | +| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr | +| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) | +| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 | + +## 目录结构 + +``` +trash-division/ +├── baseline/ # 基线模型目录 +│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本 +│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调 +│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线 +│ ├── compare_models.py # 多模型对比脚本 +│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出) +│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出) +│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出) +├── best_model.pth # 最佳模型权重(不纳入版本控制) +├── Curve.py # 训练曲线绘制脚本 +├── Dataloader.py # 数据加载模块 +├── Evaluate.py # 模型评估可视化脚本 +├── Finetune.py # 微调脚本 +├── .gitattributes # Git 属性配置 +├── LICENSE 
# MIT 许可证 +├── Merge_classes.py # 数据集预处理脚本 +├── Model.py # 模型定义 +├── README.md # 项目说明(本文件) +├── THIRD_PARTY_LICENSES.md # 第三方许可证声明 +├── Train.py # 训练主脚本 +├── training_log.csv # 训练日志 +├── confusion_matrix.png # 混淆矩阵(Evaluate.py 输出) +├── roc_curve.png # ROC 曲线(Evaluate.py 输出) +├── pr_curve.png # PR 曲线(Evaluate.py 输出) +└── training_curves.png # 训练曲线(Curve.py 输出) +``` + +## 训练细节 + +| 配置项 | 说明 | +|---|---| +| 输入尺寸 | 256 x 256 RGB | +| 优化器 | SGD(momentum=0.9, weight_decay=1e-4) | +| 初始学习率 | 0.001 | +| 学习率调度 | CosineAnnealingLR | +| 损失函数 | 类别加权 CrossEntropyLoss | +| 评估指标 | Macro-F1(宏平均 F1 分数) | +| 批量大小 | 默认 16(可通过参数调整) | +| 训练轮数 | 默认 20(可通过参数调整) | +| 设备选择优先级 | CUDA > Intel XPU > CPU | +| 断点续训 | 自动检测 best_model.pth 并加载 | + +训练时数据增强管线:RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(+-15 deg) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2) + +## 评估与可视化 + +训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析: + +### Evaluate.py — 模型评估 + +在验证集上推理,生成三张评估图表: + +```bash +python Evaluate.py +``` + +脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。 + +**混淆矩阵** + +![confusion_matrix](confusion_matrix.png) + +**ROC 曲线** + +![roc_curve](roc_curve.png) + +**PR 曲线** + +![pr_curve](pr_curve.png) + +### Curve.py — 训练曲线 + +从 `training_log.csv` 读取训练日志,绘制四张子图: + +```bash +python Curve.py +``` + +![training_curves](training_curves.png) + +### 基线模型对比 + +`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表: + +```bash +python baseline/compare_models.py +``` + +对比阵容:ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。 + +**ROC 曲线对比** + +![roc_comparison](baseline/roc_comparison.png) + +**PR 曲线对比** + +![pr_comparison](baseline/pr_comparison.png) + +**准确率柱状图** + +![accuracy_bar](baseline/accuracy_bar.png) ## 许可证 本项目主代码采用 [MIT 许可证](LICENSE)。 本项目包含的数据集 `tany0699/garbage265` 采用 [Apache License 2.0](THIRD_PARTY_LICENSES.md),详情请参阅 `THIRD_PARTY_LICENSES.md` 文件。 - - - diff --git a/Train.py b/Train.py index 
def compute_macro_f1(predicted, targets, num_classes=4):
    """Return the macro-averaged F1 score of integer class predictions.

    Per-class precision, recall and F1 are computed for the class ids
    ``0 .. num_classes - 1`` and the unweighted mean of the per-class F1
    values is returned. A class that occurs in neither tensor contributes an
    F1 of 0 to the mean.

    Args:
        predicted: Tensor of predicted class ids.
        targets: Tensor of ground-truth class ids with the same number of
            elements as ``predicted``. Both tensors are flattened.
        num_classes: Number of classes to average over.

    Returns:
        The macro F1 score as a Python float.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1 or the two tensors
            hold a different number of elements.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if predicted.numel() != targets.numel():
        raise ValueError(
            f"predicted has {predicted.numel()} elements but targets has {targets.numel()}"
        )

    # Flatten explicitly: comparing shapes like (N,) and (N, 1) would otherwise
    # broadcast to an (N, N) matrix and silently produce wrong counts.
    predicted = predicted.reshape(-1)
    targets = targets.reshape(-1)

    # Row c of each mask marks the samples predicted as / labelled with class c.
    class_ids = torch.arange(num_classes, device=predicted.device).unsqueeze(1)
    pred_is_c = predicted.unsqueeze(0) == class_ids
    true_is_c = targets.unsqueeze(0) == class_ids

    tp = (pred_is_c & true_is_c).sum(dim=1).float()
    fp = (pred_is_c & ~true_is_c).sum(dim=1).float()
    fn = (~pred_is_c & true_is_c).sum(dim=1).float()

    # The small epsilon keeps the ratios defined when a class has no samples.
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return f1.mean().item()
* correct / total:.2f}%'}) epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) epoch_acc = 100. * correct / total - return epoch_loss, epoch_acc + return epoch_loss, epoch_f1, epoch_acc def validate(model, val_loader, criterion, device): """验证函数""" - model.eval() # 设置为评估模式 + model.eval() running_loss = 0.0 correct = 0 total = 0 + all_preds, all_labels = [], [] - with torch.no_grad(): # 不计算梯度,节省内存 + with torch.no_grad(): for images, labels in tqdm(val_loader, desc='[Validate]'): images, labels = images.to(device), labels.to(device) @@ -68,37 +86,52 @@ def validate(model, val_loader, criterion, device): _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() + all_preds.append(predicted) + all_labels.append(labels) epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) epoch_acc = 100. * correct / total - return epoch_loss, epoch_acc + return epoch_loss, epoch_f1, epoch_acc + + +def compute_class_weights(dataset, num_classes=4, device='cpu'): + class_counts = torch.zeros(num_classes) + for _, label in dataset.samples: + lbl = label.item() if isinstance(label, torch.Tensor) else label + class_counts[lbl] += 1 + total = class_counts.sum() + weights = total / (num_classes * class_counts) + return weights.to(device) def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): """主训练函数""" # 1. 
定义损失函数和优化器 - criterion = nn.CrossEntropyLoss() # 多分类用交叉熵 + class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device) + criterion = nn.CrossEntropyLoss(weight=class_weights) # 多分类用交叉熵 - # 优化器选择(推荐 Adam 或 SGD) - optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4) # 或者使用 SGD + 动量 - # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4) + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4) # 学习率调度器(可选,帮助收敛) - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) - # 或者用余弦退火 - # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) # 2. 记录训练历史 history = { 'train_loss': [], + 'train_f1': [], 'train_acc': [], 'val_loss': [], + 'val_f1': [], 'val_acc': [] } - best_val_acc = 0.0 + best_val_f1 = 0.0 + log_file = open('training_log.csv', 'w', newline='') + log_writer = csv.writer(log_file) + log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', 'val_loss', 'val_f1', 'val_acc', 'lr', 'best']) # 3. 
开始训练 for epoch in range(epochs): @@ -106,77 +139,65 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): print(f'Epoch {epoch + 1}/{epochs}') # 训练 - train_loss, train_acc = train_one_epoch(model, train_loader, criterion, - optimizer, device, epoch) + train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion, + optimizer, device, epoch) # 验证 - val_loss, val_acc = validate(model, val_loader, criterion, device) + val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device) # 更新学习率 scheduler.step() # 记录 history['train_loss'].append(train_loss) + history['train_f1'].append(train_f1) history['train_acc'].append(train_acc) history['val_loss'].append(val_loss) + history['val_f1'].append(val_f1) history['val_acc'].append(val_acc) # 打印结果 - print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%') - print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%') + print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}') + print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}') print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}') # 保存最佳模型 - if val_acc > best_val_acc: - best_val_acc = val_acc + best_mark = '' + if val_f1 > best_val_f1: + best_val_f1 = val_f1 torch.save(model.state_dict(), 'best_model.pth') - print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)') + best_mark = 'best' + print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})') + + lr = optimizer.param_groups[0]['lr'] + log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc, val_loss, val_f1, val_acc, lr, best_mark]) + log_file.flush() # 4. 
绘制训练曲线 - plot_training_history(history) print(f'\n{"=" * 50}') - print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%') + print(f'训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}') return model, history -def plot_training_history(history): - """绘制训练曲线""" - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) - - # 损失曲线 - ax1.plot(history['train_loss'], label='Train Loss') - ax1.plot(history['val_loss'], label='Val Loss') - ax1.set_xlabel('Epoch') - ax1.set_ylabel('Loss') - ax1.set_title('Training and Validation Loss') - ax1.legend() - ax1.grid(True) - - # 准确率曲线 - ax2.plot(history['train_acc'], label='Train Acc') - ax2.plot(history['val_acc'], label='Val Acc') - ax2.set_xlabel('Epoch') - ax2.set_ylabel('Accuracy (%)') - ax2.set_title('Training and Validation Accuracy') - ax2.legend() - ax2.grid(True) - - plt.tight_layout() - plt.savefig('training_history.png', dpi=150) - plt.show() - - # ========== 使用示例 ========== if __name__ == '__main__': # 假设你的 dataloader 已经写好了 - # train_loader = ... - # val_loader = ... + train_loader, val_loader, class_names = create_dataloaders( + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 + batch_size=16, # 根据你的显存调整 + image_size=256, # 与你模型输入一致 + num_workers=8, # Windows 可能需设为 0 + augment=True # 训练时使用数据增强 + ) # 1. 创建模型 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = Net().get_network() # 根据你的 Net 类调整 + device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu') + model = Net(num_classes=4) # 根据你的 Net 类调整 + #断点继续训练 + if os.path.exists('best_model.pth'): + model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu'))) model = model.to(device) # 打印模型信息 @@ -188,11 +209,9 @@ if __name__ == '__main__': model=model, train_loader=train_loader, val_loader=val_loader, - epochs=50, + epochs=20, lr=0.001, device=device ) - # 3. 
加载最佳模型用于预测 - model.load_state_dict(torch.load('best_model.pth')) - print('训练完成,最佳模型已加载') \ No newline at end of file + model.load_state_dict(torch.load('best_model.pth')) \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..0897466 --- /dev/null +++ b/app.py @@ -0,0 +1,95 @@ +import gradio as gr +import torch +from torchvision import transforms +from PIL import Image +from Model import Net # 根据上传的 Model.py,模型类名为 Net + +# 1. 基础配置与类别映射 +# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾 +class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + +# 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致 +device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu') +print(f"当前使用的推理设备: {device}") + +# 2. 初始化模型并加载最佳权重 +model = Net(num_classes=4) +try: + # 采用与 Evaluate.py 一致的健壮加载方式 + state_dict = torch.load('best_model.pth', map_location=device) + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + + model.load_state_dict(state_dict) + model = model.to(device).eval() + print("✅ 成功加载 best_model.pth 权重") +except Exception as e: + print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}") + +# 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致) +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) +]) + +# 4. 
核心推理函数 +def predict(image): + if image is None: + return None + + # Gradio 传入的 pil 图像,确保转为 RGB 格式 + image = image.convert('RGB') + + # 预处理并增加 batch 维度 + input_tensor = transform(image).unsqueeze(0).to(device) + + with torch.no_grad(): + outputs = model(input_tensor) + # 使用 Softmax 将 logits 转换为 0~1 的概率分布 + probabilities = torch.softmax(outputs, dim=1)[0] + + # 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 } + result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} + return result_dict + +# 5. 构建与美化 Gradio 界面 +with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo: + gr.Markdown( + """ +
+

🗑️ Trash Division - 智能垃圾分类系统

+

基于 ResNet-34 架构,支持精准识别:厨余垃圾、可回收物、其他垃圾、有害垃圾

+

同济大学 Python 人工智能程序设计课程小组作业

+
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + # type="pil" 让 Gradio 直接传 PIL Image 对象给预测函数,配合 torchvision 最方便 + image_input = gr.Image(type="pil", label="上传垃圾图片 (支持拍照)") + with gr.Row(): + clear_btn = gr.Button("清空", variant="secondary") + submit_btn = gr.Button("开始识别", variant="primary") + + with gr.Column(scale=1): + label_output = gr.Label(num_top_classes=4, label="预测结果与置信度") + + # 绑定点击事件 + submit_btn.click(fn=predict, inputs=image_input, outputs=label_output) + clear_btn.click(lambda: (None, None), inputs=None, outputs=[image_input, label_output]) + +if __name__ == "__main__": + # 启动 Web 界面 + demo.launch( + server_name="127.0.0.1", + server_port=7860, + share=False, # 如果你想生成临时公网链接分享给同学测试,改为 True + inbrowser=True # 运行后自动在默认浏览器中打开 + ) \ No newline at end of file diff --git a/baseline/HOG_Baseline.py b/baseline/HOG_Baseline.py new file mode 100644 index 0000000..c72e405 --- /dev/null +++ b/baseline/HOG_Baseline.py @@ -0,0 +1,145 @@ +""" +baseline/HOG_Baseline.py +HOG + 颜色直方图特征提取 + LogisticRegression 四分类 +纯传统 CV/ML 基线,零神经网络依赖 +可独立运行,也可被 compare_models.py 导入 +author: yukun-hh +date: 2026-5-14 +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import numpy as np +from PIL import Image +from tqdm import tqdm +import matplotlib.pyplot as plt +import matplotlib + +from skimage.feature import hog +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + accuracy_score, f1_score, + confusion_matrix, ConfusionMatrixDisplay, + classification_report, +) + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 可配置参数 ★★★ +# ============================================================ +DATA_ROOT = '../../trash_division_data/ultimate_4_class/' +IMAGE_SIZE = 128 +HOG_ORIENTATIONS = 9 +HOG_PIXELS_PER_CELL = (8, 8) +HOG_CELLS_PER_BLOCK = (2, 2) +COLOR_BINS 
= 32 +# ============================================================ + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] +NUM_CLASSES = 4 + + +def extract_hog_color(image): + img = image.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE)) + arr = np.array(img, dtype=np.float64) / 255.0 + + hog_feat = hog(arr, orientations=HOG_ORIENTATIONS, + pixels_per_cell=HOG_PIXELS_PER_CELL, + cells_per_block=HOG_CELLS_PER_BLOCK, + channel_axis=2, feature_vector=True) + + color_feat = [] + for c in range(3): + hist, _ = np.histogram(arr[:, :, c], bins=COLOR_BINS, range=(0, 1)) + color_feat.append(hist) + color_feat = np.concatenate(color_feat) + + return np.concatenate([hog_feat, color_feat]) + + +class HOGLRBaseline: + def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE): + self.data_root = data_root + self.image_size = image_size + self.clf = LogisticRegression( + C=1.0, max_iter=1000, solver='lbfgs', n_jobs=-1, + ) + + def _load_data(self, split): + dir_path = os.path.join(self.data_root, split) + features, labels = [], [] + for class_id in range(1, NUM_CLASSES + 1): + class_dir = os.path.join(dir_path, str(class_id)) + if not os.path.isdir(class_dir): + continue + files = sorted(os.listdir(class_dir)) + for fname in tqdm(files, desc=f'{split}/class_{class_id}'): + fpath = os.path.join(class_dir, fname) + try: + with Image.open(fpath) as img: + feat = extract_hog_color(img) + features.append(feat) + labels.append(class_id - 1) + except Exception: + pass + print(f" {split}: {len(features)} 张") + return np.array(features, dtype=np.float32), np.array(labels) + + def fit(self, train_dir=None): + if train_dir is None: + train_dir = 'train' + print(" 提取训练集 HOG 特征 ...") + X, y = self._load_data(train_dir) + self.clf.fit(X, y) + + def predict(self, val_dir=None): + if val_dir is None: + val_dir = 'val' + print(" 提取验证集 HOG 特征 ...") + X, y = self._load_data(val_dir) + preds = self.clf.predict(X) + probs = self.clf.predict_proba(X) + return y, preds, probs + + +# 
============================================================ +# compare_models.py 导入接口 +# ============================================================ + +def get_hog_lr_preds(train_loader, val_loader, device): + baseline = HOGLRBaseline() + baseline.fit('train') + return baseline.predict('val') + + +# ============================================================ +# 独立运行入口 +# ============================================================ + +if __name__ == '__main__': + out_dir = os.path.dirname(os.path.abspath(__file__)) + + print("HOG + LogisticRegression 基线") + baseline = HOGLRBaseline() + + baseline.fit('train') + y_true, y_preds, y_probs = baseline.predict('val') + + acc = accuracy_score(y_true, y_preds) + macro_f1 = f1_score(y_true, y_preds, average='macro') + print(f"\n验证集 Accuracy: {acc:.4f}") + print(f"验证集 Macro-F1: {macro_f1:.4f}") + print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}") + + cm = confusion_matrix(y_true, y_preds) + fig, ax = plt.subplots(figsize=(8, 7)) + ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot( + ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) + ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14) + plt.tight_layout() + cm_path = os.path.join(out_dir, 'hog_lr_confusion_matrix.png') + plt.savefig(cm_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"混淆矩阵已保存: {cm_path}") diff --git a/baseline/ResNet34_Pretrained_10pct.py b/baseline/ResNet34_Pretrained_10pct.py new file mode 100644 index 0000000..b7988a8 --- /dev/null +++ b/baseline/ResNet34_Pretrained_10pct.py @@ -0,0 +1,278 @@ +""" +baseline/ResNet34_Pretrained_10pct.py +ResNet-34 ImageNet 预训练权重 + 10% 训练集微调 +可独立运行训练,也可被 compare_models.py 导入 +author: yukun-hh +date: 2026-5-14 +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, 
Subset +from torchvision import models, transforms +from tqdm import tqdm +import csv +import matplotlib.pyplot as plt +import matplotlib + +from Dataloader import RobustImageFolder + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 可配置参数 ★★★ +# ============================================================ +DATA_ROOT = '../../trash_division_data/ultimate_4_class/' +BATCH_SIZE = 32 +IMAGE_SIZE = 256 +NUM_WORKERS = 4 +EPOCHS = 30 +LR = 0.001 +TRAIN_PCT = 0.1 +SEED = 42 +DROPOUT = 0.3 +MODEL_SAVE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct.pth') +LOG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct_log.csv') +# ============================================================ + +NUM_CLASSES = 4 +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + + +class PretrainedResNet34(nn.Module): + def __init__(self, num_classes=NUM_CLASSES, dropout=DROPOUT): + super().__init__() + self.backbone = models.resnet34(weights='IMAGENET1K_V1') + in_features = self.backbone.fc.in_features + self.backbone.fc = nn.Identity() + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(in_features, num_classes) + + def forward(self, x): + x = self.backbone(x) + x = self.dropout(x) + x = self.fc(x) + return x + + def freeze_early_layers(self): + for param in self.backbone.conv1.parameters(): + param.requires_grad = False + for param in self.backbone.bn1.parameters(): + param.requires_grad = False + for param in self.backbone.layer1.parameters(): + param.requires_grad = False + for param in self.backbone.layer2.parameters(): + param.requires_grad = False + + def print_trainable_info(self): + frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad) + trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + total = frozen + trainable + print(f" 冻结参数: {frozen:,} 
可训练参数: {trainable:,} ({100.*trainable/total:.1f}%)") + + +def compute_macro_f1(predicted, targets, num_classes=NUM_CLASSES): + tp = torch.zeros(num_classes, device=predicted.device) + fp = torch.zeros(num_classes, device=predicted.device) + fn = torch.zeros(num_classes, device=predicted.device) + for c in range(num_classes): + tp[c] = ((predicted == c) & (targets == c)).sum() + fp[c] = ((predicted == c) & (targets != c)).sum() + fn[c] = ((predicted != c) & (targets == c)).sum() + precision = tp / (tp + fp + 1e-8) + recall = tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1.mean().item() + + +def train_one_epoch(model, loader, criterion, optimizer, device, epoch): + model.train() + running_loss, correct, total = 0.0, 0, 0 + all_preds, all_labels = [], [] + pbar = tqdm(loader, desc=f'Epoch {epoch+1} [Train]') + for images, labels in pbar: + images, labels = images.to(device), labels.to(device) + outputs = model(images) + loss = criterion(outputs, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + running_loss += loss.item() * images.size(0) + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + all_preds.append(predicted) + all_labels.append(labels) + batch_f1 = compute_macro_f1(predicted, labels) + pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', + 'Acc': f'{100.*correct/total:.2f}%'}) + epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) + epoch_acc = 100. 
* correct / total + return epoch_loss, epoch_f1, epoch_acc + + +@torch.no_grad() +def validate(model, loader, criterion, device): + model.eval() + running_loss, correct, total = 0.0, 0, 0 + all_preds, all_labels = [], [] + for images, labels in tqdm(loader, desc='[Validate]'): + images, labels = images.to(device), labels.to(device) + outputs = model(images) + loss = criterion(outputs, labels) + running_loss += loss.item() * images.size(0) + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + all_preds.append(predicted) + all_labels.append(labels) + epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) + epoch_acc = 100. * correct / total + return epoch_loss, epoch_f1, epoch_acc + + +def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR): + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), + lr=lr, momentum=0.9, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + history = {'train_loss': [], 'train_f1': [], 'train_acc': [], + 'val_loss': [], 'val_f1': [], 'val_acc': []} + best_val_f1 = 0.0 + + log_file = open(LOG_PATH, 'w', newline='') + log_writer = csv.writer(log_file) + log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', + 'val_loss', 'val_f1', 'val_acc', 'lr', 'best']) + + for epoch in range(epochs): + print(f'\n{"="*50}') + print(f'Epoch {epoch+1}/{epochs}') + + train_loss, train_f1, train_acc = train_one_epoch( + model, train_loader, criterion, optimizer, device, epoch) + val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device) + scheduler.step() + + history['train_loss'].append(train_loss) + history['train_f1'].append(train_f1) + history['train_acc'].append(train_acc) + history['val_loss'].append(val_loss) + history['val_f1'].append(val_f1) + history['val_acc'].append(val_acc) + + 
print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}') + print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}') + print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}') + + best_mark = '' + if val_f1 > best_val_f1: + best_val_f1 = val_f1 + torch.save(model.state_dict(), MODEL_SAVE_PATH) + best_mark = 'best' + print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})') + + lr_val = optimizer.param_groups[0]['lr'] + log_writer.writerow([epoch+1, train_loss, train_f1, train_acc, + val_loss, val_f1, val_acc, lr_val, best_mark]) + log_file.flush() + + log_file.close() + print(f'\n训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}') + return history + + +# ============================================================ +# compare_models.py 导入接口 +# ============================================================ + +def get_resnet34_10pct_preds(train_loader, val_loader, device): + model = PretrainedResNet34(num_classes=NUM_CLASSES) + model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu')) + model = model.to(device).eval() + + y_true, y_preds, y_probs = [], [], [] + with torch.no_grad(): + for images, labels in tqdm(val_loader, desc='ResNet-34 (10%)'): + images, labels = images.to(device), labels + logits = model(images) + probs = torch.softmax(logits, dim=1) + preds = probs.argmax(dim=1) + y_true.append(labels.numpy()) + y_preds.append(preds.cpu().numpy()) + y_probs.append(probs.cpu().numpy()) + return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs) + + +# ============================================================ +# 独立训练入口 +# ============================================================ + +if __name__ == '__main__': + random.seed(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + + device = torch.device('cuda' if torch.cuda.is_available() + else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() + else 'cpu') + print(f"Device: {device}") + + val_transform = 
transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomRotation(degrees=15), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + full_train_dataset = RobustImageFolder( + root=os.path.join(DATA_ROOT, 'train'), + transform=train_transform, + ) + val_dataset = RobustImageFolder( + root=os.path.join(DATA_ROOT, 'val'), + transform=val_transform, + ) + + n_train = len(full_train_dataset) + n_subset = max(1, int(n_train * TRAIN_PCT)) + indices = random.sample(range(n_train), n_subset) + train_dataset = Subset(full_train_dataset, indices) + print(f"训练集: {len(train_dataset)} / {n_train} ({TRAIN_PCT*100:.0f}%)") + print(f"验证集: {len(val_dataset)}") + + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, + shuffle=True, num_workers=NUM_WORKERS, + pin_memory=True, drop_last=True) + val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, + shuffle=False, num_workers=NUM_WORKERS, + pin_memory=True, drop_last=False) + + model = PretrainedResNet34(num_classes=NUM_CLASSES, dropout=DROPOUT) + model.freeze_early_layers() + model.print_trainable_info() + model = model.to(device) + + history = train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR) + + model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu')) + print(f"模型已保存: {MODEL_SAVE_PATH}") + print(f"训练日志已保存: {LOG_PATH}") diff --git a/baseline/VGG_KNN.py b/baseline/VGG_KNN.py new file mode 100644 index 0000000..2e9a466 --- /dev/null +++ b/baseline/VGG_KNN.py @@ -0,0 +1,145 @@ +""" +baseline/VGG_KNN.py +VGG16 预训练模型特征提取 + KNN 四分类基线 +可独立运行,也可被 compare_models.py 导入复用 
+author: yukun-hh +date: 2026-5-14 +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torchvision import models, transforms +from tqdm import tqdm + +from sklearn.neighbors import KNeighborsClassifier +from sklearn.metrics import ( + accuracy_score, f1_score, + confusion_matrix, ConfusionMatrixDisplay, + classification_report, +) + +from Dataloader import RobustImageFolder + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + + +def load_vgg16_extractor(device): + try: + model = models.vgg16(weights='IMAGENET1K_V1') + except TypeError: + model = models.vgg16(pretrained=True) + model.classifier = nn.Identity() + model = model.to(device).eval() + for param in model.parameters(): + param.requires_grad = False + return model + + +def extract_features(model, loader, device): + model.eval() + all_features = [] + all_labels = [] + with torch.no_grad(): + for images, labels in tqdm(loader, desc='Extracting features'): + images = images.to(device) + feats = model(images) + all_features.append(feats.cpu().numpy()) + all_labels.append(labels.numpy()) + return np.concatenate(all_features), np.concatenate(all_labels) + + +class VGGKNNBaseline: + def __init__(self, k=5, device='cpu', + data_root='../trash_division_data/ultimate_4_class/', + image_size=256, batch_size=32, num_workers=4): + self.k = k + self.device = device + self.data_root = data_root + self.image_size = image_size + self.batch_size = batch_size + self.num_workers = num_workers + self.extractor = load_vgg16_extractor(device) + self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1) + + def _get_loader(self, split): + transform = transforms.Compose([ + 
transforms.Resize((self.image_size, self.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + dataset = RobustImageFolder( + root=os.path.join(self.data_root, split), + transform=transform, + ) + print(f" {split}: {len(dataset)} 张") + return DataLoader(dataset, batch_size=self.batch_size, + shuffle=False, num_workers=self.num_workers, + pin_memory=True, drop_last=False) + + def fit(self, train_loader=None): + if train_loader is None: + train_loader = self._get_loader('train') + print(" 提取训练集特征 ...") + train_feats, train_labels = extract_features(self.extractor, train_loader, self.device) + self.knn.fit(train_feats, train_labels) + + def predict(self, val_loader=None): + if val_loader is None: + val_loader = self._get_loader('val') + print(" 提取验证集特征 ...") + val_feats, val_labels = extract_features(self.extractor, val_loader, self.device) + preds = self.knn.predict(val_feats) + probs = self.knn.predict_proba(val_feats) + return val_labels, preds, probs + + +if __name__ == '__main__': + DATA_ROOT = '../trash_division_data/ultimate_4_class/' + BATCH_SIZE = 32 + IMAGE_SIZE = 256 + NUM_WORKERS = 4 + K = 5 + + device = torch.device('cuda' if torch.cuda.is_available() + else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() + else 'cpu') + print(f"Device: {device}") + + baseline = VGGKNNBaseline(k=K, device=device, + data_root=DATA_ROOT, image_size=IMAGE_SIZE, + batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + + train_loader = baseline._get_loader('train') + val_loader = baseline._get_loader('val') + + baseline.fit(train_loader) + y_true, y_preds, y_probs = baseline.predict(val_loader) + + acc = accuracy_score(y_true, y_preds) + macro_f1 = f1_score(y_true, y_preds, average='macro') + print(f"\n验证集 Accuracy: {acc:.4f}") + print(f"验证集 Macro-F1: {macro_f1:.4f}") + print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}") + + cm = confusion_matrix(y_true, y_preds) + 
fig, ax = plt.subplots(figsize=(8, 7)) + ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot( + ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) + ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14) + plt.tight_layout() + out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vgg_knn_confusion_matrix.png') + plt.savefig(out_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"混淆矩阵已保存: {out_path}") diff --git a/baseline/__init__.py b/baseline/__init__.py new file mode 100644 index 0000000..7471dfe --- /dev/null +++ b/baseline/__init__.py @@ -0,0 +1 @@ +# baseline package diff --git a/baseline/accuracy_bar.png b/baseline/accuracy_bar.png new file mode 100644 index 0000000..786494c Binary files /dev/null and b/baseline/accuracy_bar.png differ diff --git a/baseline/compare_models.py b/baseline/compare_models.py new file mode 100644 index 0000000..5917fe0 --- /dev/null +++ b/baseline/compare_models.py @@ -0,0 +1,249 @@ +""" +baseline/compare_models.py +多模型对比:ROC 曲线 + 准确率柱状图 +添加新模型只需在 MODELS 列表加一行,无需修改绘图代码 +author: yukun-hh +date: 2026-5-14 +""" +import sys, os, re +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm + +from sklearn.metrics import ( + roc_curve, auc, accuracy_score, + precision_recall_curve, average_precision_score, +) + +from Model import Net +from Dataloader import RobustImageFolder +from baseline.VGG_KNN import VGGKNNBaseline +from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds +from baseline.HOG_Baseline import get_hog_lr_preds + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 可配置参数 ★★★ +# 
============================================================ +DATA_ROOT = '../../trash_division_data/ultimate_4_class/' +BATCH_SIZE = 32 +IMAGE_SIZE = 256 +NUM_WORKERS = 4 +K_KNN = 5 +# ============================================================ + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] +NUM_CLASSES = 4 + +# ============================================================ +# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs) +# ============================================================ + +def get_resnet34_preds(train_loader, val_loader, device): + model = Net(num_classes=NUM_CLASSES) + state_dict = torch.load('../best_model.pth', map_location='cpu') + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + model.load_state_dict(state_dict) + model = model.to(device).eval() + + y_true, y_preds, y_probs = [], [], [] + with torch.no_grad(): + for images, labels in tqdm(val_loader, desc='ResNet-34'): + images, labels = images.to(device), labels + logits = model(images) + probs = torch.softmax(logits, dim=1) + preds = probs.argmax(dim=1) + y_true.append(labels.numpy()) + y_preds.append(preds.cpu().numpy()) + y_probs.append(probs.cpu().numpy()) + return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs) + + +def get_vgg_knn_preds(train_loader, val_loader, device): + baseline = VGGKNNBaseline(k=K_KNN, device=device) + baseline.fit(train_loader) + return baseline.predict(val_loader) + + +# ============================================================ +# ★ 模型注册表 — 添加新模型只需在这里加一行 ★ +# ============================================================ + +MODELS = [ + ('ResNet-34', get_resnet34_preds), + ('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds), + ('VGG16 + KNN (K=5)', get_vgg_knn_preds), + ('HOG + LogisticRegression', get_hog_lr_preds), + # 未来轻松扩展示例: + # ('ResNet-18 (pretrained)', get_resnet18_preds), + # ('ResNet-50 
(pretrained)', get_resnet50_preds), + # ('ResNet-34 (finetuned)', get_finetuned_preds), +] + +# ============================================================ +# 调色板 (扩展时无需修改) +# ============================================================ +COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', + '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] + + +def compute_macro_roc(y_true, y_probs): + one_hot = np.eye(NUM_CLASSES)[y_true] + fpr_dict, tpr_dict = {}, {} + for c in range(NUM_CLASSES): + fpr_dict[c], tpr_dict[c], _ = roc_curve(one_hot[:, c], y_probs[:, c]) + all_fpr = np.unique(np.concatenate([fpr_dict[c] for c in range(NUM_CLASSES)])) + mean_tpr = np.zeros_like(all_fpr) + for c in range(NUM_CLASSES): + mean_tpr += np.interp(all_fpr, fpr_dict[c], tpr_dict[c]) + mean_tpr /= NUM_CLASSES + macro_auc = auc(all_fpr, mean_tpr) + return all_fpr, mean_tpr, macro_auc + + +def compute_macro_pr(y_true, y_probs): + one_hot = np.eye(NUM_CLASSES)[y_true] + prec_dict, rec_dict = {}, {} + for c in range(NUM_CLASSES): + prec_dict[c], rec_dict[c], _ = precision_recall_curve(one_hot[:, c], y_probs[:, c]) + all_rec = np.linspace(0, 1, 200) + mean_prec = np.zeros_like(all_rec) + for c in range(NUM_CLASSES): + mean_prec += np.interp(all_rec, rec_dict[c][::-1], prec_dict[c][::-1]) + mean_prec /= NUM_CLASSES + macro_ap = average_precision_score(one_hot, y_probs, average='macro') + return all_rec, mean_prec, macro_ap + + +def sanitize_filename(name): + return re.sub(r'[^\w\-_]', '_', name).strip('_') + + +def preds_csv_path(out_dir, model_name): + safe = sanitize_filename(model_name) + return os.path.join(out_dir, f'{safe}_preds.csv') + + +def save_preds_csv(path, y_true, y_preds, y_probs): + header = 'y_true,y_pred,' + ','.join(f'prob_{c}' for c in range(NUM_CLASSES)) + data = np.column_stack([y_true.astype(float), y_preds.astype(float), y_probs]) + np.savetxt(path, data, delimiter=',', header=header, comments='', fmt='%.6f') + + +def load_preds_csv(path): + data = 
def load_preds_csv(path, num_classes=None):
    """Load a prediction cache written as ``y_true,y_pred,prob_0,...``.

    Args:
        path: CSV file with one header row followed by one row per sample.
        num_classes: Number of probability columns to read after the two
            label columns; ``None`` reads every remaining column.

    Returns:
        ``(y_true, y_preds, y_probs)`` where the first two are integer
        arrays and ``y_probs`` is a 2-D float array.
    """
    # ndmin=2 keeps a file with a single data row two-dimensional; without
    # it loadtxt returns a 1-D array and the column slicing below fails.
    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
    y_true = data[:, 0].astype(int)
    y_preds = data[:, 1].astype(int)
    stop = None if num_classes is None else 2 + num_classes
    y_probs = data[:, 2:stop]
    return y_true, y_preds, y_probs


def _pick_device():
    """Return the first available device among CUDA, XPU and CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.device('xpu')
    return torch.device('cpu')


def _build_loaders(device):
    """Create the (train, val) loaders, both with the deterministic eval transform."""
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
                                      transform=val_transform)
    val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
                                    transform=val_transform)
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    # Pinned host memory only helps transfers to an accelerator; skip it on
    # CPU-only runs.
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS,
                         pin_memory=device.type != 'cpu', drop_last=False)
    return (DataLoader(train_dataset, **loader_kwargs),
            DataLoader(val_dataset, **loader_kwargs))


def _evaluate_models(device, out_dir):
    """Return per-model predictions and metrics, using CSV caches when present."""
    results = {}
    loaders = None  # built on the first cache miss only
    for name, func in MODELS:
        print(f"\n{'='*50}")
        csv_path = preds_csv_path(out_dir, name)
        if os.path.exists(csv_path):
            print(f"加载缓存: {os.path.basename(csv_path)}")
            y_true, y_preds, y_probs = load_preds_csv(csv_path)
        else:
            print(f"评估: {name}")
            if loaders is None:
                # Constructing the datasets is deferred so that a run where
                # every model is cached does not need the dataset at all.
                loaders = _build_loaders(device)
            train_loader, val_loader = loaders
            y_true, y_preds, y_probs = func(train_loader, val_loader, device)
            save_preds_csv(csv_path, y_true, y_preds, y_probs)
            print(f" 预测数据已保存: {os.path.basename(csv_path)}")

        acc = accuracy_score(y_true, y_preds)
        fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
        rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
        results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs,
                         'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc,
                         'rec': rec, 'prec': prec, 'ap': macro_ap}
        print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")
    return results


def _plot_roc(results, out_dir):
    """Draw all macro-average ROC curves on one figure and save it."""
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        ax.plot(r['fpr'], r['tpr'], color=COLORS[i % len(COLORS)], lw=2,
                label=f"{name} (AUC={r['auc']:.4f})")
    # Diagonal reference line.
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    roc_path = os.path.join(out_dir, 'roc_comparison.png')
    plt.savefig(roc_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown
    print(f"\nROC 对比图已保存: {roc_path}")


def _plot_pr(results, out_dir):
    """Draw all macro-average precision-recall curves on one figure and save it."""
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        ax.plot(r['rec'], r['prec'], color=COLORS[i % len(COLORS)], lw=2,
                label=f"{name} (AP={r['ap']:.4f})")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower left')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    pr_path = os.path.join(out_dir, 'pr_comparison.png')
    plt.savefig(pr_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"PR 对比图已保存: {pr_path}")


def _plot_accuracy_bars(results, out_dir):
    """Draw one accuracy bar per model, annotated with its value, and save it."""
    names = list(results.keys())
    accs = [results[n]['acc'] for n in names]
    fig, ax = plt.subplots(figsize=(8, 5))
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))]
    bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    # Zoom the y-axis around the observed range so differences are visible.
    ax.set_ylim(min(accs) - 0.03, max(accs) * 1.08)
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    bar_path = os.path.join(out_dir, 'accuracy_bar.png')
    plt.savefig(bar_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"准确率柱状图已保存: {bar_path}")


def main():
    """Evaluate every registered model and save ROC, PR and accuracy figures.

    Figures and prediction caches are written next to this file.
    """
    out_dir = os.path.dirname(os.path.abspath(__file__))

    device = _pick_device()
    print(f"Device: {device}")

    results = _evaluate_models(device, out_dir)
    if not results:
        # min()/max() over an empty accuracy list would raise further down.
        print("No models registered; nothing to plot.")
        return

    _plot_roc(results, out_dir)
    _plot_pr(results, out_dir)
    _plot_accuracy_bars(results, out_dir)


if __name__ == '__main__':
    main()
+3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316, +4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best +5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076, +6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best +7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118, +8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884, +9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621, +10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217, +11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best +12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092, +13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994, +14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best +15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364, +16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369, 
+17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863, +18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best +19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214, +20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369, +21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355, +22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546, +23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best +24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182, +25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052, +26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best +27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best +28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867, +29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357, +30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best 
+31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732, +32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683, +33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best +34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633, +35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205, +36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577, +37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594, +38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612, +39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671, +40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996, +41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best +42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878, +43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063, +44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237, 
+45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best +46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372, +47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best +48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best +49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best +50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762, +51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431, +52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best +53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best +54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285, +55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best +56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178, +57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161, +58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best 
+59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best +60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316, +61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best +62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best +63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best +64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best +65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376, +66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966, +67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077, +68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best +69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best +70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321, +71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best +72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best 
+73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05, +74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best +75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best +76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05, +77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05, +78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06, +79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best +80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,