Unverified commit 0dff82c2, authored by JZ-LIANG, committed by GitHub

Recompute: fix bug with transformer attention mask (#34664)

Parent b7355d8e
@@ -145,23 +145,25 @@ class RecomputeFunction(PyLayer):
         # run backward() with only tensor that requires grad
         forward_outputs_with_grad = []
-        backward_inputs = list(args)
+        # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
+        # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
+        # tensor that need grad does not match.
+        # the following backward_inputs_with_grad is used to avoid this case.
+        backward_inputs_with_grad = []
         for i in range(len(outputs)):
             if isinstance(outputs[i],
                           core.VarBase) and not outputs[i].stop_gradient:
                 forward_outputs_with_grad.append(outputs[i])
+                backward_inputs_with_grad.append(args[i])
         if len(forward_outputs_with_grad) == 0:
             raise RuntimeError(
                 "none of output has requires_grad=True, this recompute() is not necessary"
             )
-        assert len(backward_inputs) == len(
-            forward_outputs_with_grad
-        ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
-            len(forward_outputs_with_grad), len(backward_inputs))
         # actually backward
-        paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+        paddle.autograd.backward(forward_outputs_with_grad,
+                                 backward_inputs_with_grad)
         grads = list(inp._grad_ivar() for inp in detached_inputs
                      if isinstance(inp, core.VarBase))
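For context, the scenario described in the NOTE above can be illustrated with a small, self-contained sketch. It is not part of this commit; the Block layer, tensor shapes, and variable names below are made up for illustration, and it assumes Paddle's dygraph mode with the fleet recompute utility: a recompute segment whose outputs include the attention mask, a tensor that needs no gradient.

    # Minimal sketch (illustrative only): a recompute segment that returns the
    # attention mask alongside the hidden state.
    import paddle
    from paddle.distributed.fleet.utils import recompute

    class Block(paddle.nn.Layer):
        def __init__(self, hidden=16):
            super().__init__()
            self.linear = paddle.nn.Linear(hidden, hidden)

        def forward(self, x, mask):
            # The mask only gates the activations; it never needs a gradient.
            y = self.linear(x) * mask
            # Returning the mask puts a grad-free tensor in the segment output.
            return y, mask

    block = Block()
    x = paddle.randn([4, 16])
    x.stop_gradient = False
    mask = paddle.ones([4, 16])  # stop_gradient stays True: no grad wanted for the mask

    y, _ = recompute(block, x, mask)
    y.sum().backward()
    print(x.grad.shape)  # gradients still reach x

Before this fix, the mask in the segment output was forced to stop_gradient=False by PyLayer, so the count of outputs that "need grad" no longer matched the grad tensors passed to backward(); pairing each forward output with its grad via backward_inputs_with_grad avoids that mismatch.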