重置所有文件与main一致,新增app.py(Gradio推理前端)

This commit is contained in:
weikaiwen348-code 2026-06-03 20:33:22 +08:00
parent 543e833fd0
commit 0e9bd58dce
24 changed files with 1957 additions and 168 deletions

30
.gitignore vendored Normal file
View file

@ -0,0 +1,30 @@
*
!/.gitattributes
!/Dataloader.py
!/LICENSE
!/Merge_classes.py
!/Model.py
!/README.md
!/requirements.txt
!/THIRD_PARTY_LICENSES.md
!/Train.py
!/app.py
!/Baseline.py
!/Finetune.py
!/Curve.py
!/Evaluate.py
!/baseline/
!/baseline/__init__.py
!/baseline/VGG_KNN.py
!/baseline/compare_models.py
!/baseline/ResNet34_Pretrained_10pct.py
!/baseline/HOG_Baseline.py
!/baseline/roc_comparison.png
!/baseline/pr_comparison.png
!/baseline/accuracy_bar.png
!/training_log.csv
!/confusion_matrix.png
!/roc_curve.png
!/pr_curve.png
!/training_curves.png
!.gitignore

50
Curve.py Normal file
View file

@ -0,0 +1,50 @@
"""
plot_training_curves.py
training_log.csv 读取日志绘制 Loss / F1 / Accuracy / LR 曲线
"""
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
# Font list is tried in order: SimHei first, DejaVu Sans as fallback.
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False

# ============ Load data ============
# One row per epoch; the columns read below are epoch, train_*/val_* metrics and lr.
df = pd.read_csv('training_log.csv')

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# ---- Panels 1-3: a train curve and a validation curve per metric ----
# (axes position, column suffix, legend suffix, y-axis label, title)
paired_panels = [
    ((0, 0), 'loss', 'Loss', 'Loss', 'Loss vs Epoch'),
    ((0, 1), 'f1', 'F1', 'F1 Score', 'F1 Score vs Epoch'),
    ((1, 0), 'acc', 'Acc', 'Accuracy (%)', 'Accuracy vs Epoch'),
]
for position, column, legend, y_label, title in paired_panels:
    ax = axes[position]
    ax.plot(df['epoch'], df[f'train_{column}'], label=f'Train {legend}', color='#1f77b4', lw=1.5)
    ax.plot(df['epoch'], df[f'val_{column}'], label=f'Val {legend}', color='#ff7f0e', lw=1.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# ---- Panel 4: learning rate ----
ax = axes[1, 1]
ax.plot(df['epoch'], df['lr'], color='#2ca02c', lw=1.5)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate vs Epoch')
# Learning rates are small, so force scientific notation on the y axis.
ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("训练曲线已保存: training_curves.png")

View file

@ -13,9 +13,75 @@ import os
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def create_dataloaders(data_root='..', class RobustImageFolder(Dataset):
"""包装 ImageFolder自动跳过损坏图片带进度条"""
def __init__(self, root, transform=None):
    """Index an ImageFolder-style directory, keeping only files that pass an integrity check.

    Args:
        root: Directory handed to ``datasets.ImageFolder`` to discover classes and files.
        transform: Optional callable applied to each loaded image in ``__getitem__``.

    Populates ``self.samples`` (list of ``(path, label)``), ``self.classes`` and
    ``self.class_to_idx``, and prints a summary of how many files were kept or skipped.
    """
    self.transform = transform
    self.samples = []

    # A plain ImageFolder is built only to obtain the class list and candidate files.
    temp_dataset = datasets.ImageFolder(root, transform=None)
    self.classes = temp_dataset.classes
    self.class_to_idx = temp_dataset.class_to_idx

    print(f"\n正在扫描: {root}")
    print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")

    corrupted_count = 0
    success_count = 0

    for path, label in tqdm(temp_dataset.samples,
                            desc="验证图片完整性",
                            unit="",
                            ncols=80):
        try:
            # Actually open and verify the file; the previous version only appended
            # to the list here, so no file could ever be reported as corrupted.
            with Image.open(path) as img:
                img.verify()
        except Exception:
            # Pillow raises different exception types depending on the decoder,
            # so any failure is treated as "corrupted" and the file is skipped.
            corrupted_count += 1
            # Only report the first 10 failures to avoid flooding the console.
            if corrupted_count <= 10:
                tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
            elif corrupted_count == 11:
                tqdm.write(f"⚠️ 后续损坏图片将不再显示...")
            continue
        self.samples.append((path, label))
        success_count += 1

    print(f"\n✅ 扫描完成!")
    print(f" 📁 有效图片: {success_count}")
    print(f" ❌ 损坏跳过: {corrupted_count}")
    print(f" 📊 总计: {len(self.samples)}\n")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
    """Return ``(image, label)`` for ``idx``, falling back to later samples on failure.

    If the sample at ``idx`` cannot be opened, converted or transformed, the following
    samples are tried in order (wrapping around), so the returned pair may belong to a
    different index than the one requested.

    Raises:
        IndexError: if ``idx`` is out of range for ``self.samples``.
        RuntimeError: if no sample in the dataset could be loaded.
    """
    samples = self.samples
    samples[idx]  # out-of-range indices must still raise IndexError, as before
    total = len(samples)
    last_error = None
    # Bounded loop instead of recursion: each sample is tried at most once, so a
    # dataset full of unreadable files fails cleanly instead of overflowing the stack.
    for offset in range(total):
        path, label = samples[(idx + offset) % total]
        try:
            img = Image.open(path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label
        except Exception as e:
            # Best-effort skip: decoder and transform errors vary in type.
            last_error = e
    raise RuntimeError(
        f"no loadable image found among {total} samples"
    ) from last_error
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
batch_size=32, batch_size=32,
image_size=256, image_size=256,
val_split=0.2, val_split=0.2,
@ -77,12 +143,12 @@ def create_dataloaders(data_root='..',
# 2. 加载数据集 # 2. 加载数据集
# ================================== # ==================================
print("使用独立的 val 文件夹") print("使用独立的 val 文件夹")
train_dataset = datasets.ImageFolder( train_dataset = RobustImageFolder(
root=os.path.join(data_root, 'train'), root=os.path.join(data_root, 'train'),
transform=train_transform if augment else val_transform transform=train_transform if augment else val_transform
) )
val_dataset = datasets.ImageFolder( val_dataset = RobustImageFolder(
root=os.path.join(data_root, 'val'), root=os.path.join(data_root, 'val'),
transform=val_transform transform=val_transform
) )
@ -111,9 +177,8 @@ def create_dataloaders(data_root='..',
) )
# 4. 获取类别名称 # 4. 获取类别名称
class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3'] class_names = train_dataset.classes
print(f"类别: {class_names}") print(f"类别: {class_names}")
print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}")
return train_loader, val_loader, class_names return train_loader, val_loader, class_names
@ -158,10 +223,10 @@ def visualize_batch(dataloader, class_names, num_images=8):
if __name__ == '__main__': if __name__ == '__main__':
train_loader, val_loader, class_names = create_dataloaders( train_loader, val_loader, class_names = create_dataloaders(
data_root='..', # 与trash-division同级文件夹 data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
batch_size=32, # 根据你的显存调整 batch_size=16, # 根据你的显存调整
image_size=256, # 与你模型输入一致 image_size=256, # 与你模型输入一致
num_workers=4, # Windows 可能需设为 0 num_workers=16, # Windows 可能需设为 0
augment=True # 训练时使用数据增强 augment=True # 训练时使用数据增强
) )
visualize_batch(train_loader, class_names, num_images=8) visualize_batch(train_loader, class_names, num_images=8)

147
Evaluate.py Normal file
View file

@ -0,0 +1,147 @@
"""
evaluate_and_plot.py
加载模型在验证集上推理绘制混淆矩阵 / ROC / PR 曲线
"""
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import (
confusion_matrix, ConfusionMatrixDisplay,
roc_curve, auc,
precision_recall_curve, average_precision_score,
)
from Model import Net
from Dataloader import RobustImageFolder
# Font list is tried in order: SimHei first, DejaVu Sans as fallback.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# ============================================================
# Parameters you are expected to edit
# ============================================================
MODEL_PATH = 'best_model.pth'  # path to the model weights
DATA_ROOT = '../trash_division_data/ultimate_4_class/'  # dataset root directory
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
# ============================================================

# ---------- 1. Load the validation set ----------
# Deterministic preprocessing only: resize, tensor conversion, channel normalisation.
# NOTE(review): mean/std look like the usual ImageNet statistics — confirm they match training.
_val_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
val_transform = transforms.Compose(_val_steps)

val_dataset = RobustImageFolder(
    root=os.path.join(DATA_ROOT, 'val'),
    transform=val_transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=False,
)

# Class names come from the dataset, so the model head size follows the data.
class_names = val_dataset.classes
num_classes = len(class_names)
print(f"类别: {class_names}")
# ---------- 2. Load the model ----------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Net(num_classes=num_classes)

# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_PATH, map_location=device)
# Accept either a bare state dict or a checkpoint that wraps it under a known key;
# only the first matching key is unwrapped.
for wrapper_key in ('model_state_dict', 'model'):
    if wrapper_key in state_dict:
        state_dict = state_dict[wrapper_key]
        break

model.load_state_dict(state_dict)
model = model.to(device)
model = model.eval()
print("模型加载完成")
# ---------- 3. Inference ----------
# Collect per-batch arrays first and concatenate once at the end.
all_labels = []
all_probs = []
with torch.no_grad():
    for images, labels in val_loader:
        logits = model(images.to(device))
        batch_probs = torch.softmax(logits, dim=1)
        all_probs.append(batch_probs.cpu().numpy())
        all_labels.append(labels.numpy())

all_labels = np.concatenate(all_labels)
all_probs = np.concatenate(all_probs)
# Predicted class = index of the highest softmax probability.
all_preds = np.argmax(all_probs, axis=1)
print(f"推理完成, 共 {len(all_labels)} 样本")
# ============================================================
# (1) Confusion matrix
# ============================================================
# Pass the full label set explicitly: otherwise a class that appears in neither the
# labels nor the predictions is dropped and the matrix no longer matches class_names.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))
fig, ax = plt.subplots(figsize=(8, 7))
ConfusionMatrixDisplay(cm, display_labels=class_names).plot(
    ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
ax.set_title('Confusion Matrix', fontsize=14)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
print("混淆矩阵已保存: confusion_matrix.png")
# ============================================================
# (2) ROC curves (one-vs-rest + macro-average)
# ============================================================
# Row i of the identity matrix is the one-hot encoding of label i.
one_hot = np.eye(num_classes)[all_labels]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
fig, ax = plt.subplots(figsize=(8, 7))
fpr_d, tpr_d, auc_d = {}, {}, {}
for i in range(num_classes):
    fpr_d[i], tpr_d[i], _ = roc_curve(one_hot[:, i], all_probs[:, i])
    auc_d[i] = auc(fpr_d[i], tpr_d[i])
    # Cycle through the palette so more than four classes cannot raise IndexError.
    ax.plot(fpr_d[i], tpr_d[i], color=colors[i % len(colors)], lw=2,
            label=f'{class_names[i]} (AUC={auc_d[i]:.4f})')

# Macro-average: interpolate every class curve onto the union of FPR points, then average.
all_fpr = np.unique(np.concatenate([fpr_d[i] for i in range(num_classes)]))
mean_tpr = sum(np.interp(all_fpr, fpr_d[i], tpr_d[i]) for i in range(num_classes)) / num_classes
macro_auc = auc(all_fpr, mean_tpr)
ax.plot(all_fpr, mean_tpr, 'navy', lw=2, ls='--',
        label=f'Macro-avg (AUC={macro_auc:.4f})')

# Diagonal reference line from (0, 0) to (1, 1).
ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.05)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve', fontsize=14)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('roc_curve.png', dpi=150, bbox_inches='tight')
plt.show()
print("ROC 曲线已保存: roc_curve.png")
# ============================================================
# (3) Precision-Recall curves
# ============================================================
fig, ax = plt.subplots(figsize=(8, 7))
for i in range(num_classes):
    prec, rec, _ = precision_recall_curve(one_hot[:, i], all_probs[:, i])
    ap = average_precision_score(one_hot[:, i], all_probs[:, i])
    # Cycle through the palette so more than four classes cannot raise IndexError.
    ax.plot(rec, prec, color=colors[i % len(colors)], lw=2,
            label=f'{class_names[i]} (AP={ap:.4f})')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.05)
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve', fontsize=14)
ax.legend(loc='best')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('pr_curve.png', dpi=150, bbox_inches='tight')
plt.show()
print("PR 曲线已保存: pr_curve.png")

225
Finetune.py Normal file
View file

@ -0,0 +1,225 @@
"""
微调脚本冻结 conv1 + stage2微调 stage3~fc
加大少样本类别的 loss 权重
author: yukun-hh
date 2026-4-25
"""
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from Model import Net
from Dataloader import create_dataloaders
import os
import csv
def compute_macro_f1(predicted, targets, num_classes=4):
    """Return the macro-averaged F1 score of ``predicted`` against ``targets``.

    Per-class precision, recall and F1 are computed for class ids ``0..num_classes-1``
    (with a 1e-8 term in each denominator) and the unweighted mean of the per-class
    F1 values is returned as a Python float.
    """
    eps = 1e-8
    # Rows hold, in order: true positives, false positives, false negatives.
    counts = torch.zeros(3, num_classes, device=predicted.device)
    for cls in range(num_classes):
        pred_hit = predicted == cls
        true_hit = targets == cls
        counts[0, cls] = (pred_hit & true_hit).sum()
        counts[1, cls] = (pred_hit & ~true_hit).sum()
        counts[2, cls] = (~pred_hit & true_hit).sum()
    tp, fp, fn = counts

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    per_class_f1 = 2 * precision * recall / (precision + recall + eps)
    return per_class_f1.mean().item()
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc)``: sample-weighted mean loss, macro-F1 over
        all predictions of the epoch, and accuracy in percent.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_batches = []
    label_batches = []

    progress = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        loss_sum += loss.item() * images.size(0)
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()
        pred_batches.append(predicted)
        label_batches.append(labels)

        # Progress bar shows the F1 of this batch and the running accuracy so far.
        batch_f1 = compute_macro_f1(predicted, labels)
        running_acc = 100. * n_correct / n_seen
        progress.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', 'Acc': f'{running_acc:.2f}%'})

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_batches), torch.cat(label_batches))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc)``: sample-weighted mean loss, macro-F1 over
        all predictions, and accuracy in percent.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='[Validate]'):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight the batch loss by batch size so the epoch mean is per-sample.
            loss_sum += batch_loss.item() * images.size(0)
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
            pred_batches.append(predicted)
            label_batches.append(labels)

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_batches), torch.cat(label_batches))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
def compute_class_weights(dataset, num_classes=4, device='cpu', power=1.0):
    """Compute inverse-frequency class weights from ``dataset.samples``.

    Args:
        dataset: Object exposing ``samples``, an iterable of ``(item, label)`` pairs.
        num_classes: Number of classes; labels are used as indices ``0..num_classes-1``.
        device: Device the returned tensor is moved to.
        power: Exponent applied to the weights; values above 1 sharpen the imbalance
            correction.

    Returns:
        A float tensor of shape ``(num_classes,)`` with
        ``(total / (num_classes * count)) ** power`` per class.
    """
    class_counts = torch.zeros(num_classes)
    for _, label in dataset.samples:
        class_counts[int(label)] += 1

    total = class_counts.sum()
    # A class with no samples would divide by zero and yield an infinite weight,
    # which turns the loss into inf/NaN as soon as that class shows up. Treat an
    # empty class as having one sample so its weight stays finite.
    safe_counts = class_counts.clamp(min=1)
    weights = total / (num_classes * safe_counts)
    weights = weights ** power
    return weights.to(device)
