# 雅可比矩阵、海森矩阵、hvp、vhp 等:组合函数转换 > 原文:[`pytorch.org/tutorials/intermediate/jacobians_hessians.html`](https://pytorch.org/tutorials/intermediate/jacobians_hessians.html) > > 译者:[飞龙](https://github.com/wizardforcel) > > 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) 注意 点击这里下载完整的示例代码 计算雅可比矩阵或海森矩阵在许多非传统的深度学习模型中是有用的。使用 PyTorch 的常规自动微分 API(`Tensor.backward()`,`torch.autograd.grad`)高效地计算这些量是困难的(或者烦人的)。PyTorch 的 [受 JAX 启发的](https://github.com/google/jax) [函数转换 API](https://pytorch.org/docs/master/func.html) 提供了高效计算各种高阶自动微分量的方法。 注意 本教程需要 PyTorch 2.0.0 或更高版本。 ## 计算雅可比矩阵 ```py import torch import torch.nn.functional as F from functools import partial _ = torch.manual_seed(0) ``` 让我们从一个我们想要计算雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。 ```py def predict(weight, bias, x): return F.linear(x, weight, bias).tanh() ``` 让我们添加一些虚拟数据:一个权重、一个偏置和一个特征向量 x。 ```py D = 16 weight = torch.randn(D, D) bias = torch.randn(D) x = torch.randn(D) # feature vector ``` 让我们将 `predict` 视为一个将输入 `x` 从 $R^D \to R^D$ 的函数。PyTorch Autograd 计算向量-雅可比乘积。为了计算这个 $R^D \to R^D$ 函数的完整雅可比矩阵,我们将不得不逐行计算,每次使用一个不同的单位向量。 ```py def compute_jac(xp): jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] for vec in unit_vectors] return torch.stack(jacobian_rows) xp = x.clone().requires_grad_() unit_vectors = torch.eye(D) jacobian = compute_jac(xp) print(jacobian.shape) print(jacobian[0]) # show first row ``` ```py torch.Size([16, 16]) tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190, 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308]) ``` 我们可以使用 PyTorch 的 `torch.vmap` 函数转换来消除循环并向量化计算,而不是逐行计算雅可比矩阵。我们不能直接将 `vmap` 应用于 `torch.autograd.grad`;相反,PyTorch 提供了一个 `torch.func.vjp` 转换,与 `torch.vmap` 组合使用: ```py from torch.func import vmap, vjp _, vjp_fn = vjp(partial(predict, weight, bias), x) ft_jacobian, = vmap(vjp_fn)(unit_vectors) # let's confirm both methods compute the same result assert torch.allclose(ft_jacobian, jacobian) ``` 在后续教程中,反向模式自动微分和 `vmap` 的组合将给我们提供每个样本的梯度。在本教程中,组合反向模式自动微分和 `vmap` 将给我们提供雅可比矩阵的计算!`vmap` 和自动微分转换的各种组合可以给我们提供不同的有趣量。 PyTorch 提供了 `torch.func.jacrev` 作为一个方便的函数,执行 `vmap-vjp` 组合来计算雅可比矩阵。`jacrev` 接受一个 `argnums` 参数,指定我们想要相对于哪个参数计算雅可比矩阵。 ```py from torch.func import jacrev ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) # Confirm by running the following: assert torch.allclose(ft_jacobian, jacobian) ``` 让我们比较两种计算雅可比矩阵的方式的性能。函数转换版本要快得多(并且随着输出数量的增加而变得更快)。 一般来说,我们期望通过 `vmap` 的向量化可以帮助消除开销,并更好地利用硬件。 `vmap` 通过将外部循环下推到函数的原始操作中,以获得更好的性能。 让我们快速创建一个函数来评估性能,并处理微秒和毫秒的测量: ```py def get_perf(first, first_descriptor, second, second_descriptor): """takes torch.benchmark objects and compares delta of second vs first.""" faster = second.times[0] slower = first.times[0] gain = (slower-faster)/slower if gain < 0: gain *=-1 final_gain = gain*100 print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ") ``` 然后进行性能比较: ```py from torch.utils.benchmark import Timer without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) no_vmap_timer = without_vmap.timeit(500) with_vmap_timer = with_vmap.timeit(500) print(no_vmap_timer) print(with_vmap_timer) ``` ```py compute_jac(xp) 1.43 ms 1 measurement, 500 runs , 1 thread jacrev(predict, argnums=2)(weight, bias, x) 435.16 us 1 measurement, 500 runs , 1 thread ``` 让我们通过我们的 `get_perf` 函数进行上述的相对性能比较: ```py get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap") ``` ```py Performance delta: 69.4681 percent improvement with vmap ``` 此外,很容易将问题转换过来,说我们想要计算模型参数(权重、偏置)的雅可比矩阵,而不是输入的雅可比矩阵 ```py # note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x) ``` ## 反向模式雅可比矩阵(`jacrev`) vs 正向模式雅可比矩阵(`jacfwd`) 我们提供了两个 API 来计算雅可比矩阵:`jacrev` 和 `jacfwd`: + `jacrev` 使用反向模式自动微分。正如你在上面看到的,它是我们 `vjp` 和 `vmap` 转换的组合。 + `jacfwd` 使用正向模式自动微分。它是我们 `jvp` 和 `vmap` 转换的组合实现。 `jacfwd` 和 `jacrev` 可以互相替代,但它们具有不同的性能特征。 作为一个经验法则,如果你正在计算一个 $R^N \to R^M$ 函数的雅可比矩阵,并且输出比输入要多得多(例如,$M > N$),那么首选 `jacfwd`,否则使用 `jacrev`。当然,这个规则也有例外,但以下是一个非严格的论证: 在反向模式 AD 中,我们逐行计算雅可比矩阵,而在正向模式 AD(计算雅可比向量积)中,我们逐列计算。雅可比矩阵有 M 行和 N 列,因此如果它在某个方向上更高或更宽,我们可能更喜欢处理较少行或列的方法。 ```py from torch.func import jacrev, jacfwd ``` 首先,让我们使用更多的输入进行基准测试: ```py Din = 32 Dout = 2048 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) # remember the general rule about taller vs wider... here we have a taller matrix: print(weight.shape) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) jacfwd_timing = using_fwd.timeit(500) jacrev_timing = using_bwd.timeit(500) print(f'jacfwd time: {jacfwd_timing}') print(f'jacrev time: {jacrev_timing}') ``` ```py torch.Size([2048, 32]) jacfwd time: jacfwd(predict, argnums=2)(weight, bias, x) 773.29 us 1 measurement, 500 runs , 1 thread jacrev time: jacrev(predict, argnums=2)(weight, bias, x) 8.54 ms 1 measurement, 500 runs , 1 thread ``` 然后进行相对基准测试: ```py get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", ); ``` ```py Performance delta: 1004.5112 percent improvement with jacrev ``` 现在反过来 - 输出(M)比输入(N)更多: ```py Din = 2048 Dout = 32 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) jacfwd_timing = using_fwd.timeit(500) jacrev_timing = using_bwd.timeit(500) print(f'jacfwd time: {jacfwd_timing}') print(f'jacrev time: {jacrev_timing}') ``` ```py jacfwd time: jacfwd(predict, argnums=2)(weight, bias, x) 7.15 ms 1 measurement, 500 runs , 1 thread jacrev time: jacrev(predict, argnums=2)(weight, bias, x) 533.13 us 1 measurement, 500 runs , 1 thread ``` 以及相对性能比较: ```py get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd") ``` ```py Performance delta: 1241.8207 percent improvement with jacfwd ``` ## 使用 functorch.hessian 进行 Hessian 计算 我们提供了一个方便的 API 来计算 Hessian:`torch.func.hessiani`。Hessians 是雅可比矩阵的雅可比矩阵(或偏导数的偏导数,也称为二阶导数)。 这表明可以简单地组合 functorch 雅可比变换来计算 Hessian。实际上,在内部,`hessian(f)`就是`jacfwd(jacrev(f))`。 注意:为了提高性能:根据您的模型,您可能还希望使用`jacfwd(jacfwd(f))`或`jacrev(jacrev(f))`来计算 Hessian,利用上述关于更宽还是更高矩阵的经验法则。 ```py from torch.func import hessian # lets reduce the size in order not to overwhelm Colab. Hessians require # significant memory: Din = 512 Dout = 32 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) hess_api = hessian(predict, argnums=2)(weight, bias, x) hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x) hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x) ``` 让我们验证无论是使用 Hessian API 还是使用`jacfwd(jacfwd())`,我们都会得到相同的结果。 ```py torch.allclose(hess_api, hess_fwdfwd) ``` ```py True ``` ## 批处理雅可比矩阵和批处理 Hessian 在上面的例子中,我们一直在操作单个特征向量。在某些情况下,您可能希望对一批输出相对于一批输入进行雅可比矩阵的计算。也就是说,给定形状为`(B, N)`的输入批次和一个从$R^N \to R^M$的函数,我们希望得到形状为`(B, M, N)`的雅可比矩阵。 使用`vmap`是最简单的方法: ```py batch_size = 64 Din = 31 Dout = 33 weight = torch.randn(Dout, Din) print(f"weight shape = {weight.shape}") bias = torch.randn(Dout) x = torch.randn(batch_size, Din) compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0)) batch_jacobian0 = compute_batch_jacobian(weight, bias, x) ``` ```py weight shape = torch.Size([33, 31]) ``` 如果您有一个从(B, N) -> (B, M)的函数,而且确定每个输入产生独立的输出,那么有时也可以通过对输出求和,然后计算该函数的雅可比矩阵来实现,而无需使用`vmap`: ```py def predict_with_output_summed(weight, bias, x): return predict(weight, bias, x).sum(0) batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0) assert torch.allclose(batch_jacobian0, batch_jacobian1) ``` 如果您的函数是从$R^N \to R^M$,但输入是批处理的,您可以组合`vmap`和`jacrev`来计算批处理雅可比矩阵: 最后,批次 Hessian 矩阵的计算方式类似。最容易的方法是使用`vmap`批处理 Hessian 计算,但在某些情况下,求和技巧也适用。 ```py compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0)) batch_hess = compute_batch_hessian(weight, bias, x) batch_hess.shape ``` ```py torch.Size([64, 33, 31, 31]) ``` ## 计算 Hessian 向量积 计算 Hessian 向量积的朴素方法是将完整的 Hessian 材料化并与向量进行点积。我们可以做得更好:事实证明,我们不需要材料化完整的 Hessian 来做到这一点。我们将介绍两种(许多种)不同的策略来计算 Hessian 向量积:-将反向模式 AD 与反向模式 AD 组合-将反向模式 AD 与正向模式 AD 组合 将反向模式 AD 与正向模式 AD 组合(而不是反向模式与反向模式)通常是计算 HVP 的更节省内存的方式,因为正向模式 AD 不需要构建 Autograd 图并保存反向传播的中间结果: ```py from torch.func import jvp, grad, vjp def hvp(f, primals, tangents): return jvp(grad(f), primals, tangents)[1] ``` 以下是一些示例用法。 ```py def f(x): return x.sin().sum() x = torch.randn(2048) tangent = torch.randn(2048) result = hvp(f, (x,), (tangent,)) ``` 如果 PyTorch 正向 AD 没有覆盖您的操作,那么我们可以将反向模式 AD 与反向模式 AD 组合: ```py def hvp_revrev(f, primals, tangents): _, vjp_fn = vjp(grad(f), *primals) return vjp_fn(*tangents) result_hvp_revrev = hvp_revrev(f, (x,), (tangent,)) assert torch.allclose(result, result_hvp_revrev[0]) ``` **脚本的总运行时间:**(0 分钟 10.644 秒) `下载 Python 源代码:jacobians_hessians.py` `下载 Jupyter 笔记本:jacobians_hessians.ipynb` [Sphinx-Gallery 生成的图库](https://sphinx-gallery.github.io)