未验证 提交 88f2f4a4 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel] Support save/load for PipeLineParallel (#34768)

* add save/load for pipelineparallel

* add save/load
上级 b5ec65e1
......@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import re
import glob
import os
import numpy as np
import random
from functools import partial
import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from functools import partial
__all__ = []
......@@ -310,3 +316,48 @@ class PipelineLayer(Layer):
for layer in self.run_function:
input = layer(input)
return input
def save_state_dict(self, path):
if self._topo.get_coord(self.global_rank).data != 0:
return
def _offset_dirname(ckpt_dir, local_layer_idx):
idx = local_layer_idx + self._start_pos
model_rank = self._topo.get_coord(self.global_rank).model
rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
layer_save_path = os.path.join(ckpt_dir,
'layer_{:0>2d}'.format(idx))
layer_save_path = layer_save_path + rank_message + '-model_states.pdparams'
return layer_save_path
os.makedirs(path, exist_ok=True)
for idx, layer in enumerate(self.run_function):
model_save_path = _offset_dirname(path, idx)
if not hasattr(layer, 'state_dict'):
continue
paddle.save(layer.state_dict(), model_save_path)
logger.info("save model state successfully...")
def set_state_dir(self, path):
assert os.path.exists(
path), "{} not found, please check the path".format(path)
for idx, layer in enumerate(self.run_function):
if not hasattr(layer, 'set_state_dict'):
continue
layer_idx = idx + self._start_pos
layer_save_path = os.path.join(path,
'layer_{0:0>2d}'.format(layer_idx))
model_files = glob.glob(layer_save_path + "*model_states.pdparams")
model_files.sort()
mp_rank = self._topo.get_coord(self.global_rank).model
mp_world_size = self._topo.get_dim('model')
num_files = len(model_files)
load_param_path = model_files[mp_rank * num_files // mp_world_size]
model_state_dict = paddle.load(load_param_path)
layer.set_state_dict(model_state_dict)
self._synchronize_shared_weights()
logger.info("load model state successfully...")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import os
import shutil
import tempfile
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import ModelPipe, set_random_seed
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
class TestDistPPSaveLoadTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
output_dir = tempfile.mkdtemp()
# warmup step
for step_id in range(2):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
model._layers.save_state_dict(output_dir)
paddle.save(optimizer.state_dict(),
os.path.join(output_dir, "model_state.pdopt"))
# construct data
test_steps = 5
np_data = np.random.randint(
0, vocab_size, size=[test_steps, batch_size, length])
origin_loss = []
for step_id in range(5):
x_data = np_data[step_id, :]
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
origin_loss.append(loss.numpy())
# test step
model._layers.set_state_dir(output_dir)
opt_dict = paddle.load(os.path.join(output_dir, "model_state.pdopt"))
optimizer.set_state_dict(opt_dict)
for step_id in range(5):
x_data = np_data[step_id, :]
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
print("origin loss: ", origin_loss[step_id], "current loss: ",
loss.numpy())
np.testing.assert_allclose(loss.numpy(), origin_loss[step_id])
# finally, remove the model/optimizer path
shutil.rmtree(output_dir)
if __name__ == "__main__":
unittest.main()
......@@ -86,7 +86,8 @@ class TransformerNet(Layer):
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
weights = F.softmax(product + mask)
weights = F.dropout(weights, 0.2)
# TODO(shenliang03) For save/load in PipeLineParallel, can’t support dropout temporarily.
# weights = F.dropout(weights, 0.2)
tgt = layers.matmul(weights, v)
residual = tgt
tgt = self.norm1(tgt)
......
......@@ -36,6 +36,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
def test_hybrid_parallel_transformer(self):
self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
def test_hybrid_parallel_transformer(self):
self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册