def freeze_base_layers(model):
    """Freeze the stem and the first two residual stages of ``model`` in place.

    Sets ``requires_grad = False`` on every parameter of ``conv1``, ``bn1`` and two
    residual stages, then prints how many parameters remain trainable.

    The stages are looked up as ``stage2`` / ``stage3`` first and fall back to
    ``layer1`` / ``layer2``, the attribute names used by ``Net`` in Model.py.
    NOTE(review): this assumes ``stage2``/``stage3`` mean the first two residual
    stages — confirm. The file header says stage3 is fine-tuned, but this function
    has always frozen it; that behaviour is kept here — confirm which is intended.

    Args:
        model: Module exposing ``conv1``, ``bn1`` and the stage attributes above.

    Returns:
        The same ``model`` object.

    Raises:
        AttributeError: if none of the candidate names for a group exists on ``model``.
    """
    # Each tuple lists the acceptable attribute names for one group, in priority order.
    groups = (
        ('conv1',),
        ('bn1',),
        ('stage2', 'layer1'),
        ('stage3', 'layer2'),
    )

    frozen_layers = []
    for candidates in groups:
        attr = next((name for name in candidates if hasattr(model, name)), None)
        if attr is None:
            raise AttributeError(
                f'{type(model).__name__} has none of the attributes {candidates}'
            )
        for name, param in getattr(model, attr).named_parameters():
            param.requires_grad = False
            frozen_layers.append(f'{attr}.{name}')

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f'冻结层数: {len(frozen_layers)} 个参数组')
    print(f'可训练参数量: {trainable:,} / {total:,} ({100. * trainable / total:.1f}%)')
    return model
