Update Train.py

添加模型中断后继续训练的功能
This commit is contained in:
yukun-hh 2026-04-16 20:43:08 +08:00
parent cb6bdc7eb8
commit cff98f70cc

View file

@ -13,6 +13,7 @@ from tqdm import tqdm # 进度条,可选
import matplotlib.pyplot as plt
from Model import Net
from Dataloader import create_dataloaders
import os
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
"""训练一个epoch"""
model.train() # 设置为训练模式
@ -150,6 +151,9 @@ if __name__ == '__main__':
# 1. 创建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
model = Net(num_classes=4) # 根据你的 Net 类调整
#断点继续训练
if os.path.exists('best_model.pth'):
model.load_state_dict(torch.load('best_model.pth'))
model = model.to(device)
# 打印模型信息
@ -161,11 +165,9 @@ if __name__ == '__main__':
model=model,
train_loader=train_loader,
val_loader=val_loader,
epochs=50,
epochs=20,
lr=0.001,
device=device
)
# 3. 加载最佳模型用于预测
model.load_state_dict(torch.load('best_model.pth'))
print('训练完成,最佳模型已加载')
model.load_state_dict(torch.load('best_model.pth'))