# 序列化语义 > 原文:[`pytorch.org/docs/stable/notes/serialization.html`](https://pytorch.org/docs/stable/notes/serialization.html) 本说明描述了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块,以便它们可以在 C++中加载。 目录 + 序列化语义 + 保存和加载张量 + 保存和加载张量保留视图 + 保存和加载 torch.nn.Modules + 序列化 torch.nn.Modules 并在 C++中加载它们 + 在不同 PyTorch 版本中保存和加载 ScriptModules + torch.div 执行整数除法 + torch.full 总是推断为浮点 dtype + 实用函数 ## 保存和加载张量[](#saving-and-loading-tensors "跳转到此标题") `torch.save()` 和 `torch.load()` 让您轻松保存和加载张量: ```py >>> t = torch.tensor([1., 2.]) >>> torch.save(t, 'tensor.pt') >>> torch.load('tensor.pt') tensor([1., 2.]) ``` 按照惯例,PyTorch 文件通常使用‘.pt’或‘.pth’扩展名编写。 `torch.save()` 和 `torch.load()` 默认使用 Python 的 pickle,因此您也可以将多个张量保存为 Python 对象的一部分,如元组、列表和字典: ```py >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])} >>> torch.save(d, 'tensor_dict.pt') >>> torch.load('tensor_dict.pt') {'a': tensor([1., 2.]), 'b': tensor([3., 4.])} ``` 如果数据结构是可 pickle 的,那么包含 PyTorch 张量的自定义数据结构也可以保存。## 保存和加载张量保留视图[](#saving-and-loading-tensors-preserves-views "跳转到此标题") 保存张量保留它们的视图关系: ```py >>> numbers = torch.arange(1, 10) >>> evens = numbers[1::2] >>> torch.save([numbers, evens], 'tensors.pt') >>> loaded_numbers, loaded_evens = torch.load('tensors.pt') >>> loaded_evens *= 2 >>> loaded_numbers tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9]) ``` 在幕后,这些张量共享相同的“存储”。查看[Tensor Views](https://pytorch.org/docs/main/tensor_view.html)了解更多关于视图和存储的信息。 当 PyTorch 保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个实现细节,可能会在将来发生变化,但通常可以节省空间,并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入到‘tensors.pt’中。 然而,在某些情况下,保存当前存储对象可能是不必要的,并且会创建过大的文件。在下面的代码片段中,一个比保存的张量大得多的存储被写入到文件中: ```py >>> large = torch.arange(1, 1000) >>> small = large[0:5] >>> torch.save(small, 'small.pt') >>> loaded_small = torch.load('small.pt') >>> loaded_small.storage().size() 999 ``` 与仅将小张量中的五个值保存到‘small.pt’不同,它与 large 共享的存储中的 999 个值被保存和加载。 当保存具有比其存储对象更少元素的张量时,可以通过首先克隆张量来减小保存文件的大小。克隆张量会产生一个新的张量,其中包含张量中的值的新存储对象: ```py >>> large = torch.arange(1, 1000) >>> small = large[0:5] >>> torch.save(small.clone(), 'small.pt') # saves a clone of small >>> loaded_small = torch.load('small.pt') >>> loaded_small.storage().size() 5 ``` 然而,由于克隆的张量彼此独立,它们没有原始张量的视图关系。如果在保存比其存储对象小的张量时,文件大小和视图关系都很重要,则必须小心构建新张量,以最小化其存储对象的大小,但仍具有所需的视图关系后再保存。## 保存和加载 torch.nn.Modules[](#saving-and-loading-torch-nn-modules "跳转到此标题") 参见:[教程:保存和加载模块](https://pytorch.org/tutorials/beginner/saving_loading_models.html) 在 PyTorch 中,模块的状态经常使用‘state dict’进行序列化。模块的状态字典包含所有参数和持久缓冲区: ```py >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True) >>> list(bn.named_parameters()) [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)), ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))] >>> list(bn.named_buffers()) [('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))] >>> bn.state_dict() OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))]) ``` 为了兼容性的原因,建议不直接保存模块,而是只保存其状态字典。Python 模块甚至有一个函数`load_state_dict()`,可以从状态字典中恢复它们的状态: ```py >>> torch.save(bn.state_dict(), 'bn.pt') >>> bn_state_dict = torch.load('bn.pt') >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True) >>> new_bn.load_state_dict(bn_state_dict) ``` 请注意,状态字典首先使用`torch.load()`从文件中加载,然后使用`load_state_dict()`恢复状态。 即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用这种模式: ```py # A module with two linear layers >>> class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.l0 = torch.nn.Linear(4, 2) self.l1 = torch.nn.Linear(2, 1) def forward(self, input): out0 = self.l0(input) out0_relu = torch.nn.functional.relu(out0) return self.l1(out0_relu) >>> m = MyModule() >>> m.state_dict() OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406], [-0.3289, 0.2827, 0.4588, 0.2031]])), ('l0.bias', tensor([ 0.0300, -0.1316])), ('l1.weight', tensor([[0.6533, 0.3413]])), ('l1.bias', tensor([-0.1112]))]) >>> torch.save(m.state_dict(), 'mymodule.pt') >>> m_state_dict = torch.load('mymodule.pt') >>> new_m = MyModule() >>> new_m.load_state_dict(m_state_dict) ``` ## 序列化 torch.nn.Modules 并在 C++中加载它们[](#serializing-torch-nn-modules-and-loading-them-in-c "跳转到此标题") 另请参阅:[教程:在 C++中加载 TorchScript 模型](https://pytorch.org/tutorials/advanced/cpp_export.html) ScriptModules 可以被序列化为 TorchScript 程序,并使用`torch.jit.load()`加载。这种序列化编码了所有模块的方法、子模块、参数和属性,并允许在 C++中加载序列化的程序(即不需要 Python)。 `torch.jit.save()`和`torch.save()`之间的区别可能不是立即清楚的。`torch.save()`使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。另一方面,`torch.jit.save()`将 ScriptModules 序列化为可以在 Python 或 C++中加载的格式。这在保存和加载 C++模块或在 C++中运行在 Python 中训练的模块时非常有用,这是部署 PyTorch 模型时的常见做法。 要在 Python 中脚本化、序列化和加载模块: ```py >>> scripted_module = torch.jit.script(MyModule()) >>> torch.jit.save(scripted_module, 'mymodule.pt') >>> torch.jit.load('mymodule.pt') RecursiveScriptModule( original_name=MyModule (l0): RecursiveScriptModule(original_name=Linear) (l1): RecursiveScriptModule(original_name=Linear) ) ``` 跟踪模块也可以使用`torch.jit.save()`保存,但要注意只有跟踪的代码路径被序列化。以下示例演示了这一点: ```py # A module with control flow >>> class ControlFlowModule(torch.nn.Module): def __init__(self): super().__init__() self.l0 = torch.nn.Linear(4, 2) self.l1 = torch.nn.Linear(2, 1) def forward(self, input): if input.dim() > 1: return torch.tensor(0) out0 = self.l0(input) out0_relu = torch.nn.functional.relu(out0) return self.l1(out0_relu) >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4)) >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt') >>> loaded = torch.jit.load('controlflowmodule_traced.pt') >>> loaded(torch.randn(2, 4))) tensor([[-0.1571], [-0.3793]], grad_fn=) >>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4)) >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt') >>> loaded = torch.jit.load('controlflowmodule_scripted.pt') >> loaded(torch.randn(2, 4)) tensor(0) ``` 上述模块有一个 if 语句,不会被跟踪的输入触发,因此不是跟踪的模块的一部分,也不会与之一起序列化。然而,脚本化模块包含 if 语句,并与之一起序列化。有关脚本化和跟踪的更多信息,请参阅[TorchScript 文档](https://pytorch.org/docs/stable/jit.html)。 最后,在 C++中加载模块: ```py >>> torch::jit::script::Module module; >>> module = torch::jit::load('controlflowmodule_scripted.pt'); ``` 有关如何在 C++中使用 PyTorch 模块的详细信息,请参阅[PyTorch C++ API 文档](https://pytorch.org/cppdocs/)。## 在 PyTorch 版本间保存和加载 ScriptModules[](#saving-and-loading-scriptmodules-across-pytorch-versions "跳转到此标题") PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。较旧版本的 PyTorch 可能不支持较新的模块,而较新版本可能已删除或修改了较旧的行为。这些更改在 PyTorch 的[发布说明](https://github.com/pytorch/pytorch/releases)中有明确描述,依赖已更改功能的模块可能需要更新才能继续正常工作。在有限的情况下,如下所述,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。 ### torch.div 执行整数除法[](#torch-div-performing-integer-division "跳转到此标题") 在 PyTorch 1.5 及更早版本中,当给定两个整数输入时,`torch.div()`将执行地板除法: ```py # PyTorch 1.5 (and earlier) >>> a = torch.tensor(5) >>> b = torch.tensor(3) >>> a / b tensor(1) ``` 然而,在 PyTorch 1.7 中,`torch.div()`将始终执行其输入的真除法,就像 Python 3 中的除法一样: ```py # PyTorch 1.7 >>> a = torch.tensor(5) >>> b = torch.tensor(3) >>> a / b tensor(1.6667) ``` `torch.div()`的行为在序列化的 ScriptModules 中得到保留。也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 将继续看到当给定两个整数输入时,`torch.div()`执行地板除法,即使在较新版本的 PyTorch 中加载时也是如此。然而,使用`torch.div()`并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在较早版本的 PyTorch 中加载,因为这些较早版本不理解新的行为。 ### torch.full 总是推断浮点数据类型[](#torch-full-always-inferring-a-float-dtype "跳转到此标题的永久链接") 在 PyTorch 1.5 及更早版本中,`torch.full()`始终返回一个浮点张量,而不管给定的填充值是什么: ```py # PyTorch 1.5 and earlier >>> torch.full((3,), 1) # Note the integer fill value... tensor([1., 1., 1.]) # ...but float tensor! ``` 然而,在 PyTorch 1.7 中,`torch.full()`将从填充值推断返回的张量的数据类型: ```py # PyTorch 1.7 >>> torch.full((3,), 1) tensor([1, 1, 1]) >>> torch.full((3,), True) tensor([True, True, True]) >>> torch.full((3,), 1.) tensor([1., 1., 1.]) >>> torch.full((3,), 1 + 1j) tensor([1.+1.j, 1.+1.j, 1.+1.j]) ``` `torch.full()`的行为在序列化的 ScriptModules 中得到保留。也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 将继续看到 torch.full 默认返回浮点张量,即使给定布尔或整数填充值。然而,使用`torch.full()`并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在较早版本的 PyTorch 中加载,因为这些较早版本不理解新的行为。## 实用函数[](#utility-functions "跳转到此标题的永久链接") 以下实用函数与序列化相关: ```py torch.serialization.register_package(priority, tagger, deserializer)¶ ``` 为标记和反序列化存储对象注册可调用对象,并附带优先级。标记在保存时将设备与存储对象关联,而在加载时反序列化将存储对象移动到适当的设备上。`tagger`和`deserializer`按照它们的`priority`给出的顺序运行,直到一个标记器/反序列化器返回一个不是 None 的值。 要覆盖全局注册表中设备的反序列化行为,可以注册一个优先级高于现有标记器的标记器。 此函数还可用于为新设备注册标记器和反序列化器。 参数 + **priority**([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12)"))– 指示与标记器和反序列化器相关联的优先级,其中较低的值表示较高的优先级。 + **tagger**([*Callable*](https://docs.python.org/3/library/typing.html#typing.Callable "(在 Python v3.12)")*[**[*[*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12)")***Storage**,* [*TypedStorage**,* *UntypedStorage**]**]**,* [*Optional*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12)")*[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")*]**]*) – 接受存储对象并返回其标记设备的可调用对象,返回字符串或 None。 + deserializer(Callable)[Union[Storage, TypedStorage, UntypedStorage], str] [Optional[Union[Storage, TypedStorage, UntypedStorage]]] - 接受存储对象和设备字符串并返回适当设备上的存储对象或 None 的可调用函数。 返回 无 示例 ```py >>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) ``` ```py torch.serialization.get_default_load_endianness()¶ ``` 获取用于加载文件的回退字节顺序 如果保存的检查点中不存在字节顺序标记,则将使用此字节顺序作为回退。默认情况下是“本机”字节顺序。 返回 Optional[LoadEndianness] 返回类型 default_load_endian ```py torch.serialization.set_default_load_endianness(endianness)¶ ``` 设置用于加载文件的回退字节顺序 如果保存的检查点中不存在字节顺序标记,则将使用此字节顺序作为回退。默认情况下是“本机”字节顺序。 参数 endianness - 新的回退字节顺序