def finetune(model, train_loader, val_loader, epochs=30, lr=0.0001, device='cuda'):
    """Fine-tune the trainable parameters of ``model`` with a class-weighted loss.

    Uses SGD (momentum 0.9, weight decay 1e-4) over parameters with
    ``requires_grad=True``, a cosine-annealed learning rate, and CrossEntropyLoss
    weighted by ``compute_class_weights(..., power=1.5)``. Every epoch is appended to
    ``finetune_log.csv``; the weights with the best validation macro-F1 are saved to
    ``finetuned_model.pth``.

    Args:
        model: Network to fine-tune (already moved to ``device``).
        train_loader: Training DataLoader; its ``dataset`` must expose ``samples``.
        val_loader: Validation DataLoader.
        epochs: Number of epochs, also used as ``T_max`` of the scheduler.
        lr: Initial learning rate.
        device: Device used for batches and class weights.

    Returns:
        ``(model, history)`` where ``history`` maps metric names to per-epoch lists.
    """
    class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device, power=1.5)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Only parameters left trainable (i.e. not frozen) are handed to the optimizer.
    optimizer = optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, momentum=0.9, weight_decay=1e-4
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {
        'train_loss': [],
        'train_f1': [],
        'train_acc': [],
        'val_loss': [],
        'val_f1': [],
        'val_acc': []
    }
    best_val_f1 = 0.0

    # The context manager guarantees the log is closed even if training raises.
    with open('finetune_log.csv', 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc',
                             'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])

        for epoch in range(epochs):
            print(f'\n{"=" * 50}')
            print(f'Epoch {epoch + 1}/{epochs}')

            train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion,
                                                              optimizer, device, epoch)
            val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
            scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_f1'].append(train_f1)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_f1'].append(val_f1)
            history['val_acc'].append(val_acc)

            # Read after scheduler.step(), so this is the rate the next epoch will use.
            current_lr = optimizer.param_groups[0]['lr']

            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')

            best_mark = ''
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), 'finetuned_model.pth')
                best_mark = 'best'
                print(f'✓ 保存最佳微调模型 (Macro-F1: {val_f1:.4f})')

            log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc,
                                 val_loss, val_f1, val_acc, current_lr, best_mark])
            # Flush per epoch so the log survives an interrupted run.
            log_file.flush()

    print(f'\n{"=" * 50}')
    print(f'微调完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
    return model, history
if __name__ == '__main__':
    # Entry point: build the loaders, load the base weights, freeze the early layers,
    # fine-tune, then reload the best fine-tuned checkpoint.
    train_loader, val_loader, class_names = create_dataloaders(
        data_root='../trash_division_data/ultimate_4_class/',
        batch_size=16,
        image_size=256,
        num_workers=8,
        augment=True
    )

    # Device priority: CUDA > Intel XPU > CPU. torch.xpu does not exist in every
    # PyTorch build, so guard the attribute access instead of crashing on CPU-only hosts.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
        device = torch.device('xpu')
    else:
        device = torch.device('cpu')

    model = Net(num_classes=4)

    if os.path.exists('best_model.pth'):
        # Load on CPU first; the model is moved to the target device below.
        model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
        print('✓ 加载预训练权重 best_model.pth')
    else:
        print('⚠ 未找到 best_model.pth使用随机初始化权重')

    model = model.to(device)
    model = freeze_base_layers(model)

    print(f'Device: {device}')
    print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')

    trained_model, history = finetune(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=30,
        lr=0.0001,
        device=device
    )

    # finetune() writes this file only when an epoch beats a validation macro-F1 of 0.0,
    # so it may be missing; map_location keeps the tensors on the selected device.
    if os.path.exists('finetuned_model.pth'):
        model.load_state_dict(torch.load('finetuned_model.pth', map_location=device))

View file

@ -1,5 +1,5 @@
"""将原数据集合并为我们需要的四个大类 """将原数据集合并为我们需要的四个大类
运行时先配置路径 已修改成相对路径 具体配置方法详见README.md
author author
weikaiwen weikaiwen
@ -18,9 +18,9 @@ import shutil
# ================= 1. 配置你的路径 ================= # ================= 1. 配置你的路径 =================
# 注意:请确保相对路径正确,以下为示例 # 注意:请确保相对路径正确,以下为示例
ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录
NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录
CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置
# =================================================== # ===================================================

180
Model.py
View file

@ -1,99 +1,113 @@
""" """
这个文件是模型的定义文件请不要擅自修改如有疑问微信群里反馈 模型定义文件 - ResNet-34
单独运行本文件将会输出模型结构
目前的话是一个36层的模型模型总量应该是在80M左右 如果到时候还是欠拟合的话再考虑去做更深的结构
author : yukun-hh author : yukun-hh
date : 2026-4-10 date : 2026-4-10
""" """
#神经网络模型库
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torchsummary import summary from torchsummary import summary
#残差块
class Resblock(nn.Module):
def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1):
"""
:param input_channels: 进入残差块时的原通道
:param output_channels: 输出时的通道数 class BasicBlock(nn.Module):
:param use_1x1conv: 如果输入和输出通道不相等时需要用一个1x1的卷积层对原来的输入进行一个通道提升 """
:param strides: 默认1如果大于1起到缩小张量的作用 ResNet-34 基础残差块3x3 -> 3x3
""" 若需要下采样或通道变化则在跳跃连接中使用 1x1 卷积
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super().__init__() super().__init__()
self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=1) self.bn1 = nn.BatchNorm2d(out_channels)
if use_1x1conv: self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides) self.bn2 = nn.BatchNorm2d(out_channels)
else: self.relu = nn.ReLU(inplace=True)
self.conv3 = None self.downsample = downsample
self.bn1 = nn.BatchNorm2d(output_channels)
self.bn2 = nn.BatchNorm2d(output_channels)
def forward(self,X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3 is not None:
X = self.conv3(X)
Y += X
return F.relu(Y)
class Net(): def forward(self, x):
""" identity = x
模型的主要结构就在这里了到时也好该和调用
现在必须实现的方法
目前还是以图片缩放到256256构建残差块
"""
net = nn.Sequential()
def resnet_block(self,input_channels, num_channels, num_residuals,
first_block=False):
"""
:param input_channels: 输入维度
:param num_channels: 输出维度
:param num_residuals: 单个残差层的残差块数
:param first_block: 第一块不用下采样 特殊控制
:return: list[nn.Module]
"""
blk = []
for i in range(num_residuals): out = self.conv1(x)
if i == 0 and not first_block: out = self.bn1(out)
blk.append(Resblock(input_channels, num_channels, out = self.relu(out)
use_1x1conv=True, strides=2))
else:
blk.append(Resblock(num_channels, num_channels))
return blk
def __init__(self):
b1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
"""
7×7 卷积层输出通道 64步长 2填充 3
(3×256×256)->(64×128×128)
批归一化 relu层
最大池化
(64×128×128)->(64×64×64)
"""
b2 = nn.Sequential(*self.resnet_block(64, 64, num_residuals=3, first_block=True))
b3 = nn.Sequential(*self.resnet_block(64, 128, num_residuals=4))
b4 = nn.Sequential(*self.resnet_block(128, 256, num_residuals=6))
b5 = nn.Sequential(*self.resnet_block(256, 512, num_residuals=3))
self.net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 4))
def get_network(self):
return self.net
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Net(nn.Module):
    """ResNet-34-style classifier: conv stem, four ``BasicBlock`` stages, pooled linear head.

    Args:
        num_classes: Size of the final linear layer's output.
        dropout: Dropout probability applied to the pooled features before ``fc``.
    """

    def __init__(self, num_classes=4, dropout=0.5):
        super().__init__()
        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Channel count entering the next stage; updated by _make_layer.
        self.in_channels = 64

        # (number of blocks, output channels, stride of the first block) per stage.
        # Stages are registered in this order, which fixes the state_dict key order.
        stage_specs = {
            'layer1': (3, 64, 1),
            'layer2': (4, 128, 2),
            'layer3': (6, 256, 2),
            'layer4': (3, 512, 2),
        }
        for attr_name, spec in stage_specs.items():
            setattr(self, attr_name, self._make_layer(spec))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, config):
        """Build one stage from ``config = (num_blocks, out_channels, stride)``.

        Only the first block receives ``stride`` and, when the shape changes, a
        1x1-conv + BatchNorm projection on the skip path. Updates ``self.in_channels``.
        """
        num_blocks, out_channels, stride = config

        shape_changes = stride != 1 or self.in_channels != out_channels
        if shape_changes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            downsample = None

        blocks = [BasicBlock(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for _ in range(num_blocks - 1):
            blocks.append(BasicBlock(self.in_channels, out_channels))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map an image batch to class logits of shape ``(batch, num_classes)``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(self.dropout(pooled))
if __name__ == '__main__': if __name__ == '__main__':
Net_new = Net() model = Net(num_classes=4)
X = torch.rand(size=(1, 3, 256, 256)) summary(model, input_size=(3, 256, 256))
summary(Net_new.get_network(), input_size=(3, 256, 256))

245
README.md
View file

@ -1,14 +1,249 @@
# trash-division # trash-division
## 一个基于卷积神经网络的垃圾分类识别系统
### 同济大学python人工智能程序设计课程小组作业 一个基于卷积神经网络的垃圾分类识别系统
> 同济大学 Python 人工智能程序设计课程小组作业
基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
---
## 目录
- [项目特点](#项目特点)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [快速开始](#快速开始)
- [文件说明](#文件说明)
- [目录结构](#目录结构)
- [训练细节](#训练细节)
- [评估与可视化](#评估与可视化)
- [许可证](#许可证)
---
## 项目特点
- **四类垃圾分类**：厨余垃圾（1）、可回收物（2）、其他垃圾（3）、有害垃圾（4）
- **ResNet-34 架构**:约 21M 参数34 层深度残差网络,含 Dropout 正则化
- **数据增强**:训练时使用随机裁剪、水平翻转、旋转、色彩抖动
- **Macro-F1 评估**:采用宏平均 F1 分数作为主要评估指标,兼顾各类别表现
- **类别加权损失**:自动计算类别权重,缓解类别不平衡问题
- **余弦退火学习率调度**:使用 CosineAnnealingLR 平滑调整学习率
- **断点续训**:自动检测 `best_model.pth` 并加载继续训练
- **多设备支持**:自动选择 CUDA > Intel XPU > CPU
## 模型架构
模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。
### BasicBlock 块
每个 BasicBlock 包含两个 3x3 卷积层 + 跳跃连接:
| 层 | 作用 | 说明 |
|---|---|---|
| 3x3 Conv | 特征提取 | 第一层卷积 |
| 3x3 Conv | 特征提取 | 第二层卷积 |
### 网络结构
| 阶段 | 块数 | 输出通道数 | 说明 |
|---|---|---|---|
| 初始层 | - | 64 | 7x7 Conv, stride=2 + MaxPool |
| Layer1 | 3 | 64 | 第一个残差阶段 |
| Layer2 | 4 | 128 | - |
| Layer3 | 6 | 256 | - |
| Layer4 | 3 | 512 | 最终残差阶段 |
| 分类头 | - | 4 | 全局平均池化 + Dropout + 全连接层 |
## 数据集
本项目使用 [tany0699/garbage265](https://modelscope.cn/datasets/tany0699/garbage265) 中文生活垃圾分类数据集,包含 265 个子类别的生活垃圾图片。
通过 `Merge_classes.py` 脚本将 265 个子类别合并为 4 个顶级类别:
```
厨余垃圾 -> 1
可回收物 -> 2
其他垃圾 -> 3
有害垃圾 -> 4
```
数据集预期放置在 `../trash_division_data/`(与项目根目录平级的兄弟目录)。
## 环境要求
```bash
pip install -r requirements.txt
```
> **注意**：`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本，请根据硬件自行安装对应版本，例如：
> - NVIDIA GPU：`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
> - Intel GPU (XPU)：`pip install torch torchvision --index-url https://download.pytorch.org/whl/xpu`
> - CPU：`pip install torch torchvision` 即可
## 快速开始
1. **数据预处理**:将 265 个子类别合并为 4 个顶级类别
```bash
python Merge_classes.py
```
2. **训练模型**
```bash
python Train.py
```
3. **微调模型**(可选,冻结浅层、微调深层):
```bash
python Finetune.py
```
4. **评估与可视化**
```bash
python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线
python Curve.py # 训练过程的 loss/f1/acc/lr 曲线
python baseline/compare_models.py # 多模型基线对比（ROC/PR/准确率）
```
> **注意**
> - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本
> - Windows 系统需将 `num_workers` 设为 `0`（参见 `Dataloader.py`、`Train.py`）
> - 训练会自动从 `best_model.pth` 断点续训(若存在)
## 文件说明
| 文件 | 功能 |
|---|---|
| `Train.py` | 训练主脚本,包含训练循环、验证、评估 |
| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 |
| `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 |
| `Model.py` | 模型定义：ResNet-34（BasicBlock）+ Dropout |
| `Merge_classes.py` | 数据集预处理265 类合并为 4 类 |
| `Evaluate.py` | 模型评估绘制混淆矩阵、ROC 曲线、PR 曲线 |
| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 |
| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 |
| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression纯传统 CV |
| `baseline/compare_models.py` | 多模型对比ROC / PR 曲线 + 准确率柱状图) |
| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB不纳入版本控制 |
| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
## 目录结构
```
trash-division/
├── baseline/ # 基线模型目录
│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本
│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调
│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线
│ ├── compare_models.py # 多模型对比脚本
│ ├── roc_comparison.png # 多模型 ROC 对比compare_models.py 输出)
│ ├── pr_comparison.png # 多模型 PR 对比compare_models.py 输出)
│ └── accuracy_bar.png # 多模型准确率对比compare_models.py 输出)
├── best_model.pth # 最佳模型权重(不纳入版本控制)
├── Curve.py # 训练曲线绘制脚本
├── Dataloader.py # 数据加载模块
├── Evaluate.py # 模型评估可视化脚本
├── Finetune.py # 微调脚本
├── .gitattributes # Git 属性配置
├── LICENSE # MIT 许可证
├── Merge_classes.py # 数据集预处理脚本
├── Model.py # 模型定义
├── README.md # 项目说明(本文件)
├── THIRD_PARTY_LICENSES.md # 第三方许可证声明
├── Train.py # 训练主脚本
├── training_log.csv # 训练日志
├── confusion_matrix.png # 混淆矩阵Evaluate.py 输出)
├── roc_curve.png # ROC 曲线Evaluate.py 输出)
├── pr_curve.png # PR 曲线Evaluate.py 输出)
└── training_curves.png # 训练曲线Curve.py 输出)
```
## 训练细节
| 配置项 | 说明 |
|---|---|
| 输入尺寸 | 256 x 256 RGB |
| 优化器 | SGD（momentum=0.9, weight_decay=1e-4） |
| 初始学习率 | 0.001 |
| 学习率调度 | CosineAnnealingLR |
| 损失函数 | 类别加权 CrossEntropyLoss |
| 评估指标 | Macro-F1宏平均 F1 分数) |
| 批量大小 | 默认 16可通过参数调整 |
| 训练轮数 | 默认 20可通过参数调整 |
| 设备选择优先级 | CUDA > Intel XPU > CPU |
| 断点续训 | 自动检测 best_model.pth 并加载 |
训练时数据增强管线：RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(±15°) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
## 评估与可视化
训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析:
### Evaluate.py — 模型评估
在验证集上推理,生成三张评估图表:
```bash
python Evaluate.py
```
脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。
**混淆矩阵**
![confusion_matrix](confusion_matrix.png)
**ROC 曲线**
![roc_curve](roc_curve.png)
**PR 曲线**
![pr_curve](pr_curve.png)
### Curve.py — 训练曲线
从 `training_log.csv` 读取训练日志，绘制四张子图：
```bash
python Curve.py
```
![training_curves](training_curves.png)
### 基线模型对比
`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表:
```bash
python baseline/compare_models.py
```
对比阵容ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。
**ROC 曲线对比**
![roc_comparison](baseline/roc_comparison.png)
**PR 曲线对比**
![pr_comparison](baseline/pr_comparison.png)
**准确率柱状图**
![accuracy_bar](baseline/accuracy_bar.png)
## 许可证 ## 许可证
本项目主代码采用 [MIT 许可证](LICENSE)。 本项目主代码采用 [MIT 许可证](LICENSE)。
本项目包含的数据集 `tany0699/garbage265` 采用 [Apache License 2.0](THIRD_PARTY_LICENSES.md),详情请参阅 `THIRD_PARTY_LICENSES.md` 文件。 本项目包含的数据集 `tany0699/garbage265` 采用 [Apache License 2.0](THIRD_PARTY_LICENSES.md),详情请参阅 `THIRD_PARTY_LICENSES.md` 文件。

155
Train.py
View file

@ -12,52 +12,70 @@ import torch.optim as optim
from tqdm import tqdm # 进度条,可选 from tqdm import tqdm # 进度条,可选
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from Model import Net from Model import Net
from Dataloader import create_dataloaders
import os
import csv
def compute_macro_f1(predicted, targets, num_classes=4):
    """Return the macro-averaged F1 score of ``predicted`` against ``targets``.

    Per-class precision, recall and F1 are computed for class ids ``0..num_classes-1``
    (with a 1e-8 term in each denominator) and the unweighted mean of the per-class
    F1 values is returned as a Python float.
    """
    eps = 1e-8
    # Rows hold, in order: true positives, false positives, false negatives.
    counts = torch.zeros(3, num_classes, device=predicted.device)
    for cls in range(num_classes):
        pred_hit = predicted == cls
        true_hit = targets == cls
        counts[0, cls] = (pred_hit & true_hit).sum()
        counts[1, cls] = (pred_hit & ~true_hit).sum()
        counts[2, cls] = (~pred_hit & true_hit).sum()
    tp, fp, fn = counts

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    per_class_f1 = 2 * precision * recall / (precision + recall + eps)
    return per_class_f1.mean().item()
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
"""训练一个epoch""" """训练一个epoch"""
model.train() # 设置为训练模式 model.train()
running_loss = 0.0 running_loss = 0.0
correct = 0 correct = 0
total = 0 total = 0
all_preds, all_labels = [], []
# 使用 tqdm 显示进度条(可选)
pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]') pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
for images, labels in pbar: for images, labels in pbar:
# 将数据移到 GPU/CPU
images, labels = images.to(device), labels.to(device) images, labels = images.to(device), labels.to(device)
# 前向传播
outputs = model(images) outputs = model(images)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
# 反向传播 optimizer.zero_grad()
optimizer.zero_grad() # 清空梯度 loss.backward()
loss.backward() # 计算梯度 optimizer.step()
optimizer.step() # 更新参数
# 统计
running_loss += loss.item() * images.size(0) running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1) _, predicted = outputs.max(1)
total += labels.size(0) total += labels.size(0)
correct += predicted.eq(labels).sum().item() correct += predicted.eq(labels).sum().item()
all_preds.append(predicted)
all_labels.append(labels)
# 更新进度条信息 batch_f1 = compute_macro_f1(predicted, labels)
pbar.set_postfix({'loss': loss.item(), 'acc': 100. * correct / total}) pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', 'Acc': f'{100. * correct / total:.2f}%'})
epoch_loss = running_loss / total epoch_loss = running_loss / total
epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc return epoch_loss, epoch_f1, epoch_acc
def validate(model, val_loader, criterion, device): def validate(model, val_loader, criterion, device):
"""验证函数""" """验证函数"""
model.eval() # 设置为评估模式 model.eval()
running_loss = 0.0 running_loss = 0.0
correct = 0 correct = 0
total = 0 total = 0
all_preds, all_labels = [], []
with torch.no_grad(): # 不计算梯度,节省内存 with torch.no_grad():
for images, labels in tqdm(val_loader, desc='[Validate]'): for images, labels in tqdm(val_loader, desc='[Validate]'):
images, labels = images.to(device), labels.to(device) images, labels = images.to(device), labels.to(device)
@ -68,37 +86,52 @@ def validate(model, val_loader, criterion, device):
_, predicted = outputs.max(1) _, predicted = outputs.max(1)
total += labels.size(0) total += labels.size(0)
correct += predicted.eq(labels).sum().item() correct += predicted.eq(labels).sum().item()
all_preds.append(predicted)
all_labels.append(labels)
epoch_loss = running_loss / total epoch_loss = running_loss / total
epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels))
epoch_acc = 100. * correct / total epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc return epoch_loss, epoch_f1, epoch_acc
def compute_class_weights(dataset, num_classes=4, device='cpu'):
    """Compute inverse-frequency class weights from ``dataset.samples``.

    Each weight is ``total / (num_classes * count)``, so rarer classes get
    larger weights.

    Args:
        dataset: Object exposing ``samples``, an iterable of
            ``(item, label)`` pairs whose label is an int or a scalar tensor.
        num_classes: Number of classes; labels must be below this value.
        device: Device the returned weight tensor is moved to.

    Returns:
        A 1-D tensor of ``num_classes`` finite weights on ``device``.
    """
    class_counts = torch.zeros(num_classes)
    for _, label in dataset.samples:
        lbl = label.item() if isinstance(label, torch.Tensor) else label
        class_counts[lbl] += 1
    total = class_counts.sum()
    # A class with no samples would otherwise divide by zero and produce an
    # ``inf`` weight; treat its count as 1 so every weight stays finite.
    safe_counts = class_counts.clamp(min=1)
    weights = total / (num_classes * safe_counts)
    return weights.to(device)
def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
"""主训练函数""" """主训练函数"""
# 1. 定义损失函数和优化器 # 1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 多分类用交叉熵 class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights) # 多分类用交叉熵
# 优化器选择(推荐 Adam 或 SGD
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
# 或者使用 SGD + 动量 # 或者使用 SGD + 动量
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
# 学习率调度器(可选,帮助收敛) # 学习率调度器(可选,帮助收敛)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# 或者用余弦退火
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# 2. 记录训练历史 # 2. 记录训练历史
history = { history = {
'train_loss': [], 'train_loss': [],
'train_f1': [],
'train_acc': [], 'train_acc': [],
'val_loss': [], 'val_loss': [],
'val_f1': [],
'val_acc': [] 'val_acc': []
} }
best_val_acc = 0.0 best_val_f1 = 0.0
log_file = open('training_log.csv', 'w', newline='')
log_writer = csv.writer(log_file)
log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', 'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])
# 3. 开始训练 # 3. 开始训练
for epoch in range(epochs): for epoch in range(epochs):
@ -106,77 +139,65 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
print(f'Epoch {epoch + 1}/{epochs}') print(f'Epoch {epoch + 1}/{epochs}')
# 训练 # 训练
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion,
optimizer, device, epoch) optimizer, device, epoch)
# 验证 # 验证
val_loss, val_acc = validate(model, val_loader, criterion, device) val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
# 更新学习率 # 更新学习率
scheduler.step() scheduler.step()
# 记录 # 记录
history['train_loss'].append(train_loss) history['train_loss'].append(train_loss)
history['train_f1'].append(train_f1)
history['train_acc'].append(train_acc) history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss) history['val_loss'].append(val_loss)
history['val_f1'].append(val_f1)
history['val_acc'].append(val_acc) history['val_acc'].append(val_acc)
# 打印结果 # 打印结果
print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%') print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%') print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}') print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
# 保存最佳模型 # 保存最佳模型
if val_acc > best_val_acc: best_mark = ''
best_val_acc = val_acc if val_f1 > best_val_f1:
best_val_f1 = val_f1
torch.save(model.state_dict(), 'best_model.pth') torch.save(model.state_dict(), 'best_model.pth')
print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)') best_mark = 'best'
print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')
lr = optimizer.param_groups[0]['lr']
log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc, val_loss, val_f1, val_acc, lr, best_mark])
log_file.flush()
# 4. 绘制训练曲线 # 4. 绘制训练曲线
plot_training_history(history)
print(f'\n{"=" * 50}') print(f'\n{"=" * 50}')
print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%') print(f'训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')
return model, history return model, history
def plot_training_history(history):
    """Plot loss and accuracy curves side by side, save them, and show them.

    Args:
        history: Mapping with the keys ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc``, each a per-epoch sequence.

    The figure is written to ``training_history.png`` before being displayed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One entry per panel: the two series to draw, the y label, the title.
    panels = (
        ((('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
         'Loss', 'Training and Validation Loss'),
        ((('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')),
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )
    for axis, (series, y_label, title) in zip(axes, panels):
        for key, legend in series:
            axis.plot(history[key], label=legend)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150)
    plt.show()
# ========== 使用示例 ========== # ========== 使用示例 ==========
if __name__ == '__main__': if __name__ == '__main__':
# 假设你的 dataloader 已经写好了 # 假设你的 dataloader 已经写好了
# train_loader = ... train_loader, val_loader, class_names = create_dataloaders(
# val_loader = ... data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
batch_size=16, # 根据你的显存调整
image_size=256, # 与你模型输入一致
num_workers=8, # Windows 可能需设为 0
augment=True # 训练时使用数据增强
)
# 1. 创建模型 # 1. 创建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
model = Net().get_network() # 根据你的 Net 类调整 model = Net(num_classes=4) # 根据你的 Net 类调整
#断点继续训练
if os.path.exists('best_model.pth'):
model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu')))
model = model.to(device) model = model.to(device)
# 打印模型信息 # 打印模型信息
@ -188,11 +209,9 @@ if __name__ == '__main__':
model=model, model=model,
train_loader=train_loader, train_loader=train_loader,
val_loader=val_loader, val_loader=val_loader,
epochs=50, epochs=20,
lr=0.001, lr=0.001,
device=device device=device
) )
# 3. 加载最佳模型用于预测 # 3. 加载最佳模型用于预测
model.load_state_dict(torch.load('best_model.pth')) model.load_state_dict(torch.load('best_model.pth'))
print('训练完成,最佳模型已加载')

