diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index 59bcf50ffb7988b7d505e78e259f97be1e338b3e..6c8badd64e161b11d1d44f061b311af0551b2ce9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -19,7 +19,7 @@ from paddle.fluid import core
 from paddle import _C_ops
 from paddle.autograd import PyLayer, EagerPyLayer
 from paddle.fluid import framework
-from ...utils.recompute import check_recompute_necessary, detach_variable
+from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
 from ..parallel_layers.random import get_rng_state_tracker
 from paddle.fluid.framework import in_dygraph_mode
 
@@ -151,20 +151,6 @@ def _merge_activation(tensor):
     return _all_gather(tensor, group=mp_group)
 
 
-@contextlib.contextmanager
-def _swith_rng_state_tracker(rng_state, tracker):
-    orig_cuda_rng_state = paddle.get_cuda_rng_state()
-    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
-
-    paddle.set_cuda_rng_state(rng_state)
-    get_rng_state_tracker().set_states_tracker(tracker)
-    try:
-        yield
-    finally:
-        paddle.set_cuda_rng_state(orig_cuda_rng_state)
-        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
-
-
 class _HPEagerRecomputeFunction(EagerPyLayer):
     """
     Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
@@ -261,8 +247,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
             tracer._has_grad = True
 
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
-                                          ctx.fwd_cuda_rng_state_tracker):
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+                                         ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
@@ -393,8 +379,8 @@ class _HPRecomputeFunction(PyLayer):
             tracer._has_grad = True
 
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
-                                          ctx.fwd_cuda_rng_state_tracker):
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+                                         ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index c767be77d83841ae0c94ac3cb841325573b8329d..b8d1c881a08f93f4e38837516239efb37c53d8cf 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -53,18 +53,24 @@ def check_recompute_necessary(inputs):
 
 
 @contextlib.contextmanager
-def swith_rng_state(rng_state):
+def swith_rng_state_tracker(rng_state, tracker):
+    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
+    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
+
     paddle.set_cuda_rng_state(rng_state)
+    get_rng_state_tracker().set_states_tracker(tracker)
     try:
         yield
     finally:
         paddle.set_cuda_rng_state(orig_cuda_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
 
 
 class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
 
@@ -98,6 +104,8 @@ class EagerRecomputeFunction(EagerPyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -126,6 +134,7 @@ class EagerRecomputeFunction(EagerPyLayer):
 
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
 
@@ -143,7 +152,8 @@ class EagerRecomputeFunction(EagerPyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
@@ -199,6 +209,7 @@ class EagerRecomputeFunction(EagerPyLayer):
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
 
@@ -232,6 +243,8 @@ class RecomputeFunction(PyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -260,6 +273,7 @@ class RecomputeFunction(PyLayer):
 
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
 
@@ -277,7 +291,8 @@ class RecomputeFunction(PyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_black_list,
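The sketch below is a minimal, hedged illustration of the pattern this diff applies: the forward pass snapshots both the global CUDA RNG state and the model-parallel RNG tracker states, and the backward/recompute pass replays them through the new `swith_rng_state_tracker` context manager so recomputed dropout masks match the original forward pass. It assumes the patched `recompute.py` is importable, a single GPU, and that dropout draws from the default CUDA generator; the `Dropout` layer, tensor shape, and variable names are illustrative only, not part of the patch.

```python
# Illustrative only: exercises swith_rng_state_tracker (as introduced by this
# diff) outside of the recompute PyLayers. Requires a GPU build of Paddle.
import paddle
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
from paddle.distributed.fleet.utils.recompute import swith_rng_state_tracker

dropout = paddle.nn.Dropout(p=0.5)  # hypothetical stand-in for a recomputed block
x = paddle.randn([4, 8])

# "Forward": capture both RNG sources before running the block, mirroring what
# forward() stores in ctx.fw_cuda_rng_state / ctx.fwd_cuda_rng_state_tracker.
fw_cuda_rng_state = paddle.get_cuda_rng_state()
fw_cuda_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
y1 = dropout(x)

# "Backward": replay the captured states while re-running the block; the
# context manager restores the caller's RNG state and tracker on exit.
with swith_rng_state_tracker(fw_cuda_rng_state, fw_cuda_rng_state_tracker):
    y2 = dropout(x)

# With the same CUDA RNG state replayed, the dropout mask should be identical.
assert paddle.allclose(y1, y2)
```

Design note: the patched `recompute.py` imports `get_rng_state_tracker` lazily inside the functions (presumably to avoid a circular import between `fleet/utils` and `fleet/meta_parallel`); a standalone script like this one can import it at module level.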