trash-division/Train.py
2026-04-12 14:39:58 +08:00

198 lines
No EOL
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
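This training code was first generated by AI and has not been adjusted yet, because the data iterator has not been designed.
This file cannot be run yet!!!
The best model will be saved in the root directory.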
author:yukun-hh
date 2026-4-10
"""
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # 进度条,可选
import matplotlib.pyplot as plt
from Model import Net
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` apply the update.
        device: Device every batch is moved to before the forward pass.
        epoch: Zero-based epoch index, used only in the progress-bar label.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch
        size and divided by the number of samples, and the accuracy in
        percent over all samples seen.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # tqdm wraps the loader so progress and running metrics are displayed.
    pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear old gradients, compute new ones, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar loss once and reuse it for both the running total
        # and the progress bar instead of calling .item() twice per step.
        batch_loss = loss.item()
        running_loss += batch_loss * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': batch_loss, 'acc': 100. * correct / total})

    # Without this guard an empty loader ends in a bare ZeroDivisionError.
    if total == 0:
        raise ValueError(
            'train_loader yielded no samples; cannot compute epoch metrics'
        )

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch
        size and divided by the number of samples, and the accuracy in
        percent over all samples seen.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='[Validate]'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Without this guard an empty loader ends in a bare ZeroDivisionError.
    if total == 0:
        raise ValueError(
            'val_loader yielded no samples; cannot compute validation metrics'
        )

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc
def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
    """Run the full training loop and keep the best checkpoint.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``,
    steps the learning-rate scheduler and records the metrics. Whenever the
    validation accuracy improves, the model's ``state_dict`` is written to
    ``best_model.pth``. After the last epoch ``plot_training_history`` is
    called with the collected history.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of epochs to run.
        lr: Initial learning rate passed to the Adam optimizer.
        device: Device forwarded to the per-epoch train/validate helpers.

    Returns:
        Tuple ``(model, history)`` where ``history`` is a dict with the lists
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'``,
        one entry per epoch.
    """
    # 1. Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()  # cross-entropy for multi-class classification
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # Alternative: optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    # Learning-rate schedule: multiply the LR by 0.1 every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    # Alternative: optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # 2. Per-epoch training history.
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
    }
    best_val_acc = 0.0

    # 3. Training loop.
    for epoch in range(epochs):
        print(f'\n{"=" * 50}')
        print(f'Epoch {epoch + 1}/{epochs}')

        # Capture the learning rate *before* scheduler.step() so the log shows
        # the rate this epoch was actually trained with, not the next one.
        epoch_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                                optimizer, device, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {epoch_lr:.6f}')

        # Always checkpoint the first epoch so 'best_model.pth' exists after
        # training even if validation accuracy never rises above 0.0.
        if epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)')

    # 4. Plot the training curves.
    plot_training_history(history)
    print(f'\n{"=" * 50}')
    print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%')
    return model, history
def plot_training_history(history):
    """Draw loss and accuracy curves side by side, save them, then show them.

    Args:
        history: Mapping that provides the sequences ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.

    The figure is written to ``training_history.png`` (dpi=150) and then
    displayed with ``plt.show()``.
    """
    _figure, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One entry per panel: the two curves to draw, the y-axis label, the title.
    panels = (
        (
            (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
            'Loss',
            'Training and Validation Loss',
        ),
        (
            (('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')),
            'Accuracy (%)',
            'Training and Validation Accuracy',
        ),
    )

    for axis, (curves, y_label, title) in zip(axes, panels):
        for key, legend_name in curves:
            axis.plot(history[key], label=legend_name)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150)
    plt.show()
# ========== Usage example ==========
if __name__ == '__main__':
    # The data loaders are assumed to be prepared here:
    # train_loader = ...
    # val_loader = ...
    # NOTE(review): train_loader / val_loader are not defined anywhere in the
    # visible file, so the train() call below raises NameError until they are.

    # 1. Build the model on the best available device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().get_network().to(device)  # adjust to match your Net class

    # Report the device and the total parameter count.
    parameter_count = sum(p.numel() for p in model.parameters())
    print(f'Device: {device}')
    print(f'Model parameters: {parameter_count:,}')

    # 2. Run training.
    trained_model, history = train(
        model,
        train_loader,
        val_loader,
        epochs=50,
        lr=0.001,
        device=device,
    )

    # 3. Reload the best checkpoint for later prediction.
    best_state = torch.load('best_model.pth')
    model.load_state_dict(best_state)
    print('训练完成,最佳模型已加载')