95
app.py Normal file
View file

@ -0,0 +1,95 @@
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from Model import Net # 根据上传的 Model.py模型类名为 Net
# 1. Basic configuration and class mapping.
# Per Merge_classes.py: 1=kitchen waste, 2=recyclables, 3=other waste, 4=hazardous waste.
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']


def _select_device():
    """Pick the inference device: CUDA, then XPU, then MPS, then CPU.

    The ``xpu`` and ``mps`` backends are probed defensively because they are
    not present in every PyTorch build; an unguarded attribute access would
    raise ``AttributeError`` at import time.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.device('xpu')
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = _select_device()
print(f"当前使用的推理设备: {device}")

# 2. Build the model and load the best checkpoint.
model = Net(num_classes=4)
try:
    # Accept either a bare state dict or a checkpoint dict that wraps it.
    state_dict = torch.load('best_model.pth', map_location=device)
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    elif 'model' in state_dict:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    print("✅ 成功加载 best_model.pth 权重")
except Exception as e:
    # Best effort: the UI still starts, but predictions come from an
    # untrained network until a valid best_model.pth is provided.
    print(f"⚠️ 模型加载失败，请确保目录下存在 best_model.pth。错误信息: {e}")
# Always place the model on the inference device in eval mode, even when the
# checkpoint could not be loaded; otherwise predict() would feed a tensor on
# ``device`` into a CPU model that is still in training mode.
model = model.to(device).eval()

# 3. Preprocessing pipeline (the original note says it must match the
# val_transform used in Evaluate.py).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# 4. 核心推理函数
def predict(image):
    """Classify one image and return a ``{class name: probability}`` dict.

    Args:
        image: Image object supporting ``convert('RGB')``, or ``None`` when
            nothing was uploaded.

    Returns:
        ``None`` for a missing image; otherwise a dict mapping every entry of
        ``class_names`` to its softmax probability as a Python float.
    """
    if image is None:
        return None

    # Force three channels, preprocess, and add the batch dimension.
    rgb_image = image.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        # Softmax over the class axis; keep the single sample of the batch.
        probs = torch.softmax(logits, dim=1)[0]

    return {name: float(probs[idx]) for idx, name in enumerate(class_names)}
# 5. Build and style the Gradio interface.
_HEADER_HTML = """
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
<h1>🗑 Trash Division - 智能垃圾分类系统</h1>
<p>基于 <b>ResNet-34</b> 架构支持精准识别<b>厨余垃圾可回收物其他垃圾有害垃圾</b></p>
<p><i>同济大学 Python 人工智能程序设计课程小组作业</i></p>
</div>
"""


def _reset_fields():
    """Return empty values for the image input and the label output."""
    return None, None


with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo:
    gr.Markdown(_HEADER_HTML)

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand a PIL image to the callback.
            image_input = gr.Image(type="pil", label="上传垃圾图片 (支持拍照)")
            with gr.Row():
                clear_btn = gr.Button("清空", variant="secondary")
                submit_btn = gr.Button("开始识别", variant="primary")

        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=4, label="预测结果与置信度")

    # Wire up the button click handlers.
    submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
    clear_btn.click(_reset_fields, inputs=None, outputs=[image_input, label_output])

if __name__ == "__main__":
    # Start the web UI.
    launch_options = {
        "server_name": "127.0.0.1",
        "server_port": 7860,
        "share": False,     # set to True to get a temporary public link
        "inbrowser": True,  # open the default browser on start
    }
    demo.launch(**launch_options)

145
baseline/HOG_Baseline.py Normal file
View file

@ -0,0 +1,145 @@
"""
baseline/HOG_Baseline.py
HOG + 颜色直方图特征提取 + LogisticRegression 四分类
纯传统 CV/ML 基线零神经网络依赖
可独立运行也可被 compare_models.py 导入
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib
from skimage.feature import hog
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
IMAGE_SIZE = 128
HOG_ORIENTATIONS = 9
HOG_PIXELS_PER_CELL = (8, 8)
HOG_CELLS_PER_BLOCK = (2, 2)
COLOR_BINS = 32
# ============================================================
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
def extract_hog_color(image):
    """Build a HOG + per-channel colour-histogram feature vector.

    Args:
        image: Image object supporting ``convert('RGB')`` and ``resize``.

    Returns:
        1-D array: the HOG descriptor followed by three ``COLOR_BINS``-bin
        histograms, one per RGB channel.
    """
    resized = image.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
    # Scale pixel values into [0, 1] so the histogram range below applies.
    pixels = np.array(resized, dtype=np.float64) / 255.0

    hog_feat = hog(
        pixels,
        orientations=HOG_ORIENTATIONS,
        pixels_per_cell=HOG_PIXELS_PER_CELL,
        cells_per_block=HOG_CELLS_PER_BLOCK,
        channel_axis=2,
        feature_vector=True,
    )

    channel_hists = [
        np.histogram(pixels[:, :, ch], bins=COLOR_BINS, range=(0, 1))[0]
        for ch in range(3)
    ]
    return np.concatenate([hog_feat, *channel_hists])
