未验证 提交 ed6f1f90 编写于 作者: H Haohongxiang 提交者: GitHub

fix warning infos of recompute (#43495)

上级 b1f77b4d
......@@ -73,8 +73,6 @@ class EagerRecomputeFunction(EagerPyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
......@@ -211,8 +209,6 @@ class RecomputeFunction(PyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
......@@ -466,6 +462,9 @@ def recompute(function, *args, **kwargs):
raise ValueError("Unexpected keyword arguments: " +
",".join(arg for arg in kwargs))
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
if in_dygraph_mode():
return EagerRecomputeFunction.apply(function, preserve, *args)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册