84.md 68.5 KB
Newer Older
片刻小哥哥's avatar
片刻小哥哥 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
# torch脚本

> 原文: [https://pytorch.org/docs/stable/jit.html](https://pytorch.org/docs/stable/jit.html)

*   [创建 TorchScript 代码](#creating-torchscript-code)

*   [混合跟踪和脚本编写](#mixing-tracing-and-scripting)

*   [迁移到 PyTorch 1.2 递归脚本 API](#migrating-to-pytorch-1-2-recursive-scripting-api)

    *   [模块](#modules)

    *   [功能](#functions)

    *   [TorchScript 类](#torchscript-classes)

    *   [属性](#attributes)

        *   [Python 2](#python-2)

    *   [常数](#constants)

    *   [变量](#variables)

*   [TorchScript 语言参考](#torchscript-language-reference)

    *   [类型](#supported-type)

        *   [默认类型](#default-types)

        *   [可选类型细化](#optional-type-refinement)

        *   [TorchScript 类](#id3)

        *   [命名为元组](#named-tuples)

    *   [表达式](#expressions)

        *   [文字](#literals)

            *   [列表结构](#list-construction)

            *   [元组结构](#tuple-construction)

            *   [字典结构](#dict-construction)

        *   [变量](#id5)

        *   [算术运算符](#arithmetic-operators)

        *   [比较运算符](#comparison-operators)

        *   [逻辑运算符](#logical-operators)

        *   [下标和切片](#subscripts-and-slicing)

        *   [函数调用](#function-calls)

        *   [方法调用](#method-calls)

        *   [三元表达式](#ternary-expressions)

        *   [演员表](#casts)

        *   [访问模块参数](#accessing-module-parameters)

    *   [语句](#statements)

        *   [简单分配](#simple-assignments)

        *   [模式匹配分配](#pattern-matching-assignments)

        *   [打印报表](#print-statements)

        *   [If 语句](#if-statements)

        *   [While 循环](#while-loops)

        *   [适用于范围为](#for-loops-with-range)的循环

        *   [用于遍历元组的循环](#for-loops-over-tuples)

        *   [用于在常量 nn.ModuleList](#for-loops-over-constant-nn-modulelist) 上循环

        *   [中断并继续](#break-and-continue)

        *   [返回](#return)

    *   [可变分辨率](#variable-resolution)

    *   [使用 Python 值](#use-of-python-values)

        *   [功能](#id6)

        *   [Python 模块上的属性查找](#attribute-lookup-on-python-modules)

        *   [Python 定义的常量](#python-defined-constants)

        *   [模块属性](#module-attributes)

    *   [调试](#debugging)

        *   [禁用用于调试的 JIT](#disable-jit-for-debugging)

        *   [检查码](#inspecting-code)

        *   [解释图](#interpreting-graphs)

        *   [追踪案例](#tracing-edge-cases)

        *   [自动跟踪检查](#automatic-trace-checking)

        *   [跟踪器警告](#tracer-warnings)

    *   [内置函数](#builtin-functions)

*   [常见问题解答](#frequently-asked-questions)

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存并加载到没有 Python 依赖项的进程中。

我们提供了将模型从纯 Python 程序逐步过渡到可以独立于 Python 运行的 TorchScript 程序的工具,例如在独立的 C ++程序中。 这样就可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境中,在该生产环境中 Python 程序可能由于性能和多线程原因而处于不利地位。

有关 TorchScript 的简要介绍,请参见 [TorchScript 简介](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)教程。

有关将 PyTorch 模型转换为 TorchScript 并在 C ++中运行的端到端示例,请参见[在 C ++中加载 PyTorch 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)教程。

## [创建 TorchScript 代码](#id10)

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
132
class torch.jit.ScriptModule
片刻小哥哥's avatar
片刻小哥哥 已提交
133 134 135 136 137
```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
138
property code
片刻小哥哥's avatar
片刻小哥哥 已提交
139 140 141 142 143 144 145
```

返回`forward`方法的内部图的漂亮打印表示形式(作为有效的 Python 语法)。 有关详细信息,请参见[检查代码](#inspecting-code)

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
146
property graph
片刻小哥哥's avatar
片刻小哥哥 已提交
147 148 149 150 151 152 153
```

返回`forward`方法的内部图形的字符串表示形式。 有关详细信息,请参见[解释图](#interpreting-graphs)

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
154
save(f, _extra_files=ExtraFilesMap{})
片刻小哥哥's avatar
片刻小哥哥 已提交
155 156 157 158 159 160 161
```

有关详细信息,请参见 [`torch.jit.save`](#torch.jit.save "torch.jit.save")

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
162
class torch.jit.ScriptFunction
片刻小哥哥's avatar
片刻小哥哥 已提交
163 164 165 166 167 168 169
```

功能上与 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 等效,但是代表单个功能,没有任何属性或参数。

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
170
torch.jit.script(obj)
片刻小哥哥's avatar
片刻小哥哥 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
```

为函数或`nn.Module`编写脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,然后返回 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule")[`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 TorchScript 本身是 Python 语言的子集,因此 Python 并非所有功能都可以使用,但是我们提供了足够的功能来在张量上进行计算并执行与控制有关的操作。 有关完整指南,请参见 [TorchScript 语言参考](#torchscript-language-reference)

`torch.jit.script`可用作模块和功能的函数,以及 [TorchScript 类](#torchscript-class)和功能的修饰器`@torch.jit.script`

```
Scripting a function
```

`@torch.jit.script`装饰器将通过编译函数的主体来构造 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction")

示例(编写函数):

```
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFuncion

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))

```

```
Scripting an nn.Module
```

默认情况下,为`nn.Module`编写脚本将编译`forward`方法,并递归编译`forward`调用的任何方法,子模块和函数。 如果`nn.Module`仅使用 TorchScript 支持的功能,则无需更改原始模块代码。 `script`将构建 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,该副本具有原始模块的属性,参数和方法的副本。

示例(使用参数编写简单模块的脚本):

```
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

```

示例(使用跟踪的子模块编写模块脚本):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
      input = F.relu(self.conv1(input))
      input = F.relu(self.conv2(input))
      return input

scripted_module = torch.jit.script(MyModule())

```

要编译除`forward`以外的方法(并递归编译其调用的任何内容),请将 [`@torch.jit.export`](#torch.jit.export "torch.jit.export") 装饰器添加到该方法。 要选择退出编译,请使用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore")

示例(模块中的导出方法和忽略方法):

```
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
298
torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)
片刻小哥哥's avatar
片刻小哥哥 已提交
299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
```

跟踪一个函数并返回将使用即时编译进行优化的可执行文件或 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction") 。 对于仅在`Tensor``Tensor`的列表,字典和元组上运行的代码,跟踪是理想的选择。

使用`torch.jit.trace`[`torch.jit.trace_module`](#torch.jit.trace_module "torch.jit.trace_module") ,您可以将现有模块或 Python 函数转换为 TorchScript [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction")[`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 。 您必须提供示例输入,然后我们运行该函数,记录在所有张量上执行的操作。

*   独立功能的最终记录将产生 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction")

*   `nn.Module``nn.Module``forward`功能的所得记录产生 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule")

该模块还包含原始模块也具有的任何参数。

警告

跟踪仅正确记录不依赖数据的功能和模块(例如,对张量中的数据没有条件)并且不包含任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。 跟踪仅记录在给定张量上运行给定函数时执行的操作。 因此,返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将始终在任何输入上运行相同的跟踪图。 当期望模块根据输入和/或模块状态运行不同的操作集时,这具有重要意义。 例如,

*   跟踪将不会记录任何控制流,例如 if 语句或循环。 当整个模块的控制流恒定时,这很好,并且通常内联控制流决策。 但是有时控制流实际上是模型本身的一部分。 例如,循环网络是输入序列(可能是动态)长度上的循环。

*   在返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 中,在`training``eval`模式下具有不同行为的操作将始终像在跟踪过程中一样处于运行状态,无论[是哪种模式 ] `ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 已插入。

在这种情况下,跟踪是不合适的, [`scripting`](#torch.jit.script "torch.jit.script") 是更好的选择。 如果跟踪此类模型,则可能在随后的模型调用中静默地得到不正确的结果。 在执行可能会导致产生不正确跟踪的操作时,跟踪器将尝试发出警告。

参数

*   **函数**(可调用的_或_ [_torch.nn.Module_](nn.html#torch.nn.Module "torch.nn.Module"))– Python 函数或`torch.nn.Module``example_inputs`一起运行。 `func`的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。 将模块传递到 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace") 时,仅运行并跟踪`forward`方法(有关详细信息,请参见 [`torch.jit.trace`](#torch.jit.trace_module "torch.jit.trace_module"))。

*   **example_inputs**  (_tuple_ )–示例输入的元组,将在跟踪时传递给函数。 假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入来运行结果跟踪。 `example_inputs`也可以是单个张量,在这种情况下,它会自动包装在元组中。

```
Keyword Arguments
```

*   **check_trace** (`bool`,可选)–检查通过跟踪代码运行的相同输入是否产生相同的输出。 默认值:`True`。 例如,如果您的网络包含不确定性操作,或者即使检查程序失败,但您确定网络正确,则可能要禁用此功能。

*   **check_inputs** (_元组列表_ _,_ _可选_)–输入参数的元组列表,应使用这些元组来检查跟踪内容 是期待。 每个元组等效于`example_inputs`中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的`example_inputs`进行检查

*   **check_tolerance**  (_python:float_ _,_ _可选_)–在检查程序中使用的浮点比较公差。 如果结果由于已知原因(例如操作员融合)而在数值上出现差异,则可以使用此方法来放松检查器的严格性。

退货

如果`callable``nn.Module``nn.Module``forward`,则`trace`将使用包含跟踪代码的单个`forward`方法返回 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象。 返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有与原始`nn.Module`相同的子模块和参数集。 如果`callable`是独立功能,则`trace`返回 [`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction")

示例(跟踪函数):

```
import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

```

示例(跟踪现有模块):

```
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
388
torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)
片刻小哥哥's avatar
片刻小哥哥 已提交
389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
```

跟踪模块并返回可执行文件 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,该文件将使用即时编译进行优化。 将模块传递到 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace") 时,仅运行并跟踪`forward`方法。 使用`trace_module`,您可以指定方法名称的字典作为示例输入,以跟踪下面的参数(请参见`example_inputs`)。

有关跟踪的更多信息,请参见 [`torch.jit.trace`](#torch.jit.trace "torch.jit.trace")

Parameters

*   **mod**  ([_Torch.nn.Module_](nn.html#torch.nn.Module "torch.nn.Module"))–一种`torch.nn.Module`,其中包含名称在`example_inputs`中指定的方法。 给定的方法将被编译为单个 <cite>ScriptModule</cite> 的一部分。

*   **example_inputs**  (_dict_ )–包含样本输入的字典,该样本输入由`mod`中的方法名称索引。 输入将在跟踪时传递给名称与输入键对应的方法。 `{ 'forward' : example_forward_input, 'method2': example_method2_input}`

```
Keyword Arguments
```

*   **check_trace** (`bool`, optional) – Check if the same inputs run through traced code produce the same outputs. Default: `True`. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite a checker failure.

*   **check_inputs** (_字典列表_ _,_ _可选_)–输入参数的字典列表,用于检查跟踪内容 是期待。 每个元组等效于`example_inputs`中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的`example_inputs`进行检查

*   **check_tolerance** (_python:float__,_ _optional_) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

Returns

具有单个`forward`方法的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象,其中包含跟踪的代码。 当`func``torch.nn.Module`时,返回的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有与`func`相同的子模块和参数集。

示例(使用多种方法跟踪模块):

```
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
454
torch.jit.save(m, f, _extra_files=ExtraFilesMap{})
片刻小哥哥's avatar
片刻小哥哥 已提交
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509
```

保存此模块的脱机版本以在单独的过程中使用。 保存的模块将序列化此模块的所有方法,子模块,参数和属性。 可以使用`torch::jit::load(filename)`将其加载到 C ++ API 中,或者使用 [`torch.jit.load`](#torch.jit.load "torch.jit.load") 加载到 Python API 中。

为了能够保存模块,它不得对本地 Python 函数进行任何调用。 这意味着所有子模块也必须是`torch.jit.ScriptModule`的子类。

危险

所有模块,无论使用哪种设备,都始终在加载期间加载到 CPU 中。 这与 [`load`](#torch.jit.load "torch.jit.load") 的语义不同,并且将来可能会发生变化。

Parameters

*   **m** –要保存的 ScriptModule。

*   **f** –类似于文件的对象(必须实现写入和刷新)或包含文件名的字符串。

*   **_extra_files** -从文件名映射到将作为“ f”的一部分存储的内容。

Warning

如果您使用的是 Python 2,`torch.jit.save`不支持`StringIO.StringIO`作为有效的类似文件的对象。 这是因为 write 方法应返回写入的字节数; `StringIO.write()`不这样做。

请改用`io.BytesIO`之类的东西。

例:

```
import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
510
torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{})
片刻小哥哥's avatar
片刻小哥哥 已提交
511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695
```

加载先前用 [`torch.jit.save`](#torch.jit.save "torch.jit.save") 保存的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule")[`ScriptFunction`](#torch.jit.ScriptFunction "torch.jit.ScriptFunction")

之前保存的所有模块,无论使用何种设备,都首先加载到 CPU 中,然后再移动到保存它们的设备上。 如果失败(例如,因为运行时系统没有某些设备),则会引发异常。

Parameters

*   **f** –类似于文件的对象(必须实现读取,读取行,告诉和查找),或包含文件名的字符串

*   **map_location** (_字符串_ _或_ [_torch设备_](tensor_attributes.html#torch.torch.device "torch.torch.device"))– `torch.save``map_location`的简化版本 用于动态地将存储重新映射到另一组设备。

*   **_extra_files** (_文件名到内容的字典_)–映射中给定的多余文件名将被加载,其内容将存储在提供的映射中。

Returns

[`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 对象。

Example:

```
import torch
import io

torch.jit.load('scriptmodule.pt')

# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
torch.jit.load(buffer)

# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))

# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')

# Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])

```

## [混合跟踪和脚本编写](#id11)

在许多情况下,将模型转换为 TorchScript 都可以使用跟踪或脚本编写。 可以组成跟踪和脚本以适合模型一部分的特定要求。

脚本函数可以调用跟踪函数。 当您需要在简单的前馈模型周围使用控制流时,这特别有用。 例如,序列到序列模型的波束搜索通常将用脚本编写,但是可以调用使用跟踪生成的编码器模块。

示例(在脚本中调用跟踪的函数):

```
import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

```

跟踪的函数可以调用脚本函数。 即使大部分模型只是前馈网络,当模型的一小部分需要一些控制流时,这也很有用。 跟踪函数调用的脚本函数内部的控制流已正确保留。

示例(在跟踪函数中调用脚本函数):

```
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

```

此组合也适用于`nn.Module`,在这里它可用于通过跟踪来生成子模块,该跟踪可以从脚本模块的方法中调用。

示例(使用跟踪模块):

```
import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

```

## [迁移到 PyTorch 1.2 递归脚本 API](#id12)

本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。 如果您不熟悉 TorchScript,则可以跳过本节。 PyTorch 1.2 对 TorchScript API 进行了两个主要更改。

1\. [`torch.jit.script`](#torch.jit.script "torch.jit.script") 现在将尝试递归编译遇到的函数,方法和类。 调用`torch.jit.script`后,编译将是“选择退出”,而不是“选择加入”。

2.现在`torch.jit.script(nn_module_instance)`是创建 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的首选方法,而不是从`torch.jit.ScriptModule`继承。 这些更改组合在一起,提供了一个更简单易用的 API,可将您的`nn.Module`转换为 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,可以在非 Python 环境中进行优化和执行。

新用法如下所示:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)

```

*   该模块的`forward`是默认编译的。 从`forward`调用的方法将按照在`forward`中使用的顺序进行延迟编译。

*   要编译未从`forward`调用的`forward`以外的方法,请添加`@torch.jit.export`

*   要停止编译器编译方法,请添加 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore")[`@torch.jit.unused`](#torch.jit.unused "torch.jit.unused")`@ignore`离开

*   方法作为对 python 的调用,并且`@unused`将其替换为异常。 `@ignored`无法导出; `@unused`可以。

*   可以推断大多数属性类型,因此不需要`torch.jit.Attribute`。 对于空容器类型,请使用 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释对其类型进行注释。

*   可以使用`Final`类注释来标记常量,而不是将成员的名称添加到`__constants__`中。

*   可以使用 Python 3 类型提示代替`torch.jit.annotate`

```
As a result of these changes, the following items are considered deprecated and should not appear in new code:
```

*   `@torch.jit.script_method`装饰器

*   继承自`torch.jit.ScriptModule`的类

*   `torch.jit.Attribute`包装器类

*   `__constants__`数组

*   `torch.jit.annotate`功能

### [模块](#id13)

Warning

[`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 注释的行为在 PyTorch 1.2 中发生了变化。 在 PyTorch 1.2 之前,@ ignore 装饰器用于使函数或方法可从导出的代码中调用。 要恢复此功能,请使用`@torch.jit.unused()``@torch.jit.ignore`现在等同于`@torch.jit.ignore(drop=False)`。 有关详细信息,请参见 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore")[`@torch.jit.unused`](#torch.jit.unused "torch.jit.unused")

当传递给 [`torch.jit.script`](#torch.jit.script "torch.jit.script") 函数时,`torch.nn.Module`的数据将复制到 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") ,然后 TorchScript 编译器将编译该模块。 该模块的`forward`默认为编译状态。 从`forward`调用的方法以及它们在`forward`中使用的顺序都是按延迟顺序编译的。

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
696
torch.jit.export(fn)
片刻小哥哥's avatar
片刻小哥哥 已提交
697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557
```

此修饰符指示`nn.Module`上的方法用作 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的入口点,应进行编译。

`forward`隐式地假定为入口点,因此不需要此装饰器。 从`forward`调用的函数和方法在编译器看到的情况下进行编译,因此它们也不需要此装饰器。

示例(在方法上使用`@torch.jit.export`):

```
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

```

### [功能](#id14)

功能没有太大变化,可以根据需要用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore")[`torch.jit.unused`](#torch.jit.unused "torch.jit.unused") 装饰。

```
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

```

### [TorchScript 类](#id15)

默认情况下,将导出用户定义的 [TorchScript 类](#torchscript-class)中的所有内容,可以根据需要用 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore") 修饰功能。

### [属性](#id16)

TorchScript 编译器需要知道[模块属性](#module-attributes)的类型。 大多数类型可以从成员的值推断出来。 空列表和字典不能推断其类型,而必须使用 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释来注释其类型。 如果无法推断类型并且未对显式类型进行注释,则不会将其作为属性添加到结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule")

旧 API:

```
from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

```

新 API:

```
from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super(MyModule, self).__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

```

#### [Python 2](#id17)

如果您受制于 Python 2 并且无法使用类注释语法,则可以使用`__annotations__`类成员直接应用类型注释。

```
from typing import Dict

class MyModule(torch.jit.ScriptModule):
    __annotations__ = {'my_dict': Dict[str, int]}

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = {}
        self.my_int = 20

```

### [常数](#id18)

`Final`类型的构造函数可用于将成员标记为[常量](#constant)。 如果成员未标记为常量,则将其复制为结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 作为属性。 如果已知该值是固定的,则使用`Final`可以进行优化,并提供附加的类型安全性。

Old API:

```
class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

```

New API:

```
try:
    from typing_extensions import Final
except:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

```

### [变量](#id19)

假定容器的类型为`Tensor`,并且是非可选的(有关更多信息,请参见[默认类型](#default-types))。 以前,`torch.jit.annotate`用来告诉 TorchScript 编译器类型是什么。 现在支持 Python 3 样式类型提示。

```
import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

```

## [TorchScript 语言参考](#id20)

TorchScript 是 Python 的静态类型子集,可以直接编写(使用 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 装饰器),也可以通过跟踪从 Python 代码自动生成。 使用跟踪时,通过仅在张量上记录实际的运算符并简单地执行和丢弃其他周围的 Python 代码,代码会自动转换为 Python 的此子集。

使用`@torch.jit.script`装饰器直接编写 TorchScript 时,程序员只能使用 TorchScript 支持的 Python 子集。 本节记录了 TorchScript 支持的功能,就像它是独立语言的语言参考一样。 本参考中未提及的 Python 的任何功能都不属于 TorchScript。 有关可用的 Pytorch 张量方法,模块和功能的完整参考,请参见[内置函数](#builtin-functions)

作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。 这样就可以[禁用 TorchScript](#disable-torchscript) 并使用`pdb`之类的标准 Python 工具调试该功能。 反之则不成立:有许多有效的 Python 程序不是有效的 TorchScript 程序。 相反,TorchScript 特别专注于表示 PyTorch 中的神经网络模型所需的 Python 功能。

### [类型](#id21)

TorchScript 与完整的 Python 语言之间的最大区别是 TorchScript 仅支持表达神经网络模型所需的一小部分类型。 特别是,TorchScript 支持:

| 

类型

 | 

描述

 |
| --- | --- |
| `Tensor` | 任何 dtype,尺寸或后端的 PyTorch 张量 |
| `Tuple[T0, T1, ...]` | 包含子类型`T0``T1`等(例如`Tuple[Tensor, Tensor]`)的元组 |
| `bool` | 布尔值 |
| `int` | 标量整数 |
| `float` | 标量浮点数 |
| `str` | 一串 |
| `List[T]` | 所有成员均为`T`类型的列表 |
| `Optional[T]` | 无或输入`T`的值 |
| `Dict[K, V]` | 键类型为`K`而值类型为`V`的字典。 只能将`str``int``float`作为密​​钥类型。 |
| `T` | 一个 [TorchScript 类](#torchscript-class) |
| `NamedTuple[T0, T1, ...]` | `collections.namedtuple`元组类型 |

与 Python 不同,TorchScript 函数中的每个变量都必须具有一个静态类型。 这使优化 TorchScript 函数变得更加容易。

示例(类型不匹配)

```
import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r

```

```
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~...  <--- HERE
        r = torch.rand(1)
    else:
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE
...

```

#### [默认类型](#id22)

默认情况下,TorchScript 函数的所有参数均假定为 Tensor。 要指定 TorchScript 函数的参数是其他类型,可以使用上面列出的类型使用 MyPy 样式的类型注释。

```
import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

```

注意

也可以使用`typing`模块中的 Python 3 类型提示来注释类型。

```
import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

```

在我们的示例中,我们使用基于注释的类型提示来确保 Python 2 的兼容性。

假定空列表为`List[Tensor]`,空字典为`Dict[str, Tensor]`。 要实例化其他类型的空列表或字典,请使用 [Python 3 类型提示](#python-3-type-hints)。 如果您使用的是 Python 2,则可以使用`torch.jit.annotate`

示例(Python 3 的类型注释):

```
import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

```

示例(适用于 Python 2 的`torch.jit.annotate`):

```
import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()

    def forward(self, x):
        # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]

        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list = torch.jit.annotate(List[Tuple[int, float]], [])
        for i in range(10):
            my_list.append((i, float(x.item())))

        my_dict = torch.jit.annotate(Dict[str, int], {})
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

```

#### [可选类型细化](#id23)

在 if 语句的条件内或在`assert`中检查与`None`的比较时,TorchScript 将优化`Optional[T]`类型的变量的类型。 编译器可以推理与`and``or``not`结合的多个`None`检查。 对于未明确编写的 if 语句的 else 块,也会进行优化。

`None`检查必须在 if 语句的条件内; 将`None`检查分配给变量,并在 if 语句的条件下使用它,将不会优化检查中的变量类型。 仅局部变量将被细化,`self.x`之类的属性将不会且必须分配给要细化的局部变量。

示例(优化参数和局部变量的类型):

```
import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super(M, self).__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

```

#### [TorchScript 类](#id24)

如果 Python 类使用 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 注释,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:

```
@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

```

此子集受限制:

*   所有函数必须是有效的 TorchScript 函数(包括`__init__()`)。

*   这些类必须是新型类,因为我们使用`__new__()`和 pybind11 来构造它们。

*   TorchScript 类是静态类型的。 只能通过在`__init__()`方法中分配给 self 来声明成员。

    &gt; 例如,在`__init__()`方法之外分配给`self`:
    &gt; 
    &gt; ```
    &gt; @torch.jit.script
    &gt; class Foo:
    &gt;   def assign_x(self):
    &gt;     self.x = torch.rand(2, 3)
    &gt; 
    &gt; ```
    &gt; 
    &gt; 将导致:
    &gt; 
    &gt; ```
    &gt; RuntimeError:
    &gt; Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    &gt; def assign_x(self):
    &gt;   self.x = torch.rand(2, 3)
    &gt;   ~~~~~~~~~~~~~~~~~~~~~~~~ &lt;--- HERE
    &gt; 
    &gt; ```

*   类的主体中不允许使用除方法定义之外的任何表达式。

*   除了从`object`继承以指定新样式类外,不支持继承或任何其他多态策略。

定义了一个类之后,就可以像其他任何 TorchScript 类型一样在 TorchScript 和 Python 中互换使用该类:

```
# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

```

#### [命名为元组](#id25)

`collections.namedtuple`产生的类型可以在 TorchScript 中使用。

```
import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

```

### [表达式](#id26)

支持以下 Python 表达式。

#### [文字](#id27)

```
True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

```

##### [列表结构](#id28)

假定一个空列表具有`List[Tensor]`类型。 其他列表文字的类型是从成员的类型派生的。 有关更多详细信息,请参见[默认类型](#default-types)

```
[3, 4]
[]
[torch.rand(3), torch.rand(4)]

```

##### [元组结构](#id29)

```
(3, 4)
(3,)

```

##### [字典结构](#id30)

假定一个空字典为`Dict[str, Tensor]`类型。 其他 dict 文字的类型是从成员的类型派生的。 有关更多详细信息,请参见[默认类型](#default-types)

```
{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

```

#### [变量](#id31)

有关如何解析变量的信息,请参见[变量分辨率](#variable-resolution)

```
my_variable_name

```

#### [算术运算符](#id32)

```
a + b
a - b
a * b
a / b
a ^ b
a @ b

```

#### [比较运算符](#id33)

```
a == b
a != b
a < b
a > b
a <= b
a >= b

```

#### [逻辑运算符](#id34)

```
a and b
a or b
not b

```

#### [下标和切片](#id35)

```
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

```

#### [函数调用](#id36)

调用[内置函数](#builtin-functions)

```
torch.rand(3, dtype=torch.int)

```

调用其他脚本函数:

```
import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

```

#### [方法调用](#id37)

调用诸如张量之类的内置类型的方法:`x.mm(y)`

在模块上,必须先编译方法才能调用它们。 TorchScript 编译器以递归方式编译在编译其他方法时看到的方法。 默认情况下,编译从`forward`方法开始。 将编译`forward`调用的任何方法,以及这些方法调用的任何方法,依此类推。 要以`forward`以外的方法开始编译,请使用 [`@torch.jit.export`](#torch.jit.export "torch.jit.export") 装饰器(`forward`隐式标记为`@torch.jit.export`)。

直接调用子模块(例如`self.resnet(input)`)等效于调用其`forward`方法(例如`self.resnet.forward(input)`)。

```
import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

```

#### [三元表达式](#id38)

```
x if x > y else y

```

#### [演员表](#id39)

```
float(ten)
int(3.5)
bool(ten)
str(2)``

```

#### [访问模块参数](#id40)

```
self.my_parameter
self.my_submodule.my_parameter

```

### [语句](#id41)

TorchScript 支持以下类型的语句:

#### [简单分配](#id42)

```
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

```

#### [模式匹配分配](#id43)

```
a, b = tuple_or_list
a, b, *c = a_tuple

```

多项分配

```
a = b, c = tup

```

#### [打印报表](#id44)

```
print("the result of an add:", a + b)

```

#### [If 语句](#id45)

```
if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

```

除布尔值外,浮点数,整数和张量还可以在条件中使用,并将隐式转换为布尔值。

#### [While 循环](#id46)

```
a = 0
while a < 4:
    print(a)
    a += 1

```

#### [适用于范围为](#id47)的循环

```
x = 0
for i in range(10):
    x *= i

```

#### [用于遍历元组的循环](#id48)

这些展开循环,为元组的每个成员生成一个主体。 主体必须对每个成员进行正确的类型检查。

```
tup = (3, torch.rand(4))
for x in tup:
    print(x)

```

#### [用于在常量 nn.ModuleList](#id49) 上循环

要在已编译方法中使用`nn.ModuleList`,必须通过将属性名称添加到`__constants__`列表中的类型来将其标记为常量。 `nn.ModuleList`上的 for 循环将在编译时展开循环的主体,并使用常量模块列表的每个成员。

```
class SubModule(torch.nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super(MyModule, self).__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v

m = torch.jit.script(MyModule())

```

#### [中断并继续](#id50)

```
for i in range(5):
    if i == 1:
    continue
    if i == 3:
    break
    print(i)

```

#### [返回](#id51)

```
return a, b

```

### [可变分辨率](#id52)

TorchScript 支持 Python 的可变分辨率(即作用域)规则的子集。 局部变量的行为与 Python 中的相同,不同之处在于,在通过函数的所有路径上,变量必须具有相同的类型。 如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它是错误的。

同样,如果沿函数的某些路径仅将_定义为_,则不允许使用该变量。

Example:

```
@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)

```

```
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~...  <--- HERE
        y = 4
    print(y)
...

```

定义函数时,会在编译时将非局部变量解析为 Python 值。 然后使用 [Python 值使用](#use-of-python-values)中描述的规则将这些值转换为 TorchScript 值。

### [使用 Python 值](#id53)

为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围范围中的 Python 值。 例如,任何时候只要引用`torch`,当声明函数时,TorchScript 编译器实际上就会将其解析为`torch` Python 模块。 这些 Python 值不是 TorchScript 的一流部分。 而是在编译时将它们分解为 TorchScript 支持的原始类型。 这取决于编译发生时引用的 Python 值的动态类型。 本节介绍在 TorchScript 中访问 Python 值时使用的规则。

#### [功能](#id54)

TorchScript 可以调用 Python 函数。 当将模型逐步转换为 TorchScript 时,此功能非常有用。 可以将模型逐函数移至 TorchScript,而对 Python 函数的调用保留在原处。 这样,您可以在进行过程中逐步检查模型的正确性。

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1558
torch.jit.ignore(drop=False, **kwargs)
片刻小哥哥's avatar
片刻小哥哥 已提交
1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617
```

该装饰器向编译器指示应忽略函数或方法,而将其保留为 Python 函数。 这使您可以将代码保留在尚未与 TorchScript 兼容的模型中。 具有忽略功能的模型无法导出; 请改用 torch.jit.unused。

示例(在方法上使用`@torch.jit.ignore`):

```
import torch
import torch.nn as nn

class MyModule(nn.Module):
    @torch.jit.ignore
    def debugger(self, x):
        import pdb
        pdb.set_trace()

    def forward(self, x):
        x += 10
        # The compiler would normally try to compile `debugger`,
        # but since it is `@ignore`d, it will be left as a call
        # to Python
        self.debugger(x)
        return x

m = torch.jit.script(MyModule())

# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")

```

示例(在方法上使用`@torch.jit.ignore(drop=True)`):

```
import torch
import torch.nn as nn

class MyModule(nn.Module):
    @torch.jit.ignore(drop=True)
    def training_method(self, x):
        import pdb
        pdb.set_trace()

    def forward(self, x):
        if self.training:
            self.training_method(x)
        return x

m = torch.jit.script(MyModule())

# This is OK since `training_method` is not saved, the call is replaced
# with a `raise`.
m.save("m.pt")

```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1618
torch.jit.unused(fn)
片刻小哥哥's avatar
片刻小哥哥 已提交
1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658
```

此装饰器向编译器指示应忽略函数或方法,并用引发异常的方法代替。 这样,您就可以在尚不兼容 TorchScript 的模型中保留代码,并仍然可以导出模型。

> 示例(在方法上使用`@torch.jit.unused`):
> 
> ```
> import torch
> import torch.nn as nn
> 
> class MyModule(nn.Module):
>     def __init__(self, use_memory_efficent):
>         super(MyModule, self).__init__()
>         self.use_memory_efficent = use_memory_efficent
> 
>     @torch.jit.unused
>     def memory_efficient(self, x):
>         import pdb
>         pdb.set_trace()
>         return x + 10
> 
>     def forward(self, x):
>         # Use not-yet-scriptable memory efficient mode
>         if self.use_memory_efficient:
>             return self.memory_efficient(x)
>         else:
>             return x + 10
> 
> m = torch.jit.script(MyModule(use_memory_efficent=False))
> m.save("m.pt")
> 
> m = torch.jit.script(MyModule(use_memory_efficient=True))
> # exception raised
> m(torch.rand(100))
> 
> ```

* * *

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1659
torch.jit.is_scripting()
片刻小哥哥's avatar
片刻小哥哥 已提交
1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826
```

在编译时返回 True 的函数,否则返回 False 的函数。 这对于使用@unused 装饰器尤其有用,可以将尚不兼容 TorchScript 的代码保留在模型中。 .. testcode:

```
import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
   if not torch.jit.is_scripting():
      return torch.linear(x)
   else:
      return unsupported_linear_op(x)

```

#### [Python 模块上的属性查找](#id55)

TorchScript 可以在模块上查找属性。 [像`torch.add`这样的内置功能](#builtin-functions)可以通过这种方式访问​​。 这使 TorchScript 可以调用其他模块中定义的函数。

#### [Python 定义的常量](#id56)

TorchScript 还提供了一种使用 Python 中定义的常量的方法。 这些可用于将超参数硬编码到函数中,或定义通用常量。 有两种指定 Python 值应视为常量的方式。

1.  查找为模块属性的值假定为常量:

```
import math
import torch

@torch.jit.script
def fn():
    return math.pi

```

1.  可以通过使用`Final[T]`注释 ScriptModule 的属性来将其标记为常量。

```
import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

```

支持的常量 Python 类型是

*   `int`

*   `float`

*   `bool`

*   `torch.device`

*   `torch.layout`

*   `torch.dtype`

*   包含受支持类型的元组

*   `torch.nn.ModuleList`可以在 TorchScript for 循环中使用

Note

如果您使用的是 Python 2,则可以通过将属性名称添加到类的`__constants__`属性中来将其标记为常量:

```
import torch
import torch.nn as nn

class Foo(nn.Module):
    __constants__ = ['a']

    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

```

#### [模块属性](#id57)

`torch.nn.Parameter`包装器和`register_buffer`可用于将张量分配给模块。 如果可以推断出其他类型的值,则分配给已编译模块的其他值将添加到已编译模块中。 TorchScript 中可用的所有[类型](#types)都可以用作模块属性。 张量属性在语义上与缓冲区相同。 空列表和字典的类型以及`None`值无法推断,必须通过 [PEP 526 样式](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations)类注释指定。 如果无法推断出类型并且未对其进行显式注释,则不会将其作为属性添加到结果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 中。

Example:

```
from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

```

Note

如果您使用的是 Python 2,则可以通过将属性的类型添加到`__annotations__`类属性中作为属性名字典来标记属性的类型

```
from typing import List, Dict

class Foo(nn.Module):
    __annotations__ = {'words': List[str], 'some_dict': Dict[str, int]}

    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

```

### [调试](#id58)

#### [禁用用于调试的 JIT](#id59)

```
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1827
PYTORCH_JIT
片刻小哥哥's avatar
片刻小哥哥 已提交
1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194
```

设置环境变量`PYTORCH_JIT=0`将禁用所有脚本和跟踪注释。 如果您的 TorchScript 模型之一存在难以调试的错误,则可以使用此标志来强制一切都使用本机 Python 运行。 由于此标志禁用了 TorchScript(脚本编写和跟踪),因此可以使用`pdb`之类的工具来调试模型代码。

给定一个示例脚本:

```
@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

```

除调用[,`@torch.jit.script`,](#torch.jit.script "torch.jit.script")函数外,使用`pdb`调试此脚本是可行的。 我们可以全局禁用 JIT,以便我们可以将 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 函数作为普通的 Python 函数调用,而不进行编译。 如果上述脚本称为`disable_jit_example.py`,我们可以这样调用它:

```
$ PYTORCH_JIT=0 python disable_jit_example.py

```

并且我们将能够像普通的 Python 函数一样进入 [`@torch.jit.script`](#torch.jit.script "torch.jit.script") 函数。 要为特定功能禁用 TorchScript 编译器,请参见 [`@torch.jit.ignore`](#torch.jit.ignore "torch.jit.ignore")

#### [检查码](#id60)

TorchScript 为所有 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 实例提供了代码漂亮的打印机。 这个漂亮的打印机可以将脚本方法的代码解释为有效的 Python 语法。 例如:

```
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

```

具有单个`forward`方法的 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 将具有属性`code`,您可以使用该属性检查 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的代码。 如果 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 具有多个方法,则需要在方法本身而非模块上访问`.code`。 我们可以通过访问`.foo.code`在 ScriptModule 上检查名为`foo`的方法的代码。 上面的示例产生以下输出:

```
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

```

这是 TorchScript 对`forward`方法的代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)正确捕获了模型代码。

#### [解释图](#id61)

TorchScript 还以 IR 图的形式在比代码漂亮打印机更低的层次上进行表示。

TorchScript 使用静态单分配(SSA)中间表示(IR)表示计算。 这种格式的指令由 ATen(PyTorch 的 C ++后端)运算符和其他原始运算符组成,包括用于循环和条件的控制流运算符。 举个例子:

```
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

```

`graph`遵循[检查代码](#inspecting-code)部分中关于`forward`方法查找所述的相同规则。

上面的示例脚本生成图形:

```
graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

```

以指令`%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`为例。

*   `%rv.1 : Tensor`表示我们将输出分配给一个名为`rv.1`的(唯一)值,该值是`Tensor`类型,并且我们不知道其具体形状。

*   `aten::zeros`是运算符(与`torch.zeros`等效),输入列表`(%4, %6, %6, %10, %12)`指定范围中的哪些值应作为输入传递。 可以在[内置函数](#builtin-functions)中找到`aten::zeros`等内置函数的模式。

*   `# test.py:9:10`是生成此指令的原始源文件中的位置。 在这种情况下,它是第 9 行和字符 10 处名为 &lt;cite&gt;test.py&lt;/cite&gt; 的文件。

请注意,运算符也可以具有关联的`blocks`,即`prim::Loop``prim::If`运算符。 在图形打印输出中,这些运算符被格式化以反映其等效的源代码形式,以方便进行调试。

如下图所示,可以检查图表以确认 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 所描述的计算是正确的,无论是自动方式还是手动方式。

#### [追踪案例](#id62)

在某些极端情况下,给定 Python 函数/模块的跟踪不会代表基础代码。 这些情况可以包括:

*   跟踪取决于输入的控制流(例如张量形状)

*   跟踪张量视图的就地操作(例如,分配左侧的索引)

请注意,这些情况实际上将来可能是可追溯的。

#### [自动跟踪检查](#id63)

自动捕获跟踪中许多错误的一种方法是使用`torch.jit.trace()` API 上的`check_inputs``check_inputs`提取输入元组的列表,这些列表将用于重新追踪计算并验证结果。 例如:

```
def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

```

为我们提供以下诊断信息:

```
ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

```

此消息向我们表明,在我们第一次追踪它和使用`check_inputs`追踪它之间,计算有所不同。 实际上,`loop_in_traced_fn`主体内的循环取决于输入`x`的形状,因此,当我们尝试另一种形状不同的`x`时,迹线会有所不同。

在这种情况下,可以使用 [`torch.jit.script()`](#torch.jit.script "torch.jit.script") 来捕获类似于数据的控制流:

```
def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))

```

产生:

```
graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

```

#### [跟踪器警告](#id64)

跟踪器会针对跟踪计算中的几种有问题的模式生成警告。 举个例子,追踪一个在 Tensor 的切片(视图)上包含就地分配的函数:

```
def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

```

产生几个警告和一个仅返回输入的图形:

```
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

```

我们可以通过修改代码来解决此问题,使其不使用就地更新,而是使用`torch.cat`来错位构建结果张量:

```
def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

```

### [内置函数](#id65)

TorchScript 支持 PyTorch 提供的内置张量和神经网络功能的子集。 Tensor 上的大多数方法以及`torch`名称空间中的函数,`torch.nn.functional`中的所有函数以及`torch.nn`中的所有模块在 TorchScript 中均受支持,下表中没有列出。 对于不支持的模块,建议使用 [`torch.jit.trace()`](#torch.jit.trace "torch.jit.trace")

不支持的`torch.nn`模块

```
torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
torch.nn.modules.normalization.CrossMapLRN2d
torch.nn.modules.rnn.RNN

```

有关支持的功能的完整参考,请参见 [TorchScript 内置函数](jit_builtin_functions.html#builtin-functions)

## [常见问题解答](#id66)

问:我想在 GPU 上训练模型并在 CPU 上进行推理。 最佳做法是什么?

> 首先将模型从 GPU 转换为 CPU,然后将其保存,如下所示:
> 
> ```
> cpu_model = gpu_model.cpu()
> sample_input_cpu = sample_input_gpu.cpu()
> traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)
> torch.jit.save(traced_cpu, "cpu.pth")
> 
> traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)
> torch.jit.save(traced_gpu, "gpu.pth")
> 
> # ... later, when using the model:
> 
> if use_gpu:
>   model = torch.jit.load("gpu.pth")
> else:
>   model = torch.jit.load("cpu.pth")
> 
> model(input)
> 
> ```
> 
> 推荐这样做是因为跟踪器可能会在特定设备上见证张量的创建,因此强制转换已加载的模型可能会产生意想不到的效果。 在保存之前对模型_进行转换可确保跟踪器具有正确的设备信息。_

问:如何在 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 上存储属性?

> 说我们有一个像这样的模型:
> 
> ```
> class Model(nn.Module):
>     def __init__(self):
>         super(Model, self).__init__()
>         self.x = 2
> 
>     def forward(self):
>         return self.x
> 
> m = torch.jit.script(Model())
> 
> ```
> 
> 如果实例化`Model`,则将导致编译错误,因为编译器不了解`x`。 有四种方法可以通知编译器 [`ScriptModule`](#torch.jit.ScriptModule "torch.jit.ScriptModule") 的属性:
> 
> 1\. `nn.Parameter`-包装在`nn.Parameter`中的值将像在`nn.Module`上一样工作
> 
> 2\. `register_buffer`-包装在`register_buffer`中的值将像在`nn.Module`上一样工作。 这等效于`Tensor`类型的属性(请参见 4)。
> 
> 3.常量-将类成员注释为`Final`(或在类定义级别将其添加到名为`__constants__`的列表中)会将包含的名称标记为常量。 常数直接保存在模型代码中。 有关详细信息,请参见 [Python 定义的常量](#python-defined-constants)。
> 
> 4.属性-可以将[支持的类型](#supported-type)的值添加为可变属性。 可以推断大多数类型,但可能需要指定一些类型,有关详细信息,请参见[模块属性](#module-attributes)。

问:我想跟踪模块的方法,但一直出现此错误:

`RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient`

> 此错误通常表示您要跟踪的方法使用模块的参数,并且您正在传递模块的方法而不是模块实例(例如`my_module_instance.forward`与`my_module_instance`)。
> 
> &gt; *   使用模块的方法调用`trace`会将模块参数(可能需要渐变)捕获为**常量**。
> &gt;     
> &gt;     
> &gt; *   另一方面,使用模块实例(例如`my_module`)调用`trace`会创建一个新模块,并将参数正确复制到新模块中,以便在需要时可以累积梯度。
> 
> 要跟踪模块上的特定方法,请参见 [`torch.jit.trace_module`](#torch.jit.trace_module "torch.jit.trace_module")