"""PyTorch model persistence cheat sheet.

Two approaches are shown:

1. Save / load the entire model object (pickles the module itself).
2. Save / load only the model parameters via ``state_dict`` -- the approach
   the PyTorch serialization tutorial recommends, because the saved file
   does not depend on the exact class/module layout at load time.
"""
import torch


# --- 1. Save and load the entire model ---------------------------------------
def save_whole_model(model_object, path="model.pth"):
    """Pickle the whole module object *model_object* to *path*."""
    torch.save(model_object, path)


def load_whole_model(path="model.pth", map_location=None):
    """Load and return a module previously saved by ``save_whole_model``.

    ``weights_only=False`` is passed explicitly: since PyTorch 2.6,
    ``torch.load`` defaults to ``weights_only=True``, which refuses to
    unpickle arbitrary ``nn.Module`` objects.

    SECURITY: this performs full pickle deserialization, so only load files
    from a trusted source. The model's class must also be importable under
    the same module path it had when the file was saved.

    Args:
        path: File written by ``save_whole_model``.
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``) to remap
            tensor storages to a different device.
    """
    return torch.load(path, map_location=map_location, weights_only=False)


# --- 2. Save and load only the parameters (state_dict) ------------------------
def save_params(model_object, path="params.pth"):
    """Save only ``model_object.state_dict()`` to *path*."""
    torch.save(model_object.state_dict(), path)


def load_params(model_object, path="params.pth", map_location=None):
    """Load a state dict from *path* into *model_object* in place; return it.

    ``weights_only=True`` restricts unpickling to tensors and primitive
    containers, which is all a ``state_dict`` needs, so this is the safer
    loader.

    Args:
        model_object: An already-constructed module with matching structure.
        path: File written by ``save_params``.
        map_location: Forwarded to ``torch.load`` to remap tensor devices.
    """
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model_object.load_state_dict(state_dict)
    return model_object


if __name__ == "__main__":
    # Demo with a small stand-in model; replace with your own nn.Module.
    model_object = torch.nn.Linear(4, 2)

    # Approach 1: whole model.
    save_whole_model(model_object, "model.pth")
    model = load_whole_model("model.pth")

    # Approach 2: parameters only.
    save_params(model_object, "params.pth")
    load_params(model_object, "params.pth")