未验证 提交 d9fa8fde 编写于 作者: Y Yuang Liu 提交者: GitHub

fix vpp mp init prob (#54903)

* fix vpp mp init prob

* trigger ci
上级 8670df60
...@@ -637,6 +637,13 @@ class PipelineLayer(nn.Layer): ...@@ -637,6 +637,13 @@ class PipelineLayer(nn.Layer):
logger.info(f"loss: {self._loss_fn.__class__.__name__}") logger.info(f"loss: {self._loss_fn.__class__.__name__}")
def _build_layer_with_interleave(self): def _build_layer_with_interleave(self):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for i in range(len(self._start_poss)): for i in range(len(self._start_poss)):
start = self._start_poss[i] start = self._start_poss[i]
end = self._end_poss[i] end = self._end_poss[i]
...@@ -647,11 +654,24 @@ class PipelineLayer(nn.Layer): ...@@ -647,11 +654,24 @@ class PipelineLayer(nn.Layer):
self._model_chunks.append(chunk) self._model_chunks.append(chunk)
self.add_sublayer(str(start), chunk) self.add_sublayer(str(start), chunk)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
def _build_layer(self):
    """Build the run_function list for the 1F1B (non-interleaved) scheduler.

    Layer construction may consume random numbers (e.g. parameter init under
    model parallelism), which would desynchronize RNG streams across ranks.
    To keep subsequent random state identical on every rank, the global RNG
    state and the model-parallel RNG tracker states are snapshotted before
    building and restored afterwards.
    """
    # Local import to avoid a circular dependency with the random utilities.
    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
        get_rng_state_tracker,
    )

    # Snapshot RNG state so layer construction doesn't perturb it.
    orig_rng_state = paddle.get_rng_state()
    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()

    start = self._start_pos
    end = self._end_pos
    # Materialize the layers owned by this pipeline stage.
    self.run_function = self._build_layer_impl(start, end)

    # Restore RNG state so all ranks continue from identical streams.
    paddle.set_rng_state(orig_rng_state)
    get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
def _build_layer_impl(self, start, end): def _build_layer_impl(self, start, end):
if self._num_virtual_pipeline_stages > 1: if self._num_virtual_pipeline_stages > 1:
# For interleave scheduler, all layers relating with one model chunk will be saved in PipelineLayerChunk # For interleave scheduler, all layers relating with one model chunk will be saved in PipelineLayerChunk
...@@ -660,13 +680,6 @@ class PipelineLayer(nn.Layer): ...@@ -660,13 +680,6 @@ class PipelineLayer(nn.Layer):
# For 1f1b scheduler, just use run_function list # For 1f1b scheduler, just use run_function list
run_function = self.run_function run_function = self.run_function
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for index, layer in enumerate(self._layers_desc[start:end]): for index, layer in enumerate(self._layers_desc[start:end]):
layer_index = start + index layer_index = start + index
...@@ -720,9 +733,6 @@ class PipelineLayer(nn.Layer): ...@@ -720,9 +733,6 @@ class PipelineLayer(nn.Layer):
self.add_sublayer(str(layer_index), model) self.add_sublayer(str(layer_index), model)
else: else:
run_function.append(layer) run_function.append(layer)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
return run_function return run_function
def forward_function(self, start, end): def forward_function(self, start, end):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册