未验证 提交 5e97651e 编写于 作者: H Haohongxiang 提交者: GitHub

fix warning infos of recompute hybrid in eager mode (#47288)

上级 3f64a2c3
...@@ -86,7 +86,6 @@ class _HPRecomputeFunction(PyLayer): ...@@ -86,7 +86,6 @@ class _HPRecomputeFunction(PyLayer):
*args, *args,
**kwargs **kwargs
): ):
check_recompute_necessary(args)
# store for recomputing # store for recomputing
ctx.run_function = run_function ctx.run_function = run_function
...@@ -264,6 +263,9 @@ def recompute_hybrid(ctx, function, *args, **kwargs): ...@@ -264,6 +263,9 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
offload = ctx.get('offload', False) offload = ctx.get('offload', False)
partition = ctx.get('partition', False) partition = ctx.get('partition', False)
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)
all_outputs = [] all_outputs = []
_HPRecomputeFunction.apply( _HPRecomputeFunction.apply(
function, all_outputs, mp_group, offload, partition, *args, **kwargs function, all_outputs, mp_group, offload, partition, *args, **kwargs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册