2026-04-10 09:16:49 +00:00
|
|
|
"""
This training code was first generated by an AI and has not been adjusted yet,
because the data iterator had not been designed at the time of writing.

This file cannot be run yet!!!

The best model will be saved in the root directory.

author:yukun-hh
date  :2026-4-10
"""
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
|
from tqdm import tqdm # 进度条,可选
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
2026-04-12 06:39:58 +00:00
|
|
|
|
from Model import Net
|
2026-06-03 12:33:22 +00:00
|
|
|
|
from Dataloader import create_dataloaders
|
|
|
|
|
|
import os
|
|
|
|
|
|
import csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_macro_f1(predicted, targets, num_classes=4):
    """Return the macro-averaged F1 score of ``predicted`` against ``targets``.

    Args:
        predicted: Tensor of predicted class indices.
        targets: Tensor of ground-truth class indices with the same number of
            elements as ``predicted``.
        num_classes: Number of classes; class ids ``0 .. num_classes - 1`` are
            scored.

    Returns:
        The unweighted mean of the per-class F1 scores as a Python float.
        A class that appears in neither tensor contributes 0 to the mean
        (the epsilon terms turn 0/0 into 0).
    """
    # Flatten both inputs so that e.g. (N,) vs (N, 1) is compared element by
    # element instead of silently broadcasting to an (N, N) grid, which would
    # inflate every count.
    predicted = predicted.reshape(-1)
    targets = targets.reshape(-1).to(predicted.device)

    # Row c of each mask marks the samples predicted as / labelled as class c.
    class_ids = torch.arange(num_classes, device=predicted.device).unsqueeze(1)
    pred_is_c = predicted.unsqueeze(0) == class_ids
    true_is_c = targets.unsqueeze(0) == class_ids

    tp = (pred_is_c & true_is_c).sum(dim=1).float()
    fp = (pred_is_c & ~true_is_c).sum(dim=1).float()
    fn = (~pred_is_c & true_is_c).sum(dim=1).float()

    # The small epsilon keeps the divisions defined when a count is zero.
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return f1.mean().item()
|
|
|
|
|
|
|
2026-04-10 09:16:49 +00:00
|
|
|
|
|
|
|
|
|
|
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one training pass over ``train_loader`` and report epoch metrics.

    Args:
        model: Network to optimise; switched to training mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar title.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc)``: the batch-size-weighted loss sum
        divided by the sample count, the macro-F1 over every prediction of the
        epoch, and the accuracy in percent.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks = []
    label_chunks = []

    progress = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running statistics; the batch loss is scaled by the batch size so
        # the final division by the sample count handles a short last batch.
        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)
        predicted = outputs.max(1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        pred_chunks.append(predicted)
        label_chunks.append(labels)

        # Live feedback: current batch loss/F1 and the running accuracy.
        batch_f1 = compute_macro_f1(predicted, labels)
        progress.set_postfix({
            'loss': batch_loss,
            'F1': f'{batch_f1:.4f}',
            'Acc': f'{100. * n_correct / n_seen:.2f}%',
        })

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
|
2026-04-10 09:16:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc)``: the batch-size-weighted loss sum
        divided by the sample count, the macro-F1 over all predictions, and the
        accuracy in percent.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks = []
    label_chunks = []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='[Validate]'):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, labels).item()

            loss_sum += batch_loss * images.size(0)
            predicted = outputs.max(1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            pred_chunks.append(predicted)
            label_chunks.append(labels)

    epoch_loss = loss_sum / n_seen
    epoch_f1 = compute_macro_f1(torch.cat(pred_chunks), torch.cat(label_chunks))
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_f1, epoch_acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_class_weights(dataset, num_classes=4, device='cpu'):
    """Compute inverse-frequency class weights from ``dataset.samples``.

    Each class ``c`` gets ``total / (num_classes * count_c)``, so a perfectly
    balanced dataset yields a weight of 1.0 for every class.

    Args:
        dataset: Object exposing ``samples``, an iterable of ``(item, label)``
            pairs whose label is an int or a single-element tensor.
        num_classes: Number of classes; labels index into a vector of this
            length.
        device: Device the returned weight tensor is moved to.

    Returns:
        A float tensor of shape ``(num_classes,)`` on ``device``. Classes with
        no samples are treated as if they had one sample, so their weight is
        finite (``total / num_classes``) instead of ``inf``.
    """
    class_counts = torch.zeros(num_classes)
    for _, label in dataset.samples:
        # Labels may be stored as plain ints or as single-element tensors.
        idx = label.item() if isinstance(label, torch.Tensor) else label
        class_counts[idx] += 1

    total = class_counts.sum()
    # Guard against division by zero: an empty class would otherwise get an
    # infinite weight, which turns the weighted loss into inf/NaN as soon as
    # a sample of that class is seen.
    weights = total / (num_classes * class_counts.clamp(min=1))
    return weights.to(device)
|
2026-04-10 09:16:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
    """Train ``model`` and keep the checkpoint with the best validation macro-F1.

    Per-epoch metrics are appended to ``training_log.csv`` (rewritten on every
    call) and the best weights are saved to ``best_model.pth``; both paths are
    relative to the current working directory.

    Args:
        model: Network to train; expected to already be on ``device``.
        train_loader: Training loader; its ``dataset`` is used to derive the
            class weights for the loss.
        val_loader: Validation loader.
        epochs: Number of epochs, also used as the cosine schedule length.
        lr: Initial learning rate for SGD.
        device: Device used for the batches and the class-weight tensor.

    Returns:
        ``(model, history)`` where ``history`` maps ``train_loss``,
        ``train_f1``, ``train_acc``, ``val_loss``, ``val_f1`` and ``val_acc``
        to per-epoch lists.
    """
    # 1. Loss function and optimiser.
    class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)  # cross-entropy for multi-class

    # SGD with momentum.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    # Learning-rate scheduler (optional, helps convergence).
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # 2. Training history.
    history = {
        'train_loss': [],
        'train_f1': [],
        'train_acc': [],
        'val_loss': [],
        'val_f1': [],
        'val_acc': [],
    }

    best_val_f1 = 0.0

    # The context manager closes the log even if training is interrupted by
    # an exception; previously the handle was never closed.
    with open('training_log.csv', 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(['epoch', 'train_loss', 'train_f1', 'train_acc', 'val_loss', 'val_f1', 'val_acc', 'lr', 'best'])

        # 3. Training loop.
        for epoch in range(epochs):
            print(f'\n{"=" * 50}')
            print(f'Epoch {epoch + 1}/{epochs}')

            # Train.
            train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion,
                                                              optimizer, device, epoch)

            # Validate.
            val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device)

            # Update the learning rate.
            scheduler.step()

            # Record.
            history['train_loss'].append(train_loss)
            history['train_f1'].append(train_f1)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_f1'].append(val_f1)
            history['val_acc'].append(val_acc)

            # Read after scheduler.step(), so this is the rate the next epoch
            # will use. A separate name avoids rebinding the ``lr`` parameter.
            current_lr = optimizer.param_groups[0]['lr']

            # Print results.
            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')

            # Save the best model.
            best_mark = ''
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), 'best_model.pth')
                best_mark = 'best'
                print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})')

            log_writer.writerow([epoch + 1, train_loss, train_f1, train_acc, val_loss, val_f1, val_acc, current_lr, best_mark])
            # Flush every epoch so the log stays usable if the run is killed.
            log_file.flush()

    # 4. Summary (plotting the training curves is not implemented here).
    print(f'\n{"=" * 50}')
    print(f'训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}')

    return model, history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Usage example ==========
