未验证 提交 61ca8b39 编写于 作者: S ShenLiang 提交者: GitHub

[BugFix] Fix random seed bug in hybridparallel (#52656)

* add seed control

* fix bug
上级 6cd095fc
......@@ -90,14 +90,20 @@ def model_parallel_random_seed(seed=None):
from paddle.distributed import fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
pp_rank = hcg.get_stage_id()
pp_size = hcg.get_pipe_parallel_world_size()
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
# dp/sharding ranks share the same seed; it varies only with mp_rank and pp_rank
local_seed = seed + 1 + mp_rank * pp_size + pp_rank
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
global_seed = np.random.randint(0, 10000)
local_seed = global_seed + 1 + mp_rank * pp_size + pp_rank
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
......
......@@ -330,6 +330,9 @@ class PipelineLayer(nn.Layer):
self._recompute_interval = recompute_interval
self.recompute_ctx = recompute_ctx
# Defaults to 1234 to initialize layer parameters
self._base_seed = 1234
if recompute_interval > 0:
assert (
recompute_ctx is not None
......@@ -626,8 +629,21 @@ class PipelineLayer(nn.Layer):
# For 1f1b scheduler, just use run_function list
run_function = self.run_function
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for index, layer in enumerate(self._layers_desc[start:end]):
layer_index = start + index
# NOTE(shenliang03): Different seeds must be set for pipeline parameter initialization.
# Since model-parallel parameters are controlled by their own RNG_STATE_TRACKER,
# only the non-mp parameters in pp are controlled here.
paddle.seed(self._base_seed + layer_index)
if isinstance(layer, nn.Layer):
run_function.append(layer)
if self._num_virtual_pipeline_stages == 1:
......@@ -666,6 +682,8 @@ class PipelineLayer(nn.Layer):
else:
run_function.append(layer)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
return run_function
def forward_function(self, start, end):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册