doc22_060.md 28.2 KB
Newer Older
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1 2
# TorchScript

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
3 4 5 6 7 8
> 原文:[`pytorch.org/docs/stable/jit.html`](https://pytorch.org/docs/stable/jit.html)> 
>
> 译者:[飞龙](https://github.com/wizardforcel)
>
> 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/)

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
9

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
10
+   TorchScript 语言参考
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
11

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
12
+   创建 TorchScript 代码
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
13

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
14
+   混合追踪和脚本化
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
15

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
16
+   TorchScript 语言
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
17

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
18
+   内置函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
19

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
20
    +   PyTorch 函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
21

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
22
    +   Python 函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
23

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
24
    +   Python 语言参考比较
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
25

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
26
+   调试
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
27

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
28
    +   用于调试的禁用 JIT
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
29

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
30
    +   检查代码
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
31

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
32
    +   解释图表
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
33

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
34
    +   追踪器
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
35

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
36
+   常见问题解答
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
37

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
38
+   已知问题
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
39

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
40
+   附录
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
41

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
42
    +   迁移到 PyTorch 1.2 递归脚本 API
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
43

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
44
    +   融合后端
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
45

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
46
    +   参考资料
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
47

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
48
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方式。任何 TorchScript 程序都可以从 Python 进程中保存并加载到没有 Python 依赖项的进程中。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
49

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
50
我们提供工具,逐步将模型从纯 Python 程序转换为一个可以独立于 Python 运行的 TorchScript 程序,比如在一个独立的 C++程序中。这使得可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到一个生产环境中,在这个环境中,由于性能和多线程原因,Python 程序可能不利。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
51

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
52
想要了解 TorchScript 的初学者,可以参考[介绍 TorchScript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)教程。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
53

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
54
想要了解如何将 PyTorch 模型转换为 TorchScript 并在 C++中运行的端到端示例,可以参考[在 C++中加载 PyTorch 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)教程。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
55

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
56
## 创建 TorchScript 代码
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
57

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
58
| `script` | 对函数进行脚本化。 |
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
59
| --- | --- |
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
| `trace` | 追踪一个函数并返回一个可执行的或`ScriptFunction`,该函数将使用即时编译进行优化。 |
| `script_if_tracing` | 当在追踪期间首次调用`fn`时,编译`fn`。 |
| `trace_module` | 追踪一个模块并返回一个可执行的`ScriptModule`,该模块将使用即时编译进行优化。 |
| `fork` | 创建一个执行 func 的异步任务,并引用此执行结果的值。 |
| `wait` | 强制完成一个 torch.jit.Future[T]异步任务,返回任务的结果。 |
| `ScriptModule` | 用于 C++ torch::jit::Module 的包装器,具有方法、属性和参数。 |
| `ScriptFunction` | 与`ScriptModule`在功能上等效,但表示单个函数,不具有任何属性或参数。 |
| `freeze` | 冻结 ScriptModule,内联子模块,并将属性作为常量。 |
| `optimize_for_inference` | 执行一系列优化传递,以优化模型以用于推断目的。 |
| `enable_onednn_fusion` | 根据启用的参数启用或禁用 onednn JIT 融合。 |
| `onednn_fusion_enabled` | 返回 onednn JIT 融合是否已启用。 |
| `set_fusion_strategy` | 设置在融合过程中可以发生的特化类型和数量。 |
| `strict_fusion` | 如果推断中没有融合所有节点,或者在训练中没有符号微分,则会出错。 |
| `save` | 保存此模块的离线版本,以供在单独的进程中使用。 |
| `load` | 加载之前使用 `torch.jit.save` 保存的 `ScriptModule``ScriptFunction`。 |
| `ignore` | 此装饰器指示编译器应忽略一个函数或方法,并将其保留为 Python 函数。 |
| `unused` | 此装饰器指示编译器应忽略一个函数或方法,并用引发异常替换。 |
| `interface` | 用于注释不同类型的类或模块的装饰器。 |
| `isinstance` | 在 TorchScript 中提供容器类型细化。 |
| `Attribute` | 此方法是一个传递函数,返回值,主要用于指示 TorchScript 编译器左侧表达式是具有类型的类实例属性。 |
| `annotate` | 用于在 TorchScript 编译器中给出 the_value 的类型。 |

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
82
## 混合跟踪和脚本化
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
83 84

在许多情况下,跟踪或脚本化是将模型转换为 TorchScript 的更简单方法。跟踪和脚本化可以组合以满足模型的特定要求。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145

脚本函数可以调用跟踪函数。当您需要在一个简单的前馈模型周围使用控制流时,这是特别有用的。例如,序列到序列模型的波束搜索通常会以脚本形式编写,但可以调用使用跟踪生成的编码器模块。

示例(在脚本中调用跟踪函数):

```py
import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x) 
```

跟踪函数可以调用脚本函数。当模型的一小部分需要一些控制流时,即使大部分模型只是一个前馈网络时,这是有用的。由跟踪函数调用的脚本函数内的控制流会被正确保留。

示例(在跟踪函数中调用脚本函数):

```py
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))) 
```

这种组合也适用于`nn.Module`,可以使用跟踪生成一个子模块,可以从脚本模块的方法中调用。

示例(使用跟踪模块):

```py
import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule()) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
146
## TorchScript 语言
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
147

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
148
TorchScript 是 Python 的静态类型子集,因此许多 Python 功能直接适用于 TorchScript。有关详细信息,请参阅完整的 TorchScript 语言参考。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
149

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
150
## 内置函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
151

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
152
TorchScript 支持大多数 PyTorch 函数和许多 Python 内置函数。查看 TorchScript 内置函数以获取支持函数的完整参考。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
153

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
154
### PyTorch 函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
155

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
156
TorchScript 支持 PyTorch 提供的张量和神经网络函数的子集。Tensor 上的大多数方法以及`torch`命名空间中的函数,`torch.nn.functional`中的所有函数以及`torch.nn`中的大多数模块都受 TorchScript 支持。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
157

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
158
请查看 TorchScript 不支持的 PyTorch 构造以获取不支持的 PyTorch 函数和模块列表。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
159

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
160
### Python 函数和模块
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
161

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
162
TorchScript 支持 Python 的许多[内置函数](https://docs.python.org/3/library/functions.html)[`math`](https://docs.python.org/3/library/math.html#module-math "(在 Python v3.12 中)")模块也受支持(有关详细信息,请参阅 math 模块),但不支持其他 Python 模块(内置或第三方)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
163

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
164
### Python 语言参考比较
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
165

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
166
要查看支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
167

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
168
## 调试
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
169

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
170
### 调试时禁用 JIT
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
171 172 173

PYTORCH_JIT

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
174
设置环境变量`PYTORCH_JIT=0`将禁用所有脚本和跟踪注释。如果您的 TorchScript 模型中有难以调试的错误,您可以使用此标志强制所有内容都使用本机 Python 运行。由于使用此标志禁用了 TorchScript(脚本和跟踪),您可以使用`pdb`等工具来调试模型代码。例如:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191

```py
@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4)) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
192
使用`pdb`调试此脚本可以正常工作,除非我们调用`@torch.jit.script`函数。我们可以全局禁用 JIT,这样我们可以像普通 Python 函数一样调用`@torch.jit.script`函数而不进行编译。如果上述脚本被称为`disable_jit_example.py`,我们可以这样调用它:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
193 194 195 196 197

