添加模型断点继续加载训练功能
This commit is contained in:
parent
968e108857
commit
f8bb340a70
1 changed files with 1 additions and 1 deletions
2
Train.py
2
Train.py
|
|
@ -153,7 +153,7 @@ if __name__ == '__main__':
|
||||||
model = Net(num_classes=4) # 根据你的 Net 类调整
|
model = Net(num_classes=4) # 根据你的 Net 类调整
|
||||||
#断点继续训练
|
#断点继续训练
|
||||||
if os.path.exists('best_model.pth'):
|
if os.path.exists('best_model.pth'):
|
||||||
model.load_state_dict(torch.load('best_model.pth'))
|
model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu')))
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
# 打印模型信息
|
# 打印模型信息
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue