未验证 提交 d9fa8fde 编写于 作者: Y Yuang Liu 提交者: GitHub

fix vpp mp init prob (#54903)

* fix vpp mp init prob

* trigger ci
上级 8670df60
......@@ -637,6 +637,13 @@ class PipelineLayer(nn.Layer):
logger.info(f"loss: {self._loss_fn.__class__.__name__}")
def _build_layer_with_interleave(self):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for i in range(len(self._start_poss)):
start = self._start_poss[i]
end = self._end_poss[i]
......@@ -647,11 +654,24 @@ class PipelineLayer(nn.Layer):
self._model_chunks.append(chunk)
self.add_sublayer(str(start), chunk)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
def _build_layer(self):
    """Build the layers for ``[self._start_pos, self._end_pos)``.

    Delegates to ``self._build_layer_impl`` and stores its result in
    ``self.run_function``. The paddle RNG state and the RNG tracker's
    states are captured before the build and put back afterwards, so
    constructing the layers leaves both exactly as they were on entry.
    """
    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
        get_rng_state_tracker,
    )

    # Snapshot both RNG sources before any layer is constructed.
    orig_rng_state = paddle.get_rng_state()
    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()

    start = self._start_pos
    end = self._end_pos
    try:
        self.run_function = self._build_layer_impl(start, end)
    finally:
        # Restore even when layer construction raises, so a failed build
        # does not leave the RNG state / tracker states modified.
        paddle.set_rng_state(orig_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
def _build_layer_impl(self, start, end):
if self._num_virtual_pipeline_stages > 1:
# For interleave scheduler, all layers relating with one model chunk will be saved in PipelineLayerChunk
......@@ -660,13 +680,6 @@ class PipelineLayer(nn.Layer):
# For 1f1b scheduler, just use run_function list
run_function = self.run_function
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for index, layer in enumerate(self._layers_desc[start:end]):
layer_index = start + index
......@@ -720,9 +733,6 @@ class PipelineLayer(nn.Layer):
self.add_sublayer(str(layer_index), model)
else:
run_function.append(layer)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
return run_function
def forward_function(self, start, end):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册