From 99f601885387a90bb9185f2d3d7e1b7b5ed859f5 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Wed, 2 Nov 2022 10:50:32 +0800
Subject: [PATCH] support unbalanced data for pipeline (#47199)

* add unbalanced data

* fix utest
---
 .../fleet/meta_parallel/pipeline_parallel.py  | 78 ++++++++++---------
 ...parallel_pp_transformer_unbalanced_data.py | 67 ++++++++++++++++
 ...test_parallel_dygraph_pipeline_parallel.py |  7 ++
 3 files changed, 115 insertions(+), 37 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 89a3d61921..b7d1eb39c0 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -355,51 +355,55 @@ class PipelineParallel(MetaParallelBase):
         input_tensor_grad = input_tensor.grad
         return input_tensor_grad
 
-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self.micro_batch_size, self.accumulate_steps)
+        )
+
+    def _load_micro_batch_impl(self, inputs, cache_id):
         begin = cache_id * self.micro_batch_size
         end = begin + self.micro_batch_size
 
-        # The virtual first and last pipeline stage need data, all others don't need.
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self.accumulate_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self.accumulate_steps,
+                        len(data),
+                    )
+                    output.append(data[cache_id].detach())
+                else:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+            return tuple(output)
+
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self.accumulate_steps
+            ), "length of data should be %d, but it is %d" % (
+                self.accumulate_steps,
+                len(inputs),
+            )
+            return inputs[cache_id].detach()
+        else:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+
+    def _load_micro_batch(self, cache_id):
+        inputs = self.data
         if self.is_pipeline_first_stage():
             assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[0], tuple):
-                assert (
-                    len(inputs[0]) > 1
-                ), "If you use tuple for input data, it should have at least two inputs."
-                batch_size = inputs[0][0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                ), (
-                    "batch_size needs to be divisible by micro_batch_size. Currently, "
-                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-                    % (batch_size, self.micro_batch_size, self.accumulate_steps)
-                )
-                data = [input[begin:end, :].detach() for input in inputs[0]]
-                return tuple(data)
-            else:
-                batch_size = inputs[0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                return inputs[0][begin:end, :].detach()
+            return self._load_micro_batch_impl(inputs[0], cache_id)
         elif self.is_pipeline_last_stage():
             assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[1], tuple):
-                batch_size = inputs[1][0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                data = [input[begin:end, :].detach() for input in inputs[1]]
-                return tuple(data)
-            else:
-                batch_size = inputs[1].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                return inputs[1][begin:end, :].detach()
+            return self._load_micro_batch_impl(inputs[1], cache_id)
         else:
-            # No data input is required for other stages
             inputs = None
 
     def _broadcast_final_loss(self):
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py
new file mode 100644
index 0000000000..1db15407a5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_unbalanced_data.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import numpy as np
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from hybrid_parallel_pp_transformer import (
+    TestDistPPTraning,
+    set_random_seed,
+    ModelPipe,
+    batch_size,
+    length,
+    micro_batch_size,
+    vocab_size,
+)
+
+
+class TestDistPPTraningUnbalancedData(TestDistPPTraning):
+    def test_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        topology = hcg.topology()
+        set_random_seed(1024, dp_id, rank_id)
+
+        model = ModelPipe(topology)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True
+        )
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=scheduler, parameters=model.parameters()
+        )
+
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+
+        for step_id in range(5):
+            x = []
+            for _ in range(batch_size // micro_batch_size):
+                size = micro_batch_size
+                x_data = np.random.randint(0, vocab_size, size=[size, length])
+                x.append(paddle.to_tensor(x_data))
+            e_loss = model.eval_batch([x, x], True)
+            loss = model.train_batch([x, x], optimizer, scheduler)
+
+            # TODO(shenliang03) add utest for loss
+            if pp_id != 0:
+                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
index f45104de32..275c3721d6 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
@@ -64,6 +64,13 @@ class TestHybridPipeParallel(TestMultipleGpus):
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py', eager_mode=False)
 
+    def test_hybrid_parallel_transformer_unbalanced_data(self):
+        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer_unbalanced_data.py',
+            eager_mode=False,
+        )
+
 
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"
--
GitLab
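
As a usage note beyond the patch itself: the new _load_micro_batch_impl accepts, for each field of the input tuple, either a single tensor that is sliced into equal micro batches, or a Python list with one entry per accumulation step, in which case entry cache_id is used directly and entries may differ in shape (for example, sequence length). The sketch below is a minimal, standalone illustration of that dispatch rule; the function name, the constants, and the NumPy arrays standing in for Paddle tensors are assumptions for illustration, not part of the patch.

import numpy as np

MICRO_BATCH_SIZE = 2   # assumed micro batch size for this sketch
ACCUMULATE_STEPS = 4   # assumed number of accumulation steps


def load_micro_batch_sketch(inputs, cache_id):
    # Mirrors the branch structure of _load_micro_batch_impl in the diff above.
    begin = cache_id * MICRO_BATCH_SIZE
    end = begin + MICRO_BATCH_SIZE
    output = []
    for data in inputs:
        if isinstance(data, list):
            # Unbalanced path: one entry per accumulation step; entries may
            # have different shapes (e.g. different sequence lengths).
            assert len(data) == ACCUMULATE_STEPS
            output.append(data[cache_id])
        else:
            # Balanced path: a single array sliced into equal micro batches.
            assert data.shape[0] == MICRO_BATCH_SIZE * ACCUMULATE_STEPS
            output.append(data[begin:end])
    return tuple(output)


if __name__ == "__main__":
    balanced = np.arange(8 * 4).reshape(8, 4)                         # 8 == 2 * 4
    unbalanced = [np.zeros([MICRO_BATCH_SIZE, n]) for n in (5, 7, 3, 9)]
    for cache_id in range(ACCUMULATE_STEPS):
        feat, label = load_micro_batch_sketch((balanced, unbalanced), cache_id)
        print(cache_id, feat.shape, label.shape)

The new unit test exercises the same idea end to end: each element of the [x, x] input passed to model.train_batch is a list of per-micro-batch tensors rather than one large batch tensor.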