```py
$ PYTORCH_JIT=0 python disable_jit_example.py 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
198
我们将能够像普通 Python 函数一样进入`@torch.jit.script`函数。要禁用特定函数的 TorchScript 编译器,请参阅`@torch.jit.ignore`。### 检查代码
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
199

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
200
TorchScript 为所有`ScriptModule`实例提供了代码漂亮打印机。这个漂亮打印机将脚本方法的代码解释为有效的 Python 语法。例如:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216

```py
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
217
具有单个`forward`方法的`ScriptModule`将具有一个名为`code`的属性,您可以使用它来检查`ScriptModule`的代码。如果`ScriptModule`有多个方法,则需要在方法本身上访问`.code`,而不是模块。我们可以通过访问`.foo.code`来检查名为`foo`的方法的代码,该方法位于`ScriptModule`上。上面的示例产生了这个输出:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231

```py
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
232
这是 TorchScript 对`forward`方法的代码编译。您可以使用此功能来确保 TorchScript(跟踪或脚本化)正确捕获了您的模型代码。### 解释图形
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
233

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
234
TorchScript 还具有比代码漂亮打印机更低级别的表示形式,即 IR 图形。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
235

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
236
TorchScript 使用静态单赋值(SSA)中间表示(IR)来表示计算。此格式中的指令包括 ATen(PyTorch 的 C++后端)运算符和其他原始运算符,包括用于循环和条件的控制流运算符。例如:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252

```py
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
253
`graph`遵循与检查代码部分中关于`forward`方法查找的相同规则。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287

