# 使用自定义函数进行双向传播 > 原文:[`pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html`](https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html) > > 译者:[飞龙](https://github.com/wizardforcel) > > 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) 有时候需要通过向后图两次运行反向传播,例如计算高阶梯度。然而,要支持双向传播需要对 autograd 有一定的理解和谨慎。支持单次向后传播的函数不一定能够支持双向传播。在本教程中,我们展示了如何编写支持双向传播的自定义 autograd 函数,并指出一些需要注意的事项。 当编写自定义 autograd 函数以进行两次向后传播时,重要的是要知道自定义函数中的操作何时被 autograd 记录,何时不被记录,以及最重要的是,save_for_backward 如何与所有这些操作一起使用。 自定义函数隐式影响梯度模式的两种方式: + 在向前传播期间,autograd 不会记录任何在前向函数内执行的操作的图形。当前向完成时,自定义函数的向后函数将成为每个前向输出的 grad_fn + 在向后传播期间,如果指定了 create_graph 参数,autograd 会记录用于计算向后传播的计算图 接下来,为了了解 save_for_backward 如何与上述交互,我们可以探索一些示例: ## 保存输入 考虑这个简单的平方函数。它保存一个输入张量以备向后传播使用。当 autograd 能够记录向后传播中的操作时,双向传播会自动工作,因此当我们保存一个输入以备向后传播时,通常不需要担心,因为如果输入是任何需要梯度的张量的函数,它应该有 grad_fn。这样可以正确传播梯度。 ```py import torch class Square(torch.autograd.Function): @staticmethod def forward(ctx, x): # Because we are saving one of the inputs use `save_for_backward` # Save non-tensors and non-inputs/non-outputs directly on ctx ctx.save_for_backward(x) return x**2 @staticmethod def backward(ctx, grad_out): # A function support double backward automatically if autograd # is able to record the computations performed in backward x, = ctx.saved_tensors return grad_out * 2 * x # Use double precision because finite differencing method magnifies errors x = torch.rand(3, 3, requires_grad=True, dtype=torch.double) torch.autograd.gradcheck(Square.apply, x) # Use gradcheck to verify second-order derivatives torch.autograd.gradgradcheck(Square.apply, x) ``` 我们可以使用 torchviz 来可视化图形以查看为什么这样可以工作 ```py import torchviz x = torch.tensor(1., requires_grad=True).clone() out = Square.apply(x) grad_x, = torch.autograd.grad(out, x, create_graph=True) torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out}) ``` 我们可以看到对于 x 的梯度本身是 x 的函数(dout/dx = 2x),并且这个函数的图形已经正确构建 ![`user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png`](img/664c9393ebdb32f044c3ab5f5780b3f7.png) ## 保存输出 在前一个示例的轻微变化是保存输出而不是输入。机制类似,因为输出也与 grad_fn 相关联。 ```py class Exp(torch.autograd.Function): # Simple case where everything goes well @staticmethod def forward(ctx, x): # This time we save the output result = torch.exp(x) # Note that we should use `save_for_backward` here when # the tensor saved is an ouptut (or an input). ctx.save_for_backward(result) return result @staticmethod def backward(ctx, grad_out): result, = ctx.saved_tensors return result * grad_out x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone() # Validate our gradients using gradcheck torch.autograd.gradcheck(Exp.apply, x) torch.autograd.gradgradcheck(Exp.apply, x) ``` 使用 torchviz 来可视化图形: ```py out = Exp.apply(x) grad_x, = torch.autograd.grad(out, x, create_graph=True) torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out}) ``` ![`user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png`](img/7ab379f6d65d456373fdf6a3cdb35b1a.png) ## 保存中间结果 更棘手的情况是当我们需要保存一个中间结果时。我们通过实现以下情况来演示这种情况: $$sinh(x) := \frac{e^x - e^{-x}}{2} $$ 由于 sinh 的导数是 cosh,因此在向后计算中重复使用 exp(x)和 exp(-x)这两个中间结果可能很有用。 尽管如此,中间结果不应直接保存并在向后传播中使用。因为前向是在无梯度模式下执行的,如果前向传递的中间结果用于计算向后传递中的梯度,则梯度的向后图将不包括计算中间结果的操作。这会导致梯度不正确。 ```py class Sinh(torch.autograd.Function): @staticmethod def forward(ctx, x): expx = torch.exp(x) expnegx = torch.exp(-x) ctx.save_for_backward(expx, expnegx) # In order to be able to save the intermediate results, a trick is to # include them as our outputs, so that the backward graph is constructed return (expx - expnegx) / 2, expx, expnegx @staticmethod def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp): expx, expnegx = ctx.saved_tensors grad_input = grad_out * (expx + expnegx) / 2 # We cannot skip accumulating these even though we won't use the outputs # directly. They will be used later in the second backward. grad_input += _grad_out_exp * expx grad_input -= _grad_out_negexp * expnegx return grad_input def sinh(x): # Create a wrapper that only returns the first output return Sinh.apply(x)[0] x = torch.rand(3, 3, requires_grad=True, dtype=torch.double) torch.autograd.gradcheck(sinh, x) torch.autograd.gradgradcheck(sinh, x) ``` 使用 torchviz 来可视化图形: ```py out = sinh(x) grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True) torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) ``` ![`user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png`](img/66f87d1f09778a82307fefa72409569c.png) ## 保存中间结果:不要这样做 现在我们展示当我们不返回中间结果作为输出时会发生什么:grad_x 甚至不会有一个反向图,因为它纯粹是一个函数 exp 和 expnegx,它们不需要 grad。 ```py class SinhBad(torch.autograd.Function): # This is an example of what NOT to do! @staticmethod def forward(ctx, x): expx = torch.exp(x) expnegx = torch.exp(-x) ctx.expx = expx ctx.expnegx = expnegx return (expx - expnegx) / 2 @staticmethod def backward(ctx, grad_out): expx = ctx.expx expnegx = ctx.expnegx grad_input = grad_out * (expx + expnegx) / 2 return grad_input ``` 使用 torchviz 来可视化图形。请注意,grad_x 不是图形的一部分! ```py out = SinhBad.apply(x) grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True) torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) ``` ![`user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png`](img/c57a22a13ed99e177d45732c5bcc36ff.png) ## 当不跟踪反向传播时 最后,让我们考虑一个例子,即 autograd 可能根本无法跟踪函数的反向梯度。我们可以想象 cube_backward 是一个可能需要非 PyTorch 库(如 SciPy 或 NumPy)或编写为 C++扩展的函数。这里演示的解决方法是创建另一个自定义函数 CubeBackward,在其中手动指定 cube_backward 的反向传播! ```py def cube_forward(x): return x**3 def cube_backward(grad_out, x): return grad_out * 3 * x**2 def cube_backward_backward(grad_out, sav_grad_out, x): return grad_out * sav_grad_out * 6 * x def cube_backward_backward_grad_out(grad_out, x): return grad_out * 3 * x**2 class Cube(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return cube_forward(x) @staticmethod def backward(ctx, grad_out): x, = ctx.saved_tensors return CubeBackward.apply(grad_out, x) class CubeBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_out, x): ctx.save_for_backward(x, grad_out) return cube_backward(grad_out, x) @staticmethod def backward(ctx, grad_out): x, sav_grad_out = ctx.saved_tensors dx = cube_backward_backward(grad_out, sav_grad_out, x) dgrad_out = cube_backward_backward_grad_out(grad_out, x) return dgrad_out, dx x = torch.tensor(2., requires_grad=True, dtype=torch.double) torch.autograd.gradcheck(Cube.apply, x) torch.autograd.gradgradcheck(Cube.apply, x) ``` 使用 torchviz 来可视化图形: ```py out = Cube.apply(x) grad_x, = torch.autograd.grad(out, x, create_graph=True) torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out}) ``` ![`user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png`](img/44368555f30978a287e8a47e0cfff9ee.png) 总之,双向传播是否适用于您的自定义函数取决于反向传播是否可以被 autograd 跟踪。通过前两个示例,我们展示了双向传播可以直接使用的情况。通过第三和第四个示例,我们展示了使反向函数可以被跟踪的技术,否则它们将无法被跟踪。