# Saving and Loading Models

> Translator: [bruce1408](https://github.com/bruce1408)

**Author:** [Matthew Inkawhich](https://github.com/MatthewInkawhich)

This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core functions to be familiar with (a minimal sketch tying them together follows the contents list below):

1. [torch.save](https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save): Saves a serialized object to disk. This function uses Python's [pickle](https://docs.python.org/3/library/pickle.html) utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
2. [torch.load](https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load): Uses [pickle](https://docs.python.org/3/library/pickle.html)'s unpickling facilities to deserialize pickled object files into memory. This function also lets you specify the device to load the data onto (see [Saving & Loading Model Across Devices](#saving-loading-model-across-devices)).
3. [torch.nn.Module.load_state_dict](https://pytorch.org/docs/stable/nn.html?highlight=load_state_dict#torch.nn.Module.load_state_dict): Loads a model's parameter dictionary from a deserialized _state_dict_. For more information on _state_dict_, see [What is a state_dict?](#what-is-a-state-dict).

**Contents:**

* [What is a state_dict?](#what-is-a-state-dict)
* [Saving & Loading Model for Inference](#saving-loading-model-for-inference)
* [Saving & Loading a General Checkpoint](#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training)
* [Saving Multiple Models in One File](#saving-multiple-models-in-one-file)
* [Warmstarting Model Using Parameters from a Different Model](#warmstarting-model-using-parameters-from-a-different-model)
* [Saving & Loading Model Across Devices](#saving-loading-model-across-devices)
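Before the details, here is a minimal end-to-end sketch of how the three functions fit together; `TheModelClass` and `PATH` are placeholders for your own model class and checkpoint path, in keeping with the snippets throughout this document:

```py
import torch

model = TheModelClass()                  # placeholder model class

# 1. torch.save: pickle the parameter dictionary to disk.
torch.save(model.state_dict(), PATH)

# 2. torch.load: unpickle the file back into an in-memory dict.
state_dict = torch.load(PATH)

# 3. load_state_dict: copy those parameters into a model instance.
model.load_state_dict(state_dict)
```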
## What is a `state_dict`?

In PyTorch, the learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model's _parameters_ (accessed with `model.parameters()`). A _state_dict_ is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) have entries in the model's _state_dict_. Optimizer objects (`torch.optim`) also have a _state_dict_, which contains information about the optimizer's state, as well as the hyperparameters used.

Because _state_dict_ objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

### Example:

Let's take a look at the _state_dict_ from the simple model used in the [Training a classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py) tutorial.

```py
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```

**Output:**

```py
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [...]}]
```

## Saving & Loading Model for Inference

### Save/Load `state_dict` (Recommended)

**Save:**

```py
torch.save(model.state_dict(), PATH)
```

**Load:**

```py
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```

When saving a model for inference, it is only necessary to save the trained model's learned parameters. Saving the model's _state_dict_ with the `torch.save()` function gives you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.

A common PyTorch convention is to save models using either a `.pt` or `.pth` file extension.

Remember that you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

Note

Notice that the `load_state_dict()` function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved _state_dict_ before you pass it to the `load_state_dict()` function. For example, you CANNOT load using `model.load_state_dict(PATH)`.
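As a minimal illustration of this note (assuming `model` and `PATH` are defined as in the snippets above):

```py
# Correct: deserialize to a dictionary first, then hand over the dict.
state_dict = torch.load(PATH)
model.load_state_dict(state_dict)

# Incorrect: load_state_dict() expects a dict, not a str, so passing
# the path directly raises an error.
# model.load_state_dict(PATH)
```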
### Save/Load Entire Model

**Save:**

```py
torch.save(model, PATH)
```

**Load:**

```py
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
```

This save/load process uses the most intuitive syntax and involves the least amount of code. Saving a model this way serializes the entire module using Python's [pickle](https://docs.python.org/3/library/pickle.html) module. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. The reason is that pickle does not save the model class itself; rather, it saves a path to the file containing the class, which is used during load time. Because of this, your code can break in various ways when used in other projects or after refactors.

A common PyTorch convention is to save models using either a `.pt` or `.pth` file extension.

Remember that you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

## Saving & Loading a General Checkpoint for Inference and/or Resuming Training

### Save:

```py
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
```

### Load:

```py
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
```

When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model's _state_dict_. It is important to also save the optimizer's _state_dict_, as this contains buffers and parameters that are updated as the model trains. Other items you may want to save are the epoch you left off on, the latest recorded training loss, external `torch.nn.Embedding` layers, and so on.

To save multiple components, organize them in a dictionary and use `torch.save()` to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the `.tar` file extension.

To load the items, first initialize the model and optimizer, then load the dictionary locally using `torch.load()`. From here, you can easily access the saved items by simply querying the dictionary as you would expect.

Remember that you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call `model.train()` to ensure these layers are in training mode.
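Building on the loading snippet above, here is a sketch of how the restored values might drive a resumed training run; `train_one_epoch` and `num_epochs` are hypothetical stand-ins for your own training loop:

```py
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1   # continue after the saved epoch

model.train()                           # back to training mode
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)   # hypothetical helper
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, PATH)
```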
## Saving Multiple Models in One File

### Save:

```py
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)
```

### Load:

```py
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```

When saving a model comprised of multiple `torch.nn.Modules`, such as a GAN, a sequence-to-sequence model, or an ensemble of models, you follow the same approach as when saving a general checkpoint. In other words, save a dictionary of each model's _state_dict_ and corresponding optimizer. As mentioned before, you can save any other items that may aid you in resuming training by simply appending them to the dictionary.

A common PyTorch convention is to save these checkpoints using the `.tar` file extension.

To load the models, first initialize the models and optimizers, then load the dictionary locally using `torch.load()`. From here, you can easily access the saved items by simply querying the dictionary as you would expect.

Remember that you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call `model.train()` to set these layers to training mode.

## Warmstarting Model Using Parameters from a Different Model

### Save:

```py
torch.save(modelA.state_dict(), PATH)
```

### Load:

```py
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
```

Partially loading a model or loading a partial model are common scenarios when transfer learning or when training a new, complex model. Leveraging trained parameters, even if only a few are usable, helps to warmstart the training process and will hopefully help your model converge much faster than training from scratch.

Whether you are loading from a partial _state_dict_, which is missing some keys, or loading a _state_dict_ with more keys than the model you are loading into, you can set the `strict` argument to **False** in the `load_state_dict()` function to ignore non-matching keys.

If you want to load parameters from one layer to another, but some keys do not match, simply change the names of the parameter keys in the _state_dict_ you are loading to match the keys in the model you are loading into, as in the sketch below.
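A sketch of that renaming step, assuming purely for illustration that the saved keys are prefixed `features.` while `modelB` expects `backbone.`:

```py
state_dict = torch.load(PATH)

# Rebuild the dict, rewriting each key to match modelB's naming scheme
# (the 'features.' -> 'backbone.' prefixes here are made-up examples).
renamed = {key.replace('features.', 'backbone.'): value
           for key, value in state_dict.items()}

# strict=False still ignores whatever keys remain unmatched.
modelB.load_state_dict(renamed, strict=False)
```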
## Saving & Loading Model Across Devices

### Save on GPU, Load on CPU

**Save:**

```py
torch.save(model.state_dict(), PATH)
```

**Load:**

```py
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
```

When loading a model on a CPU that was trained on a GPU, pass `torch.device('cpu')` to the `map_location` argument of the `torch.load()` function. In this case, the storages underlying the tensors are dynamically remapped to the CPU device using the `map_location` argument.

### Save on GPU, Load on GPU

**Save:**

```py
torch.save(model.state_dict(), PATH)
```

**Load:**

```py
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
```

When loading a model on a GPU that was trained and saved on a GPU, simply convert the initialized `model` to a CUDA optimized model using `model.to(torch.device('cuda'))`. Also, be sure to use the `.to(torch.device('cuda'))` function on all model inputs to prepare the data for the model. Note that calling `my_tensor.to(device)` returns a new copy of `my_tensor` on the GPU; it does NOT overwrite `my_tensor`. Therefore, remember to manually overwrite tensors: `my_tensor = my_tensor.to(torch.device('cuda'))`.

### Save on CPU, Load on GPU

**Save:**

```py
torch.save(model.state_dict(), PATH)
```

**Load:**

```py
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
```

When loading a model on a GPU that was trained and saved on a CPU, set the `map_location` argument of the `torch.load()` function to _cuda:device_id_. This loads the model to the given GPU device. Next, be sure to call `model.to(torch.device('cuda'))` to convert the model's parameter tensors to CUDA tensors. Finally, be sure to use the `.to(torch.device('cuda'))` function on all model inputs to prepare the data for the CUDA optimized model. Note that calling `my_tensor.to(device)` returns a new copy of `my_tensor` on the GPU; it does NOT overwrite `my_tensor`. Therefore, remember to manually overwrite tensors: `my_tensor = my_tensor.to(torch.device('cuda'))`.

### Saving `torch.nn.DataParallel` Models

**Save:**

```py
torch.save(model.module.state_dict(), PATH)
```

**Load:**

```py
# Load to whatever device you want
```
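The `Load` block above is intentionally just a comment; one hedged way to fill it in (the CPU target here is only one choice among many) is:

```py
# Rebuild a plain, unwrapped model and pull the saved parameters onto
# whichever device you like via map_location.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location='cpu'))

# Re-wrap in DataParallel only if you are again training on multiple GPUs.
model = torch.nn.DataParallel(model)
```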
`torch.nn.DataParallel` is a model wrapper that enables parallel GPU utilization. To save a `DataParallel` model generically, save the `model.module.state_dict()`. This way, you have the flexibility to load the model any way you want to any device you want.
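That flexibility comes from `map_location`, which also accepts a dictionary or a function for finer-grained remapping; a short sketch of both documented forms:

```py
# Dictionary form: remap storages saved on GPU 1 onto GPU 0.
state_dict = torch.load(PATH, map_location={'cuda:1': 'cuda:0'})

# Function form: called with each (storage, location) pair; returning
# the storage unchanged keeps everything on the CPU.
state_dict = torch.load(PATH, map_location=lambda storage, loc: storage)
```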