diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py
index b5b2010ba97a596f66b280c5795ad903780a00a8..1447fccb66eb56586872dd4add56a28b753d4ce5 100644
--- a/python/paddle/distributed/fleet/layers/mpu/random.py
+++ b/python/paddle/distributed/fleet/layers/mpu/random.py
@@ -90,14 +90,20 @@ def model_parallel_random_seed(seed=None):
     from paddle.distributed import fleet
 
     hcg = fleet.get_hybrid_communicate_group()
-    rank = hcg.get_model_parallel_rank()
+
+    mp_rank = hcg.get_model_parallel_rank()
+    mp_size = hcg.get_model_parallel_world_size()
+
+    pp_rank = hcg.get_stage_id()
+    pp_size = hcg.get_pipe_parallel_world_size()
 
     if seed:
         global_seed = seed
-        local_seed = seed * 1024 + rank * 100
+        # dp/sharding seed is same
+        local_seed = seed + 1 + mp_rank * pp_size + pp_rank
     else:
-        global_seed = np.random.randint(0, 655350)
-        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
+        global_seed = np.random.randint(0, 10000)
+        local_seed = global_seed + 1 + mp_rank * pp_size + pp_rank
 
     RNG_STATE_TRACKER.reset()
     RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 30485d8e633359aca2c2d9da13dc97dad47afeac..2f5c42a69e362fdcbf063aa336257a10bc0af1ab 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -330,6 +330,9 @@ class PipelineLayer(nn.Layer):
         self._recompute_interval = recompute_interval
         self.recompute_ctx = recompute_ctx
 
+        # Defaults to 1234 to initialize layer parameters
+        self._base_seed = 1234
+
         if recompute_interval > 0:
             assert (
                 recompute_ctx is not None
@@ -626,8 +629,21 @@ class PipelineLayer(nn.Layer):
             # For 1f1b scheduler, just use run_function list
             run_function = self.run_function
 
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
+
+        orig_rng_state = paddle.get_rng_state()
+        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
+
         for index, layer in enumerate(self._layers_desc[start:end]):
             layer_index = start + index
+
+            # NOTE(shenliang03): need set different seeds for pipeline parameters initialization.
+            # Since the parameters of model_parallel are controlled by its own RNG_STATE_TRACKER,
+            # only non-mp parameters in pp are controlled here.
+            paddle.seed(self._base_seed + layer_index)
+
             if isinstance(layer, nn.Layer):
                 run_function.append(layer)
                 if self._num_virtual_pipeline_stages == 1:
@@ -666,6 +682,8 @@ class PipelineLayer(nn.Layer):
             else:
                 run_function.append(layer)
 
+        paddle.set_rng_state(orig_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
         return run_function
 
     def forward_function(self, start, end):
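
Below is a minimal, standalone sketch (not part of the PR) of the seed layout that the first hunk introduces in model_parallel_random_seed(). The helper derive_seeds and the parallel sizes used are hypothetical, chosen only to illustrate the arithmetic: every rank shares the same global_seed, while local_seed is unique per (mp_rank, pp_rank) pair and identical across data-parallel/sharding replicas of that pair.

# Illustrative only: mirrors the seed arithmetic added by the diff, without Paddle.
def derive_seeds(seed, mp_rank, pp_rank, pp_size):
    global_seed = seed
    # dp/sharding ranks do not appear in the formula, so replicas share a seed,
    # while each (mp_rank, pp_rank) pair gets a distinct local seed.
    local_seed = seed + 1 + mp_rank * pp_size + pp_rank
    return global_seed, local_seed

if __name__ == "__main__":
    seed, mp_size, pp_size = 2023, 2, 4   # hypothetical hybrid-parallel sizes
    seen = set()
    for mp_rank in range(mp_size):
        for pp_rank in range(pp_size):
            g, l = derive_seeds(seed, mp_rank, pp_rank, pp_size)
            assert g == seed          # global seed identical on every rank
            assert l not in seen      # local seed unique per (mp, pp) pair
            seen.add(l)
    print(sorted(seen))  # 2024..2031 for this configuration

The pp_layers.py hunks apply the same idea during layer construction: each layer description is built under paddle.seed(self._base_seed + layer_index), so non-model-parallel parameters get a reproducible per-layer seed, and the original RNG state and tracker states are restored afterwards so that building the PipelineLayer does not disturb the surrounding RNG stream.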