Unverified commit a9cc0274, authored by Yuang Liu, committed by GitHub

[dygraph hybrid pp for interleave] Save/Load for interleaved pipeline. (#45797)

Parent 960109af
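With interleaved scheduling, each pipeline rank owns several non-contiguous chunks of layers rather than one contiguous block. Checkpointing therefore needs a per-chunk start offset (`_start_poss`) to recover a layer's global index, and a `-virtual_pp_stage_XX` tag in each file name so shards from different chunks on the same rank do not collide. A minimal sketch of the kind of layer-to-rank layout involved (the numbers and the assignment formula below are illustrative, not taken from this commit):

    # Illustrative interleaved layout: 8 layers, pp_degree=2, 2 virtual
    # chunks per rank. Rank 0 holds layers [0, 1] and [4, 5]; rank 1 holds
    # [2, 3] and [6, 7]. Each rank keeps one start offset per chunk, which
    # is the role _start_poss plays in the diff below.
    num_layers, pp_degree, num_chunks = 8, 2, 2
    layers_per_chunk = num_layers // (pp_degree * num_chunks)  # 2
    for rank in range(pp_degree):
        start_poss = [(chunk * pp_degree + rank) * layers_per_chunk
                      for chunk in range(num_chunks)]
        print(rank, start_poss)  # 0 -> [0, 4], 1 -> [2, 6]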
@@ -657,21 +657,42 @@ class PipelineLayer(Layer):
         if self._topo.get_coord(self.global_rank).data != 0:
             return
 
-        def _offset_dirname(ckpt_dir, local_layer_idx):
-            idx = local_layer_idx + self._start_pos
+        def _offset_dirname(ckpt_dir, local_layer_idx, local_chunk_id=None):
+            if self._num_virtual_pipeline_stages == 1:
+                pos_offset = self._start_pos
+            else:
+                assert hasattr(self, '_start_poss')
+                assert local_chunk_id < len(self._start_poss)
+                pos_offset = self._start_poss[local_chunk_id]
+            idx = local_layer_idx + pos_offset
             model_rank = self._topo.get_coord(self.global_rank).model
             rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
+            virtual_pipeline_stage_message = ""
+            if self._num_virtual_pipeline_stages > 1:
+                # add virtual pipeline info to the save path
+                assert local_chunk_id is not None
+                virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(
+                    local_chunk_id)
             layer_save_path = os.path.join(ckpt_dir,
                                            'layer_{:0>2d}'.format(idx))
-            layer_save_path = layer_save_path + rank_message + '-model_states.pdparams'
+            layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams'
             return layer_save_path
 
+        def _save_model(run_functions, local_chunk_id=None):
+            for idx, layer in enumerate(run_functions):
+                model_save_path = _offset_dirname(path, idx, local_chunk_id)
+                if not hasattr(layer, 'state_dict'):
+                    continue
+                paddle.save(layer.state_dict(), model_save_path)
+
         os.makedirs(path, exist_ok=True)
-        for idx, layer in enumerate(self.run_function):
-            model_save_path = _offset_dirname(path, idx)
-            if not hasattr(layer, 'state_dict'):
-                continue
-            paddle.save(layer.state_dict(), model_save_path)
+        if self._num_virtual_pipeline_stages > 1:
+            logger.info("save model state for virtual pipeline stage...")
+            for chunk_id in range(len(self._model_chunks)):
+                run_function = self._model_chunks[chunk_id].get_run_function()
+                _save_model(run_function, chunk_id)
+        else:
+            _save_model(self.run_function)
 
         logger.info("save model state successfully...")
@@ -679,21 +700,43 @@ class PipelineLayer(Layer):
         assert os.path.exists(
             path), "{} not found, please check the path".format(path)
 
-        for idx, layer in enumerate(self.run_function):
-            if not hasattr(layer, 'set_state_dict'):
-                continue
-            layer_idx = idx + self._start_pos
-            layer_save_path = os.path.join(path,
-                                           'layer_{0:0>2d}'.format(layer_idx))
-            model_files = glob.glob(layer_save_path + "*model_states.pdparams")
-            model_files.sort()
-            mp_rank = self._topo.get_coord(self.global_rank).model
-            mp_world_size = self._topo.get_dim('model')
-            num_files = len(model_files)
+        def _load_model(run_functions, local_chunk_id=None):
+            for idx, layer in enumerate(run_functions):
+                if not hasattr(layer, 'set_state_dict'):
+                    continue
+                if self._num_virtual_pipeline_stages == 1:
+                    pos_offset = self._start_pos
+                else:
+                    assert hasattr(self, '_start_poss')
+                    assert local_chunk_id < len(self._start_poss)
+                    pos_offset = self._start_poss[local_chunk_id]
+                layer_idx = idx + pos_offset
+                layer_save_path = os.path.join(
+                    path, 'layer_{0:0>2d}'.format(layer_idx))
+                if self._num_virtual_pipeline_stages > 1:
+                    # add virtual pipeline info to the path
+                    assert local_chunk_id is not None
+                    layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format(
+                        local_chunk_id)
+                model_files = glob.glob(layer_save_path +
+                                        "*model_states.pdparams")
+                model_files.sort()
+                mp_rank = self._topo.get_coord(self.global_rank).model
+                mp_world_size = self._topo.get_dim('model')
+                num_files = len(model_files)
 
-            load_param_path = model_files[mp_rank * num_files // mp_world_size]
-            model_state_dict = paddle.load(load_param_path)
-            layer.set_state_dict(model_state_dict)
+                load_param_path = model_files[mp_rank * num_files //
+                                              mp_world_size]
+                model_state_dict = paddle.load(load_param_path)
+                layer.set_state_dict(model_state_dict)
+
+        if self._num_virtual_pipeline_stages > 1:
+            logger.info("load model state for virtual pipeline stage...")
+            for chunk_id in range(len(self._model_chunks)):
+                run_function = self._model_chunks[chunk_id].get_run_function()
+                _load_model(run_function, chunk_id)
+        else:
+            _load_model(self.run_function)
 
         self._synchronize_shared_weights()
         logger.info("load model state successfully...")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import os
import shutil
import tempfile
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer_with_virtual_stage import ModelPipe, set_random_seed

batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128


class TestDistPPSaveLoadTraning(unittest.TestCase):

    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                       values=[0.001, 0.002],
                                                       verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        output_dir = tempfile.mkdtemp()

        # warmup step
        for step_id in range(2):
            x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)

        model._layers.save_state_dict(output_dir)
        paddle.save(optimizer.state_dict(),
                    os.path.join(output_dir, "model_state.pdopt"))

        # construct data
        test_steps = 5
        np_data = np.random.randint(0,
                                    vocab_size,
                                    size=[test_steps, batch_size, length])

        origin_loss = []
        for step_id in range(5):
            x_data = np_data[step_id, :]
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)
            origin_loss.append(loss.numpy())

        # test step
        model._layers.set_state_dir(output_dir)
        opt_dict = paddle.load(os.path.join(output_dir, "model_state.pdopt"))
        optimizer.set_state_dict(opt_dict)

        for step_id in range(5):
            x_data = np_data[step_id, :]
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)
            print("origin loss: ", origin_loss[step_id], "current loss: ",
                  loss.numpy())
            np.testing.assert_allclose(loss.numpy(), origin_loss[step_id])

        # finally, remove the model/optimizer path
        shutil.rmtree(output_dir)


if __name__ == "__main__":
    unittest.main()
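For readers skimming the test: the save/load pair it exercises can be wrapped like this (the helper names are ours; the underlying calls are exactly the ones used in the test above):

    import os
    import paddle

    def save_pp_checkpoint(model, optimizer, ckpt_dir):
        # layer shards: one file per layer, mp rank and virtual stage
        model._layers.save_state_dict(ckpt_dir)
        # optimizer state is saved separately
        paddle.save(optimizer.state_dict(),
                    os.path.join(ckpt_dir, "model_state.pdopt"))

    def load_pp_checkpoint(model, optimizer, ckpt_dir):
        model._layers.set_state_dir(ckpt_dir)
        optimizer.set_state_dict(
            paddle.load(os.path.join(ckpt_dir, "model_state.pdopt")))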
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,6 +30,10 @@ class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):
         self.run_mnist_2gpu(
             'hybrid_parallel_pp_transformer_with_virtual_stage.py')
 
+    def test_hybrid_parallel_save_load_with_virtual_stage(self):
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_save_load_with_virtual_stage.py')
+
 
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"