doc22_014.md 41.8 KB
Newer Older
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1
# 扩展 PyTorch
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
2

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
3 4 5 6 7 8
> 原文:[`pytorch.org/docs/stable/notes/extending.html`](https://pytorch.org/docs/stable/notes/extending.html)> 
>
> 译者:[飞龙](https://github.com/wizardforcel)
>
> 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/)

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
9

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
10
在本说明中,我们将介绍扩展`torch.nn``torch.autograd``torch`以及编写自定义 C++扩展的方法。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
11

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
12
## 扩展`torch.autograd`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
13

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
14
`autograd`添加操作需要为每个操作实现一个新的`Function`子类。请记住,`autograd`使用`Function`来编码操作历史并计算梯度。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
15 16 17 18 19

本文档的第一部分侧重于反向模式自动微分,因为它是最广泛使用的功能。文末的一节讨论了前向模式自动微分的扩展。

### 何时使用

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
20
一般来说,如果您想在模型中执行不可微分的计算或依赖于非 PyTorch 库(例如 NumPy),但仍希望您的操作与其他操作链接并与 autograd 引擎一起工作,则实现自定义函数。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
21

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
22
在某些情况下,自定义函数也可以用于提高性能和内存使用:如果您使用[C++扩展](https://pytorch.org/tutorials/advanced/cpp_extension.html)实现了前向和反向传递,您可以将它们包装在`Function`中,以与 autograd 引擎进行交互。如果您想减少为反向传播保存的缓冲区数量,可以使用自定义函数将操作组合在一起。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
23 24 25

### 何时不使用

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
26
如果您已经可以使用 PyTorch 内置操作编写函数,则其反向图(很可能)已经能够被 autograd 记录。在这种情况下,您不需要自己实现反向函数。考虑使用普通的 Python 函数。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
27

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
28
如果您需要保持状态,即可训练参数,您应该(也)使用自定义模块。有关在`torch.nn`上扩展的更多信息,请参阅下面的部分。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
29 30 31 32 33

如果您想在反向传播过程中更改梯度或执行副作用,请考虑注册一个[tensor](https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html#torch.Tensor.register_hook)[Module](https://pytorch.org/docs/stable/notes/modules.html#module-hooks) hook。

### 如何使用

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
34
按照以下步骤进行:1. 子类化`Function`并实现`forward()`,(可选)`setup_context()``backward()`方法。2. 在 ctx 参数上调用适当的方法。3. 声明您的函数是否支持[双向传播](https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)。4. 使用 gradcheck 验证您的梯度是否正确。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
35

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
36
**步骤 1:**在子类化`Function`后,您需要定义 3 个方法:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
37

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
38
+   `forward()` 是执行操作的代码。它可以接受任意数量的参数,其中一些是可选的,如果您指定了默认值。这里接受所有类型的 Python 对象。跟踪历史记录的 `Tensor` 参数(即具有 `requires_grad=True` 的参数)在调用之前将被转换为不跟踪历史记录的参数,并且它们的使用将在图中注册。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑直接作为调用参数的张量。您可以返回单个 `Tensor` 输出,或者如果有多个输出,则可以返回张量的 [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple "(in Python v3.12)")。此外,请参考 `Function` 的文档,以查找只能从 `forward()` 中调用的有用方法的描述。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
39

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
40
+   `setup_context()`(可选)。可以编写一个接受 `ctx` 对象的“组合” `forward()` 或(从 PyTorch 2.0 开始)一个不接受 `ctx` 的单独 `forward()` 和一个 `setup_context()` 方法,在其中进行 `ctx` 修改。`forward()` 应该包含计算,而 `setup_context()` 应该只负责 `ctx` 的修改(不包含任何计算)。一般来说,单独的 `forward()``setup_context()` 更接近于 PyTorch 原生操作的工作方式,因此更具有与各种 PyTorch 子系统的可组合性。有关更多详细信息,请参见组合或分离的 forward() 和 setup_context()。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
41

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
42
+   `backward()`(或 `vjp()`)定义了梯度公式。它将作为输出的数量给出与之对应的梯度的 `Tensor` 参数,每个参数表示相对于该输出的梯度。重要的是绝对不要就地修改这些参数。它应该返回与输入的数量相同的张量,每个张量包含相对于其对应输入的梯度。如果您的输入不需要梯度(`needs_input_grad` 是一个布尔值元组,指示每个输入是否需要梯度计算),或者是非 `Tensor` 对象,您可以返回 `python:None`。此外,如果您在 `forward()` 中有可选参数,您可以返回比输入数量更多的梯度,只要它们都是 [`None`](https://docs.python.org/3/library/constants.html#None "(in Python v3.12)")
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
43

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
44
**步骤 2:** 您有责任正确使用 `ctx` 中的函数,以确保新的 `Function` 与自动求导引擎正常工作。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
45

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
46
+   `save_for_backward()`必须用于保存在反向传播中使用的任何张量。非张量应直接存储在 ctx 上。如果保存了既不是输入也不是输出的张量以进行反向传播,您的`Function`可能不支持双向传播(参见步骤 3)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
47

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
48
+   `mark_dirty()`必须用于标记由前向函数就地修改的任何输入。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
49

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
50
+   `mark_non_differentiable()`必须用于告诉引擎输出是否不可微分。默认情况下,所有可微分类型的输出张量都将被设置为需要梯度。不可微分类型的张量(即整数类型)永远不会被标记为需要梯度。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
51

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
52
+   `set_materialize_grads()`可用于告诉自动求导引擎在输出不依赖于输入的情况下优化梯度计算,方法是不将传递给反向函数的梯度张量实例化。也就是说,如果设置为 False,Python 中的 None 对象或 C++中的“未定义张量”(对于其 defined()为 False 的张量 x)将不会在调用反向传播之前转换为填充了零的张量,因此您的代码将需要处理这些对象,就好像它们是填充了零的张量一样。此设置的默认值为 True。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
53

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
54
**步骤 3:** 如果您的`Function`不支持双向传播,您应该通过在反向传播中使用`once_differentiable()`来显式声明这一点。使用这个装饰器,尝试通过您的函数执行双向传播将产生错误。有关双向传播的更多信息,请参阅我们的双向传播教程。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
55

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
56
**步骤 4:** 建议您使用`torch.autograd.gradcheck()`来检查您的反向函数是否正确计算了前向梯度,方法是使用您的反向函数计算雅可比矩阵,并将其值逐个元素与使用有限差分数值计算的雅可比矩阵进行比较。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138

### 示例

下面您可以找到一个`Linear`函数的代码,附加了注释:

```py
# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias 
```

现在,为了更容易使用这些自定义操作,我们建议要么给它们取别名,要么将它们包装在一个函数中。将其包装在一个函数中可以让我们支持默认参数和关键字参数:

```py
# Option 1: alias
linear = LinearFunction.apply

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias) 
```

在这里,我们给出了一个通过非张量参数进行参数化的函数的额外示例:

```py
class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
139
在这里,我们通过调用 set_materialize_grads(False)来优化上面的示例:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164

```py
class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None

        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
165
如果您需要在`forward()`中计算的“中间”张量被保存,要么它们必须作为输出返回,要么结合`forward``setup_context()`(参见合并或分开 forward()和 setup_context())。请注意,这意味着如果您希望梯度通过这些中间值流动,您需要为它们定义梯度公式(也请参阅[双向传播教程](https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)):
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199

```py
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result 
```

注意

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
200
传递给`backward`的输入,即`grad_output`,也可以是跟踪历史记录的张量。因此,如果`backward`是通过可微操作实现的(例如,调用另一个自定义`Function`),高阶导数将起作用。在这种情况下,使用`save_for_backward`保存的张量也可以在反向中使用,并且梯度会回流,但在`ctx`中保存的张量不会有梯度回流。如果您需要梯度回流到在`ctx`中保存的张量,您应该将其作为自定义`Function`的输出并使用`save_for_backward`保存它。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
201 202 203 204 205 206 207 208 209 210 211 212 213 214

您可能希望检查您实现的反向方法是否实际计算了函数的导数。可以通过使用小的有限差分进行数值近似来进行比较:

```py
from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
215
有关有限差分梯度比较的更多详细信息,请参见数值梯度检查。如果您的函数用于高阶导数(对反向传递进行微分),则可以使用同一软件包中的`gradgradcheck`函数来检查高阶导数。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
216

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
217
### 合并或单独的`forward()`和`setup_context()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
218

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
219
有两种主要方法来定义`Function`。要么:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
220

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
221
+   定义一个结合了前向计算逻辑和`setup_context()``forward()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
222

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
223
+   (截至 PyTorch 2.0)定义一个单独的`forward()``setup_context()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
224

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
225
我们推荐第二种选项(单独的`forward()``setup_context()`),因为这更接近 PyTorch 原生操作的实现方式,并且与`torch.func`转换组合。但是,我们计划支持两种方法;结合`forward()``setup_context()`:会更加灵活,因为您可以保存中间结果而无需将它们作为输出返回。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
226

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
227
有关如何定义具有单独`forward()``setup_context()``Function`的详细信息,请参见前一节。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
228

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
229
以下是如何定义一个带有合并`forward()``setup_context()``Function`的示例:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257

```py
class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias 
```  ### 正向模式 AD

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
258
覆盖正向模式 AD 公式具有非常相似的 API,但有一些不同的微妙之处。您可以实现`jvp()`函数。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
259

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
260
它将被给予与输入相同数量的`Tensor`参数,每个参数代表相对于该输入的梯度。它应返回与输出相同数量的张量,每个张量包含相对于其对应输出的梯度。`jvp()`将在`forward()`方法之后,在`apply()`返回之前调用。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
261

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
262
`jvp()``backward()`函数有一些微妙的区别:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
263

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
264
+   您可以使用 ctx 将任何数据从`forward()`传递到`jvp()`函数。如果该状态在`backward()`中不需要,您可以在`jvp()`函数末尾通过`del ctx.foo`显式释放它。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
265

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
266
+   `jvp()`的实现必须是反向可微的,或者明确检查给定的前向模式梯度中是否有`requires_grad`设置。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
267

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
268
+   `jvp()`函数必须匹配`forward()`的视图/原地行为。例如,如果第`i`个输入被原地修改,则第`i`个梯度必须被原地更新。同样,如果第`j`个输出是第`k`个输入的视图。那么返回的第`j`个输出梯度必须是给定第`k`个输入梯度的视图。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
269

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
270
+   因为用户无法指定需要计算哪个梯度,`jvp()`函数应始终计算所有输出的梯度。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
271

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
272
+   前向模式梯度确实遵守`set_materialize_grads()`设置的标志,当禁用时,您可以获得 None 输入梯度。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
273

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
274
### `torch.func`转换和/或`torch.vmap()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
275

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
276
有关详细信息,请参阅使用 autograd.Function 扩展 torch.func。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
277

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
278
## 扩展`torch.nn`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
279

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
280
`nn`导出两种接口 - 模块及其功能版本。您可以以两种方式扩展它,但我们建议对所有包含任何参数或缓冲区的层使用模块,并建议对参数为空的操作(如激活函数、池化等)使用功能形式。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
281 282 283

在上面的部分中已经完全涵盖了添加操作的功能版本。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
284
### 添加一个`Module`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
285

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
286
由于`nn`大量使用`autograd`,添加一个新的`Module`需要实现一个执行操作并能计算梯度的`Function`。从现在开始,让我们假设我们想要实现一个`Linear`模块,并且我们已经按照上面的列表实现了该函数。添加这个需要非常少的代码。现在,需要实现两个函数:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
287 288 289

+   `__init__`*可选*)- 接受参数,如内核大小、特征数量等,并初始化参数和缓冲区。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
290
+   `forward()` - 实例化一个`Function`并使用它执行操作。它与上面显示的功能包装器非常相似。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332

这是如何实现`Linear`模块的方法:

```py
class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        ) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
333
## 扩展`torch` Python API
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
334

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
335
您可以通过定义一个具有与`Tensor`匹配的方法的自定义类来创建模拟`Tensor`的自定义类型。但是如果您希望能够将这些类型传递给像`torch.add()`这样的顶层`torch`命名空间中接受`Tensor`操作数的函数,该怎么办?
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
336

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
337
如果您的自定义 Python 类型定义了一个名为`__torch_function__`的方法,PyTorch 将在将您的自定义类的实例传递给`torch`命名空间中的函数时调用您的`__torch_function__`实现。这使得可以为`torch`命名空间中的任何函数定义自定义实现,您的`__torch_function__`实现可以调用这些函数,使您的用户能够利用已经为`Tensor`编写的现有 PyTorch 工作流程来使用您的自定义类型。这适用于与`Tensor`无关的“鸭子”类型以及`Tensor`的用户定义子类。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
338

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
339
### 使用类似`Tensor`的类型扩展`torch`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
340 341 342

注意

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
343
这个功能受到了 NumPy `__array_function__`协议的启发。有关更多详细信息,请参阅[NumPy 文档](https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)[NEP-0018](https://numpy.org/neps/nep-0018-array-function-protocol.html)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
344

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
345
为了具体化这一点,让我们从一个简单的示例开始,说明 API 分发机制。我们将创建一个自定义类型,表示一个二维标量张量,由对角线条目的顺序`N`和值`value`参数化:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373

```py
class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N) 
```

这个设计的第一个迭代并不是非常有用。`ScalarTensor`的主要功能是提供比基本张量类更紧凑的标量张量的字符串表示:

```py
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
 [0., 2., 0., 0., 0.],
 [0., 0., 2., 0., 0.],
 [0., 0., 0., 2., 0.],
 [0., 0., 0., 0., 2.]]) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