class HOGLRBaseline:
    """HOG + colour-histogram features classified by logistic regression.

    Images are read from ``<data_root>/<split>/<class_id>/`` where
    ``class_id`` runs from 1 to ``NUM_CLASSES``; labels are ``class_id - 1``.
    """

    def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE):
        self.data_root = data_root
        # NOTE: stored only; extract_hog_color resizes with the module-level
        # IMAGE_SIZE and does not consult this attribute.
        self.image_size = image_size
        self.clf = LogisticRegression(
            C=1.0, max_iter=1000, solver='lbfgs', n_jobs=-1,
        )

    def _load_data(self, split):
        """Extract features and labels for every readable file of ``split``.

        Files that cannot be opened or processed are skipped (best effort),
        but each skip is reported and counted instead of being dropped
        silently.

        Returns:
            Tuple ``(features, labels)``: a float32 feature array and an
            array of zero-based labels.
        """
        dir_path = os.path.join(self.data_root, split)
        features, labels = [], []
        skipped = 0
        for class_id in range(1, NUM_CLASSES + 1):
            class_dir = os.path.join(dir_path, str(class_id))
            if not os.path.isdir(class_dir):
                continue
            files = sorted(os.listdir(class_dir))
            for fname in tqdm(files, desc=f'{split}/class_{class_id}'):
                fpath = os.path.join(class_dir, fname)
                try:
                    with Image.open(fpath) as img:
                        feat = extract_hog_color(img)
                except Exception as exc:
                    # Keep going on a bad file, but leave a trace of it.
                    skipped += 1
                    tqdm.write(f"  [skip] {fpath}: {exc}")
                else:
                    features.append(feat)
                    labels.append(class_id - 1)
        print(f" {split}: {len(features)} ")
        if skipped:
            print(f" {split}: skipped {skipped} unreadable file(s)")
        return np.array(features, dtype=np.float32), np.array(labels)

    def fit(self, train_dir=None):
        """Fit the classifier on the split named ``train_dir`` (default ``'train'``)."""
        if train_dir is None:
            train_dir = 'train'
        print(" 提取训练集 HOG 特征 ...")
        X, y = self._load_data(train_dir)
        self.clf.fit(X, y)

    def predict(self, val_dir=None):
        """Predict on the split named ``val_dir`` (default ``'val'``).

        Returns:
            Tuple ``(y_true, preds, probs)``.
        """
        if val_dir is None:
            val_dir = 'val'
        print(" 提取验证集 HOG 特征 ...")
        X, y = self._load_data(val_dir)
        preds = self.clf.predict(X)
        probs = self.clf.predict_proba(X)
        return y, preds, probs
# ============================================================
# compare_models.py 导入接口
# ============================================================
def get_hog_lr_preds(train_loader, val_loader, device):
    """Train the HOG + LR baseline and return its validation predictions.

    The three parameters exist only to match the signature shared by the
    prediction functions; none of them is used here.

    Returns:
        Whatever ``HOGLRBaseline.predict('val')`` returns.
    """
    hog_lr = HOGLRBaseline()
    hog_lr.fit('train')
    val_outputs = hog_lr.predict('val')
    return val_outputs
# ============================================================
# 独立运行入口
# ============================================================
if __name__ == '__main__':
    # Outputs are written next to this script.
    out_dir = os.path.dirname(os.path.abspath(__file__))

    print("HOG + LogisticRegression 基线")
    baseline = HOGLRBaseline()
    baseline.fit('train')
    y_true, y_preds, y_probs = baseline.predict('val')

    # Headline metrics.
    acc = accuracy_score(y_true, y_preds)
    macro_f1 = f1_score(y_true, y_preds, average='macro')
    print(f"\n验证集 Accuracy: {acc:.4f}")
    print(f"验证集 Macro-F1: {macro_f1:.4f}")
    report = classification_report(y_true, y_preds, target_names=CLASS_NAMES)
    print(f"\n分类报告:\n{report}")

    # Confusion matrix figure.
    cm = confusion_matrix(y_true, y_preds)
    fig, ax = plt.subplots(figsize=(8, 7))
    display = ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES)
    display.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14)
    plt.tight_layout()

    cm_path = os.path.join(out_dir, 'hog_lr_confusion_matrix.png')
    plt.savefig(cm_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"混淆矩阵已保存: {cm_path}")

View file

@ -0,0 +1,278 @@
"""
baseline/ResNet34_Pretrained_10pct.py
ResNet-34 ImageNet 预训练权重 + 10% 训练集微调
可独立运行训练也可被 compare_models.py 导入
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import models, transforms
from tqdm import tqdm
import csv
import matplotlib.pyplot as plt
import matplotlib
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 30
LR = 0.001
TRAIN_PCT = 0.1
SEED = 42
DROPOUT = 0.3
MODEL_SAVE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct.pth')
LOG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet34_10pct_log.csv')
# ============================================================
NUM_CLASSES = 4
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
class PretrainedResNet34(nn.Module):
    """ResNet-34 backbone with a dropout + linear classification head.

    The backbone is created with the ``IMAGENET1K_V1`` weights and its own
    ``fc`` layer is replaced by an identity, so it outputs pooled features.
    """

    def __init__(self, num_classes=NUM_CLASSES, dropout=DROPOUT):
        super().__init__()
        self.backbone = models.resnet34(weights='IMAGENET1K_V1')
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.fc(self.dropout(self.backbone(x)))

    def freeze_early_layers(self):
        """Disable gradients for conv1, bn1, layer1 and layer2 of the backbone."""
        early_stages = (
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.layer1,
            self.backbone.layer2,
        )
        for stage in early_stages:
            for param in stage.parameters():
                param.requires_grad = False

    def print_trainable_info(self):
        """Print the frozen and trainable parameter counts."""
        frozen = 0
        trainable = 0
        for p in self.parameters():
            if p.requires_grad:
                trainable += p.numel()
            else:
                frozen += p.numel()
        total = frozen + trainable
        print(f" 冻结参数: {frozen:,} 可训练参数: {trainable:,} ({100.*trainable/total:.1f}%)")
def compute_macro_f1(predicted, targets, num_classes=NUM_CLASSES):
    """Return the macro-averaged F1 score as a Python float.

    Per-class F1 is computed one-vs-rest and averaged with equal weight over
    all ``num_classes`` classes. A small epsilon in each denominator gives a
    score of 0 instead of a division by zero when a class has no predictions
    or no targets.

    Args:
        predicted: Tensor of predicted class indices.
        targets: Tensor of true class indices.
        num_classes: Number of classes to average over.
    """
    eps = 1e-8
    where = predicted.device
    tp = torch.zeros(num_classes, device=where)
    fp = torch.zeros(num_classes, device=where)
    fn = torch.zeros(num_classes, device=where)

    for cls in range(num_classes):
        pred_is_cls = predicted == cls
        true_is_cls = targets == cls
        tp[cls] = (pred_is_cls & true_is_cls).sum()
        fp[cls] = (pred_is_cls & ~true_is_cls).sum()
        fn[cls] = (~pred_is_cls & true_is_cls).sum()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1.mean().item()
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, shown in the progress bar.

    Returns:
        Tuple ``(epoch_loss, epoch_f1, epoch_acc)`` with the sample-weighted
        mean loss, the macro-F1 over the whole epoch, and accuracy in percent.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks = [], []

    progress = tqdm(loader, desc=f'Epoch {epoch+1} [Train]')
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size for a per-sample epoch mean.
        loss_sum += loss.item() * images.size(0)
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()
        pred_chunks.append(predicted)
        label_chunks.append(labels)

        batch_f1 = compute_macro_f1(predicted, labels)
        progress.set_postfix({
            'loss': loss.item(),
            'F1': f'{batch_f1:.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%',
        })

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        Tuple ``(epoch_loss, epoch_f1, epoch_acc)`` with the sample-weighted
        mean loss, the macro-F1 over all batches, and accuracy in percent.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks = [], []

    for images, labels in tqdm(loader, desc='[Validate]'):
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        loss_sum += loss.item() * images.size(0)
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()
        pred_chunks.append(predicted)
        label_chunks.append(labels)

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR):
    """Train ``model``, log every epoch to CSV, and checkpoint the best one.

    Only parameters with ``requires_grad`` set are handed to the SGD
    optimizer. After each epoch the metrics are appended to ``LOG_PATH``;
    whenever the validation macro-F1 improves, the state dict is saved to
    ``MODEL_SAVE_PATH``.

    Args:
        model: Network to train, already placed on ``device``.
        train_loader: Training batches.
        val_loader: Validation batches.
        device: Device passed to the epoch helpers.
        epochs: Number of epochs; also the cosine schedule's ``T_max``.
        lr: Initial learning rate.

    Returns:
        ``history`` dict with per-epoch lists for ``train_loss``,
        ``train_f1``, ``train_acc``, ``val_loss``, ``val_f1``, ``val_acc``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=lr, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_f1': [], 'train_acc': [],
               'val_loss': [], 'val_f1': [], 'val_acc': []}
    best_val_f1 = 0.0

    # ``with`` closes the log even when an epoch raises (e.g. out of memory
    # or Ctrl-C), which the previous open()/close() pair did not guarantee.
    with open(LOG_PATH, 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc',
                             'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])

        for epoch in range(epochs):
            print(f'\n{"="*50}')
            print(f'Epoch {epoch+1}/{epochs}')

            train_loss, train_f1, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device, epoch)
            val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)
            scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_f1'].append(train_f1)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_f1'].append(val_f1)
            history['val_acc'].append(val_acc)

            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
            print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

            best_mark = ''
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                best_mark = 'best'
                print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')

            # The logged rate is read after scheduler.step().
            lr_val = optimizer.param_groups[0]['lr']
            log_writer.writerow([epoch+1, train_loss, train_f1, train_acc,
                                 val_loss, val_f1, val_acc, lr_val, best_mark])
            # Flush every epoch so the log survives an interrupted run.
            log_file.flush()

    print(f'\n训练完成！最佳验证 Macro-F1: {best_val_f1:.4f}')
    return history