上面的示例脚本产生了以下图形:

```py
graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv) 
```

以指令`%rv.1:Tensor = aten::zeros(%4,%6,%6,%10,%12)#test.py:9:10`为例。

+   `%rv.1:Tensor`表示我们将输出分配给名为`rv.1`的(唯一)值,该值是`Tensor`类型的,我们不知道其具体形状。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
288
+   `aten::zeros`是运算符(等同于`torch.zeros`),输入列表`(%4,%6,%6,%10,%12)`指定应将哪些作用域内的值作为输入传递。内置函数(如`aten::zeros`)的模式可以在 Builtin Functions 中找到。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
289

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
290
+   `#test.py:9:10`是生成此指令的原始源文件中的位置。在这种情况下,它是一个名为 test.py 的文件,在第 9 行,第 10 个字符处。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
291 292 293

请注意,运算符也可以有关联的`blocks`,即`prim::Loop``prim::If`运算符。在图形打印输出中,这些运算符的格式化方式反映了它们等效的源代码形式,以便进行简单的调试。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
294
可以像下面描述的那样检查图形,以确认由`ScriptModule`描述的计算是正确的,无论是自动还是手动方式。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
295

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
296
### 跟踪器
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
297 298 299

#### 跟踪特殊情况

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
300
存在一些特殊情况,其中给定 Python 函数/模块的跟踪可能不代表底层代码。这些情况可能包括:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
301 302 303 304 305 306 307 308 309

+   跟踪依赖于输入(例如张量形状)的控制流

+   跟踪张量视图的原位操作(例如在赋值的左侧进行索引)

请注意,这些情况实际上可能在将来是可跟踪的。

#### 自动跟踪检查

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
310
通过在`torch.jit.trace()` API 上使用`check_inputs`来自动捕获跟踪中的许多错误是一种方法。`check_inputs`接受一个输入元组列表,用于重新跟踪计算并验证结果。例如:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364

```py
def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs) 
```

给我们提供以下诊断信息:

```py
ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            } 
```

这条消息告诉我们,当我们首次跟踪它和使用`check_inputs`重新跟踪它时,计算之间存在差异。实际上,在`loop_in_traced_fn`的主体内的循环取决于输入`x`的形状,因此当我们尝试另一个具有不同形状的`x`时,跟踪会有所不同。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
365
在这种情况下,像这样的数据相关控制流可以使用`torch.jit.script()`来捕获:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439

```py
def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple)) 
```

这将产生:

```py
graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
} 
```

#### 跟踪器警告

跟踪器会为跟踪计算中的几种问题模式产生警告。例如,考虑包含对张量切片(视图)进行就地赋值的函数的跟踪:

```py
def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph) 
```

产生几个警告和一个简单返回输入的图形:

```py
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
} 
```

我们可以通过修改代码以不使用就地更新,而是使用`torch.cat`在原地构建结果张量来解决这个问题:

