# Extending torch.func with autograd.Function

> Original: [`pytorch.org/docs/stable/notes/extending.func.html`](https://pytorch.org/docs/stable/notes/extending.func.html)
>
> Translator: [飞龙](https://github.com/wizardforcel)
>
> License: [CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/)

So you want to use `torch.autograd.Function` with the `torch.func` transforms, such as `torch.vmap()` and `torch.func.grad()`.

There are two main use cases:

+ You want to call code that does not contain PyTorch operations and have it work with the function transforms. That is, your `torch.autograd.Function`'s forward/backward/etc. call into functions from other systems (such as C++, CUDA, or NumPy).
+ You want to specify custom gradient rules, similar to JAX's [custom_vjp/custom_jvp](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html).

PyTorch combines both of these concepts into `torch.autograd.Function`.

## Basic Usage

This guide assumes you are familiar with Extending torch.autograd, which explains how to use `torch.autograd.Function`.

A `torch.autograd.Function` can either have a `forward()` that accepts a ctx object, or it can have separate `forward()` (which does not accept `ctx`) and `setup_context()` staticmethods, where only the latter modifies the `ctx` object.

Only the latter style is supported by the function transforms:

+ `forward()` is the code that performs the operation, and it should not accept a `ctx` object.
+ `setup_context(ctx, inputs, output)` is the code where you can call methods on `ctx`. This is where you should save Tensors for backward (by calling `ctx.save_for_backward(*tensors)`) or save non-Tensors (by assigning them onto the `ctx` object).

Because `setup_context()` accepts only `inputs` and `output`, the only quantities that can be saved are objects (such as Tensors) that are in the inputs or outputs, or quantities derived from them (like `Tensor.shape`). If you wish to save a non-input intermediate activation from `Function.forward()` for backward, then you need to return it as an output from `forward()` so that it gets passed to `setup_context()`.

Depending on the transform,

+ to support reverse-mode AD (`torch.func.grad()`, `torch.func.vjp()`), the `torch.autograd.Function` needs a `backward()` staticmethod.
+ to support `torch.vmap()`, the `torch.autograd.Function` needs a `vmap()` staticmethod.
+ to support `torch.func.jvp()`, the `torch.autograd.Function` needs a `jvp()` staticmethod.
+ to support compositions of transforms (like `torch.func.jacrev()`, `torch.func.jacfwd()`, `torch.func.hessian()`), you may need multiple of the above.

In order for the `torch.autograd.Function` to be arbitrarily composable with function transforms, we recommend that all staticmethods other than `forward()` and `setup_context()` be transformable: that is, they must consist only of PyTorch operators or call other `torch.autograd.Function`s (which may themselves call into C++/CUDA/etc.).
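As a quick illustration of the split between `forward()` and `setup_context()`, here is a minimal sketch. The class `MyExp` and its contents are hypothetical and not part of the original guide; they simply follow the structure described above:

```py
import torch

class MyExp(torch.autograd.Function):
    # forward performs the computation and does NOT take ctx.
    @staticmethod
    def forward(x):
        return torch.exp(x)

    # setup_context only sees the inputs and the output; anything
    # needed for backward must be saved here.
    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)

    # backward consists only of PyTorch operations, so it stays
    # composable with the function transforms.
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.randn(3)
grad_x = torch.func.grad(lambda x: MyExp.apply(x).sum())(x)
assert torch.allclose(grad_x, torch.exp(x))
```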
Let's go over some examples of the common use cases.

### Example 1: autograd.Function that calls into another system

A common case is a `torch.autograd.Function` with both forward() and backward() that calls into another system (like C++, CUDA, NumPy, triton).

```py
import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the
        # ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None
```

Now, to make `NumpySort` easier to use (hiding the intermediates we returned as outputs, and allowing default args and kwargs), we create a new function that invokes it:

```py
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result
```

And here's a sanity check:

```py
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
```
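Beyond the quick check above, one way to gain more confidence in a hand-written `backward()` is `torch.autograd.gradcheck`, which compares the analytical gradients against finite-difference estimates. This short sketch is not part of the original guide; it assumes the `numpy_sort` wrapper defined above and uses double-precision inputs, as gradcheck requires:

```py
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
# Compares the gradients produced by NumpySort.backward against
# numerically computed gradients; raises if they disagree.
torch.autograd.gradcheck(numpy_sort, (x,))
```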
### Example 2: autograd.Function specifies custom gradient rules

Another common case is a `torch.autograd.Function` that is implemented with PyTorch operations. PyTorch is able to compute gradients for PyTorch operations automatically, but perhaps we wish to customize how the gradients are computed. Some reasons why we may want a backward different from the one PyTorch gives us are:

+ improving numeric stability
+ changing the performance characteristics of the backward
+ changing how edge cases are handled (e.g. nans, inf)
+ modifying the gradient (e.g. gradient clipping)

Here's an example of a `torch.autograd.Function` for the function `y = x ** 3` where we change the performance characteristics (some computation that would normally happen during the backward pass, computing dx, now happens in the forward pass).

```py
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result
```

Now, to make `MyCube` easier to use (hiding the intermediate we returned as an output), we create a new function that invokes it:

```py
def my_cube(x):
    result, _ = MyCube.apply(x)
    return result
```

Here's a sanity check that computes second-order gradients:

```py
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
```

### Limitations and gotchas

Warning

Please read these limitations of using `torch.autograd.Function` with the torch.func transforms carefully. We are not able to catch many of these situations and error out gracefully, so they will lead to undefined behavior.

Please do not capture Tensors that are being transformed, that have requires_grad=True, or that are dual Tensors, inside the methods of the `torch.autograd.Function`. The way to be completely safe is to ensure that the only Tensors used inside any method of the `torch.autograd.Function` are ones passed in directly as inputs (or via the ctx object), rather than ones coming from outside the `torch.autograd.Function`.

`torch.autograd.Function` does not handle Tensors inside pytrees (arbitrarily nested Python data structures that may or may not contain Tensors). For those Tensors to be tracked by autograd, they must be passed directly as arguments to the `torch.autograd.Function`. This is in contrast to jax.{custom_vjp, custom_jvp}, which do accept pytrees.

Please only use `save_for_backward()` or `save_for_forward()` to save Tensors. Please do not assign Tensors or collections of Tensors directly onto the ctx object; those Tensors will not be tracked.

## `torch.vmap()` Support

To use a `torch.autograd.Function` with `torch.vmap()`, you must either:

+ provide a `vmap()` staticmethod that tells us how the `torch.autograd.Function` behaves under `torch.vmap()`, or
+ ask us to autogenerate it by setting `generate_vmap_rule=True`.

### Automatically generate a vmap rule

If your `torch.autograd.Function` satisfies the following additional constraints, then we are able to generate a vmap rule for it. If it does not satisfy the constraints, or if you want custom behavior under vmap, please manually define a vmap staticmethod (see the next section).

Warning

We are not able to easily check the following constraints and error out gracefully. Violating them may lead to undefined behavior.

+ The `torch.autograd.Function`'s `forward()`, `backward()` (if it exists), and `jvp()` (if it exists) staticmethods must be transformable by `torch.vmap()`. That is, they must consist of only PyTorch operations (as opposed to, e.g., NumPy or custom CUDA kernels).

Example:

```py
class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
```
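Since every staticmethod of `MyCube` consists only of PyTorch operations and `generate_vmap_rule=True` is set, the Function also composes with the other transforms. The following check is not in the original guide; it is a small sketch that computes per-sample gradients by combining `torch.vmap` with `torch.func.grad`:

```py
x = torch.randn(3)
# d/dx x ** 3 = 3 * x ** 2, evaluated independently for each element of x.
per_sample_grads = torch.vmap(torch.func.grad(my_cube))(x)
assert torch.allclose(per_sample_grads, 3 * x ** 2)
```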
### Defining the vmap staticmethod

If your `torch.autograd.Function` calls into another system (like NumPy, C++, CUDA, triton), then to get it to work with `torch.vmap()` or transforms that use it, you'll need to manually define a `vmap()` staticmethod.

Depending on the transforms you want to use and your use case, you may not need to add a `vmap()` staticmethod to all of your `torch.autograd.Function`s:

+ For example, `torch.func.jacrev()` performs `vmap()` over the backward pass. So if you're only interested in using `torch.func.jacrev()`, only the `backward()` staticmethod needs to be vmappable.

We do recommend making sure all of your `torch.autograd.Function`s support `torch.vmap()` though, especially if you are writing a third-party library and you want your `torch.autograd.Function`s to work with all combinations of the `torch.func` transforms.

Conceptually, the vmap staticmethod is responsible for defining how `forward()` should behave under `torch.vmap()`. That is, it defines how to transform `forward()` to run over inputs with an additional dimension (the dimension being vmapped over). This is similar to how `torch.vmap()` is implemented for PyTorch operations: for each operation, we define a vmap rule (sometimes also referred to as a "batching rule").

Here's how to define a `vmap()` staticmethod:

+ The signature is `vmap(info, in_dims: Tuple[Optional[int]], *args)`, where `*args` is the same as the args to `forward()`.
+ The vmap staticmethod is responsible for defining how `forward()` should behave under `torch.vmap()`. That is, given inputs with an additional dimension (specified by `in_dims`), how do we compute the batched version of `forward()`?
+ For each arg in `args`, `in_dims` has a corresponding `Optional[int]`. It is `None` if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
+ `info` is a collection of additional metadata that may be helpful: `info.batch_size` specifies the size of the dimension being vmapped over, while `info.randomness` is the `randomness` option that was passed to `torch.vmap()`.
+ The return of the vmap staticmethod is a tuple of `(output, out_dims)`. Similar to `in_dims`, `out_dims` should have the same structure as `output` and contain one `out_dim` per output that specifies whether the output has a vmapped dimension and, if so, what index it is at.

Example:

```py
def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        #    is being vmapped over, and if so, the index of the
        #    dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return result, (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x.dim() - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(x, 1))
```
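As mentioned above, `torch.func.jacrev()` runs `vmap()` over the backward pass. Because `NumpySort.backward` is written in terms of `NumpyTake`, which has its own `vmap()` staticmethod, reverse-mode Jacobians work as well. This quick check is a sketch that is not part of the original guide:

```py
x = torch.randn(2, 3)
jacobian = torch.func.jacrev(numpy_sort)(x)
# Sorting is locally a permutation of its input, so each output element
# depends on exactly one input element with derivative 1.
assert jacobian.shape == (2, 3, 2, 3)
assert torch.allclose(jacobian.sum(dim=(-1, -2)), torch.ones(2, 3))
```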
Note

The vmap staticmethod should aim to preserve the semantics of the entire `Function`. That is, (pseudocode) `grad(vmap(MyFunc))` should be replaceable with `grad(map(MyFunc))`.

If your autograd.Function has any custom behavior in the backward pass, please keep this in mind.

Note

It is a legitimate use case to write a custom vmap staticmethod for a `Function` that PyTorch is able to generate a vmap rule for via `generate_vmap_rule=True`. You may wish to do this if the generated vmap rule does not have the semantics you're looking for.

## `torch.func.jvp()` Support

To support forward-mode AD, a `torch.autograd.Function` must have a `jvp()` staticmethod. Please see Forward mode AD for details.
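For reference, here is a minimal sketch of what a `jvp()` staticmethod can look like, reusing the cube function from earlier. It is an illustrative assumption based on the patterns in this guide rather than an excerpt from the Forward mode AD notes; the class `MyCubeJvp` is hypothetical. Like the other staticmethods, `jvp()` receives `ctx` plus one tangent per input and should consist of transformable operations:

```py
class MyCubeJvp(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        # save_for_forward is the forward-AD analogue of save_for_backward.
        ctx.save_for_forward(x)

    @staticmethod
    def jvp(ctx, x_tangent):
        # Propagate the tangent: d(x ** 3) = 3 * x ** 2 * dx
        x, = ctx.saved_tensors
        return 3 * x ** 2 * x_tangent

x = torch.randn(3)
tangent = torch.randn(3)
out, out_tangent = torch.func.jvp(MyCubeJvp.apply, (x,), (tangent,))
assert torch.allclose(out_tangent, 3 * x ** 2 * tangent)
```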