diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a3c6a5b5fb665fad5b36541973222d6991c03f74..f546adc65ea7143c92b4483ccaacf9afa54cc333 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -11,12 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import math -import paddle import re +import glob +import os +import numpy as np +import random +from functools import partial + +import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from functools import partial __all__ = [] @@ -310,3 +316,48 @@ class PipelineLayer(Layer): for layer in self.run_function: input = layer(input) return input + + def save_state_dict(self, path): + if self._topo.get_coord(self.global_rank).data != 0: + return + + def _offset_dirname(ckpt_dir, local_layer_idx): + idx = local_layer_idx + self._start_pos + model_rank = self._topo.get_coord(self.global_rank).model + rank_message = "-tensor_" + "{:0>2d}".format(model_rank) + layer_save_path = os.path.join(ckpt_dir, + 'layer_{:0>2d}'.format(idx)) + layer_save_path = layer_save_path + rank_message + '-model_states.pdparams' + return layer_save_path + + os.makedirs(path, exist_ok=True) + for idx, layer in enumerate(self.run_function): + model_save_path = _offset_dirname(path, idx) + if not hasattr(layer, 'state_dict'): + continue + paddle.save(layer.state_dict(), model_save_path) + + logger.info("save model state successfully...") + + def set_state_dir(self, path): + assert os.path.exists( + path), "{} not found, please check the path".format(path) + + for idx, layer in enumerate(self.run_function): + if not hasattr(layer, 'set_state_dict'): + continue + layer_idx = idx + self._start_pos + layer_save_path = os.path.join(path, + 'layer_{0:0>2d}'.format(layer_idx)) + model_files = glob.glob(layer_save_path + "*model_states.pdparams") + model_files.sort() + mp_rank = self._topo.get_coord(self.global_rank).model + mp_world_size = self._topo.get_dim('model') + num_files = len(model_files) + + load_param_path = model_files[mp_rank * num_files // mp_world_size] + model_state_dict = paddle.load(load_param_path) + layer.set_state_dict(model_state_dict) + + self._synchronize_shared_weights() + logger.info("load model state successfully...") diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_save_load.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_save_load.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e27bbb41a8a4269806e4df1dd5d2ebc9bd3acb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_save_load.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import os +import shutil +import tempfile +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from hybrid_parallel_pp_transformer import ModelPipe, set_random_seed + +batch_size = 8 +length = 8 +micro_batch_size = 2 +vocab_size = 128 + + +class TestDistPPSaveLoadTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + topology = hcg.topology() + set_random_seed(1024, dp_id, rank_id) + + model = ModelPipe(topology) + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + output_dir = tempfile.mkdtemp() + + # warmup step + for step_id in range(2): + x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + + model._layers.save_state_dict(output_dir) + paddle.save(optimizer.state_dict(), + os.path.join(output_dir, "model_state.pdopt")) + + # construct data + test_steps = 5 + np_data = np.random.randint( + 0, vocab_size, size=[test_steps, batch_size, length]) + + origin_loss = [] + for step_id in range(5): + x_data = np_data[step_id, :] + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + origin_loss.append(loss.numpy()) + + # test step + model._layers.set_state_dir(output_dir) + opt_dict = paddle.load(os.path.join(output_dir, "model_state.pdopt")) + optimizer.set_state_dict(opt_dict) + + for step_id in range(5): + x_data = np_data[step_id, :] + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + print("origin loss: ", origin_loss[step_id], "current loss: ", + loss.numpy()) + np.testing.assert_allclose(loss.numpy(), origin_loss[step_id]) + + # finally, remove the model/optimizer path + shutil.rmtree(output_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py index 62b1a8b1da6797c602667e58cdbfc11c3979a5dc..524099c6ab05e882f5815d75522a8ab005bc0a2c 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -86,7 +86,8 @@ class TransformerNet(Layer): product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5) weights = F.softmax(product + mask) - weights = F.dropout(weights, 0.2) + # TODO(shenliang03) For save/load in PipeLineParallel, can’t support dropout temporarily. + # weights = F.dropout(weights, 0.2) tgt = layers.matmul(weights, v) residual = tgt tgt = self.norm1(tgt) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 62e781678c9fc8df45544cb287f759bd26373e1f..003e0c1685cae7d5d63fad87654728f5f40eb8ab 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -36,6 +36,9 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_transformer(self): self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py') + def test_hybrid_parallel_transformer(self): + self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py') + if __name__ == "__main__": unittest.main()