添加模型断点继续加载训练功能

This commit is contained in:
yukun-hh 2026-04-21 15:57:13 +08:00
parent 968e108857
commit f8bb340a70

View file

@ -153,7 +153,7 @@ if __name__ == '__main__':
model = Net(num_classes=4) # 根据你的 Net 类调整
#断点继续训练
if os.path.exists('best_model.pth'):
model.load_state_dict(torch.load('best_model.pth'))
model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu')))
model = model.to(device)
# 打印模型信息