From cff98f70cc170b994e616f4d403b8c4a3730303d Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Thu, 16 Apr 2026 20:43:08 +0800 Subject: [PATCH] Update Train.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加模型中断后继续训练的功能 --- Train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Train.py b/Train.py index 8b75e7a..49172dc 100644 --- a/Train.py +++ b/Train.py @@ -13,6 +13,7 @@ from tqdm import tqdm # 进度条,可选 import matplotlib.pyplot as plt from Model import Net from Dataloader import create_dataloaders +import os def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): """训练一个epoch""" model.train() # 设置为训练模式 @@ -150,6 +151,9 @@ if __name__ == '__main__': # 1. 创建模型 device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu') model = Net(num_classes=4) # 根据你的 Net 类调整 + #断点继续训练 + if os.path.exists('best_model.pth'): + model.load_state_dict(torch.load('best_model.pth')) model = model.to(device) # 打印模型信息 @@ -161,11 +165,9 @@ if __name__ == '__main__': model=model, train_loader=train_loader, val_loader=val_loader, - epochs=50, + epochs=20, lr=0.001, device=device ) - # 3. 加载最佳模型用于预测 - model.load_state_dict(torch.load('best_model.pth')) - print('训练完成,最佳模型已加载') \ No newline at end of file + model.load_state_dict(torch.load('best_model.pth')) \ No newline at end of file