```py
def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
440
## 常见问题
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
441

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
442
问:我想在 GPU 上训练模型,然后在 CPU 上进行推断。有什么最佳实践吗?
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
443

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
444
> 首先将您的模型从 GPU 转换为 CPU,然后保存,就像这样:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466
> 
> ```py
> cpu_model = gpu_model.cpu()
> sample_input_cpu = sample_input_gpu.cpu()
> traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
> torch.jit.save(traced_cpu, "cpu.pt")
> 
> traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
> torch.jit.save(traced_gpu, "gpu.pt")
> 
> # ... later, when using the model:
> 
> if use_gpu:
>   model = torch.jit.load("gpu.pt")
> else:
>   model = torch.jit.load("cpu.pt")
> 
> model(input) 
> ```
> 
> 这是推荐的做法,因为跟踪器可能会在特定设备上看到张量的创建,因此在保存模型之前对模型进行转换可能会产生意外效果。在保存模型之前对模型进行转换可以确保跟踪器具有正确的设备信息。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
467
问:如何在`ScriptModule`上存储属性?
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484

> 假设我们有一个模型:
> 
> ```py
> import torch
> 
> class Model(torch.nn.Module):
>     def __init__(self):
>         super().__init__()
>         self.x = 2
> 
>     def forward(self):
>         return self.x
> 
> m = torch.jit.script(Model()) 
> ```
> 
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
485
> 如果实例化`Model`,将导致编译错误,因为编译器不知道`x`。有 4 种方法可以通知编译器`ScriptModule`上的属性:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
486 487 488
> 
> 1. `nn.Parameter` - 包装在`nn.Parameter`中的值将像在`nn.Module`上一样工作
> 
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
489
> 2. `register_buffer` - 包装在`register_buffer`中的值将像在`nn.Module`上一样工作。这相当于类型为`Tensor`的属性(见 4)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
> 
> 3. 常量 - 将类成员标记为`Final`(或将其添加到类定义级别的名为`__constants__`的列表中)将会将包含的名称标记为常量。常量直接保存在模型的代码中。详细信息请参见内置常量。
> 
> 4. 属性 - 可以将支持的类型添加为可变属性的值。大多数类型可以推断,但有些可能需要指定,详细信息请参见模块属性。

问:我想跟踪模块的方法,但我一直收到这个错误:

`RuntimeError: 无法将需要梯度的张量插入为常量。考虑将其作为参数或输入,或分离梯度`

> 这个错误通常意味着您正在跟踪的方法使用了模块的参数,并且您正在传递模块的方法而不是模块实例(例如`my_module_instance.forward` vs `my_module_instance`)。
> 
> > +   使用模块的方法调用`trace`会捕获模块参数(可能需要梯度)作为**常量**。
> > +   
> > +   另一方面,使用模块实例(例如`my_module`)调用`trace`会创建一个新模块,并将参数正确复制到新模块中,因此如果需要,它们可以累积梯度。
> > +   
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
505
> 要跟踪模块上的特定方法,请参见`torch.jit.trace_module`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
506

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
507
## 已知问题
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
508

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
509
如果您在 TorchScript 中使用`Sequential`,则某些`Sequential`子模块的输入可能会被错误地推断为`Tensor`,即使它们已经被注释。标准解决方案是子类化`nn.Sequential`并使用正确类型的输入重新声明`forward`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
510

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
511
## 附录
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
512

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
513
### 迁移到 PyTorch 1.2 递归脚本 API
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
514

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
515
本节详细介绍了 PyTorch 1.2 中 TorchScript 的更改。如果您是 TorchScript 的新手,可以跳过本节。PyTorch 1.2 对 TorchScript API 进行了两个主要更改。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
516

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
517
1. `torch.jit.script`现在将尝试递归编译遇到的函数、方法和类。一旦调用`torch.jit.script`,编译就是“opt-out”,而不是“opt-in”。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
518

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
519
2. `torch.jit.script(nn_module_instance)`现在是创建`ScriptModule`的首选方法,而不是继承自`torch.jit.ScriptModule`。这些更改结合在一起,为将您的`nn.Module`转换为`ScriptModule`提供了一个更简单、更易于使用的 API,准备在非 Python 环境中进行优化和执行。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545

新的用法如下:

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model) 
```

+   模块的`forward`默认被编译。从`forward`调用的方法会按照它们在`forward`中被使用的顺序进行延迟编译。

+   除了从`forward`调用的方法之外,要编译其他方法,请添加`@torch.jit.export`

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
546
+   要阻止编译器编译一个方法,请添加`@torch.jit.ignore``@torch.jit.unused``@ignore`保留了
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
547

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
548
+   将方法作为对 python 的调用,并用`@unused`替换它以引发异常。`@ignored`不能被导出;`@unused`可以。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
549

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
550
+   大多数属性类型可以被推断,因此不需要`torch.jit.Attribute`。对于空容器类型,使用[PEP 526 风格](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)的类注释来注释它们的类型。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
551 552 553