if __name__ == '__main__':
    # Assumes the dataloader module is already implemented.
    train_loader, val_loader, class_names = create_dataloaders(
        data_root='../trash_division_data/ultimate_4_class/',  # folder next to trash-division
        batch_size=16,  # adjust to the available GPU memory
        image_size=256,  # must match the model's input size
        num_workers=8,  # may need to be 0 on Windows
        augment=True,  # use data augmentation for training
    )

    # 1. Create the model.
    # Prefer CUDA, then Intel XPU, then CPU. ``torch.xpu`` does not exist in
    # every PyTorch build, so look it up defensively instead of assuming it.
    xpu_backend = getattr(torch, 'xpu', None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif xpu_backend is not None and xpu_backend.is_available():
        device = torch.device('xpu')
    else:
        device = torch.device('cpu')

    model = Net(num_classes=4)  # adjust to your Net class
    # Resume from an existing checkpoint, if any.
    if os.path.exists('best_model.pth'):
        model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
    model = model.to(device)

    # Print model information.
    print(f'Device: {device}')
    print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

    # 2. Start training.
    trained_model, history = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=20,
        lr=0.001,
        device=device
    )

    # 3. Load the best model for prediction.
    # A checkpoint only exists if some epoch improved on the initial best
    # score (or a previous run left one behind), so check before loading.
    if os.path.exists('best_model.pth'):
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
|