# Gradcheck 机制 > [`pytorch.org/docs/stable/notes/gradcheck.html`](https://pytorch.org/docs/stable/notes/gradcheck.html) 本说明概述了`gradcheck()`和`gradgradcheck()`函数的工作原理。 它将涵盖实数和复数值函数的前向和反向模式 AD,以及高阶导数。本说明还涵盖了 gradcheck 的默认行为以及传递`fast_mode=True`参数的情况(以下简称为快速 gradcheck)。 + [符号和背景信息] + [默认的反向模式 gradcheck 行为] + [实数到实数函数] + [复数到实数函数] + [具有复数输出的函数] + [快速反向模式 gradcheck] + [实数到实数函数的快速 gradcheck] + [复数到实数函数的快速 gradcheck] + [具有复数输出的函数的快速 gradcheck] + [Gradgradcheck 实现] ## [符号和背景信息] 在本说明中,我们将使用以下约定: 1. $x$、$y$、$a$、$b$、$v$、$u$、$ur$和$ui$是实值向量,$z$是一个复值向量,可以用两个实值向量重新表示为$z = a + i b$。 1. $N$和$M$是我们将用于输入和输出空间的维度的两个整数。 1. $f: \mathcal{R}^N \to \mathcal{R}^M$是我们的基本实数到实数函数,使得$y = f(x)$。 1. $g: \mathcal{C}^N \to \mathcal{R}^M$是我们的基本复数到实数函数,使得$y = g(z)$。 对于简单的实数到实数情况,我们将与$f$相关的雅可比矩阵记为$J_f$,大小为$M \times N$。这个矩阵包含所有偏导数,使得位置$(i, j)$处的条目包含$\frac{\partial y_i}{\partial x_j}$。然后,反向模式 AD 计算给定大小为$M$的向量$v$的数量$v^T J_f$。另一方面,前向模式 AD 计算给定大小为$N$的向量$u$的数量$J_f u$。 对于包含复数值的函数,情况要复杂得多。我们这里只提供概要,完整描述可以在复数数值的 Autograd 中找到。 满足复数可微性(柯西-黎曼方程)的约束对于所有实值损失函数来说太过严格,因此我们选择使用 Wirtinger 微积分。在 Wirtinger 微积分的基本设置中,链式法则要求同时访问 Wirtinger 导数(以下称为$W$)和共轭 Wirtinger 导数(以下称为$CW$)。由于一般情况下,尽管它们的名称如此,但$W$和$CW$都需要传播,因为它们不是彼此的复共轭。 为了避免传播两个值,对于反向模式 AD,我们总是假设正在计算导数的函数要么是实值函数,要么是更大的实值函数的一部分。这个假设意味着我们在反向传递过程中计算的所有中间梯度也与实值函数相关联。在实践中,当进行优化时,这个假设并不具限制性,因为这样的问题需要实值目标(复数之间没有自然的排序)。 在这种假设下,使用$W$和$CW$的定义,我们可以展示$W = CW^*$,因此只需要“通过图形向后传递”两个值中的一个,另一个可以轻松恢复。为了简化内部计算,PyTorch 使用$2 * CW$作为向后传递和用户请求梯度时返回的值。类似于实数情况,当输出实际上在$\mathcal{R}^M$时,反向模式 AD 不会计算$2 * CW$,而只会计算$v^T (2 * CW)$,其中$v \in \mathcal{R}^M$。 对于前向模式 AD,我们使用类似的逻辑,假设函数是更大函数的一部分,其输入在$\mathcal{R}$中。在这种假设下,我们可以做出类似的声明,即每个中间结果对应于一个输入在$\mathcal{R}$中的函数,并且在这种情况下,使用$W$和$CW$的定义,我们可以展示中间函数的$W = CW$。为了确保前向和后向模式在一维函数的基本情况下计算相同的量,前向模式还计算$2 * CW$。类似于实数情况,当输入实际上在$\mathcal{R}^N$时,前向模式 AD 不会计算$2 * CW$,而只会计算$(2 * CW) u$,其中$u \in \mathcal{R}^N$。 ## 默认反向模式 gradcheck 行为 ### 实数到实数函数 为了测试一个函数$f: \mathcal{R}^N \to \mathcal{R}^M, x \to y$,我们以两种方式重建大小为$M \times N$的完整雅可比矩阵$J_f$:分析和数值。分析版本使用我们的反向模式 AD,而数值版本使用有限差分。然后,逐个元素比较这两个重建的雅可比矩阵是否相等。 #### 默认实数输入数值评估 如果考虑一维函数的基本情况($N = M = 1$),那么我们可以使用维基百科文章中的基本有限差分公式。我们使用“中心差分”以获得更好的数值性质。 $\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}$ 这个公式很容易推广到具有多个输出($M \gt 1$)的情况,通过将$\frac{\partial y}{\partial x}$作为大小为$M \times 1$的列向量,就像$f(x + eps)$一样。在这种情况下,上述公式可以直接重复使用,并且只需对用户函数进行两次评估(即$f(x + eps)$和$f(x - eps)$)即可近似计算完整的雅可比矩阵。 处理具有多个输入($N \gt 1$)的情况更加昂贵。在这种情况下,我们依次循环遍历所有输入,并为$x$的每个元素依次应用$eps$的扰动。这使我们能够逐列重建$J_f$矩阵。 #### 默认实数输入分析评估 对于分析评估,我们使用如上所述的事实,即反向模式 AD 计算$v^T J_f$。对于具有单个输出的函数,我们简单地使用$v = 1$来通过单个反向传递恢复完整的雅可比矩阵。 对于具有多个输出的函数,我们使用一个 for 循环,迭代输出,其中每个$v$是一个依次对应于每个输出的 one-hot 向量。这样可以逐行重建$J_f$矩阵。 ### 复数到实数函数 为了测试一个函数 $g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$,其中 $z = a + i b$,我们重建包含 $2 * CW$ 的(复数值)矩阵。 #### 默认复数输入数值评估 考虑首先 $N = M = 1$ 的基本情况。我们从[这篇研究论文](https://arxiv.org/pdf/1701.00392.pdf)中知道: $CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$ 请注意,上述方程中的 $\frac{\partial y}{\partial a}$ 和 $\frac{\partial y}{\partial b}$ 是 $\mathcal{R} \to \mathcal{R}$ 的导数。为了对这些进行数值评估,我们使用了上述实数到实数情况的描述方法。这使我们能够计算 $CW$ 矩阵,然后乘以 $2$。 请注意,截至撰写时,代码以稍微复杂的方式计算这个值: ```py # Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105 # Notation changes in this code block: # s here is y above # x, y here are a, b above ds_dx = compute_gradient(eps) ds_dy = compute_gradient(eps * 1j) # conjugate wirtinger derivative conj_w_d = 0.5 * (ds_dx + ds_dy * 1j) # wirtinger derivative w_d = 0.5 * (ds_dx - ds_dy * 1j) d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj() # Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`. ``` #### 默认复数输入解析评估 由于反向模式 AD 已经精确计算了两倍的 $CW$ 导数,因此我们在这里与实数到实数情况一样使用相同的技巧,并在有多个实数输出时逐行重建矩阵。 ### 具有复数输出的函数 在这种情况下,用户提供的函数不符合自动微分的假设,即我们为其计算反向 AD 的函数是实值的。这意味着直接在这个函数上使用自动微分是不明确定义的。为了解决这个问题,我们将测试函数 $h: \mathcal{P}^N \to \mathcal{C}^M$(其中 $\mathcal{P}$ 可以是 $\mathcal{R}$ 或 $\mathcal{C}$)替换为两个函数:$hr$ 和 $hi$,使得: $\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}$ 其中 $q \in \mathcal{P}$。然后我们对 $hr$ 和 $hi$ 进行基本的梯度检查,使用上述描述的实数到实数或复数到实数的情况,取决于 $\mathcal{P}$。 请注意,截至撰写时,代码并没有显式创建这些函数,而是通过将 $\text{grad\_out}$ 参数传递给不同的函数,手动使用 $real$ 或 $imag$ 函数执行链式规则。当 $\text{grad\_out} = 1$ 时,我们考虑 $hr$。当 $\text{grad\_out} = 1j$ 时,我们考虑 $hi$。 ## 快速反向模式梯度检查 尽管上述梯度检查的公式很好,为了确保正确性和可调试性,它非常慢,因为它重建了完整的雅可比矩阵。本节介绍了一种以更快的方式执行梯度检查的方法,而不影响其正确性。通过在检测到错误时添加特殊逻辑,可以恢复可调试性。在这种情况下,我们可以运行重建完整矩阵的默认版本,以向用户提供完整的细节。 这里的高级策略是找到一个标量量,可以通过数值和解析方法高效计算,并且能够很好地代表慢梯度检查计算的完整矩阵,以确保它能够捕捉雅可比矩阵中的任何差异。 ### 实数到实数函数的快速梯度检查 我们想要计算的标量量是给定随机向量 $v \in \mathcal{R}^M$ 和随机单位范数向量 $u \in \mathcal{R}^N$ 时的 $v^T J_f u$。 对于数值评估,我们可以高效地计算 $J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.$ 然后我们执行这个向量与 $v$ 的点积,得到感兴趣的标量值。 对于分析版本,我们可以使用反向模式自动微分来直接计算$v^T J_f$。然后执行与$u$的点积以获得期望值。 ### 快速复数输入分析评估 类似于实数到实数的情况,我们希望对完整矩阵进行简化。但是$2 * CW$矩阵是复数值的,因此在这种情况下,我们将与复数标量进行比较。 由于在数值情况下我们可以有效计算的一些约束以及为了将数值评估的数量保持最小,我们计算以下(尽管令人惊讶的)标量值: $s := 2 * v^T (real(CW) ur + i * imag(CW) ui)$ $\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}$ #### 快速复数输入数值评估 我们首先考虑如何用数值方法计算$s$。为此,牢记我们考虑的是$g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$,其中$z = a + i b$,以及$CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$,我们将其重写如下: $\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}$ 在这个公式中,我们可以看到$\frac{\partial y}{\partial a} ur$和$\frac{\partial y}{\partial b} ui$可以像实数到实数情况的快速版本一样进行评估。一旦计算出这些实值量,我们可以在右侧重建复向量,并与实值$v$向量进行点积。 #### 此时,您可能会想知道为什么我们没有选择一个复数$u$并只执行简化$2 * v^T CW u'$。为了深入探讨这一点,在本段中,我们将使用$u' = ur' + i ui'$的复数版本。使用这样的复数$u'$,问题在于在进行数值评估时,我们需要计算: 对于分析情况,事情更简单,我们将公式重写为: 复杂到实数函数的快速梯度检查 因此,我们可以利用反向模式自动微分提供的有效方法来计算$v^T (2 * CW)$,然后在重建最终复数标量$s$之前,将实部与$ur$和虚部与$ui$进行点积。 #### 为什么不使用复数$u$ 其中$v \in \mathcal{R}^M$,$ur \in \mathcal{R}^N$,$ui \in \mathcal{R}^N$。 $\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}$ 这将需要四次实数到实数有限差分的评估(与上述方法相比多两倍)。由于这种方法没有更多的自由度(相同数量的实值变量),我们尝试在这里获得最快的评估,因此使用上述的另一种公式。 ### 对于具有复杂输出的函数的快速 gradcheck[](#fast-gradcheck-for-functions-with-complex-outputs "跳转到本标题") 就像在慢速情况下一样,我们考虑两个实值函数,并对每个函数使用上面的适当规则。 ## Gradgradcheck 实现[](#gradgradcheck-implementation "跳转到本标题") PyTorch 还提供了一个工具来验证二阶梯度。这里的目标是确保反向实现也是正确可微的,并计算正确的结果。 这个特性是通过考虑函数 $F: x, v \to v^T J_f$F:x,v→vTJf​ 并在这个函数上使用上面定义的 gradcheck 来实现的。请注意,这种情况下的 $v$v 只是一个与 $f(x)$f(x) 相同类型的随机向量。 通过在相同函数 $F$F 上使用快速版本的 gradcheck 来实现 gradgradcheck 的快速版本。