Update Train.py
添加模型中断后继续训练的功能
This commit is contained in:
parent
cb6bdc7eb8
commit
cff98f70cc
1 changed files with 6 additions and 4 deletions
8
Train.py
8
Train.py
|
|
@ -13,6 +13,7 @@ from tqdm import tqdm # 进度条,可选
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from Model import Net
|
from Model import Net
|
||||||
from Dataloader import create_dataloaders
|
from Dataloader import create_dataloaders
|
||||||
|
import os
|
||||||
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
|
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
|
||||||
"""训练一个epoch"""
|
"""训练一个epoch"""
|
||||||
model.train() # 设置为训练模式
|
model.train() # 设置为训练模式
|
||||||
|
|
@ -150,6 +151,9 @@ if __name__ == '__main__':
|
||||||
# 1. 创建模型
|
# 1. 创建模型
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
||||||
model = Net(num_classes=4) # 根据你的 Net 类调整
|
model = Net(num_classes=4) # 根据你的 Net 类调整
|
||||||
|
#断点继续训练
|
||||||
|
if os.path.exists('best_model.pth'):
|
||||||
|
model.load_state_dict(torch.load('best_model.pth'))
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
# 打印模型信息
|
# 打印模型信息
|
||||||
|
|
@ -161,11 +165,9 @@ if __name__ == '__main__':
|
||||||
model=model,
|
model=model,
|
||||||
train_loader=train_loader,
|
train_loader=train_loader,
|
||||||
val_loader=val_loader,
|
val_loader=val_loader,
|
||||||
epochs=50,
|
epochs=20,
|
||||||
lr=0.001,
|
lr=0.001,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 加载最佳模型用于预测
|
# 3. 加载最佳模型用于预测
|
||||||
model.load_state_dict(torch.load('best_model.pth'))
|
model.load_state_dict(torch.load('best_model.pth'))
|
||||||
print('训练完成,最佳模型已加载')
|
|
||||||
Loading…
Reference in a new issue