+   常量可以用`Final`类注释标记,而不是将成员名称添加到`__constants__`中。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
554
+   Python 3 类型提示可以用来替代`torch.jit.annotate`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571

由于这些更改,以下项目被视为已弃用,不应出现在新代码中:

+   `@torch.jit.script_method`装饰器

+   继承自`torch.jit.ScriptModule`的类

+   `torch.jit.Attribute`包装类

+   `__constants__`数组

+   `torch.jit.annotate`函数

#### 模块

警告

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
572
在 PyTorch 1.2 中,`@torch.jit.ignore`注解的行为发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使一个函数或方法可以从导出的代码中调用。要恢复此功能,请使用`@torch.jit.unused()``@torch.jit.ignore`现在等同于`@torch.jit.ignore(drop=False)`。有关详细信息,请参阅`@torch.jit.ignore``@torch.jit.unused`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
573

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
574
当传递给`torch.jit.script`函数时,`torch.nn.Module`的数据会被复制到一个`ScriptModule`中,TorchScript 编译器会编译该模块。模块的`forward`默认被编译。从`forward`调用的方法会按照它们在`forward`中被使用的顺序进行延迟编译,以及任何`@torch.jit.export`方法。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
575 576

```py
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
577
torch.jit.export(fn)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
578 579
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
580
这个装饰器表示`nn.Module`上的一个方法被用作进入`ScriptModule`的入口点,并且应该被编译。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618

`forward`隐式地被假定为入口点,因此不需要这个装饰器。从`forward`调用的函数和方法会被编译器按照它们在`forward`中被看到的顺序编译,因此它们也不需要这个装饰器。

示例(在方法上使用`@torch.jit.export`):

```py
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule()) 
```

#### 函数

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
619
函数没有太大变化,如果需要,可以用`@torch.jit.ignore``torch.jit.unused`进行装饰。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646

```py
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
647
#### TorchScript 类
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
648 649 650

警告

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
651
TorchScript 类支持是实验性的。目前最适合简单的类似记录的类型(考虑带有附加方法的`NamedTuple`)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
652

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
653
用户定义的 TorchScript 类中的所有内容默认导出,如果需要,函数可以用`@torch.jit.ignore`进行装饰。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
654 655 656

#### 属性

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
657
TorchScript 编译器需要知道模块属性的类型。大多数类型可以从成员的值中推断出来。空列表和字典无法推断其类型,必须使用[PEP 526 风格](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)的类注释来注释其类型。如果类型无法推断并且没有明确注释,则不会将其添加为结果`ScriptModule`的属性。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
658

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
659
旧 API:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
660 661 662 663 664 665 666 667 668 669 670 671 672 673

```py
from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule() 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
674
新 API:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697

```py
from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule()) 
```

#### 常量

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
698
`Final`类型构造函数可用于将成员标记为常量。如果成员未标记为常量,则它们将被复制到结果`ScriptModule`作为属性。使用`Final`可以在值已知为固定时进行优化,并提供额外的类型安全性。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
699

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
700
旧 API:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
701 702 703 704 705 706 707 708 709 710 711 712 713 714

```py
class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule() 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
715
新 API:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735

```py
from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule()) 
```

#### 变量

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
736
假定容器的类型为`Tensor`且非可选(有关更多信息,请参见默认类型)。以前,使用`torch.jit.annotate`告诉 TorchScript 编译器应该是什么类型。现在支持 Python 3 风格的类型提示。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
737 738 739 740 741 742 743 744 745 746 747 748 749 750 751

```py
import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
752
### 融合后端
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
753

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
754
有几种融合后端可用于优化 TorchScript 执行。在 CPU 上的默认融合器是 NNC,可以为 CPU 和 GPU 执行融合。在 GPU 上的默认融合器是 NVFuser,支持更广泛的运算符,并且已经生成了具有改进吞吐量的内核。有关使用和调试的更多详细信息,请参阅[NVFuser 文档](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
755

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
756
### 参考资料
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
757

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
758
+   Python 语言参考覆盖
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
759

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
760
+   TorchScript 不支持的 PyTorch 构造