From 1e39b1ecc9b63db0cf358e6ce53d13981cd37063 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Tue, 5 Sep 2023 10:51:04 +0800
Subject: [PATCH] [incubate/new_frl] Support detach of EagerParamBase in
 recompute (#56767)

---
 .../distributed/fleet/recompute/recompute.py  | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 mode change 100755 => 100644 python/paddle/distributed/fleet/recompute/recompute.py

diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
old mode 100755
new mode 100644
index b0b9885c33e..2fc65fa3abd
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import contextlib
+import copy
 import weakref
 
 import paddle
@@ -21,6 +22,7 @@ from paddle.autograd import PyLayer
 from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
     get_rng_state_tracker,
 )
+from paddle.fluid.framework import EagerParamBase
 from paddle.framework import core, in_dygraph_mode
 
 from ..utils.log_util import logger
@@ -28,6 +30,17 @@ from ..utils.log_util import logger
 __all__ = []
 
 
+def _varbase_help(param):
+    state = copy.deepcopy(param.__dict__)
+    new_param = EagerParamBase(
+        shape=param.shape, dtype=param.dtype, name=param.name, **state
+    )
+
+    param._share_buffer_to(new_param)
+
+    return new_param
+
+
 def detach_variable(inputs):
     out = []
     for inp in inputs:
@@ -35,6 +48,10 @@ def detach_variable(inputs):
             out.append(inp)
             continue
 
+        if isinstance(inp, EagerParamBase):
+            out.append(_varbase_help(inp))
+            continue
+
         x = inp.detach()
         x.stop_gradient = inp.stop_gradient
         out.append(x)
-- 
GitLab
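
Note (not part of the patch): a minimal sketch of what the new detach_variable
branch does for parameters. _varbase_help builds a fresh EagerParamBase whose
state dict is deep-copied from the original and whose storage is shared via
_share_buffer_to, so the detached copy points at the same buffer but is a
separate object. The Linear layer below is a hypothetical stand-in for any
dygraph layer whose weight is an EagerParamBase; it assumes detach_variable
returns a tuple (as in Paddle's implementation) and that the eager Tensor API
exposes _is_shared_buffer_with.

    # Illustrative sketch only; assumes Paddle dygraph mode and the patched
    # detach_variable above. Layer/variable names here are hypothetical.
    import paddle
    from paddle.distributed.fleet.recompute.recompute import detach_variable

    linear = paddle.nn.Linear(4, 4)  # linear.weight is an EagerParamBase
    weight = linear.weight

    (detached,) = detach_variable((weight,))

    # The copy is a distinct object, so recompute can work with it without
    # disturbing the parameter object tracked outside the recompute region...
    assert detached is not weight
    # ...yet both views share one underlying buffer, so no memory is copied.
    assert detached._is_shared_buffer_with(weight)

The point of the helper, presumably, is that a plain inp.detach() on an
EagerParamBase yields an ordinary Tensor and drops the parameter-specific
state carried in param.__dict__; copying that state into a new EagerParamBase
while sharing the buffer keeps both the storage and the parameter semantics.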