diff --git a/Train.py b/Train.py index 49172dc..157af01 100644 --- a/Train.py +++ b/Train.py @@ -153,7 +153,7 @@ if __name__ == '__main__': model = Net(num_classes=4) # 根据你的 Net 类调整 #断点继续训练 if os.path.exists('best_model.pth'): - model.load_state_dict(torch.load('best_model.pth')) + model.load_state_dict(torch.load('best_model.pth',map_location=torch.device('cpu'))) model = model.to(device) # 打印模型信息