# ============================================================
# compare_models.py 导入接口
# ============================================================
def get_resnet34_10pct_preds(train_loader, val_loader, device):
    """Load the saved 10%-fine-tuned ResNet-34 and predict on ``val_loader``.

    ``train_loader`` is accepted for signature compatibility and is not used.

    Returns:
        Tuple ``(y_true, y_preds, y_probs)`` of concatenated NumPy arrays.
    """
    model = PretrainedResNet34(num_classes=NUM_CLASSES)
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))
    model = model.to(device).eval()

    true_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34 (10%)'):
            # Labels are only converted to NumPy, so they are never moved
            # to the device.
            logits = model(images.to(device))
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            true_chunks.append(labels.numpy())
            pred_chunks.append(preds.cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_preds = np.concatenate(pred_chunks)
    y_probs = np.concatenate(prob_chunks)
    return y_true, y_preds, y_probs
# ============================================================
# 独立训练入口
# ============================================================
if __name__ == '__main__':
    # Seed every RNG in use so the sampled training subset is reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
        device = torch.device('xpu')
    else:
        device = torch.device('cpu')
    print(f"Device: {device}")

    # Validation: deterministic resize + normalisation.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Training: random crop/flip/rotation/colour jitter, then normalisation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_root = os.path.join(DATA_ROOT, 'train')
    val_root = os.path.join(DATA_ROOT, 'val')
    full_train_dataset = RobustImageFolder(root=train_root, transform=train_transform)
    val_dataset = RobustImageFolder(root=val_root, transform=val_transform)

    # Keep a random TRAIN_PCT fraction of the training set, at least one sample.
    n_train = len(full_train_dataset)
    n_subset = max(1, int(n_train * TRAIN_PCT))
    indices = random.sample(range(n_train), n_subset)
    train_dataset = Subset(full_train_dataset, indices)

    print(f"训练集: {len(train_dataset)} / {n_train} ({TRAIN_PCT*100:.0f}%)")
    print(f"验证集: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False,
    )

    model = PretrainedResNet34(num_classes=NUM_CLASSES, dropout=DROPOUT)
    model.freeze_early_layers()
    model.print_trainable_info()
    model = model.to(device)

    history = train_model(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR)

    # Reload the best checkpoint written during training.
    best_state = torch.load(MODEL_SAVE_PATH, map_location='cpu')
    model.load_state_dict(best_state)
    print(f"模型已保存: {MODEL_SAVE_PATH}")
    print(f"训练日志已保存: {LOG_PATH}")

145
baseline/VGG_KNN.py Normal file
View file

@ -0,0 +1,145 @@
"""
baseline/VGG_KNN.py
VGG16 预训练模型特征提取 + KNN 四分类基线
可独立运行也可被 compare_models.py 导入复用
author: yukun-hh
date: 2026-5-14
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
accuracy_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay,
classification_report,
)
from Dataloader import RobustImageFolder
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
def load_vgg16_extractor(device):
    """Return a frozen VGG16 whose classifier is replaced by an identity.

    The network is moved to ``device`` and put in eval mode, and all of its
    parameters have gradients disabled.
    """
    try:
        net = models.vgg16(weights='IMAGENET1K_V1')
    except TypeError:
        # Fallback for torchvision versions that reject the ``weights``
        # keyword with a TypeError.
        net = models.vgg16(pretrained=True)

    net.classifier = nn.Identity()
    net = net.to(device).eval()
    net.requires_grad_(False)
    return net
def extract_features(model, loader, device):
    """Run ``model`` over ``loader`` and collect its outputs and the labels.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays concatenated over all
        batches.
    """
    model.eval()
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Extracting features'):
            outputs = model(images.to(device))
            feature_chunks.append(outputs.cpu().numpy())
            label_chunks.append(labels.numpy())
    features = np.concatenate(feature_chunks)
    targets = np.concatenate(label_chunks)
    return features, targets
class VGGKNNBaseline:
    """VGG16 feature extractor followed by a K-nearest-neighbours classifier."""

    def __init__(self, k=5, device='cpu',
                 data_root='../trash_division_data/ultimate_4_class/',
                 image_size=256, batch_size=32, num_workers=4):
        self.k = k
        self.device = device
        self.data_root = data_root
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.extractor = load_vgg16_extractor(device)
        self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)

    def _get_loader(self, split):
        """Build an unshuffled DataLoader for ``<data_root>/<split>``."""
        preprocess = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        split_root = os.path.join(self.data_root, split)
        dataset = RobustImageFolder(root=split_root, transform=preprocess)
        print(f" {split}: {len(dataset)} ")
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    def fit(self, train_loader=None):
        """Fit the KNN on extracted features; builds the 'train' loader if none is given."""
        loader = self._get_loader('train') if train_loader is None else train_loader
        print(" 提取训练集特征 ...")
        feats, targets = extract_features(self.extractor, loader, self.device)
        self.knn.fit(feats, targets)

    def predict(self, val_loader=None):
        """Predict on extracted features; builds the 'val' loader if none is given.

        Returns:
            Tuple ``(val_labels, preds, probs)``.
        """
        loader = self._get_loader('val') if val_loader is None else val_loader
        print(" 提取验证集特征 ...")
        feats, targets = extract_features(self.extractor, loader, self.device)
        preds = self.knn.predict(feats)
        probs = self.knn.predict_proba(feats)
        return targets, preds, probs
if __name__ == '__main__':
    # Run configuration.
    DATA_ROOT = '../trash_division_data/ultimate_4_class/'
    BATCH_SIZE = 32
    IMAGE_SIZE = 256
    NUM_WORKERS = 4
    K = 5

    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
        device = torch.device('xpu')
    else:
        device = torch.device('cpu')
    print(f"Device: {device}")

    baseline = VGGKNNBaseline(
        k=K,
        device=device,
        data_root=DATA_ROOT,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    train_loader = baseline._get_loader('train')
    val_loader = baseline._get_loader('val')

    baseline.fit(train_loader)
    y_true, y_preds, y_probs = baseline.predict(val_loader)

    # Headline metrics.
    acc = accuracy_score(y_true, y_preds)
    macro_f1 = f1_score(y_true, y_preds, average='macro')
    print(f"\n验证集 Accuracy: {acc:.4f}")
    print(f"验证集 Macro-F1: {macro_f1:.4f}")
    report = classification_report(y_true, y_preds, target_names=CLASS_NAMES)
    print(f"\n分类报告:\n{report}")

    # Confusion matrix figure, saved next to this script.
    cm = confusion_matrix(y_true, y_preds)
    fig, ax = plt.subplots(figsize=(8, 7))
    display = ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES)
    display.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
    ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14)
    plt.tight_layout()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(script_dir, 'vgg_knn_confusion_matrix.png')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"混淆矩阵已保存: {out_path}")

1
baseline/__init__.py Normal file
View file

@ -0,0 +1 @@
# baseline package

BIN
baseline/accuracy_bar.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

249
baseline/compare_models.py Normal file
View file

@ -0,0 +1,249 @@
"""
baseline/compare_models.py
多模型对比ROC 曲线 + 准确率柱状图
添加新模型只需在 MODELS 列表加一行无需修改绘图代码
author: yukun-hh
date: 2026-5-14
"""
import sys, os, re
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import (
roc_curve, auc, accuracy_score,
precision_recall_curve, average_precision_score,
)
from Model import Net
from Dataloader import RobustImageFolder
from baseline.VGG_KNN import VGGKNNBaseline
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
from baseline.HOG_Baseline import get_hog_lr_preds
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
# ============================================================
# ★★★ 可配置参数 ★★★
# ============================================================
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
BATCH_SIZE = 32
IMAGE_SIZE = 256
NUM_WORKERS = 4
K_KNN = 5
# ============================================================
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
NUM_CLASSES = 4
# ============================================================
# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs)
# ============================================================
def get_resnet34_preds(train_loader, val_loader, device):
    """Load ``../best_model.pth`` into ``Net`` and predict on ``val_loader``.

    ``train_loader`` is accepted for signature compatibility and is not used.
    The checkpoint may be a bare state dict or a dict wrapping it under
    ``'model_state_dict'`` or ``'model'``.

    Returns:
        Tuple ``(y_true, y_preds, y_probs)`` of concatenated NumPy arrays.
    """
    model = Net(num_classes=NUM_CLASSES)

    state_dict = torch.load('../best_model.pth', map_location='cpu')
    # Unwrap the first matching container key, if any.
    for wrapper_key in ('model_state_dict', 'model'):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    true_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='ResNet-34'):
            logits = model(images.to(device))
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            true_chunks.append(labels.numpy())
            pred_chunks.append(preds.cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_preds = np.concatenate(pred_chunks)
    y_probs = np.concatenate(prob_chunks)
    return y_true, y_preds, y_probs
def get_vgg_knn_preds(train_loader, val_loader, device):
    """Fit the VGG16 + KNN baseline on ``train_loader`` and predict on ``val_loader``.

    Returns:
        Whatever ``VGGKNNBaseline.predict`` returns for ``val_loader``.
    """
    vgg_knn = VGGKNNBaseline(k=K_KNN, device=device)
    vgg_knn.fit(train_loader)
    val_outputs = vgg_knn.predict(val_loader)
    return val_outputs
# ============================================================
# ★ 模型注册表 — 添加新模型只需在这里加一行 ★
# ============================================================
MODELS = [
('ResNet-34', get_resnet34_preds),
('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
('VGG16 + KNN (K=5)', get_vgg_knn_preds),
('HOG + LogisticRegression', get_hog_lr_preds),
# 未来轻松扩展示例:
# ('ResNet-18 (pretrained)', get_resnet18_preds),
# ('ResNet-50 (pretrained)', get_resnet50_preds),
# ('ResNet-34 (finetuned)', get_finetuned_preds),
]
# ============================================================
# 调色板 (扩展时无需修改)
# ============================================================
COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
'#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
def compute_macro_roc(y_true, y_probs):
    """Compute a macro-averaged ROC curve over ``NUM_CLASSES`` classes.

    Each class gets a one-vs-rest ROC curve; the curves are interpolated onto
    the union of their false-positive-rate points and averaged.

    Args:
        y_true: Integer class indices.
        y_probs: Per-class scores, indexed as ``y_probs[:, c]``.

    Returns:
        Tuple ``(all_fpr, mean_tpr, macro_auc)``.
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]

    # (fpr, tpr) pair per class; the thresholds are not needed.
    curves = [
        roc_curve(one_hot[:, c], y_probs[:, c])[:2]
        for c in range(NUM_CLASSES)
    ]

    all_fpr = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    mean_tpr = np.zeros_like(all_fpr)
    for fpr, tpr in curves:
        mean_tpr += np.interp(all_fpr, fpr, tpr)
    mean_tpr /= NUM_CLASSES

    macro_auc = auc(all_fpr, mean_tpr)
    return all_fpr, mean_tpr, macro_auc
