From 61ca8b3937116798a25a5fd2f22280e994fec7c4 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Sat, 8 Apr 2023 21:31:56 -0500
Subject: [PATCH] [BugFix] Fix random seed bug in hybridparallel (#52656)

* add seed control

* fix bug
---
 .../distributed/fleet/layers/mpu/random.py     | 14 ++++++++++----
 .../meta_parallel/parallel_layers/pp_layers.py | 18 ++++++++++++++++++
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py
index b5b2010ba97..1447fccb66e 100644
--- a/python/paddle/distributed/fleet/layers/mpu/random.py
+++ b/python/paddle/distributed/fleet/layers/mpu/random.py
@@ -90,14 +90,20 @@ def model_parallel_random_seed(seed=None):
     from paddle.distributed import fleet
 
     hcg = fleet.get_hybrid_communicate_group()
-    rank = hcg.get_model_parallel_rank()
+
+    mp_rank = hcg.get_model_parallel_rank()
+    mp_size = hcg.get_model_parallel_world_size()
+
+    pp_rank = hcg.get_stage_id()
+    pp_size = hcg.get_pipe_parallel_world_size()
 
     if seed:
         global_seed = seed
-        local_seed = seed * 1024 + rank * 100
+        # dp/sharding seed is same
+        local_seed = seed + 1 + mp_rank * pp_size + pp_rank
     else:
-        global_seed = np.random.randint(0, 655350)
-        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
+        global_seed = np.random.randint(0, 10000)
+        local_seed = global_seed + 1 + mp_rank * pp_size + pp_rank
 
     RNG_STATE_TRACKER.reset()
     RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 30485d8e633..2f5c42a69e3 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -330,6 +330,9 @@ class PipelineLayer(nn.Layer):
         self._recompute_interval = recompute_interval
         self.recompute_ctx = recompute_ctx
 
+        # Defaults to 1234 to initialize layer parameters
+        self._base_seed = 1234
+
         if recompute_interval > 0:
             assert (
                 recompute_ctx is not None
@@ -626,8 +629,21 @@ class PipelineLayer(nn.Layer):
             # For 1f1b scheduler, just use run_function list
             run_function = self.run_function
 
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
+
+        orig_rng_state = paddle.get_rng_state()
+        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
+
         for index, layer in enumerate(self._layers_desc[start:end]):
             layer_index = start + index
+
+            # NOTE(shenliang03): need set different seeds for pipeline parameters initialization.
+            # Since the parameters of model_parallel are controlled by its own RNG_STATE_TRACKER,
+            # only non-mp parameters in pp are controlled here.
+            paddle.seed(self._base_seed + layer_index)
+
             if isinstance(layer, nn.Layer):
                 run_function.append(layer)
                 if self._num_virtual_pipeline_stages == 1:
@@ -666,6 +682,8 @@ class PipelineLayer(nn.Layer):
             else:
                 run_function.append(layer)
 
+        paddle.set_rng_state(orig_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
         return run_function
 
     def forward_function(self, start, end):
-- 
GitLab