PyTorch模型保存与加载

pytorch notes

Posted by viewsetting on September 24, 2019

整个模型的保存与加载

torch.save(model_object, 'model.pth')  
model = torch.load('model.pth')  

只保存以及加载模型参数

torch.save(model_object.state_dict(), 'params.pth')  
model_object.load_state_dict(torch.load('params.pth'))