374
如果我们尝试使用`torch` API 中的这个对象,我们将遇到问题:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405

```py
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor 
```

`ScalarTensor`添加一个`__torch_function__`实现使得上述操作能够成功。让我们重新实现我们的代码,这次添加一个`__torch_function__`实现:

```py
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
406
        return HANDLED_FUNCTIONSfunc 
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
407 408
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
409
`__torch_function__`方法接受四个参数:`func`,被覆盖的 torch API 函数的引用,`types`,实现`__torch_function__`的 Tensor-like 类型的类型列表,`args`,传递给函数的参数元组,以及`kwargs`,传递给函数的关键字参数字典。它使用一个名为`HANDLED_FUNCTIONS`的全局分发表来存储自定义实现。这个字典的键是`torch`命名空间中的函数,值是`ScalarTensor`的实现。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
410 411 412

注意

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
413
使用全局分派表不是`__torch_function__` API 的强制部分,它只是一种有用的设计模式,用于构建您的覆盖实现。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443

这个类定义并不足以使`torch.mean`在我们传递`ScalarTensor`时执行正确的操作 - 我们还需要为`ScalarTensor`操作数定义一个`torch.mean`实现,并将该实现添加到`HANDLED_FUNCTIONS`分派表字典中。一种方法是定义一个装饰器:

```py
import functools
def implements(torch_function):
  """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator 
```

这可以应用于我们覆盖的实现:

```py
@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N 
```

有了这个改变,我们现在可以使用`ScalarTensor`来使用`torch.mean`

```py
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
444
当然,`torch.mean`是最简单的覆盖函数的一个例子,因为它只接受一个操作数。我们可以使用相同的机制来覆盖接受多个操作数的函数,其中任何一个可能是定义了`__torch_function__`的张量或类似张量,例如`torch.add()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474

```py
def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other)) 
```

这个版本对于两个操作数都是`ScalarTensor`实例时有一个快速路径,还有一个较慢的路径,当任一操作数不是`ScalarTensor`时会将数据转换为张量。这使得覆盖函数在任一操作数是`ScalarTensor`或常规`Tensor`时都能正确运行:

```py
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
 [1., 3.]]) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
475
请注意,我们的`add`实现不像`torch.add()`那样将`alpha``out`作为关键字参数:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
476 477 478 479 480 481

```py
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha' 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
482
为了速度和灵活性,`__torch_function__`分派机制不会检查覆盖函数的签名是否与在`torch` API 中被覆盖的函数的签名匹配。对于一些应用程序,忽略可选参数可能是可以的,但为了确保与`Tensor`的完全兼容性,torch API 函数的用户实现应该确保精确模拟被覆盖的函数的 API。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
483

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
484
`torch` API 中没有显式覆盖的函数将从`__torch_function__`返回`NotImplemented`。如果所有具有在其上定义了`__torch_function__`的操作数都返回`NotImplemented`,PyTorch 将引发`TypeError`。这意味着大多数情况下,对于没有特定类型的显式覆盖的操作,当传递这种类型的实例时将引发`TypeError`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
485 486 487 488 489 490 491

```py
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor] 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
492
实际上,这意味着如果您想要使用类似这样的`__torch_function__`实现来实现您的覆盖,您将需要显式实现完整的`torch` API 或您关心的用例的整个 API 子集。这可能是一个很大的挑战,因为完整的`torch` API 非常广泛。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
493

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
494
另一个选择是对于未处理的操作不返回`NotImplemented`,而是在没有可用覆盖时将`Tensor`传递给原始`torch`函数。例如,如果我们将`ScalarTensor``__torch_function__`实现更改为以下内容:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
495 496 497 498 499 500 501 502 503 504 505 506

```py
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
507
    return HANDLED_FUNCTIONSfunc 
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
508 509
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
510
然后`torch.mul()`将正常工作,尽管返回类型始终是`Tensor`而不是`ScalarTensor`,即使两个操作数都是`ScalarTensor`实例:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
511 512 513 514 515 516 517 518

```py
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
 [0., 4.]]) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
519
另请参见下面的`MetadataTensor`示例,以了解这种模式的另一种变体,但是始终返回`MetadataTensor`以通过`torch` API 中的操作传播元数据。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
520

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
521
`__torch_function__`协议旨在完全覆盖 API,部分覆盖可能会导致不良结果,特别是某些函数引发`TypeError`。这对于子类尤其重要,其中 torch.add、torch.Tensor.__add__ 和 torch.Tensor.add 这三个函数必须被覆盖,即使它们返回完全相同的结果。未能这样做也可能导致无限递归。如果需要从`torch.Tensor`子类实现一个函数,他们必须在实现中使用`super().__torch_function__`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
522 523 524

### 子类化`torch.Tensor`

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
525
从 1.7.0 版本开始,在`torch.Tensor`上的方法和公共`torch.*`命名空间中应用于`torch.Tensor`子类的函数将返回子类实例而不是`torch.Tensor`实例:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562

```py
>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor' 
```

如果存在多个子类,则默认选择层次结构中最低的子类。如果没有唯一确定这种情况的方法,则会引发`TypeError`

```py
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor] 
```

如果希望对所有张量方法进行全局覆盖,可以使用`__torch_function__`。以下是一个记录所有函数/方法调用的示例:

```py
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
563
但是,如果希望在 Tensor 子类上覆盖一个方法,可以直接覆盖该方法(为子类定义它),或者使用`__torch_function__`并与`func`匹配。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
564

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
565
在子类的`__torch_function__`中,应该始终调用`super().__torch_function__(func, ...)`而不是直接调用`func`,这是在 1.7.0 版本之前的情况。未能这样做可能导致`func`递归回到`__torch_function__`,从而导致无限递归。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
566

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
567
### 使用`Tensor`包装类型扩展`torch`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
568

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
569
另一个有用的情况是一个类型包装了一个`Tensor`,可以作为属性或通过子类化。下面我们实现了这种类型的一个特殊情况,一个`MetadataTensor`,它将元数据字典附加到通过`torch`操作传播的`Tensor`上。由于这是对完整`torch` API 的一种通用包装,我们不需要单独实现每个覆盖,因此可以使`__torch_function__`实现更宽松,允许进行更多操作:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590

```py
class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        args = [getattr(a, '_t', a) for a in args]
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0]) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
591
这个简单的实现不一定适用于`torch` API 中的每个函数,但足以捕捉大多数常见操作:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614

```py
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[2, 4],
 [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[1, 4],
 [3, 8]]) 
```

### 在定义了`__torch_function__`的多个类型上进行操作

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
615
可以使用具有各自`__torch_function__`实现的多个不同类型的 torch API,但必须特别小心。在这种情况下,规则是:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
616 617 618 619 620

+   调度操作会收集每个操作数的所有不同的`__torch_function__`实现,并按顺序调用它们:子类优先于超类,否则按照操作符表达式中的从左到右顺序。

+   如果返回的值不是`NotImplemented`,则将该值作为结果返回。实现可以通过返回`NotImplemented`来注册他们不实现的操作。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
621
+   如果所有的`__torch_function__`实现都返回`NotImplemented`,PyTorch 会引发`TypeError`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
622

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
623
### 测试 PyTorch API 的覆盖范围
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
624

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
625
实现`__torch_function__`的一个麻烦之处在于,如果某些操作有覆盖而其他操作没有覆盖,用户最多会看到不一致的体验,或者在运行时使用没有覆盖的函数时会看到错误。为了简化这个过程,PyTorch 提供了一个面向开发者的 API,用于确保对`__torch_function__`覆盖的全面支持。这个 API 是私有的,可能在未来会发生变化而没有警告。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
626

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
627
首先,要获取所有可重写函数的列表,请使用`torch.overrides._get_overridable_functions`。这将返回一个字典,其键是`PyTorch` Python API 中的命名空间,其值是该命名空间中可以被覆盖的函数列表。例如,让我们打印`torch.nn.functional`中可以被覆盖的前 5 个函数的名称:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
628 629 630 631 632 633 634 635 636 637

```py
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices'] 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
638
这些函数的列表使得可以迭代所有可重写的函数,然而在实践中,这并不足以为所有这些函数编写测试,而不是费力地手动复制每个测试的每个函数的签名。为了简化这个过程,`torch.overrides._get_testing_overrides`函数返回一个字典,将`PyTorch` API 中可重写的函数映射到具有与原始函数相同签名的虚拟 lambda 函数,但无条件返回-1。这些函数最有用的用法是与`inspect`一起使用,以分析原始`PyTorch`函数的函数签名:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
639 640 641 642 643 644 645 646 647 648

```py
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)> 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
649
最后,`torch.overrides.get_ignored_functions`返回一个函数元组,这些函数明确不能被`__torch_function__`覆盖。这个列表可以用来确认通过`get_overridable_functions`返回的字典中不存在的函数不能被覆盖。## 扩展`torch`本机 API
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
650

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
651
虽然`__torch_function__`允许有效地扩展 PyTorch 的纯 Python 组件的行为,但它不允许扩展用 C++实现的 PyTorch 的部分。为此,`Tensor`子类还可以定义`__torch_dispatch__`,它将能够在 C++级别覆盖行为。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
652

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
653
要有效地使用这个功能,重要的是要了解 PyTorch 的本机部分是如何实现的。那里最重要的组件是我们称之为“调度程序”(最好的描述可以在这篇[博文](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/)中找到,尽管它略显过时)。正如其名称所示,它负责为特定函数调用调用正确的后端函数。例如,当调用`torch.add(a, b)`时,调度程序将检查两个参数,找出应该为这个特定调用使用哪个“特征”(自动微分、自动转换、功能化等)和哪个“后端”(CPU、CUDA、MPS 等),最后调用所有正确的内核。内核经常做的一件事是“重新调度”。例如,当在 GPU 上运行神经网络时,第一个调用将是处理任何潜在自动转换逻辑并重新调度的自动转换内核。接下来的特性将是自动微分,它将正确地创建自动微分图,然后重新调度。最后,我们到达 CUDA 的后端内核,它将启动正确的 CUDA 内核并返回最终结果。在退出时,自动微分将把图附加到输出上,最后,自动转换将有机会在退出时进行任何必要的更新。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
654

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
655
调度程序的一个配置是调用所有这些特征和后端键的顺序。最新的列表及其顺序可以在`DispatchKey.h`文件中的`DispatchKey`枚举中找到。为了扩展 torch,本讨论中排序的重要子集是:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
656 657 658 659 660 661 662 663 664

vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends

对于本讨论的目的,最重要的键是`Python`,因为每个定义了`__torch_dispatch__`方法的`Tensor`子类将调用这个特性。用户定义的方法将从这里调用,行为可以任意重写。从那里,再次调用提供的`func`将执行“重新调度”。

这种实现的一些重要影响是:

+   这段代码“在所有功能之下”运行。因此,它只负责生成每个张量的输出值(可以,也应该,忽略所有高级功能,如自动求导、自动转换等),就像常规后端一样。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
665
+   如果任何高级功能实现了给定的函数而没有重新调度,它将永远不会到达`Python`键,因此`__torch_dispatch__`回调将永远不会被触发。这特别发生在 CompositeImplicitAutograd 函数中,这些函数在自动求导级别上进行评估而不进行重新调度。这是因为 CompositeImplicitAutograd 函数通过隐式调用其他本地操作来指定其自动求导公式,因此在自动求导级别上,函数被分解为其本地操作,而这些操作被评估。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
666

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
667
+   在回调到 Python 并包装结果时,与常规 PyTorch Python/C++绑定相同的转换被使用。特别是,一些对象无法在 Python 中表示,需要特殊处理(例如未定义的张量变为 None)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
668

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
669
+   我们的本地函数是惰性填充的,作为可调用的 Python 对象,以便从 Python 轻松与它们交互,命名空间为`torch.ops.{namespace}.{func_name}.{overload_name}`。给定给`__torch_dispatch__``func`对象始终是来自此命名空间的条目。这个命名空间可以用于直接调用本地操作,绕过通常的 Python API 和绑定代码。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
670

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
671
类似于`__torch_function__`能够介入 torch 的所有 Python API 和 Tensor 方法,`__torch_dispatch__`能够拦截所有对 aten 本地 API 的调用。请注意,在进入调度程序之前,所有张量上的方法都会转换为函数调用,因此在这里会显示为函数调用:`torch.add(a, 2)``a + 2`将导致完全相同的 aten 调用。这些函数的大部分在`native_functions.yaml`中定义,该文件指定了这些函数的属性以及它们的后端实现。它们的实现以及指定的特性随后会通过 codegen 自动注册。一些更奇特的函数或特性也会在 C++代码库的其他位置或用户定义的 C++扩展中注册。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
672

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
673
还可以使用`torch.library`添加新的本地函数。这个 Python 特性允许定义和/或添加新的实现到本地函数。这可以用于添加缺失的内核、替换现有的内核或定义全新的本地函数。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
674 675 676

您可以在[subclass zoo](https://github.com/albanD/subclass_zoo)存储库中找到许多基于`__torch_dispatch__`的子类的示例。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
677
## 通过 Modes 扩展所有`torch` API。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
678

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
679
不幸的是,有些函数不接受张量输入。这意味着上面描述的子类方法无法用于覆盖 PyTorch 的所有函数的行为。此外,如果用例要求拦截每个函数调用,将每个张量更改为子类可能会过于侵入性。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745

为了解决这种用例,我们引入了“模式”概念。这些模式用于`__torch_function__``__torch_dispatch__`覆盖,分别通过继承`torch.overrides.TorchFunctionMode``torch.utils._python_dispatch.TorchDispatchMode`创建,并用作上下文管理器。

为了简化它与子类和其他模式的交互的描述,每当进入模式的上下文管理器时,每个函数的行为都会像在参数列表的开头有一个额外的张量参数,其中包含子类作为模式。这意味着特别是所有模式处理程序将在任何子类处理程序之前调用,并且与内部上下文管理器对应的模式将始终首先运行。

在给定的模式处理程序中,需要注意的是,该特定模式被禁用,可以通过`with self:`手动重新启用。

这里是一个示例,显示每种类型的日志记录模式:

```py
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()

print("TorchFunctionMode logging:")
with FunctionLog():
    f()

print("TorchDispatchMode logging:")
with DispatchLog():
    f() 
```

以下打印内容,并附有额外的注释:

```py
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) 
```

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
746
## 编写自定义 C++扩展
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
747

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
748
查看这个[PyTorch 教程](https://pytorch.org/tutorials/advanced/cpp_extension.html)以获取详细解释和示例。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
749

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
750
文档可在 torch.utils.cpp_extension 找到。