Unverified commit 28795771, authored by Leo Chen, committed by GitHub

run recompute's real backward with amp disabled (#40042)

Parent 1c4e3e5d
...@@ -182,9 +182,10 @@ class RecomputeFunction(PyLayer):
                 "none of output has requires_grad=True, this recompute() is not necessary"
             )
         # actually backward
-        paddle.autograd.backward(forward_outputs_with_grad,
-                                 backward_inputs_with_grad)
+        with paddle.amp.auto_cast(enable=False):
+            paddle.autograd.backward(forward_outputs_with_grad,
+                                     backward_inputs_with_grad)
 
         grads = list(inp._grad_ivar() for inp in detached_inputs
                      if isinstance(inp, core.VarBase))
......
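
For context, here is a minimal sketch (not part of the commit) of the behavior the change relies on: ops executed inside `paddle.amp.auto_cast(enable=False)` are not auto-cast to low precision, so wrapping the real backward keeps it, and anything it triggers such as recompute's replayed forward, in the original dtype even when the outer forward ran under AMP. The toy model and tensor names below are illustrative and do not come from the PaddlePaddle source.

```python
# Minimal sketch, assuming a dygraph Paddle environment (AMP silently
# falls back to float32 on devices without low-precision support).
import paddle

linear = paddle.nn.Linear(4, 4)
x = paddle.randn([8, 4])
x.stop_gradient = False

# Forward under AMP: eligible ops may run in float16/bfloat16.
with paddle.amp.auto_cast(enable=True):
    loss = linear(x).mean()

# Real backward with AMP disabled, mirroring the diff above: ops created
# during backward (e.g. a recomputed forward) are not auto-cast.
with paddle.amp.auto_cast(enable=False):
    paddle.autograd.backward([loss])

print(x.grad.shape)  # gradients are produced as usual
```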