def compute_macro_pr(y_true, y_probs):
    """Compute a macro-averaged precision-recall curve and macro AP.

    Each class gets a one-vs-rest PR curve; precision is interpolated onto a
    fixed grid of 200 recall values in [0, 1] and averaged over the classes.

    Args:
        y_true: Integer class indices.
        y_probs: Per-class scores, indexed as ``y_probs[:, c]``.

    Returns:
        Tuple ``(all_rec, mean_prec, macro_ap)``.
    """
    one_hot = np.eye(NUM_CLASSES)[y_true]

    # (precision, recall) pair per class; the thresholds are not needed.
    curves = [
        precision_recall_curve(one_hot[:, c], y_probs[:, c])[:2]
        for c in range(NUM_CLASSES)
    ]

    all_rec = np.linspace(0, 1, 200)
    mean_prec = np.zeros_like(all_rec)
    for prec, rec in curves:
        # Both arrays are reversed before being passed to np.interp.
        mean_prec += np.interp(all_rec, rec[::-1], prec[::-1])
    mean_prec /= NUM_CLASSES

    macro_ap = average_precision_score(one_hot, y_probs, average='macro')
    return all_rec, mean_prec, macro_ap
def sanitize_filename(name):
    """Turn a display name into a string that is safe to use in a file name.

    Every character that is not a word character or a hyphen is replaced by
    an underscore (one underscore per replaced character, runs are not
    collapsed); leading and trailing underscores are then removed.

    Args:
        name: Arbitrary display name, e.g. a model name.

    Returns:
        The sanitized string; empty when ``name`` has no word characters
        or hyphens other than underscores.
    """
    replaced = re.sub(r'[^\w\-_]', '_', name)
    return replaced.strip('_')
def preds_csv_path(out_dir, model_name):
    """Return the path of the cached-predictions CSV for a model.

    Args:
        out_dir: Directory that holds the cache files.
        model_name: Display name of the model; sanitized before use.

    Returns:
        ``<out_dir>/<sanitized model name>_preds.csv``.
    """
    safe = sanitize_filename(model_name)
    csv_name = f'{safe}_preds.csv'
    return os.path.join(out_dir, csv_name)
def save_preds_csv(path, y_true, y_preds, y_probs):
    """Write labels, predictions and class probabilities to a CSV file.

    The file has a header row ``y_true,y_pred,prob_0,...`` with one
    ``prob_<c>`` name per class in ``range(NUM_CLASSES)``, followed by one
    row per sample. Every value is written with six decimal places.

    Args:
        path: Destination file path.
        y_true: Array of ground-truth labels (must support ``.astype``).
        y_preds: Array of predicted labels (must support ``.astype``).
        y_probs: Array of per-class scores, stacked as the trailing columns.
    """
    prob_columns = ','.join(f'prob_{c}' for c in range(NUM_CLASSES))
    header = 'y_true,y_pred,' + prob_columns

    # Labels are cast to float so they can share one array with the scores.
    labels = y_true.astype(float)
    preds = y_preds.astype(float)
    table = np.column_stack([labels, preds, y_probs])

    # comments='' stops numpy from prefixing the header line with '# '.
    np.savetxt(path, table, fmt='%.6f', delimiter=',', header=header, comments='')
def load_preds_csv(path):
    """Load labels, predictions and class probabilities from a cached CSV.

    Expects the layout written by ``save_preds_csv``: one header row, then
    rows of ``y_true, y_pred, prob_0, ...``.

    Args:
        path: Path of the CSV file to read.

    Returns:
        Tuple ``(y_true, y_preds, y_probs)``: the first two columns cast to
        ``int``, and the next ``NUM_CLASSES`` columns as floats.
    """
    # ndmin=2 keeps a single-row file two-dimensional; without it np.loadtxt
    # squeezes the result to 1-D and the column slicing below raises
    # IndexError.
    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
    y_true = data[:, 0].astype(int)
    y_preds = data[:, 1].astype(int)
    y_probs = data[:, 2:2 + NUM_CLASSES]
    return y_true, y_preds, y_probs
if __name__ == '__main__':
    # Cached predictions and all figures are written next to this script.
    out_dir = os.path.dirname(os.path.abspath(__file__))

    # Device preference: CUDA first, then XPU when this torch build exposes
    # it, otherwise CPU.
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
        device_name = 'xpu'
    else:
        device_name = 'cpu'
    device = torch.device(device_name)
    print(f"Device: {device}")

    # The same deterministic preprocessing is applied to both splits.
    resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])

    train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'),
                                      transform=val_transform)
    val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'),
                                    transform=val_transform)
    print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}")

    # Both loaders share the same settings; nothing is shuffled or dropped.
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=True,
                         drop_last=False)
    train_loader = DataLoader(train_dataset, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    # ---- Evaluate every registered model (skipped when a cache file exists) ----
    results = {}
    for name, func in MODELS:
        print(f"\n{'='*50}")
        csv_path = preds_csv_path(out_dir, name)
        if not os.path.exists(csv_path):
            print(f"评估: {name}")
            y_true, y_preds, y_probs = func(train_loader, val_loader, device)
            save_preds_csv(csv_path, y_true, y_preds, y_probs)
            print(f" 预测数据已保存: {os.path.basename(csv_path)}")
        else:
            print(f"加载缓存: {os.path.basename(csv_path)}")
            y_true, y_preds, y_probs = load_preds_csv(csv_path)

        acc = accuracy_score(y_true, y_preds)
        fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs)
        rec, prec, macro_ap = compute_macro_pr(y_true, y_probs)
        results[name] = dict(y_true=y_true, y_preds=y_preds, y_probs=y_probs,
                             acc=acc, fpr=fpr, tpr=tpr, auc=roc_auc,
                             rec=rec, prec=prec, ap=macro_ap)
        print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f} | Macro-AP: {macro_ap:.4f}")

    # ---- ROC comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        color = COLORS[i % len(COLORS)]
        ax.plot(r['fpr'], r['tpr'], color=color, lw=2,
                label=f"{name} (AUC={r['auc']:.4f})")
    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    roc_path = os.path.join(out_dir, 'roc_comparison.png')
    plt.savefig(roc_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\nROC 对比图已保存: {roc_path}")

    # ---- PR comparison plot ----
    fig, ax = plt.subplots(figsize=(8, 7))
    for i, (name, r) in enumerate(results.items()):
        color = COLORS[i % len(COLORS)]
        ax.plot(r['rec'], r['prec'], color=color, lw=2,
                label=f"{name} (AP={r['ap']:.4f})")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('PR Curve Comparison (Macro-Average)', fontsize=14)
    ax.legend(loc='lower left')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    pr_path = os.path.join(out_dir, 'pr_comparison.png')
    plt.savefig(pr_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"PR 对比图已保存: {pr_path}")

    # ---- Accuracy bar chart ----
    names = list(results)
    accs = [results[n]['acc'] for n in names]
    fig, ax = plt.subplots(figsize=(8, 5))
    bar_colors = [COLORS[i % len(COLORS)] for i, _ in enumerate(names)]
    bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2)
    for bar, acc in zip(bars, accs):
        # Put the numeric value just above the top centre of each bar.
        x_center = bar.get_x() + bar.get_width() / 2
        y_label = bar.get_height() + 0.005
        ax.text(x_center, y_label, f'{acc:.4f}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')
    # Zoom the y-axis onto the observed accuracy range, with headroom for labels.
    ax.set_ylim(min(accs) - 0.03, max(accs) * 1.08)
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    bar_path = os.path.join(out_dir, 'accuracy_bar.png')
    plt.savefig(bar_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"准确率柱状图已保存: {bar_path}")

BIN
baseline/pr_comparison.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

BIN
baseline/roc_comparison.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 125 KiB

BIN
confusion_matrix.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
pr_curve.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

10
requirements.txt Normal file
View file

@ -0,0 +1,10 @@
torch>=1.10
torchvision>=0.11
tqdm
matplotlib
pandas
Pillow
scikit-learn
scikit-image
numpy
torchsummary

BIN
roc_curve.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

BIN
training_curves.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

81
training_log.csv Normal file
View file

