未验证 提交 2005b98b 编写于 作者: G Guoxia Wang 提交者: GitHub

fix recompute no grad warning (#38293)

上级 06cf314a
...@@ -63,7 +63,8 @@ def swith_rng_state(rng_state): ...@@ -63,7 +63,8 @@ def swith_rng_state(rng_state):
class RecomputeFunction(PyLayer): class RecomputeFunction(PyLayer):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
check_recompute_necessary(args) if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
# store for recomputing # store for recomputing
ctx.run_function = run_function ctx.run_function = run_function
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册