From 6f54c8b13ef0368af5df8f1164576d14f3fb7024 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Sat, 25 Apr 2026 12:17:26 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E7=94=A8=E5=AE=8FF1=E8=AF=84=E4=BC=B0?= =?UTF-8?q?=E4=B8=8E=E7=B1=BB=E5=88=AB=E5=8A=A0=E6=9D=83=E6=8D=9F=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Train.py | 90 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 30 deletions(-) diff --git a/Train.py b/Train.py index 157af01..7c98d5a 100644 --- a/Train.py +++ b/Train.py @@ -14,51 +14,67 @@ import matplotlib.pyplot as plt from Model import Net from Dataloader import create_dataloaders import os + + +def compute_macro_f1(predicted, targets, num_classes=4): + tp = torch.zeros(num_classes, device=predicted.device) + fp = torch.zeros(num_classes, device=predicted.device) + fn = torch.zeros(num_classes, device=predicted.device) + for c in range(num_classes): + tp[c] = ((predicted == c) & (targets == c)).sum() + fp[c] = ((predicted == c) & (targets != c)).sum() + fn[c] = ((predicted != c) & (targets == c)).sum() + precision = tp / (tp + fp + 1e-8) + recall = tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1.mean().item() + + def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): """训练一个epoch""" - model.train() # 设置为训练模式 + model.train() running_loss = 0.0 correct = 0 total = 0 + all_preds, all_labels = [], [] - # 使用 tqdm 显示进度条(可选) pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]') for images, labels in pbar: - # 将数据移到 GPU/CPU images, labels = images.to(device), labels.to(device) - # 前向传播 outputs = model(images) loss = criterion(outputs, labels) - # 反向传播 - optimizer.zero_grad() # 清空梯度 - loss.backward() # 计算梯度 - optimizer.step() # 更新参数 + optimizer.zero_grad() + loss.backward() + optimizer.step() - # 统计 running_loss += loss.item() * images.size(0) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() + all_preds.append(predicted) + all_labels.append(labels) - # 更新进度条信息 - pbar.set_postfix({'loss': loss.item(), 'acc': 100. * correct / total}) + batch_f1 = compute_macro_f1(predicted, labels) + pbar.set_postfix({'loss': loss.item(), 'F1': f'{batch_f1:.4f}', 'Acc': f'{100. * correct / total:.2f}%'}) epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) epoch_acc = 100. * correct / total - return epoch_loss, epoch_acc + return epoch_loss, epoch_f1, epoch_acc def validate(model, val_loader, criterion, device): """验证函数""" - model.eval() # 设置为评估模式 + model.eval() running_loss = 0.0 correct = 0 total = 0 + all_preds, all_labels = [], [] - with torch.no_grad(): # 不计算梯度,节省内存 + with torch.no_grad(): for images, labels in tqdm(val_loader, desc='[Validate]'): images, labels = images.to(device), labels.to(device) @@ -69,17 +85,31 @@ def validate(model, val_loader, criterion, device): _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() + all_preds.append(predicted) + all_labels.append(labels) epoch_loss = running_loss / total + epoch_f1 = compute_macro_f1(torch.cat(all_preds), torch.cat(all_labels)) epoch_acc = 100. * correct / total - return epoch_loss, epoch_acc + return epoch_loss, epoch_f1, epoch_acc + + +def compute_class_weights(dataset, num_classes=4, device='cpu'): + class_counts = torch.zeros(num_classes) + for _, label in dataset.samples: + lbl = label.item() if isinstance(label, torch.Tensor) else label + class_counts[lbl] += 1 + total = class_counts.sum() + weights = total / (num_classes * class_counts) + return weights.to(device) def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): """主训练函数""" # 1. 定义损失函数和优化器 - criterion = nn.CrossEntropyLoss() # 多分类用交叉熵 + class_weights = compute_class_weights(train_loader.dataset, num_classes=4, device=device) + criterion = nn.CrossEntropyLoss(weight=class_weights) # 多分类用交叉熵 # 或者使用 SGD + 动量 optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4) @@ -90,12 +120,12 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): # 2. 记录训练历史 history = { 'train_loss': [], - 'train_acc': [], + 'train_f1': [], 'val_loss': [], - 'val_acc': [] + 'val_f1': [] } - best_val_acc = 0.0 + best_val_f1 = 0.0 # 3. 开始训练 for epoch in range(epochs): @@ -103,36 +133,36 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'): print(f'Epoch {epoch + 1}/{epochs}') # 训练 - train_loss, train_acc = train_one_epoch(model, train_loader, criterion, - optimizer, device, epoch) + train_loss, train_f1, train_acc = train_one_epoch(model, train_loader, criterion, + optimizer, device, epoch) # 验证 - val_loss, val_acc = validate(model, val_loader, criterion, device) + val_loss, val_f1, val_acc = validate(model, val_loader, criterion, device) # 更新学习率 scheduler.step() # 记录 history['train_loss'].append(train_loss) - history['train_acc'].append(train_acc) + history['train_f1'].append(train_f1) history['val_loss'].append(val_loss) - history['val_acc'].append(val_acc) + history['val_f1'].append(val_f1) # 打印结果 - print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%') - print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%') + print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Train Macro-F1: {train_f1:.4f}') + print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Macro-F1: {val_f1:.4f}') print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}') # 保存最佳模型 - if val_acc > best_val_acc: - best_val_acc = val_acc + if val_f1 > best_val_f1: + best_val_f1 = val_f1 torch.save(model.state_dict(), 'best_model.pth') - print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)') + print(f'✓ 保存最佳模型 (Macro-F1: {val_f1:.4f})') # 4. 绘制训练曲线 print(f'\n{"=" * 50}') - print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%') + print(f'训练完成!最佳验证 Macro-F1: {best_val_f1:.4f}') return model, history