# torch脚本 > 原文: [https://pytorch.org/docs/stable/jit.html](https://pytorch.org/docs/stable/jit.html) * [创建 TorchScript 代码](#creating-torchscript-code) * [混合跟踪和脚本编写](#mixing-tracing-and-scripting) * [迁移到 PyTorch 1.2 递归脚本 API](#migrating-to-pytorch-1-2-recursive-scripting-api) * [模块](#modules) * [功能](#functions) * [TorchScript 类](#torchscript-classes) * [属性](#attributes) * [Python 2](#python-2) * [常数](#constants) * [变量](#variables) * [TorchScript 语言参考](#torchscript-language-reference) * [类型](#supported-type) * [默认类型](#default-types) * [可选类型细化](#optional-type-refinement) * [TorchScript 类](#id3) * [命名为元组](#named-tuples) * [表达式](#expressions) * [文字](#literals) * [列表结构](#list-construction) * [元组结构](#tuple-construction) * [字典结构](#dict-construction) * [变量](#id5) * [算术运算符](#arithmetic-operators) * [比较运算符](#comparison-operators) * [逻辑运算符](#logical-operators) * [下标和切片](#subscripts-and-slicing) * [函数调用](#function-calls) * [方法调用](#method-calls) * [三元表达式](#ternary-expressions) * [演员表](#casts) * [访问模块参数](#accessing-module-parameters) * [语句](#statements) * [简单分配](#simple-assignments) * [模式匹配分配](#pattern-matching-assignments) * [打印报表](#print-statements) * [If 语句](#if-statements) * [While 循环](#while-loops) * [适用于范围为](#for-loops-with-range)的循环 * [用于遍历元组的循环](#for-loops-over-tuples) * [用于在常量 nn.ModuleList](#for-loops-over-constant-nn-modulelist) 上循环 * [中断并继续](#break-and-continue) * [返回](#return) * [可变分辨率](#variable-resolution) * [使用 Python 值](#use-of-python-values) * [功能](#id6) * [Python 模块上的属性查找](#attribute-lookup-on-python-modules) * [Python 定义的常量](#python-defined-constants) * [模块属性](#module-attributes) * [调试](#debugging) * [禁用用于调试的 JIT](#disable-jit-for-debugging) * [检查码](#inspecting-code) * [解释图](#interpreting-graphs) * [追踪案例](#tracing-edge-cases) * [自动跟踪检查](#automatic-trace-checking) * [跟踪器警告](#tracer-warnings) * [内置函数](#builtin-functions) * [常见问题解答](#frequently-asked-questions) TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存并加载到没有 Python 依赖项的进程中。 我们提供了将模型从纯 Python 程序逐步过渡到可以独立于 Python 运行的 TorchScript 程序的工具,例如在独立的 C ++程序中。 这样就可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境中,在该生产环境中 Python 程序可能由于性能和多线程原因而处于不利地位。 有关 TorchScript 的简要介绍,请参见 [TorchScript 简介](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)教程。 有关将 PyTorch 模型转换为 TorchScript 并在 C ++中运行的端到端示例,请参见[在 C ++中加载 PyTorch 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)教程。 ## [创建 TorchScript 代码](#id10) * * * ``` class torch.jit.ScriptModule ``` * * * ``` property code ``` 返回`forward`方法的内部图的漂亮打印表示形式(作为有效的 Python 语法)。 有关详细信息,请参见[检查代码](#inspecting-code)。 * * * ``` property graph ``` 返回`forward`方法的内部图形的字符串表示形式。 有关详细信息,请参见[解释图](#interpreting-graphs)。 * * * ``` save(f, _extra_files=ExtraFilesMap{}) ``` 有关详细信息,请参见 [`torch.jit.save`](#torch.jit.save "torch.jit.save") 。 * * * ``` class torch.jit.ScriptFunction ``` 功能上与 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 等效,但是代表单个功能,没有任何属性或参数。 * * * ``` torch.jit.script(obj) ``` 为函数或`nn.Module`编写脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,然后返回 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 或 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 TorchScript 本身是 Python 语言的子集,因此 Python 并非所有功能都可以使用,但是我们提供了足够的功能来在张量上进行计算并执行与控制有关的操作。 有关完整指南,请参见 [TorchScript 语言参考](#torchscript-language-reference)。 `torch.jit.script`可用作模块和功能的函数,以及 [TorchScript 类](#torchscript-class)和功能的修饰器`@torch.jit.script`。 ``` Scripting a function ``` `@torch.jit.script`装饰器将通过编译函数的主体来构造 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 示例(编写函数): ``` import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFuncion # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2)) ``` ``` Scripting an nn.Module ``` 默认情况下,为`nn.Module`编写脚本将编译`forward`方法,并递归编译`forward`调用的任何方法,子模块和函数。 如果`nn.Module`仅使用 TorchScript 支持的功能,则无需更改原始模块代码。 `script`将构建 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,该副本具有原始模块的属性,参数和方法的副本。 示例(使用参数编写简单模块的脚本): ``` import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3)) ``` 示例(使用跟踪的子模块编写模块脚本): ``` import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule()) ``` 要编译除`forward`以外的方法(并递归编译其调用的任何内容),请将 [`@torch.jit.export`](#torch.jit.export "torch.jit.export") 装饰器添加到该方法。 要选择退出编译,请使用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 。 示例(模块中的导出方法和忽略方法): ``` import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2))) ``` * * * ``` torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5) ``` 跟踪一个函数并返回将使用即时编译进行优化的可执行文件或 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 对于仅在`Tensor`和`Tensor`的列表,字典和元组上运行的代码,跟踪是理想的选择。 使用`torch.jit.trace`和 [`torch.jit.trace_module`](#torch.jit.trace_module "torch.jit.trace_module") ,您可以将现有模块或 Python 函数转换为 TorchScript [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 或 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 。 您必须提供示例输入,然后我们运行该函数,记录在所有张量上执行的操作。 * 独立功能的最终记录将产生 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 * `nn.Module`或`nn.Module`的`forward`功能的所得记录产生 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 。 该模块还包含原始模块也具有的任何参数。 警告 跟踪仅正确记录不依赖数据的功能和模块(例如,对张量中的数据没有条件)并且不包含任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。 跟踪仅记录在给定张量上运行给定函数时执行的操作。 因此,返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将始终在任何输入上运行相同的跟踪图。 当期望模块根据输入和/或模块状态运行不同的操作集时,这具有重要意义。 例如, * 跟踪将不会记录任何控制流,例如 if 语句或循环。 当整个模块的控制流恒定时,这很好,并且通常内联控制流决策。 但是有时控制流实际上是模型本身的一部分。 例如,循环网络是输入序列(可能是动态)长度上的循环。 * 在返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 中,在`training`和`eval`模式下具有不同行为的操作将始终像在跟踪过程中一样处于运行状态,无论[是哪种模式 ] `ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 已插入。 在这种情况下,跟踪是不合适的, [`scripting`](#torch.jit.script "torch.jit.script") 是更好的选择。 如果跟踪此类模型,则可能在随后的模型调用中静默地得到不正确的结果。 在执行可能会导致产生不正确跟踪的操作时,跟踪器将尝试发出警告。 参数 * **函数**(可调用的_或_ [_torch.nn.Module_](nn.html#torch.nn.Module "torch.nn.Module"))– Python 函数或`torch.nn.Module` 与`example_inputs`一起运行。 `func`的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。 将模块传递到 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace") 时,仅运行并跟踪`forward`方法(有关详细信息,请参见 [`torch.jit.trace`](#torch.jit.trace_module "torch.jit.trace_module"))。 * **example_inputs** (_tuple_ )–示例输入的元组,将在跟踪时传递给函数。 假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入来运行结果跟踪。 `example_inputs`也可以是单个张量,在这种情况下,它会自动包装在元组中。 ``` Keyword Arguments ``` * **check_trace** (`bool`,可选)–检查通过跟踪代码运行的相同输入是否产生相同的输出。 默认值:`True`。 例如,如果您的网络包含不确定性操作,或者即使检查程序失败,但您确定网络正确,则可能要禁用此功能。 * **check_inputs** (_元组列表_ _,_ _可选_)–输入参数的元组列表,应使用这些元组来检查跟踪内容 是期待。 每个元组等效于`example_inputs`中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的`example_inputs`进行检查 * **check_tolerance** (_python:float_ _,_ _可选_)–在检查程序中使用的浮点比较公差。 如果结果由于已知原因(例如操作员融合)而在数值上出现差异,则可以使用此方法来放松检查器的严格性。 退货 如果`callable`是`nn.Module`的`nn.Module`或`forward`,则`trace`将使用包含跟踪代码的单个`forward`方法返回 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象。 返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有与原始`nn.Module`相同的子模块和参数集。 如果`callable`是独立功能,则`trace`返回 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 示例(跟踪函数): ``` import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment ``` 示例(跟踪现有模块): ``` import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) ``` * * * ``` torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5) ``` 跟踪模块并返回可执行文件 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,该文件将使用即时编译进行优化。 将模块传递到 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace") 时,仅运行并跟踪`forward`方法。 使用`trace_module`,您可以指定方法名称的字典作为示例输入,以跟踪下面的参数(请参见`example_inputs`)。 有关跟踪的更多信息,请参见 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace") 。 Parameters * **mod** ([_Torch.nn.Module_](nn.html#torch.nn.Module "torch.nn.Module"))–一种`torch.nn.Module`,其中包含名称在`example_inputs`中指定的方法。 给定的方法将被编译为单个 <cite>ScriptModule</cite> 的一部分。 * **example_inputs** (_dict_ )–包含样本输入的字典,该样本输入由`mod`中的方法名称索引。 输入将在跟踪时传递给名称与输入键对应的方法。 `{ 'forward' : example_forward_input, 'method2': example_method2_input}` ``` Keyword Arguments ``` * **check_trace** (`bool`, optional) – Check if the same inputs run through traced code produce the same outputs. Default: `True`. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite a checker failure. * **check_inputs** (_字典列表_ _,_ _可选_)–输入参数的字典列表,用于检查跟踪内容 是期待。 每个元组等效于`example_inputs`中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的`example_inputs`进行检查 * **check_tolerance** (_python:float__,_ _optional_) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion. Returns 具有单个`forward`方法的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象,其中包含跟踪的代码。 当`func`是`torch.nn.Module`时,返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有与`func`相同的子模块和参数集。 示例(使用多种方法跟踪模块): ``` import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs) ``` * * * ``` torch.jit.save(m, f, _extra_files=ExtraFilesMap{}) ``` 保存此模块的脱机版本以在单独的过程中使用。 保存的模块将序列化此模块的所有方法,子模块,参数和属性。 可以使用`torch::jit::load(filename)`将其加载到 C ++ API 中,或者使用 [`torch.jit.load`](#torch.jit.load "torch.jit.load") 加载到 Python API 中。 为了能够保存模块,它不得对本地 Python 函数进行任何调用。 这意味着所有子模块也必须是`torch.jit.ScriptModule`的子类。 危险 所有模块,无论使用哪种设备,都始终在加载期间加载到 CPU 中。 这与 [`load`](#torch.jit.load "torch.jit.load") 的语义不同,并且将来可能会发生变化。 Parameters * **m** –要保存的 ScriptModule。 * **f** –类似于文件的对象(必须实现写入和刷新)或包含文件名的字符串。 * **_extra_files** -从文件名映射到将作为“ f”的一部分存储的内容。 Warning 如果您使用的是 Python 2,`torch.jit.save`不支持`StringIO.StringIO`作为有效的类似文件的对象。 这是因为 write 方法应返回写入的字节数; `StringIO.write()`不这样做。 请改用`io.BytesIO`之类的东西。 例: ``` import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 m = torch.jit.script(MyModule()) # Save to file torch.jit.save(m, 'scriptmodule.pt') # This line is equivalent to the previous m.save("scriptmodule.pt") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.jit.save(m, buffer) # Save with extra files extra_files = torch._C.ExtraFilesMap() extra_files['foo.txt'] = 'bar' torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) ``` * * * ``` torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{}) ``` 加载先前用 [`torch.jit.save`](#torch.jit.save "torch.jit.save") 保存的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 或 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 之前保存的所有模块,无论使用何种设备,都首先加载到 CPU 中,然后再移动到保存它们的设备上。 如果失败(例如,因为运行时系统没有某些设备),则会引发异常。 Parameters * **f** –类似于文件的对象(必须实现读取,读取行,告诉和查找),或包含文件名的字符串 * **map_location** (_字符串_ _或_ [_torch设备_](tensor_attributes.html#torch.torch.device "torch.torch.device"))– `torch.save`中`map_location`的简化版本 用于动态地将存储重新映射到另一组设备。 * **_extra_files** (_文件名到内容的字典_)–映射中给定的多余文件名将被加载,其内容将存储在提供的映射中。 Returns [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象。 Example: ``` import torch import io torch.jit.load('scriptmodule.pt') # Load ScriptModule from io.BytesIO object with open('scriptmodule.pt', 'rb') as f: buffer = io.BytesIO(f.read()) # Load all tensors to the original device torch.jit.load(buffer) # Load all tensors onto CPU, using a device buffer.seek(0) torch.jit.load(buffer, map_location=torch.device('cpu')) # Load all tensors onto CPU, using a string buffer.seek(0) torch.jit.load(buffer, map_location='cpu') # Load with extra files. extra_files = torch._C.ExtraFilesMap() extra_files['foo.txt'] = 'bar' torch.jit.load('scriptmodule.pt', _extra_files=extra_files) print(extra_files['foo.txt']) ``` ## [混合跟踪和脚本编写](#id11) 在许多情况下,将模型转换为 TorchScript 都可以使用跟踪或脚本编写。 可以组成跟踪和脚本以适合模型一部分的特定要求。 脚本函数可以调用跟踪函数。 当您需要在简单的前馈模型周围使用控制流时,这特别有用。 例如,序列到序列模型的波束搜索通常将用脚本编写,但是可以调用使用跟踪生成的编码器模块。 示例(在脚本中调用跟踪的函数): ``` import torch def foo(x, y): return 2 * x + y traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) @torch.jit.script def bar(x): return traced_foo(x, x) ``` 跟踪的函数可以调用脚本函数。 即使大部分模型只是前馈网络,当模型的一小部分需要一些控制流时,这也很有用。 跟踪函数调用的脚本函数内部的控制流已正确保留。 示例(在跟踪函数中调用脚本函数): ``` import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r def bar(x, y, z): return foo(x, y) + z traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))) ``` 此组合也适用于`nn.Module`,在这里它可用于通过跟踪来生成子模块,该跟踪可以从脚本模块的方法中调用。 示例(使用跟踪模块): ``` import torch import torchvision class MyScriptModule(torch.nn.Module): def __init__(self): super(MyScriptModule, self).__init__() self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) .resize_(1, 3, 1, 1)) self.resnet = torch.jit.trace(torchvision.models.resnet18(), torch.rand(1, 3, 224, 224)) def forward(self, input): return self.resnet(input - self.means) my_script_module = torch.jit.script(MyScriptModule()) ``` ## [迁移到 PyTorch 1.2 递归脚本 API](#id12) 本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。 如果您不熟悉 TorchScript,则可以跳过本节。 PyTorch 1.2 对 TorchScript API 进行了两个主要更改。 1\. [`torch.jit.script`](#torch.jit.script "torch.jit.script") 现在将尝试递归编译遇到的函数,方法和类。 调用`torch.jit.script`后,编译将是“选择退出”,而不是“选择加入”。 2.现在`torch.jit.script(nn_module_instance)`是创建 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的首选方法,而不是从`torch.jit.ScriptModule`继承。 这些更改组合在一起,提供了一个更简单易用的 API,可将您的`nn.Module`转换为 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,可以在非 Python 环境中进行优化和执行。 新用法如下所示: ``` import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) my_model = Model() my_scripted_model = torch.jit.script(my_model) ``` * 该模块的`forward`是默认编译的。 从`forward`调用的方法将按照在`forward`中使用的顺序进行延迟编译。 * 要编译未从`forward`调用的`forward`以外的方法,请添加`@torch.jit.export`。 * 要停止编译器编译方法,请添加 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 或 [`@torch.jit.unused`](#torch.jit.unused "torch.jit.unused") 。 `@ignore`离开 * 方法作为对 python 的调用,并且`@unused`将其替换为异常。 `@ignored`无法导出; `@unused`可以。 * 可以推断大多数属性类型,因此不需要`torch.jit.Attribute`。 对于空容器类型,请使用 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释对其类型进行注释。 * 可以使用`Final`类注释来标记常量,而不是将成员的名称添加到`__constants__`中。 * 可以使用 Python 3 类型提示代替`torch.jit.annotate` ``` As a result of these changes, the following items are considered deprecated and should not appear in new code: ``` * `@torch.jit.script_method`装饰器 * 继承自`torch.jit.ScriptModule`的类 * `torch.jit.Attribute`包装器类 * `__constants__`数组 * `torch.jit.annotate`功能 ### [模块](#id13) Warning [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 注释的行为在 PyTorch 1.2 中发生了变化。 在 PyTorch 1.2 之前,@ ignore 装饰器用于使函数或方法可从导出的代码中调用。 要恢复此功能,请使用`@torch.jit.unused()`。 `@torch.jit.ignore`现在等同于`@torch.jit.ignore(drop=False)`。 有关详细信息,请参见 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 和 [`@torch.jit.unused`](#torch.jit.unused "torch.jit.unused") 。 当传递给 [`torch.jit.script`](#torch.jit.script "torch.jit.script") 函数时,`torch.nn.Module`的数据将复制到 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,然后 TorchScript 编译器将编译该模块。 该模块的`forward`默认为编译状态。 从`forward`调用的方法以及它们在`forward`中使用的顺序都是按延迟顺序编译的。 * * * ``` torch.jit.export(fn) ``` 此修饰符指示`nn.Module`上的方法用作 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的入口点,应进行编译。 `forward`隐式地假定为入口点,因此不需要此装饰器。 从`forward`调用的函数和方法在编译器看到的情况下进行编译,因此它们也不需要此装饰器。 示例(在方法上使用`@torch.jit.export`): ``` import torch import torch.nn as nn class MyModule(nn.Module): def implicitly_compiled_method(self, x): return x + 99 # `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effect def forward(self, x): return x + 10 @torch.jit.export def another_forward(self, x): # When the compiler sees this call, it will compile # `implicitly_compiled_method` return self.implicitly_compiled_method(x) def unused_method(self, x): return x - 20 # `m` will contain compiled methods: # `forward` # `another_forward` # `implicitly_compiled_method` # `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export` m = torch.jit.script(MyModule()) ``` ### [功能](#id14) 功能没有太大变化,可以根据需要用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 或 [`torch.jit.unused`](#torch.jit.unused "torch.jit.unused") 装饰。 ``` # Same behavior as pre-PyTorch 1.2 @torch.jit.script def some_fn(): return 2 # Marks a function as ignored, if nothing # ever calls it then this has no effect @torch.jit.ignore def some_fn2(): return 2 # As with ignore, if nothing calls it then it has no effect. # If it is called in script it is replaced with an exception. @torch.jit.unused def some_fn3(): import pdb; pdb.set_trace() return 4 # Doesn't do anything, this function is already # the main entry point @torch.jit.export def some_fn4(): return 2 ``` ### [TorchScript 类](#id15) 默认情况下,将导出用户定义的 [TorchScript 类](#torchscript-class)中的所有内容,可以根据需要用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 修饰功能。 ### [属性](#id16) TorchScript 编译器需要知道[模块属性](#module-attributes)的类型。 大多数类型可以从成员的值推断出来。 空列表和字典不能推断其类型,而必须使用 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释来注释其类型。 如果无法推断类型并且未对显式类型进行注释,则不会将其作为属性添加到结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 旧 API: ``` from typing import Dict import torch class MyModule(torch.jit.ScriptModule): def __init__(self): super(MyModule, self).__init__() self.my_dict = torch.jit.Attribute({}, Dict[str, int]) self.my_int = torch.jit.Attribute(20, int) m = MyModule() ``` 新 API: ``` from typing import Dict class MyModule(torch.nn.Module): my_dict: Dict[str, int] def __init__(self): super(MyModule, self).__init__() # This type cannot be inferred and must be specified self.my_dict = {} # The attribute type here is inferred to be `int` self.my_int = 20 def forward(self): pass m = torch.jit.script(MyModule()) ``` #### [Python 2](#id17) 如果您受制于 Python 2 并且无法使用类注释语法,则可以使用`__annotations__`类成员直接应用类型注释。 ``` from typing import Dict class MyModule(torch.jit.ScriptModule): __annotations__ = {'my_dict': Dict[str, int]} def __init__(self): super(MyModule, self).__init__() self.my_dict = {} self.my_int = 20 ``` ### [常数](#id18) `Final`类型的构造函数可用于将成员标记为[常量](#constant)。 如果成员未标记为常量,则将其复制为结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 作为属性。 如果已知该值是固定的,则使用`Final`可以进行优化,并提供附加的类型安全性。 Old API: ``` class MyModule(torch.jit.ScriptModule): __constants__ = ['my_constant'] def __init__(self): super(MyModule, self).__init__() self.my_constant = 2 def forward(self): pass m = MyModule() ``` New API: ``` try: from typing_extensions import Final except: # If you don't have `typing_extensions` installed, you can use a # polyfill from `torch.jit`. from torch.jit import Final class MyModule(torch.nn.Module): my_constant: Final[int] def __init__(self): super(MyModule, self).__init__() self.my_constant = 2 def forward(self): pass m = torch.jit.script(MyModule()) ``` ### [变量](#id19) 假定容器的类型为`Tensor`,并且是非可选的(有关更多信息,请参见[默认类型](#default-types))。 以前,`torch.jit.annotate`用来告诉 TorchScript 编译器类型是什么。 现在支持 Python 3 样式类型提示。 ``` import torch from typing import Dict, Optional @torch.jit.script def make_dict(flag: bool): x: Dict[str, int] = {} x['hi'] = 2 b: Optional[int] = None if flag: b = 2 return x, b ``` ## [TorchScript 语言参考](#id20) TorchScript 是 Python 的静态类型子集,可以直接编写(使用 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 装饰器),也可以通过跟踪从 Python 代码自动生成。 使用跟踪时,通过仅在张量上记录实际的运算符并简单地执行和丢弃其他周围的 Python 代码,代码会自动转换为 Python 的此子集。 使用`@torch.jit.script`装饰器直接编写 TorchScript 时,程序员只能使用 TorchScript 支持的 Python 子集。 本节记录了 TorchScript 支持的功能,就像它是独立语言的语言参考一样。 本参考中未提及的 Python 的任何功能都不属于 TorchScript。 有关可用的 Pytorch 张量方法,模块和功能的完整参考,请参见[内置函数](#builtin-functions)。 作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。 这样就可以[禁用 TorchScript](#disable-torchscript) 并使用`pdb`之类的标准 Python 工具调试该功能。 反之则不成立:有许多有效的 Python 程序不是有效的 TorchScript 程序。 相反,TorchScript 特别专注于表示 PyTorch 中的神经网络模型所需的 Python 功能。 ### [类型](#id21) TorchScript 与完整的 Python 语言之间的最大区别是 TorchScript 仅支持表达神经网络模型所需的一小部分类型。 特别是,TorchScript 支持: | 类型 | 描述 | | --- | --- | | `Tensor` | 任何 dtype,尺寸或后端的 PyTorch 张量 | | `Tuple[T0, T1, ...]` | 包含子类型`T0`,`T1`等(例如`Tuple[Tensor, Tensor]`)的元组 | | `bool` | 布尔值 | | `int` | 标量整数 | | `float` | 标量浮点数 | | `str` | 一串 | | `List[T]` | 所有成员均为`T`类型的列表 | | `Optional[T]` | 无或输入`T`的值 | | `Dict[K, V]` | 键类型为`K`而值类型为`V`的字典。 只能将`str`,`int`和`float`作为密​​钥类型。 | | `T` | 一个 [TorchScript 类](#torchscript-class) | | `NamedTuple[T0, T1, ...]` | `collections.namedtuple`元组类型 | 与 Python 不同,TorchScript 函数中的每个变量都必须具有一个静态类型。 这使优化 TorchScript 函数变得更加容易。 示例(类型不匹配) ``` import torch @torch.jit.script def an_error(x): if x: r = torch.rand(1) else: r = 4 return r ``` ``` Traceback (most recent call last): ... RuntimeError: ... Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: @torch.jit.script def an_error(x): if x: ~~~~~... <--- HERE r = torch.rand(1) else: and was used here: else: r = 4 return r ~ <--- HERE ... ``` #### [默认类型](#id22) 默认情况下,TorchScript 函数的所有参数均假定为 Tensor。 要指定 TorchScript 函数的参数是其他类型,可以使用上面列出的类型使用 MyPy 样式的类型注释。 ``` import torch @torch.jit.script def foo(x, tup): # type: (int, Tuple[Tensor, Tensor]) -> Tensor t0, t1 = tup return t0 + t1 + x print(foo(3, (torch.rand(3), torch.rand(3)))) ``` 注意 也可以使用`typing`模块中的 Python 3 类型提示来注释类型。 ``` import torch from typing import Tuple @torch.jit.script def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: t0, t1 = tup return t0 + t1 + x print(foo(3, (torch.rand(3), torch.rand(3)))) ``` 在我们的示例中,我们使用基于注释的类型提示来确保 Python 2 的兼容性。 假定空列表为`List[Tensor]`,空字典为`Dict[str, Tensor]`。 要实例化其他类型的空列表或字典,请使用 [Python 3 类型提示](#python-3-type-hints)。 如果您使用的是 Python 2,则可以使用`torch.jit.annotate`。 示例(Python 3 的类型注释): ``` import torch import torch.nn as nn from typing import Dict, List, Tuple class EmptyDataStructures(torch.nn.Module): def __init__(self): super(EmptyDataStructures, self).__init__() def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: # This annotates the list to be a `List[Tuple[int, float]]` my_list: List[Tuple[int, float]] = [] for i in range(10): my_list.append((i, x.item())) my_dict: Dict[str, int] = {} return my_list, my_dict x = torch.jit.script(EmptyDataStructures()) ``` 示例(适用于 Python 2 的`torch.jit.annotate`): ``` import torch import torch.nn as nn from typing import Dict, List, Tuple class EmptyDataStructures(torch.nn.Module): def __init__(self): super(EmptyDataStructures, self).__init__() def forward(self, x): # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]] # This annotates the list to be a `List[Tuple[int, float]]` my_list = torch.jit.annotate(List[Tuple[int, float]], []) for i in range(10): my_list.append((i, float(x.item()))) my_dict = torch.jit.annotate(Dict[str, int], {}) return my_list, my_dict x = torch.jit.script(EmptyDataStructures()) ``` #### [可选类型细化](#id23) 在 if 语句的条件内或在`assert`中检查与`None`的比较时,TorchScript 将优化`Optional[T]`类型的变量的类型。 编译器可以推理与`and`,`or`和`not`结合的多个`None`检查。 对于未明确编写的 if 语句的 else 块,也会进行优化。 `None`检查必须在 if 语句的条件内; 将`None`检查分配给变量,并在 if 语句的条件下使用它,将不会优化检查中的变量类型。 仅局部变量将被细化,`self.x`之类的属性将不会且必须分配给要细化的局部变量。 示例(优化参数和局部变量的类型): ``` import torch import torch.nn as nn from typing import Optional class M(nn.Module): z: Optional[int] def __init__(self, z): super(M, self).__init__() # If `z` is None, its type cannot be inferred, so it must # be specified (above) self.z = z def forward(self, x, y, z): # type: (Optional[int], Optional[int], Optional[int]) -> int if x is None: x = 1 x = x + 1 # Refinement for an attribute by assigning it to a local z = self.z if y is not None and z is not None: x = y + z # Refinement via an `assert` assert z is not None x += z return x module = torch.jit.script(M(2)) module = torch.jit.script(M(None)) ``` #### [TorchScript 类](#id24) 如果 Python 类使用 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 注释,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式: ``` @torch.jit.script class Foo: def __init__(self, x, y): self.x = x def aug_add_x(self, inc): self.x += inc ``` 此子集受限制: * 所有函数必须是有效的 TorchScript 函数(包括`__init__()`)。 * 这些类必须是新型类,因为我们使用`__new__()`和 pybind11 来构造它们。 * TorchScript 类是静态类型的。 只能通过在`__init__()`方法中分配给 self 来声明成员。 > 例如,在`__init__()`方法之外分配给`self`: > > ``` > @torch.jit.script > class Foo: > def assign_x(self): > self.x = torch.rand(2, 3) > > ``` > > 将导致: > > ``` > RuntimeError: > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: > def assign_x(self): > self.x = torch.rand(2, 3) > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE > > ``` * 类的主体中不允许使用除方法定义之外的任何表达式。 * 除了从`object`继承以指定新样式类外,不支持继承或任何其他多态策略。 定义了一个类之后,就可以像其他任何 TorchScript 类型一样在 TorchScript 和 Python 中互换使用该类: ``` # Declare a TorchScript class @torch.jit.script class Pair: def __init__(self, first, second): self.first = first self.second = second @torch.jit.script def sum_pair(p): # type: (Pair) -> Tensor return p.first + p.second p = Pair(torch.rand(2, 3), torch.rand(2, 3)) print(sum_pair(p)) ``` #### [命名为元组](#id25) `collections.namedtuple`产生的类型可以在 TorchScript 中使用。 ``` import torch import collections Point = collections.namedtuple('Point', ['x', 'y']) @torch.jit.script def total(point): # type: (Point) -> Tensor return point.x + point.y p = Point(x=torch.rand(3), y=torch.rand(3)) print(total(p)) ``` ### [表达式](#id26) 支持以下 Python 表达式。 #### [文字](#id27) ``` True False None 'string literals' "string literals" 3 # interpreted as int 3.4 # interpreted as a float ``` ##### [列表结构](#id28) 假定一个空列表具有`List[Tensor]`类型。 其他列表文字的类型是从成员的类型派生的。 有关更多详细信息,请参见[默认类型](#default-types)。 ``` [3, 4] [] [torch.rand(3), torch.rand(4)] ``` ##### [元组结构](#id29) ``` (3, 4) (3,) ``` ##### [字典结构](#id30) 假定一个空字典为`Dict[str, Tensor]`类型。 其他 dict 文字的类型是从成员的类型派生的。 有关更多详细信息,请参见[默认类型](#default-types)。 ``` {'hello': 3} {} {'a': torch.rand(3), 'b': torch.rand(4)} ``` #### [变量](#id31) 有关如何解析变量的信息,请参见[变量分辨率](#variable-resolution)。 ``` my_variable_name ``` #### [算术运算符](#id32) ``` a + b a - b a * b a / b a ^ b a @ b ``` #### [比较运算符](#id33) ``` a == b a != b a < b a > b a <= b a >= b ``` #### [逻辑运算符](#id34) ``` a and b a or b not b ``` #### [下标和切片](#id35) ``` t[0] t[-1] t[0:2] t[1:] t[:1] t[:] t[0, 1] t[0, 1:2] t[0, :1] t[-1, 1:, 0] t[1:, -1, 0] t[i:j, i] ``` #### [函数调用](#id36) 调用[内置函数](#builtin-functions) ``` torch.rand(3, dtype=torch.int) ``` 调用其他脚本函数: ``` import torch @torch.jit.script def foo(x): return x + 1 @torch.jit.script def bar(x): return foo(x) ``` #### [方法调用](#id37) 调用诸如张量之类的内置类型的方法:`x.mm(y)` 在模块上,必须先编译方法才能调用它们。 TorchScript 编译器以递归方式编译在编译其他方法时看到的方法。 默认情况下,编译从`forward`方法开始。 将编译`forward`调用的任何方法,以及这些方法调用的任何方法,依此类推。 要以`forward`以外的方法开始编译,请使用 [`@torch.jit.export`](#torch.jit.export "torch.jit.export") 装饰器(`forward`隐式标记为`@torch.jit.export`)。 直接调用子模块(例如`self.resnet(input)`)等效于调用其`forward`方法(例如`self.resnet.forward(input)`)。 ``` import torch import torch.nn as nn import torchvision class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() means = torch.tensor([103.939, 116.779, 123.68]) self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) resnet = torchvision.models.resnet18() self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) def helper(self, input): return self.resnet(input - self.means) def forward(self, input): return self.helper(input) # Since nothing in the model calls `top_level_method`, the compiler # must be explicitly told to compile this method @torch.jit.export def top_level_method(self, input): return self.other_helper(input) def other_helper(self, input): return input + 10 # `my_script_module` will have the compiled methods `forward`, `helper`, # `top_level_method`, and `other_helper` my_script_module = torch.jit.script(MyModule()) ``` #### [三元表达式](#id38) ``` x if x > y else y ``` #### [演员表](#id39) ``` float(ten) int(3.5) bool(ten) str(2)`` ``` #### [访问模块参数](#id40) ``` self.my_parameter self.my_submodule.my_parameter ``` ### [语句](#id41) TorchScript 支持以下类型的语句: #### [简单分配](#id42) ``` a = b a += b # short-hand for a = a + b, does not operate in-place on a a -= b ``` #### [模式匹配分配](#id43) ``` a, b = tuple_or_list a, b, *c = a_tuple ``` 多项分配 ``` a = b, c = tup ``` #### [打印报表](#id44) ``` print("the result of an add:", a + b) ``` #### [If 语句](#id45) ``` if a < 4: r = -a elif a < 3: r = a + a else: r = 3 * a ``` 除布尔值外,浮点数,整数和张量还可以在条件中使用,并将隐式转换为布尔值。 #### [While 循环](#id46) ``` a = 0 while a < 4: print(a) a += 1 ``` #### [适用于范围为](#id47)的循环 ``` x = 0 for i in range(10): x *= i ``` #### [用于遍历元组的循环](#id48) 这些展开循环,为元组的每个成员生成一个主体。 主体必须对每个成员进行正确的类型检查。 ``` tup = (3, torch.rand(4)) for x in tup: print(x) ``` #### [用于在常量 nn.ModuleList](#id49) 上循环 要在已编译方法中使用`nn.ModuleList`,必须通过将属性名称添加到`__constants__`列表中的类型来将其标记为常量。 `nn.ModuleList`上的 for 循环将在编译时展开循环的主体,并使用常量模块列表的每个成员。 ``` class SubModule(torch.nn.Module): def __init__(self): super(SubModule, self).__init__() self.weight = nn.Parameter(torch.randn(2)) def forward(self, input): return self.weight + input class MyModule(torch.nn.Module): __constants__ = ['mods'] def __init__(self): super(MyModule, self).__init__() self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) def forward(self, v): for module in self.mods: v = module(v) return v m = torch.jit.script(MyModule()) ``` #### [中断并继续](#id50) ``` for i in range(5): if i == 1: continue if i == 3: break print(i) ``` #### [返回](#id51) ``` return a, b ``` ### [可变分辨率](#id52) TorchScript 支持 Python 的可变分辨率(即作用域)规则的子集。 局部变量的行为与 Python 中的相同,不同之处在于,在通过函数的所有路径上,变量必须具有相同的类型。 如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它是错误的。 同样,如果沿函数的某些路径仅将_定义为_,则不允许使用该变量。 Example: ``` @torch.jit.script def foo(x): if x < 0: y = 4 print(y) ``` ``` Traceback (most recent call last): ... RuntimeError: ... y is not defined in the false branch... @torch.jit.script... def foo(x): if x < 0: ~~~~~~~~~... <--- HERE y = 4 print(y) ... ``` 定义函数时,会在编译时将非局部变量解析为 Python 值。 然后使用 [Python 值使用](#use-of-python-values)中描述的规则将这些值转换为 TorchScript 值。 ### [使用 Python 值](#id53) 为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围范围中的 Python 值。 例如,任何时候只要引用`torch`,当声明函数时,TorchScript 编译器实际上就会将其解析为`torch` Python 模块。 这些 Python 值不是 TorchScript 的一流部分。 而是在编译时将它们分解为 TorchScript 支持的原始类型。 这取决于编译发生时引用的 Python 值的动态类型。 本节介绍在 TorchScript 中访问 Python 值时使用的规则。 #### [功能](#id54) TorchScript 可以调用 Python 函数。 当将模型逐步转换为 TorchScript 时,此功能非常有用。 可以将模型逐函数移至 TorchScript,而对 Python 函数的调用保留在原处。 这样,您可以在进行过程中逐步检查模型的正确性。 * * * ``` torch.jit.ignore(drop=False, **kwargs) ``` 该装饰器向编译器指示应忽略函数或方法,而将其保留为 Python 函数。 这使您可以将代码保留在尚未与 TorchScript 兼容的模型中。 具有忽略功能的模型无法导出; 请改用 torch.jit.unused。 示例(在方法上使用`@torch.jit.ignore`): ``` import torch import torch.nn as nn class MyModule(nn.Module): @torch.jit.ignore def debugger(self, x): import pdb pdb.set_trace() def forward(self, x): x += 10 # The compiler would normally try to compile `debugger`, # but since it is `@ignore`d, it will be left as a call # to Python self.debugger(x) return x m = torch.jit.script(MyModule()) # Error! The call `debugger` cannot be saved since it calls into Python m.save("m.pt") ``` 示例(在方法上使用`@torch.jit.ignore(drop=True)`): ``` import torch import torch.nn as nn class MyModule(nn.Module): @torch.jit.ignore(drop=True) def training_method(self, x): import pdb pdb.set_trace() def forward(self, x): if self.training: self.training_method(x) return x m = torch.jit.script(MyModule()) # This is OK since `training_method` is not saved, the call is replaced # with a `raise`. m.save("m.pt") ``` * * * ``` torch.jit.unused(fn) ``` 此装饰器向编译器指示应忽略函数或方法,并用引发异常的方法代替。 这样,您就可以在尚不兼容 TorchScript 的模型中保留代码,并仍然可以导出模型。 > 示例(在方法上使用`@torch.jit.unused`): > > ``` > import torch > import torch.nn as nn > > class MyModule(nn.Module): > def __init__(self, use_memory_efficent): > super(MyModule, self).__init__() > self.use_memory_efficent = use_memory_efficent > > @torch.jit.unused > def memory_efficient(self, x): > import pdb > pdb.set_trace() > return x + 10 > > def forward(self, x): > # Use not-yet-scriptable memory efficient mode > if self.use_memory_efficient: > return self.memory_efficient(x) > else: > return x + 10 > > m = torch.jit.script(MyModule(use_memory_efficent=False)) > m.save("m.pt") > > m = torch.jit.script(MyModule(use_memory_efficient=True)) > # exception raised > m(torch.rand(100)) > > ``` * * * ``` torch.jit.is_scripting() ``` 在编译时返回 True 的函数,否则返回 False 的函数。 这对于使用@unused 装饰器尤其有用,可以将尚不兼容 TorchScript 的代码保留在模型中。 .. testcode: ``` import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if not torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x) ``` #### [Python 模块上的属性查找](#id55) TorchScript 可以在模块上查找属性。 [像`torch.add`这样的内置功能](#builtin-functions)可以通过这种方式访问​​。 这使 TorchScript 可以调用其他模块中定义的函数。 #### [Python 定义的常量](#id56) TorchScript 还提供了一种使用 Python 中定义的常量的方法。 这些可用于将超参数硬编码到函数中,或定义通用常量。 有两种指定 Python 值应视为常量的方式。 1. 查找为模块属性的值假定为常量: ``` import math import torch @torch.jit.script def fn(): return math.pi ``` 1. 可以通过使用`Final[T]`注释 ScriptModule 的属性来将其标记为常量。 ``` import torch import torch.nn as nn class Foo(nn.Module): # `Final` from the `typing_extensions` module can also be used a : torch.jit.Final[int] def __init__(self): super(Foo, self).__init__() self.a = 1 + 4 def forward(self, input): return self.a + input f = torch.jit.script(Foo()) ``` 支持的常量 Python 类型是 * `int` * `float` * `bool` * `torch.device` * `torch.layout` * `torch.dtype` * 包含受支持类型的元组 * `torch.nn.ModuleList`可以在 TorchScript for 循环中使用 Note 如果您使用的是 Python 2,则可以通过将属性名称添加到类的`__constants__`属性中来将其标记为常量: ``` import torch import torch.nn as nn class Foo(nn.Module): __constants__ = ['a'] def __init__(self): super(Foo, self).__init__() self.a = 1 + 4 def forward(self, input): return self.a + input f = torch.jit.script(Foo()) ``` #### [模块属性](#id57) `torch.nn.Parameter`包装器和`register_buffer`可用于将张量分配给模块。 如果可以推断出其他类型的值,则分配给已编译模块的其他值将添加到已编译模块中。 TorchScript 中可用的所有[类型](#types)都可以用作模块属性。 张量属性在语义上与缓冲区相同。 空列表和字典的类型以及`None`值无法推断,必须通过 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释指定。 如果无法推断出类型并且未对其进行显式注释,则不会将其作为属性添加到结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 中。 Example: ``` from typing import List, Dict class Foo(nn.Module): # `words` is initialized as an empty list, so its type must be specified words: List[str] # The type could potentially be inferred if `a_dict` (below) was not # empty, but this annotation ensures `some_dict` will be made into the # proper type some_dict: Dict[str, int] def __init__(self, a_dict): super(Foo, self).__init__() self.words = [] self.some_dict = a_dict # `int`s can be inferred self.my_int = 10 def forward(self, input): # type: (str) -> int self.words.append(input) return self.some_dict[input] + self.my_int f = torch.jit.script(Foo({'hi': 2})) ``` Note 如果您使用的是 Python 2,则可以通过将属性的类型添加到`__annotations__`类属性中作为属性名字典来标记属性的类型 ``` from typing import List, Dict class Foo(nn.Module): __annotations__ = {'words': List[str], 'some_dict': Dict[str, int]} def __init__(self, a_dict): super(Foo, self).__init__() self.words = [] self.some_dict = a_dict # `int`s can be inferred self.my_int = 10 def forward(self, input): # type: (str) -> int self.words.append(input) return self.some_dict[input] + self.my_int f = torch.jit.script(Foo({'hi': 2})) ``` ### [调试](#id58) #### [禁用用于调试的 JIT](#id59) ``` PYTORCH_JIT ``` 设置环境变量`PYTORCH_JIT=0`将禁用所有脚本和跟踪注释。 如果您的 TorchScript 模型之一存在难以调试的错误,则可以使用此标志来强制一切都使用本机 Python 运行。 由于此标志禁用了 TorchScript(脚本编写和跟踪),因此可以使用`pdb`之类的工具来调试模型代码。 给定一个示例脚本: ``` @torch.jit.script def scripted_fn(x : torch.Tensor): for i in range(12): x = x + x return x def fn(x): x = torch.neg(x) import pdb; pdb.set_trace() return scripted_fn(x) traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),)) traced_fn(torch.rand(3, 4)) ``` 除调用[,`@torch.jit.script`,](#torch.jit.script "torch.jit.script")函数外,使用`pdb`调试此脚本是可行的。 我们可以全局禁用 JIT,以便我们可以将 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 函数作为普通的 Python 函数调用,而不进行编译。 如果上述脚本称为`disable_jit_example.py`,我们可以这样调用它: ``` $ PYTORCH_JIT=0 python disable_jit_example.py ``` 并且我们将能够像普通的 Python 函数一样进入 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 函数。 要为特定功能禁用 TorchScript 编译器,请参见 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 。 #### [检查码](#id60) TorchScript 为所有 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 实例提供了代码漂亮的打印机。 这个漂亮的打印机可以将脚本方法的代码解释为有效的 Python 语法。 例如: ``` @torch.jit.script def foo(len): # type: (int) -> torch.Tensor rv = torch.zeros(3, 4) for i in range(len): if i < 10: rv = rv - 1.0 else: rv = rv + 1.0 return rv print(foo.code) ``` 具有单个`forward`方法的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有属性`code`,您可以使用该属性检查 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的代码。 如果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 具有多个方法,则需要在方法本身而非模块上访问`.code`。 我们可以通过访问`.foo.code`在 ScriptModule 上检查名为`foo`的方法的代码。 上面的示例产生以下输出: ``` def foo(len: int) -> Tensor: rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None) rv0 = rv for i in range(len): if torch.lt(i, 10): rv1 = torch.sub(rv0, 1., 1) else: rv1 = torch.add(rv0, 1., 1) rv0 = rv1 return rv0 ``` 这是 TorchScript 对`forward`方法的代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)正确捕获了模型代码。 #### [解释图](#id61) TorchScript 还以 IR 图的形式在比代码漂亮打印机更低的层次上进行表示。 TorchScript 使用静态单分配(SSA)中间表示(IR)表示计算。 这种格式的指令由 ATen(PyTorch 的 C ++后端)运算符和其他原始运算符组成,包括用于循环和条件的控制流运算符。 举个例子: ``` @torch.jit.script def foo(len): # type: (int) -> torch.Tensor rv = torch.zeros(3, 4) for i in range(len): if i < 10: rv = rv - 1.0 else: rv = rv + 1.0 return rv print(foo.graph) ``` `graph`遵循[检查代码](#inspecting-code)部分中关于`forward`方法查找所述的相同规则。 上面的示例脚本生成图形: ``` graph(%len.1 : int): %24 : int = prim::Constant[value=1]() %17 : bool = prim::Constant[value=1]() # test.py:10:5 %12 : bool? = prim::Constant() %10 : Device? = prim::Constant() %6 : int? = prim::Constant() %1 : int = prim::Constant[value=3]() # test.py:9:22 %2 : int = prim::Constant[value=4]() # test.py:9:25 %20 : int = prim::Constant[value=10]() # test.py:11:16 %23 : float = prim::Constant[value=1]() # test.py:12:23 %4 : int[] = prim::ListConstruct(%1, %2) %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5 block0(%i.1 : int, %rv.14 : Tensor): %21 : bool = aten::lt(%i.1, %20) # test.py:11:12 %rv.13 : Tensor = prim::If(%21) # test.py:11:9 block0(): %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18 -> (%rv.3) block1(): %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18 -> (%rv.6) -> (%17, %rv.13) return (%rv) ``` 以指令`%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`为例。 * `%rv.1 : Tensor`表示我们将输出分配给一个名为`rv.1`的(唯一)值,该值是`Tensor`类型,并且我们不知道其具体形状。 * `aten::zeros`是运算符(与`torch.zeros`等效),输入列表`(%4, %6, %6, %10, %12)`指定范围中的哪些值应作为输入传递。 可以在[内置函数](#builtin-functions)中找到`aten::zeros`等内置函数的模式。 * `# test.py:9:10`是生成此指令的原始源文件中的位置。 在这种情况下,它是第 9 行和字符 10 处名为 <cite>test.py</cite> 的文件。 请注意,运算符也可以具有关联的`blocks`,即`prim::Loop`和`prim::If`运算符。 在图形打印输出中,这些运算符被格式化以反映其等效的源代码形式,以方便进行调试。 如下图所示,可以检查图表以确认 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 所描述的计算是正确的,无论是自动方式还是手动方式。 #### [追踪案例](#id62) 在某些极端情况下,给定 Python 函数/模块的跟踪不会代表基础代码。 这些情况可以包括: * 跟踪取决于输入的控制流(例如张量形状) * 跟踪张量视图的就地操作(例如,分配左侧的索引) 请注意,这些情况实际上将来可能是可追溯的。 #### [自动跟踪检查](#id63) 自动捕获跟踪中许多错误的一种方法是使用`torch.jit.trace()` API 上的`check_inputs`。 `check_inputs`提取输入元组的列表,这些列表将用于重新追踪计算并验证结果。 例如: ``` def loop_in_traced_fn(x): result = x[0] for i in range(x.size(0)): result = result * x[i] return result inputs = (torch.rand(3, 4, 5),) check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs) ``` 为我们提供以下诊断信息: ``` ERROR: Graphs differed across invocations! Graph diff: graph(%x : Tensor) { %1 : int = prim::Constant[value=0]() %2 : int = prim::Constant[value=0]() %result.1 : Tensor = aten::select(%x, %1, %2) %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=0]() %6 : Tensor = aten::select(%x, %4, %5) %result.2 : Tensor = aten::mul(%result.1, %6) %8 : int = prim::Constant[value=0]() %9 : int = prim::Constant[value=1]() %10 : Tensor = aten::select(%x, %8, %9) - %result : Tensor = aten::mul(%result.2, %10) + %result.3 : Tensor = aten::mul(%result.2, %10) ? ++ %12 : int = prim::Constant[value=0]() %13 : int = prim::Constant[value=2]() %14 : Tensor = aten::select(%x, %12, %13) + %result : Tensor = aten::mul(%result.3, %14) + %16 : int = prim::Constant[value=0]() + %17 : int = prim::Constant[value=3]() + %18 : Tensor = aten::select(%x, %16, %17) - %15 : Tensor = aten::mul(%result, %14) ? ^ ^ + %19 : Tensor = aten::mul(%result, %18) ? ^ ^ - return (%15); ? ^ + return (%19); ? ^ } ``` 此消息向我们表明,在我们第一次追踪它和使用`check_inputs`追踪它之间,计算有所不同。 实际上,`loop_in_traced_fn`主体内的循环取决于输入`x`的形状,因此,当我们尝试另一种形状不同的`x`时,迹线会有所不同。 在这种情况下,可以使用 [`torch.jit.script()`](#torch.jit.script "torch.jit.script") 来捕获类似于数据的控制流: ``` def fn(x): result = x[0] for i in range(x.size(0)): result = result * x[i] return result inputs = (torch.rand(3, 4, 5),) check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] scripted_fn = torch.jit.script(fn) print(scripted_fn.graph) #print(str(scripted_fn.graph).strip()) for input_tuple in [inputs] + check_inputs: torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple)) ``` 产生: ``` graph(%x : Tensor) { %5 : bool = prim::Constant[value=1]() %1 : int = prim::Constant[value=0]() %result.1 : Tensor = aten::select(%x, %1, %1) %4 : int = aten::size(%x, %1) %result : Tensor = prim::Loop(%4, %5, %result.1) block0(%i : int, %7 : Tensor) { %10 : Tensor = aten::select(%x, %1, %i) %result.2 : Tensor = aten::mul(%7, %10) -> (%5, %result.2) } return (%result); } ``` #### [跟踪器警告](#id64) 跟踪器会针对跟踪计算中的几种有问题的模式生成警告。 举个例子,追踪一个在 Tensor 的切片(视图)上包含就地分配的函数: ``` def fill_row_zero(x): x[0] = torch.rand(*x.shape[1:2]) return x traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) print(traced.graph) ``` 产生几个警告和一个仅返回输入的图形: ``` fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe. x[0] = torch.rand(*x.shape[1:2]) fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error: Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%) traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) graph(%0 : Float(3, 4)) { return (%0); } ``` 我们可以通过修改代码来解决此问题,使其不使用就地更新,而是使用`torch.cat`来错位构建结果张量: ``` def fill_row_zero(x): x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0) return x traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) print(traced.graph) ``` ### [内置函数](#id65) TorchScript 支持 PyTorch 提供的内置张量和神经网络功能的子集。 Tensor 上的大多数方法以及`torch`名称空间中的函数,`torch.nn.functional`中的所有函数以及`torch.nn`中的所有模块在 TorchScript 中均受支持,下表中没有列出。 对于不支持的模块,建议使用 [`torch.jit.trace()`](#torch.jit.trace "torch.jit.trace") 。 不支持的`torch.nn`模块 ``` torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss torch.nn.modules.normalization.CrossMapLRN2d torch.nn.modules.rnn.RNN ``` 有关支持的功能的完整参考,请参见 [TorchScript 内置函数](jit_builtin_functions.html#builtin-functions)。 ## [常见问题解答](#id66) 问:我想在 GPU 上训练模型并在 CPU 上进行推理。 最佳做法是什么? > 首先将模型从 GPU 转换为 CPU,然后将其保存,如下所示: > > ``` > cpu_model = gpu_model.cpu() > sample_input_cpu = sample_input_gpu.cpu() > traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu) > torch.jit.save(traced_cpu, "cpu.pth") > > traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu) > torch.jit.save(traced_gpu, "gpu.pth") > > # ... later, when using the model: > > if use_gpu: > model = torch.jit.load("gpu.pth") > else: > model = torch.jit.load("cpu.pth") > > model(input) > > ``` > > 推荐这样做是因为跟踪器可能会在特定设备上见证张量的创建,因此强制转换已加载的模型可能会产生意想不到的效果。 在保存之前对模型_进行转换可确保跟踪器具有正确的设备信息。_ 问:如何在 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 上存储属性? > 说我们有一个像这样的模型: > > ``` > class Model(nn.Module): > def __init__(self): > super(Model, self).__init__() > self.x = 2 > > def forward(self): > return self.x > > m = torch.jit.script(Model()) > > ``` > > 如果实例化`Model`,则将导致编译错误,因为编译器不了解`x`。 有四种方法可以通知编译器 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的属性: > > 1\. `nn.Parameter`-包装在`nn.Parameter`中的值将像在`nn.Module`上一样工作 > > 2\. `register_buffer`-包装在`register_buffer`中的值将像在`nn.Module`上一样工作。 这等效于`Tensor`类型的属性(请参见 4)。 > > 3.常量-将类成员注释为`Final`(或在类定义级别将其添加到名为`__constants__`的列表中)会将包含的名称标记为常量。 常数直接保存在模型代码中。 有关详细信息,请参见 [Python 定义的常量](#python-defined-constants)。 > > 4.属性-可以将[支持的类型](#supported-type)的值添加为可变属性。 可以推断大多数类型,但可能需要指定一些类型,有关详细信息,请参见[模块属性](#module-attributes)。 问:我想跟踪模块的方法,但一直出现此错误: `RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient` > 此错误通常表示您要跟踪的方法使用模块的参数,并且您正在传递模块的方法而不是模块实例(例如`my_module_instance.forward`与`my_module_instance`)。 > > > * 使用模块的方法调用`trace`会将模块参数(可能需要渐变)捕获为**常量**。 > > > > > > * 另一方面,使用模块实例(例如`my_module`)调用`trace`会创建一个新模块,并将参数正确复制到新模块中,以便在需要时可以累积梯度。 > > 要跟踪模块上的特定方法,请参见 [`torch.jit.trace_module`](#torch.jit.trace_module "torch.jit.trace_module")