diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 93e9b7d44c6fc46a562202a3a5a0e9cd90c4dd94..8773f9d82ca8b34d2160a14252dd518047fa57b1 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -637,6 +637,13 @@ class PipelineLayer(nn.Layer): logger.info(f"loss: {self._loss_fn.__class__.__name__}") def _build_layer_with_interleave(self): + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) + + orig_rng_state = paddle.get_rng_state() + orig_rng_tracker = get_rng_state_tracker().get_states_tracker() + for i in range(len(self._start_poss)): start = self._start_poss[i] end = self._end_poss[i] @@ -647,11 +654,24 @@ class PipelineLayer(nn.Layer): self._model_chunks.append(chunk) self.add_sublayer(str(start), chunk) + paddle.set_rng_state(orig_rng_state) + get_rng_state_tracker().set_states_tracker(orig_rng_tracker) + def _build_layer(self): + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) + + orig_rng_state = paddle.get_rng_state() + orig_rng_tracker = get_rng_state_tracker().get_states_tracker() + start = self._start_pos end = self._end_pos self.run_function = self._build_layer_impl(start, end) + paddle.set_rng_state(orig_rng_state) + get_rng_state_tracker().set_states_tracker(orig_rng_tracker) + def _build_layer_impl(self, start, end): if self._num_virtual_pipeline_stages > 1: # For interleave scheduler, all layers relating with one model chunk will be saved in PipelineLayerChunk @@ -660,13 +680,6 @@ class PipelineLayer(nn.Layer): # For 1f1b scheduler, just use run_function list run_function = self.run_function - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( - get_rng_state_tracker, - ) - - orig_rng_state = paddle.get_rng_state() - orig_rng_tracker = get_rng_state_tracker().get_states_tracker() - for index, layer in enumerate(self._layers_desc[start:end]): layer_index = start + index @@ -720,9 +733,6 @@ class PipelineLayer(nn.Layer): self.add_sublayer(str(layer_index), model) else: run_function.append(layer) - - paddle.set_rng_state(orig_rng_state) - get_rng_state_tracker().set_states_tracker(orig_rng_tracker) return run_function def forward_function(self, start, end):