Update Train.py
添加模型中断后继续训练的功能
This commit is contained in:
parent
cb6bdc7eb8
commit
cff98f70cc
1 changed files with 6 additions and 4 deletions
10
Train.py
10
Train.py
|
|
@ -13,6 +13,7 @@ from tqdm import tqdm # 进度条,可选
|
|||
import matplotlib.pyplot as plt
|
||||
from Model import Net
|
||||
from Dataloader import create_dataloaders
|
||||
import os
|
||||
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
|
||||
"""训练一个epoch"""
|
||||
model.train() # 设置为训练模式
|
||||
|
|
@ -150,6 +151,9 @@ if __name__ == '__main__':
|
|||
# 1. 创建模型
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
||||
model = Net(num_classes=4) # 根据你的 Net 类调整
|
||||
#断点继续训练
|
||||
if os.path.exists('best_model.pth'):
|
||||
model.load_state_dict(torch.load('best_model.pth'))
|
||||
model = model.to(device)
|
||||
|
||||
# 打印模型信息
|
||||
|
|
@ -161,11 +165,9 @@ if __name__ == '__main__':
|
|||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
epochs=50,
|
||||
epochs=20,
|
||||
lr=0.001,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 3. 加载最佳模型用于预测
|
||||
model.load_state_dict(torch.load('best_model.pth'))
|
||||
print('训练完成,最佳模型已加载')
|
||||
model.load_state_dict(torch.load('best_model.pth'))
|
||||
Loading…
Reference in a new issue