@ -0,0 +1,81 @@
epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best
1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best
2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best
3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316,
4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best
5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076,
6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best
7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118,
8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884,
9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621,
10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217,
11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best
12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092,
13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994,
14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best
15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364,
16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369,
17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863,
18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best
19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214,
20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369,
21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355,
22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546,
23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best
24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182,
25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052,
26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best
27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best
28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867,
29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357,
30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best
31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732,
32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683,
33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best
34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633,
35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205,
36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577,
37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594,
38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612,
39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671,
40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996,
41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best
42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878,
43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063,
44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237,
45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best
46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372,
47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best
48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best
49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best
50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762,
51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431,
52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best
53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best
54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285,
55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best
56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178,
57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161,
58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best
59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best
60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316,
61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best
62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best
63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best
64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best
65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376,
66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966,
67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077,
68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best
69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best
70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321,
71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best
72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best
73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05,
74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best
75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best
76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05,
77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05,
78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06,
79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best
80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,
1 epoch train_loss train_f1 train_acc val_loss val_f1 val_acc lr best
2 1 1.0409312975676923 0.4329540729522705 48.04254100337675 1.1043149583345566 0.4398210048675537 48.66548042704626 0.004998072590601808 best
3 2 0.9862563783744079 0.4695238769054413 52.5943680656054 0.9867177669753319 0.5062971115112305 58.397207774431976 0.00499229333433282 best
4 3 0.9462850892451784 0.49421826004981995 55.40279787747226 1.0144445673589866 0.4907984733581543 53.15494114426499 0.004982671142387316
5 4 0.910117163585685 0.514958381652832 57.832097202122526 0.8787865286946395 0.5453917980194092 62.544483985765126 0.004969220851487844 best
6 5 0.8786031692946986 0.5320333242416382 59.74282440906898 1.0686318878927787 0.4803737998008728 52.73063235696688 0.004951963201008076
7 6 0.8518873820889128 0.5481140613555908 61.51938615533044 0.7650798693964196 0.6073676347732544 68.98439638653161 0.004930924800994191 best
8 7 0.8256270786701512 0.5604796409606934 62.90249638205499 0.8796401012116773 0.5789190530776978 62.63345195729537 0.004906138091134118
9 8 0.8003699506646013 0.5742803812026978 64.3014351181862 0.9246643470014378 0.5521833896636963 60.88831097727895 0.004877641290737884
10 9 0.780536473588097 0.5827116966247559 65.25642185238785 0.8404132533719564 0.5876226425170898 65.89789214344374 0.00484547833980621
11 10 0.7604798049209087 0.595557451248169 66.54079232995659 0.9228097118810533 0.564703643321991 60.77881193539557 0.004809698831278217
12 11 0.7410275131047088 0.6043155789375305 67.35784491075735 0.7576604621266131 0.6295210123062134 69.83985765124555 0.0047703579345627035 best
13 12 0.7195374228732343 0.6127941608428955 68.07766521948867 0.9624476881507583 0.5630610585212708 61.59321105940323 0.00472751631047092
14 13 0.6997139808122973 0.6210756301879883 68.85175470332851 0.7615812296349155 0.6177672147750854 69.12127018888584 0.004681240017681994
15 14 0.6824904908837182 0.630592942237854 69.65448625180898 0.6715762626299165 0.6534035205841064 73.48754448398577 0.004631600410885231 best
16 15 0.6653590450468583 0.6379610300064087 70.39088880849012 0.694461440047988 0.6517682075500488 73.0906104571585 0.004578674030756364
17 16 0.6514209577758935 0.6478185653686523 71.27879281234925 0.7036816360785346 0.6470745801925659 71.76977826444019 0.004522542485937369
18 17 0.6330186040776395 0.6530008316040039 71.7935962373372 0.7222367905930823 0.6418735980987549 71.07856556255133 0.004463292327201863
19 18 0.6166394593717968 0.6634106040000916 72.6038651712494 0.6067476719332303 0.6886636018753052 77.49110320284697 0.004401014914000078 best
20 19 0.5973944908975692 0.6721534729003906 73.47367945007235 0.6952472055509845 0.6622275114059448 71.79715302491103 0.004335806273589214
21 20 0.5820678306183721 0.6758297681808472 73.73145803183792 0.7708474785342401 0.6217234134674072 68.74486723241172 0.004267766952966369
22 21 0.5650806297110982 0.6851130723953247 74.5741377231066 0.7461620579141478 0.6384793519973755 71.35231316725978 0.004197001863832355
23 22 0.5500074683958588 0.6915749311447144 75.099493487699 0.6420613189380593 0.672810435295105 74.9452504790583 0.00412362012082546
24 23 0.5367840825001858 0.6979560852050781 75.66102870236372 0.6252713002082977 0.6949211359024048 75.32849712565014 0.0040477348732745845 best
25 24 0.5234906795055925 0.7052106857299805 76.26025084418717 0.7471277477021352 0.6447888016700745 69.70982753900904 0.003969463130731182
26 25 0.5044557179049829 0.7132176160812378 76.91148094548963 0.6325626891507145 0.6857903003692627 75.52012044894607 0.0038889255825490052
27 26 0.4938347232885195 0.7174828052520752 77.23784973468403 0.5635758755127437 0.70375657081604 78.50396934026827 0.003806246411789872 best
28 27 0.4793313278116239 0.7242900133132935 77.87475880366618 0.5505193975847648 0.7201660871505737 78.89405967697783 0.003721553103742388 best
29 28 0.46573570758837524 0.7336312532424927 78.60437771345876 0.640248272807638 0.6859503984451294 75.13003011223651 0.003634976249348867
30 29 0.44927708754289913 0.737967312335968 78.9352689339122 0.6151526539644867 0.7065733075141907 74.69887763482069 0.00354664934384357
31 30 0.4373503708129221 0.7443608045578003 79.40409430776653 0.5578661908719627 0.7272701263427734 78.20969066520668 0.0034567085809127244 best
32 31 0.42717206794400175 0.7488712072372437 79.7779486251809 0.58909761693554 0.7034546136856079 76.88201478237066 0.003365292642693732
33 32 0.4100124511779706 0.7580570578575134 80.60630728412929 0.6458172935624336 0.6865078210830688 75.32849712565014 0.0032725424859373683
34 33 0.3993677339991451 0.763430118560791 80.9876989869754 0.47995706558097007 0.754202127456665 81.65891048453327 0.003178601124662685 best
35 34 0.3858378949555808 0.7697042226791382 81.54772672455378 0.6427663844838523 0.6931804418563843 74.28141253764029 0.0030836134096397633
36 35 0.37397055771404913 0.7764154672622681 81.9992161119151 0.6085299244000101 0.7046636343002319 77.08048179578428 0.0029877258050403205
37 36 0.3597575335889638 0.7818952798843384 82.53587795465509 0.5254679805051781 0.7415529489517212 78.89405967697783 0.002891086162600577
38 37 0.3487578573732404 0.7871347665786743 82.8999336710082 0.5125140355052995 0.748577356338501 80.1875171092253 0.002793843493644594
39 38 0.3325814358527052 0.7965956330299377 83.59035817655571 0.5408413317834649 0.7290798425674438 79.87270736381056 0.002696147739319612
40 39 0.3261546248608721 0.7988470792770386 83.75316570188133 0.5301555857539602 0.7376729249954224 80.06433068710649 0.002598149539397671
41 40 0.30964642827472305 0.8070269823074341 84.52725518572117 0.5468305750249544 0.7365171313285828 78.73665480427046 0.0024999999999999996
42 41 0.3009217412674191 0.8119726777076721 84.96140858658949 0.46898612490165015 0.7599539756774902 82.18587462359704 0.002401850460602329 best
43 42 0.28925693789887874 0.8200639486312866 85.6458031837916 0.5167866427621677 0.7465909123420715 80.83766767040788 0.0023038522606803878
44 43 0.2707157838268379 0.8313596248626709 86.46360950313556 0.5156203284349763 0.7596548199653625 80.6049822064057 0.0022061565063554063
45 44 0.2580273799019566 0.836384654045105 86.8412325132658 0.5318487190707494 0.746901273727417 80.27648508075555 0.0021089138373994237
46 45 0.2504911308580703 0.8410984873771667 87.30779667149059 0.49164087495639364 0.763725221157074 81.82315904735833 0.00201227419495968 best
47 46 0.24104372076995695 0.8451772928237915 87.64094910757356 0.5290114263981752 0.7580969333648682 80.94032302217356 0.0019163865903602372
48 47 0.22337641519549614 0.8570870161056519 88.56201760733236 0.43634469677838694 0.7913081049919128 85.10128661374213 0.0018213988753373142 best
49 48 0.2128122905210861 0.8645581603050232 89.1152616980222 0.43183545479456625 0.7972898483276367 85.49137695045168 0.001727457514062632 best
50 49 0.2003101470318182 0.8717849254608154 89.69187168355042 0.4289672785945178 0.806715726852417 85.7993430057487 0.0016347073573062686 best
51 50 0.1888495338707803 0.8796613216400146 90.34687047756874 0.4568697272132202 0.7956517338752747 84.64275937585546 0.0015432914190872762
52 51 0.1756466486088274 0.886497437953949 90.97096599131693 0.4541556305781541 0.7938134670257568 84.45113605255953 0.001453350656156431
53 52 0.16742963044469736 0.8907681703567505 91.3101483357453 0.4230425570913775 0.8147625923156738 86.100465370928 0.0013650237506511336 best
54 53 0.15311133117022804 0.9007841944694519 92.11212614568258 0.4146759752586969 0.8208259344100952 86.83274021352314 0.0012784468962576128 best
55 54 0.1423091164071722 0.9078108668327332 92.6714001447178 0.4709351422719488 0.8029968738555908 85.71721872433616 0.0011937535882101285
56 55 0.13160189816137902 0.9135022163391113 93.10932223830197 0.40829240685941226 0.8264325857162476 87.42814125376403 0.0011110744174509947 best
57 56 0.12707359487800943 0.9174070358276367 93.4409671972986 0.42565728100299705 0.8230471611022949 87.65398302764851 0.0010305368692688178
58 57 0.11291898237482914 0.925153374671936 94.10953328509407 0.43774206922898184 0.8247347474098206 87.5376402956474 0.0009522651267254161
59 58 0.10370979833767516 0.9329074025154114 94.62584418716835 0.4068256767521223 0.8333848118782043 88.05091705447578 0.0008763798791745416 best
60 59 0.0946220946482491 0.9380815029144287 95.09768451519537 0.41103389083751746 0.8357677459716797 88.4889132220093 0.0008029981361676465 best
61 60 0.08804238645213414 0.9423004388809204 95.45194163048721 0.4207381762007377 0.8308929204940796 88.17410347659458 0.0007322330470336316
62 61 0.07913849578165794 0.9495129585266113 95.9916184273999 0.4083157278823748 0.8420299291610718 89.00218998083767 0.0006641937264107861 best
63 62 0.06981624146470565 0.9559226036071777 96.49662325132658 0.41332166066585485 0.843511700630188 88.98850260060225 0.0005989850859999229 best
64 63 0.06394793773276639 0.9602090120315552 96.79811866859623 0.41052334102789995 0.8489691019058228 89.35121817684096 0.0005367076727981376 best
65 64 0.057007493751794744 0.9636315107345581 97.11016642547034 0.3970402057488879 0.8546013832092285 90.04927456884752 0.00047745751406263185 best
66 65 0.05285091761448427 0.967146635055542 97.40035576459238 0.4247235585867853 0.8433754444122314 89.34437448672324 0.0004213259692436376
67 66 0.04614944799553407 0.9710246324539185 97.6912988422576 0.4035414461747053 0.8538572788238525 89.85080755543389 0.00036839958911476966
68 67 0.042909558690492254 0.9727644920349121 97.8428002894356 0.41578795896298 0.8530543446540833 89.91924445661101 0.0003187599823180077
69 68 0.03640206224887977 0.9769999980926514 98.15710926193921 0.42073161891477256 0.8551151156425476 89.94661921708185 0.0002724836895290806 best
70 69 0.034432517173010414 0.9790080785751343 98.34328268210324 0.4133664407223553 0.8576940298080444 90.28880372296743 0.00022964206543729668 best
71 70 0.030669637766023813 0.9817556142807007 98.5249336710082 0.418308269951463 0.8569411039352417 90.12455516014235 0.00019030116872178321
72 71 0.028112305183924133 0.9827903509140015 98.606337433671 0.4151474575991667 0.8596312999725342 90.37092800437996 0.00015452166019378966 best
73 72 0.024704152367817256 0.9853801727294922 98.81135431741437 0.4153705465811558 0.8635820746421814 90.68573774979468 0.0001223587092621162 best
74 73 0.024846541488804174 0.9855506420135498 98.8369814278823 0.4177400436290088 0.8632140159606934 90.59676977826444 9.38619088658821e-05
75 74 0.022639600746625622 0.9868491888046265 98.94702725518572 0.41732572613841307 0.8648342490196228 90.78154941144265 6.907519900580863e-05 best
76 75 0.02120214593173326 0.9878177642822266 99.0231548480463 0.4163925825270714 0.866214394569397 90.89789214344374 4.803679899192394e-05 best
77 76 0.019741657997631577 0.9883521795272827 99.06385672937772 0.42005763620917286 0.8647006750106812 90.82945524226663 3.077914851215586e-05
78 77 0.019116416042495393 0.9889511466026306 99.10003617945007 0.4159400745789841 0.8657370805740356 90.84998631261976 1.7328857612684272e-05
79 78 0.019259902796210714 0.9888157844543457 99.0962674867342 0.4192042892654481 0.8641382455825806 90.69942513003011 7.706665667180091e-06
80 79 0.01933925595445387 0.9887675046920776 99.0759165460685 0.4180937778044573 0.8662786483764648 90.84998631261976 1.9274093981927482e-06 best
81 80 0.01922732148408437 0.9889604449272156 99.10078991799324 0.41794140280912484 0.864332914352417 90.82261155214891 0.0