From 561dc719da979f25287cdf57f7d90572ad5eae81 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Sun, 25 Apr 2021 22:45:32 +0800 Subject: [PATCH] add pipeline for dynamic graph (#32511) * add pp dygraph, test=develop --- .../distributed/fleet/base/fleet_base.py | 4 + .../fleet/meta_parallel/__init__.py | 1 + .../fleet/meta_parallel/meta_parallel_base.py | 1 + .../fleet/meta_parallel/pipeline_parallel.py | 427 ++++++++++++++++++ .../fleet/meta_parallel/pp_utils/__init__.py | 15 + .../fleet/meta_parallel/pp_utils/utils.py | 111 +++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../unittests/hybrid_parallel_pp_model.py | 93 ++++ .../tests/unittests/test_pipeline_parallel.py | 29 ++ python/setup.py.in | 1 + 10 files changed, 685 insertions(+) create mode 100644 python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py create mode 100644 python/paddle/fluid/tests/unittests/test_pipeline_parallel.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 178edc0fe88..9e200f4ee5f 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -29,6 +29,7 @@ from paddle.fluid.dygraph import parallel_helper from . import topology as tp from .topology import ParallelMode from ..meta_parallel import ModelParallel +from ..meta_parallel import PipelineParallel from ..meta_optimizers import HybridParallelOptimizer from ..meta_optimizers import HybridParallelGradScaler @@ -780,6 +781,9 @@ class Fleet(object): elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL: distributed_model = ModelParallel( model, self._hcg, strategy=self._user_defined_strategy) + elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: + distributed_model = PipelineParallel( + model, self._hcg, strategy=self._user_defined_strategy) return distributed_model @dygraph_only diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index 81fb9a6ea6d..ed1add1f7ba 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -14,3 +14,4 @@ from .parallel_layers import * from .model_parallel import ModelParallel +from .pipeline_parallel import PipelineParallel diff --git a/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py b/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py index 6c8bf68fd1f..cdf947895b7 100644 --- a/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py +++ b/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py @@ -21,6 +21,7 @@ class MetaParallelBase(Layer): self).__init__(layers.full_name() + "_meta_parallel_base") self._layers = layers self._hcg = hcg + self._strategy = strategy self._prepare_for_model() def _prepare_for_model(self): diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py new file mode 100644 index 00000000000..98a82f2b798 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import time
+import copy
+import os
+
+from types import MethodType
+
+from numpy import prod
+
+import paddle
+import paddle.fluid as fluid
+from .meta_parallel_base import MetaParallelBase
+from .pp_utils.utils import get_tensor_bytes
+from .pp_utils import utils
+from .parallel_layers.pp_layers import PipelineLayer
+
+FLOAT_TYPES = [
+    paddle.float16,
+    paddle.float32,
+    paddle.float64,
+]
+
+
+class PipelineParallel(MetaParallelBase):
+    def __init__(self, layers, hcg, strategy):
+        super(PipelineParallel, self).__init__(layers, hcg, strategy)
+
+        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
+        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
+        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+
+        self.num_caches = 0
+        self.caches = {
+            'inputs': [],
+            'labels': [],
+            'outputs': [],
+            'backward_tensors': [],
+        }
+        self.recv_cache = None
+        self.grad_tensors = None
+
+        self.meta_buffer = None
+
+        self.send_meta = True
+        self.first_gradient_send = True
+
+        self.current_loss = paddle.to_tensor(0.0)
+        self.total_loss = None
+
+    def _prepare_for_model(self):
+        self.micro_batch_size = self._strategy.pipeline_configs[
+            'micro_batch_size']
+        self.accumulate_steps = self._strategy.pipeline_configs[
+            'accumulate_steps']
+
+        self.num_stages = self._hcg.get_pipe_parallel_world_size()
+        self.stage_id = self._hcg.get_stage_id()
+        self.prev_stage_id = self.stage_id - 1
+        self.next_stage_id = self.stage_id + 1
+        self._layers = PipelineLayer(
+            layers=self._layers, num_stages=self.num_stages)
+        #TODO: init process group
+
+    def _allocate_caches(self, num_caches):
+        if self.num_caches >= num_caches:
+            return
+
+        num = num_caches - self.num_caches
+        self.num_caches = num_caches
+        for key in self.caches:
+            self.caches[key].extend([None] * num)
+
+    def train_batch(self, data_iter, optimizer):
+        self.optimizer = optimizer
+        assert fluid.framework._dygraph_tracer()._has_grad, (
+            'Please enable the generation of gradients.')
+
+        if self.stage_id == 0 or self.stage_id == self.num_stages - 1:
+            assert data_iter, (
+                "For the first and the last stage, the data_iter must be set.")
+        else:
+            assert data_iter is None, (
+                "For pipe stages other than the first and the last one, "
+                "the data_iter must be None.")
+        self.data_iter = data_iter
+        self._layers.train()
+        self.total_loss = None
+
+        minibatch_cmds = utils.TrainGenerator(self.accumulate_steps,
+                                              self.num_stages, self.stage_id)
+        self._train(minibatch_cmds)
+        return self.total_loss
+
+    def _train(self, minibatch_cmds):
+        self._allocate_caches(self.num_stages)
+        for microbatch_cmds in minibatch_cmds:
+            for cmd in microbatch_cmds:
+                if type(cmd) not in self._COMMAND_MAP:
+                    #FIXME:
+                    continue
+
+                self._apply_cmd = MethodType(self._COMMAND_MAP[type(cmd)], self)
+                self._apply_cmd(**cmd.kwargs)
+
+    def _allreduce_grads(self):
+        self._modifying_grad = True
+        assert not self.use_data_parallel, ("Do not support data parallel "
+                                            "with pipeline parallel now.")
+        self._modifying_grad = False
+
+    def _get_data(self):
+        if self.use_model_parallel:
+            mp_rank = self._hcg.get_model_parallel_rank()
+        else:
+            mp_rank = 0
+
+        data = None
+
+        # mp rank 0 loads the data and broadcasts it to the other mp ranks.
+        if mp_rank == 0:
+            data = next(self.data_iter)
+        if self.use_model_parallel:
+            data = paddle.distributed.broadcast(
+                data, group=self._hcg.get_model_parallel_group())
+        return data
+
+    def _forward(self, cache_id):
+        if isinstance(self.caches['inputs'][cache_id], tuple):
+            inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id])
+        else:
+            inputs = self.caches['inputs'][cache_id].clone()
+
+        self._clear_grads(inputs)
+        outputs = self._layers.forward(inputs)
+
+        self.caches['outputs'][cache_id] = outputs
+
+        if self.stage_id == self.num_stages - 1:
+            self.current_loss = outputs
+            if isinstance(self.current_loss, paddle.Tensor):
+                if self.total_loss is None:
+                    self.total_loss = paddle.zeros_like(self.current_loss)
+                self.total_loss += self.current_loss.detach()
+            else:
+                if self.total_loss is None:
+                    self.total_loss = [
+                        paddle.zeros_like(v) for v in self.current_loss
+                    ]
+                for idx, v in enumerate(self.current_loss):
+                    self.total_loss[idx] += v.detach()
+
+    def _backward(self, cache_id):
+        assert self.optimizer is not None
+        if self.stage_id == self.num_stages - 1:
+            paddle.autograd.backward(self.current_loss)
+            return
+
+        outputs = self.caches['outputs'][cache_id]
+
+        grad_tensors = self.grad_tensors
+        if isinstance(outputs, tuple):
+            out_tensors = [t for t in outputs if t.dtype in FLOAT_TYPES]
+            assert len(out_tensors) == len(grad_tensors)
+            paddle.autograd.backward(
+                tensors=out_tensors, grad_tensors=grad_tensors)
+        else:
+            paddle.autograd.backward(
+                tensors=[outputs], grad_tensors=[grad_tensors])
+
+        self.caches['outputs'][cache_id] = None
+        grad_tensors = None
+
+    def _load_micro_batch(self, cache_id):
+        inputs = self._get_data()
+
+        if self.stage_id == 0:
+            data = None
+            if isinstance(inputs[0], paddle.Tensor):
+                data = inputs[0].clone().detach()
+                data.stop_gradient = data.dtype == paddle.float32
+            else:
+                assert isinstance(inputs[0], tuple)
+                # Assume list or tuple
+                data = []
+                for d in inputs[0]:
+                    assert isinstance(d, paddle.Tensor)
+                    d = d.clone().detach()
+                    d.stop_gradient = d.dtype == paddle.float32
+                    data.append(d)
+                data = tuple(data)
+            self.caches['inputs'][cache_id] = data
+
+        if self.stage_id == self.num_stages - 1:
+            label = None
+            if isinstance(inputs[1], paddle.Tensor):
+                label = inputs[1]
+            elif isinstance(inputs[1], tuple):
+                label = []
+                for l in inputs[1]:
+                    assert isinstance(l, paddle.Tensor)
+                    l = l.detach()
+                    label.append(l)
+                label = tuple(label)
+            self.caches['labels'][cache_id] = label
+
+    def _send_meta(self, data, peer):
+        """
+        % type (0: tensor, 1: tuple)
+        % num_tensors if type=tuple
+        foreach tensor:
+          % ndims
+          % shape
+        """
+        if isinstance(data, paddle.Tensor):
+            tensor_type = paddle.to_tensor([0])
+            paddle.distributed.send(tensor_type, peer)
+            dims = paddle.to_tensor(len(data.shape))
+            paddle.distributed.send(dims, peer)
+            shape = paddle.to_tensor(data.shape)
+            paddle.distributed.send(shape, peer)
+        elif isinstance(data, tuple):
+            tensor_type = paddle.to_tensor([1])
+            paddle.distributed.send(tensor_type, peer)
+            nums = paddle.to_tensor(len(data))
+            paddle.distributed.send(nums, peer)
+            for idx, d in enumerate(data):
+                assert isinstance(d, paddle.Tensor)
+                dims = paddle.to_tensor(len(d.shape))
+                paddle.distributed.send(dims, peer)
+                shape = paddle.to_tensor(d.shape)
+                paddle.distributed.send(shape, peer)
+
+    def _recv_meta(self, peer):
+        tensor_type = paddle.to_tensor([0])
+        paddle.distributed.recv(tensor_type, peer)
+        tensor_type = tensor_type.numpy()[0]
+
+        if tensor_type == 0:
+            dims = paddle.to_tensor([0])
+            paddle.distributed.recv(dims, peer)
+            dims = dims.numpy()[0]
+            shape = paddle.to_tensor([0] * dims)
+            paddle.distributed.recv(shape, peer)
+            shape = shape.numpy().tolist()
+            return self._allocate_buffer(
+                shape, dtype="float32", num_buffers=1)[0]
+        elif tensor_type == 1:
+            num = paddle.to_tensor([0])
+            paddle.distributed.recv(num, peer)
+            num = num.numpy()[0]
+            shapes = []
+            for i in range(num):
+                dims = paddle.to_tensor([0])
+                paddle.distributed.recv(dims, peer)
+                dims = dims.numpy()[0]
+                shape = paddle.to_tensor([0] * dims)
+                paddle.distributed.recv(shape, peer)
+                shapes.append(shape.numpy().tolist())
+
+            dtypes = ["float32"] * len(shapes)
+            caches = self._allocate_buffers(shapes, dtypes, num_buffers=1)[0]
+            caches = tuple(caches)
+            return caches
+
+    def _send_activations(self, cache_id):
+        outputs = self.caches['outputs'][cache_id]
+
+        if self.send_meta:
+            self.send_meta = False
+            self._send_meta(outputs, self.next_stage_id)
+
+        if isinstance(outputs, paddle.Tensor):
+            paddle.distributed.send(outputs, self.next_stage_id)
+        elif isinstance(outputs, tuple):
+            for output in outputs:
+                paddle.distributed.send(output, self.next_stage_id)
+
+    def _send_gradients(self, cache_id):
+        inputs = self.caches['inputs'][cache_id]
+
+        if isinstance(inputs, paddle.Tensor):
+            assert inputs.grad is not None
+            paddle.distributed.send(
+                paddle.to_tensor(inputs.grad), self.prev_stage_id)
+        else:
+            for idx, d in enumerate(inputs):
+                # Skip tensors that will not produce a grad
+                if d.dtype not in FLOAT_TYPES:
+                    assert d.grad is None
+                    continue
+                assert d.grad is not None
+                paddle.distributed.send(d.grad, self.prev_stage_id)
+        self.caches['inputs'][cache_id] = None
+
+    def _recv_activations(self, cache_id):
+        inputs = None
+
+        # Allocate the buffer if necessary
+        if self.recv_cache is None:
+            self.recv_cache = self._recv_meta(self.prev_stage_id)
+
+        if isinstance(self.recv_cache, paddle.Tensor):
+            paddle.distributed.recv(self.recv_cache, self.prev_stage_id)
+            inputs = self.recv_cache.clone().detach()
+            inputs.stop_gradient = inputs.dtype not in FLOAT_TYPES
+        else:
+            assert isinstance(self.recv_cache, tuple)
+            inputs = [None] * len(self.recv_cache)
+            for idx, d in enumerate(self.recv_cache):
+                assert isinstance(d, paddle.Tensor)
+
+                paddle.distributed.recv(d, self.prev_stage_id)
+                inputs[idx] = d.clone().detach()
+
+            inputs = tuple(inputs)
+
+            for d in inputs:
+                d.stop_gradient = d.dtype not in FLOAT_TYPES
+
+        self.caches['inputs'][cache_id] = inputs
+
+    def _recv_gradients(self, cache_id):
+        outputs = self.caches['outputs'][cache_id]
+        if self.grad_tensors is None:
+            if isinstance(outputs, paddle.Tensor):
+                s = list(outputs.shape)
+                dtype = 'float32'
+                self.grad_tensors = self._allocate_buffer(
+                    s, dtype, num_buffers=1)[0]
+            else:
+                sizes = [
+                    list(d.shape) for d in outputs if d.dtype in FLOAT_TYPES
+                ]
+                dtypes = ['float32'] * len(sizes)
+                self.grad_tensors = self._allocate_buffers(
+                    sizes, dtypes, num_buffers=1)[0]
+
+        if isinstance(self.grad_tensors, paddle.Tensor):
+            paddle.distributed.recv(self.grad_tensors, self.next_stage_id)
+        else:
+            assert isinstance(outputs, tuple)
+            for d in self.grad_tensors:
+                paddle.distributed.recv(d, self.next_stage_id)
+
+    def _step(self, lr_kwargs=None):
+        self._modifying_grad = True
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+        self._modifying_grad = False
+
+    def _clear_grads(self, inputs):
+        if isinstance(inputs, paddle.Tensor):
+            if inputs.grad is not None:
+                inputs.clear_gradient()
+        else:
+            for d in inputs:
+                if d.grad is not None:
+                    d.clear_gradient()
+
+    def _allocate_zeros(self, shape, dtype):
+        return paddle.zeros(shape, dtype)
+
+    def _allocate_buffer(self, shape, dtype, num_buffers=-1, **kwargs):
+        buffers = []
+        if num_buffers == -1:
+            num_buffers = self.num_caches
+        for count in range(num_buffers):
+            buffers.append(self._allocate_zeros(shape, dtype))
+        return buffers
+
+    def _allocate_buffers(self, shapes, dtypes, num_buffers=-1):
+        buffers = []
+        if num_buffers == -1:
+            num_buffers = self.num_caches
+        for count in range(num_buffers):
+            buffer = []
+            for shape, dtype in zip(shapes, dtypes):
+                buffer.append(
+                    self._allocate_zeros(
+                        shape, dtype))
+            buffers.append(buffer)
+        return buffers
+
+    def save_state_dict(self, model_path):
+        state_dict = self._layers.state_dict()
+        paddle.save(state_dict, model_path)
+
+    def load_state_dict(self, model_path):
+        state_dict = paddle.load(model_path)
+        self._layers.set_state_dict(state_dict)
+
+    _COMMAND_MAP = {
+        utils.Optimize: _step,
+        #utils.ReduceGrads: _allreduce_grads,
+        utils.Forward: _forward,
+        utils.Backward: _backward,
+    }
+
+    def _pre_forward(self, *inputs, **kwargs):
+        pass
+
+    def forward(self, *inputs, **kwargs):
+        raise RuntimeError("Call train_batch for pipeline instead of forward.")
+
+    def _post_forward(self, output):
+        pass
+
+    def _pre_backward(self, loss):
+        pass
+
+    def backward_impl(self, loss, parameters):
+        pass
+
+    def _post_backward(self, loss):
+        pass
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py
new file mode 100644
index 00000000000..d39e6760a38
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import *
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
new file mode 100644
index 00000000000..56eef8d7d21
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import abc +import paddle +from ...utils import hybrid_parallel_util as hp_util + +__all__ = ['get_tensor_bytes', ] + + +def get_tensor_bytes(tensor): + """Get the bytes a tensor occupied.""" + elem_size = None + if tensor.dtype == paddle.float32: + elem_size = 4 + elif tensor.dtype == paddle.float64: + elem_size = 8 + elif tensor.dtype == paddle.int64: + elem_size = 8 + elif tensor.dtype == paddle.int32: + elem_size = 4 + elif tensor.dtype == paddle.float16: + elem_size = 2 + elif tensor.dtype == paddle.int8: + elem_size = 1 + else: + raise ValueError("unknown data type: {}".format(tensor.dtype)) + return tensor.numel() * elem_size + + +class Generator(): + def __init__(self, micro_batches, stages, stage_id): + __metaclass__ = abc.ABCMeta + + self.micro_batches = micro_batches + self.stages = stages + self.stage_id = stage_id + self.prev_stage = self.stage_id - 1 + self.next_stage = self.stage_id + 1 + assert self.micro_batches >= self.stages, ( + "micro_batches {} " + "must be greater than or equal to {}".format(self.micro_batches, + self.stages)) + + @abc.abstractmethod + def generate(self): + pass + + def __iter__(self): + self.iter = None + return self + + def __next__(self): + if self.iter is None: + self.iter = self.generate() + return next(self.iter) + + +class TrainGenerator(Generator): + def generate(self): + startup_steps = self.stages - self.stage_id - 1 + cmds = [] + forward_steps = 0 + backward_steps = 0 + while (forward_steps < startup_steps): + cmds.append(Forward) + forward_steps += 1 + while (forward_steps < self.micro_batches): + cmds.append(Forward) + forward_steps += 1 + cmds.append(Backward) + backward_steps += 1 + while (backward_steps < self.micro_batches): + cmds.append(Backward) + backward_steps += 1 + cmds.append(Optimize) + yield cmds + + +class Command: + def __init__(self, **kwargs): + self.name = self.__class__.__name__ + self.kwargs = kwargs + for key, val in kwargs.items(): + setattr(self, key, val) + + def __repr__(self): + return hp_util.call_to_str(self.name, **self.kwargs) + + +class Optimize(Command): + pass + + +class Forward(Command): + pass + + +class Backward(Command): + pass diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ae3cf5f2858..ffa347a60c3 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -79,6 +79,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_allreduce) LIST(REMOVE_ITEM TEST_OPS test_broadcast) LIST(REMOVE_ITEM TEST_OPS test_collective_reduce) + LIST(REMOVE_ITEM TEST_OPS test_pipeline_parallel) LIST(REMOVE_ITEM TEST_OPS test_collective_scatter) LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv) LIST(REMOVE_ITEM TEST_OPS test_reducescatter) @@ -878,6 +879,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_broadcast PROPERTIES TIMEOUT 120) set_tests_properties(test_reducescatter PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_reduce_api PROPERTIES TIMEOUT 120) + set_tests_properties(test_pipeline_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_reduce PROPERTIES TIMEOUT 120) set_tests_properties(test_allreduce PROPERTIES TIMEOUT 120) set_tests_properties(test_c_concat PROPERTIES TIMEOUT 120) @@ -895,6 +897,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_collective_scatter_api test_collective_barrier_api test_collective_reduce_api + test_pipeline_parallel test_collective_allreduce_api 
test_new_group_api test_collective_broadcast_api diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py new file mode 100644 index 00000000000..9b9283a1a9b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle.io import DataLoader, Dataset +import unittest + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + rank_id) + + +HIDDEN_DIM = 32 +LAYERS = 8 + + +def sequential_model(): + model = paddle.nn.Sequential( + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), + paddle.nn.Linear(HIDDEN_DIM, 1), ) + return model + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = {"accumulate_steps": 2} + paddle.distributed.init_parallel_env() + fleet.init(is_collective=True, strategy=strategy) + + def test_mp_model(self): + batch_input = paddle.randn(shape=(1, HIDDEN_DIM), dtype="float32") + pipe_model = sequential_model() + sgd = paddle.optimizer.SGD(learning_rate=0.0003, parameters=[]) + pipe_model = paddle.distributed.fleet.distributed_model(pipe_model) + + if pipe_model.stage_id == 0 or pipe_model.stage_id == 1: + pipe_input = batch_input.clone().detach() + pipe_input = paddle.cast(pipe_input, 'float32') + + def data_gen(): + gen = True + while gen: + yield [pipe_input, 0] + gen = False + + loader = paddle.io.DataLoader.from_generator(capacity=5) + loader.set_batch_generator(data_gen) + data_iter = iter(loader) + else: + data_iter = None + return True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py new file mode 100644 index 00000000000..7f8294ad0ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestPipelineParallel(TestMultipleGpus): + def test_pipeline_parallel(self): + self.run_mnist_2gpu('hybrid_parallel_pp_model.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 3458a42d2d9..0e94d02cd6f 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -159,6 +159,7 @@ packages=['paddle', 'paddle.distributed.fleet.proto', 'paddle.distributed.fleet.utils', 'paddle.distributed.fleet.meta_parallel', + 'paddle.distributed.fleet.meta_parallel.pp_utils', 'paddle.distributed.fleet.meta_parallel.parallel_layers', 'paddle.framework', 'paddle.jit', -- GitLab
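
Usage sketch (reviewer aid, not part of the patch): the snippet below shows how the new dynamic-graph pipeline path added here is expected to be driven from user code, following the unit test above. The hybrid_configs / pipeline_configs keys, the dispatch of fleet.distributed_model to PipelineParallel, and the train_batch(data_iter, optimizer) signature come from this patch; the toy model, the data generator, and the optimizer wiring are illustrative assumptions. Because the command scheduler in this first version still carries a FIXME and does not yet wire up the send/recv helpers, this sketch only illustrates the intended calling convention, not a verified end-to-end run.

    import paddle
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
    strategy.pipeline_configs = {"micro_batch_size": 1, "accumulate_steps": 2}

    paddle.distributed.init_parallel_env()
    fleet.init(is_collective=True, strategy=strategy)

    # A plain Sequential model; PipelineLayer splits it into stages inside
    # the PipelineParallel wrapper.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(32, 32), paddle.nn.Linear(32, 1))
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.001, parameters=model.parameters())

    # With pp_degree > 1, distributed_model returns a PipelineParallel
    # instance (see the fleet_base.py hunk above).
    model = fleet.distributed_model(model)

    # Only the first and the last stage feed data; the other stages pass None.
    if model.stage_id == 0 or model.stage_id == model.num_stages - 1:
        def data_gen():
            while True:
                yield [paddle.randn([1, 32]), paddle.randn([1, 1])]
        data_iter = data_gen()
    else:
        data_iter = None

    # Runs accumulate_steps micro batches and returns the accumulated loss
    # on the last stage (None on earlier stages).
    loss = model.train_batch(data_iter, optimizer)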