diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
old mode 100755
new mode 100644
index b0b9885c33efdb648d213918cc7cc1423943362a..2fc65fa3abd6beba0aa71f659802a429254460d4
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import contextlib
+import copy
 import weakref
 
 import paddle
@@ -21,6 +22,7 @@ from paddle.autograd import PyLayer
 from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
     get_rng_state_tracker,
 )
+from paddle.fluid.framework import EagerParamBase
 from paddle.framework import core, in_dygraph_mode
 
 from ..utils.log_util import logger
@@ -28,6 +30,17 @@ from ..utils.log_util import logger
 __all__ = []
 
 
+def _varbase_help(param):
+    state = copy.deepcopy(param.__dict__)
+    new_param = EagerParamBase(
+        shape=param.shape, dtype=param.dtype, name=param.name, **state
+    )
+
+    param._share_buffer_to(new_param)
+
+    return new_param
+
+
 def detach_variable(inputs):
     out = []
     for inp in inputs:
@@ -35,6 +48,10 @@
             out.append(inp)
             continue
 
+        if isinstance(inp, EagerParamBase):
+            out.append(_varbase_help(inp))
+            continue
+
         x = inp.detach()
         x.stop_gradient = inp.stop_gradient
         out.append(x)
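
For reference, a minimal sketch (not part of the diff) of how the patched `detach_variable` is expected to behave when given a parameter: an `EagerParamBase` input comes back as a new parameter that shares the original storage via `_share_buffer_to`, rather than as a plain detached tensor. The `paddle.nn.Linear` layer, its shape, and the import path `paddle.distributed.fleet.recompute.recompute` are illustrative assumptions about the target Paddle build.

```python
# Illustrative sketch only; assumes a Paddle build where this module path
# and paddle.nn.Linear are available as shown.
import paddle
from paddle.distributed.fleet.recompute.recompute import detach_variable

linear = paddle.nn.Linear(4, 4)   # hypothetical layer, any parameter works
weight = linear.weight            # an EagerParamBase instance in eager mode

detached = detach_variable((weight,))[0]

# The re-wrapped parameter keeps its name and shares the underlying buffer
# with the original, instead of being replaced by a plain detached Tensor.
print(detached.name == weight.name)              # True: identity preserved
print(bool(paddle.equal_all(detached, weight)))  # True: same underlying data
```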