From 66d46221b7265fca2ca3da101cf3f550ea11df53 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Sat, 17 Apr 2021 01:07:40 +0800 Subject: [PATCH] [Hybrid Parallel] Add model parallel support in dygraph (#32248) * add model parallel support in dygraph --- .../framework/distributed_strategy.proto | 7 + python/paddle/distributed/fleet/__init__.py | 2 + .../fleet/base/distributed_strategy.py | 34 +++ .../distributed/fleet/base/fleet_base.py | 43 +++ .../paddle/distributed/fleet/base/topology.py | 27 +- .../fleet/meta_parallel/__init__.py | 15 + .../fleet/meta_parallel/mp_utils/__init__.py | 16 + .../fleet/meta_parallel/mp_utils/layers.py | 190 ++++++++++++ .../meta_parallel/mp_utils/layers_help.py | 116 ++++++++ .../fleet/meta_parallel/mp_utils/random.py | 79 +++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + ...y => hybrid_parallel_communicate_group.py} | 0 .../unittests/hybrid_parallel_mp_layers.py | 273 ++++++++++++++++++ .../unittests/hybrid_parallel_mp_random.py | 74 +++++ .../test_fleet_distributed_strategy.py | 11 + .../fluid/tests/unittests/test_new_group.sh | 2 +- .../test_parallel_dygraph_hybrid_parallel.py | 33 +++ python/setup.py.in | 2 + 18 files changed, 917 insertions(+), 10 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_parallel/__init__.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/mp_utils/__init__.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/mp_utils/layers.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/mp_utils/layers_help.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/mp_utils/random.py rename python/paddle/fluid/tests/unittests/{hybrid_communicate_group.py => hybrid_parallel_communicate_group.py} (100%) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_mp_random.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 59af35465a4..e6a7d74cc43 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -43,6 +43,12 @@ message ShardingConfig { optional int32 pp_degree = 11 [ default = 1 ]; } +message HybridConfig { + optional int32 dp_degree = 1 [ default = -1 ]; + optional int32 mp_degree = 2 [ default = 1 ]; + optional int32 pp_degree = 3 [ default = 1 ]; +} + message AMPConfig { optional float init_loss_scaling = 1 [ default = 32768.0 ]; optional int32 incr_every_n_steps = 2 [ default = 1000 ]; @@ -175,6 +181,7 @@ message DistributedStrategy { optional LambConfig lamb_configs = 109; optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; optional ShardingConfig sharding_configs = 111; + optional HybridConfig hybrid_configs = 112; optional BuildStrategy build_strategy = 201; optional ExecutionStrategy execution_strategy = 202; } diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 6d4aedddba6..784004269d7 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -21,6 +21,7 @@ from .dataset import * from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator from . 
import metrics
 from .base.topology import CommunicateTopology, HybridCommunicateGroup
+from .meta_parallel import random, layers
 
 __all__ = [
     "DistributedStrategy", "UtilBase", "UserDefinedRoleMaker",
@@ -72,3 +73,4 @@ get_lr = fleet.get_lr
 state_dict = fleet.state_dict
 set_state_dict = fleet.set_state_dict
 shrink = fleet.shrink
+get_hybrid_communicate_group = fleet.get_hybrid_communicate_group
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 626f6a37a98..04cb7447e36 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -867,6 +867,40 @@ class DistributedStrategy(object):
                           "pipeline_configs")
         assign_configs_value(self.strategy.pipeline_configs, configs)
 
+    @property
+    def hybrid_configs(self):
+        """
+        Dynamic graph hybrid parallel strategy configuration. Three-way hybrid
+        parallelism must satisfy the following relationship:
+
+        total_number_GPUs = dp_degree * mp_degree * pp_degree
+
+        **Note**:
+        dp_degree(int): the number of GPUs in a data parallel group. Default -1.
+                        This value should be an integer greater than 0.
+                        If it is not set, or set to -1, it is inferred
+                        from the total number of GPUs.
+        mp_degree(int): the number of GPUs in a model parallel group. Default 1.
+        pp_degree(int): the number of GPUs in a pipeline parallel group. Default 1.
+
+
+        Examples:
+          .. code-block:: python
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.hybrid_configs = {
+                "dp_degree": 1,
+                "mp_degree": 2,
+                "pp_degree": 1}
+        """
+        return get_msg_dict(self.strategy.hybrid_configs)
+
+    @hybrid_configs.setter
+    def hybrid_configs(self, configs):
+        check_configs_key(self.strategy.hybrid_configs, configs,
+                          "hybrid_configs")
+        assign_configs_value(self.strategy.hybrid_configs, configs)
+
     @property
     def localsgd(self):
         """
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 0a60cbf78d5..7ed5017b815 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -26,6 +26,7 @@ from .meta_optimizer_factory import MetaOptimizerFactory
 from .runtime_factory import RuntimeFactory
 from paddle.fluid.wrapped_decorator import wrap_decorator
 from paddle.fluid.dygraph import parallel_helper
+from . import topology as tp
 
 
 def _inited_runtime_handler_(func):
@@ -234,6 +235,48 @@ class Fleet(object):
                     self._user_defined_strategy.nccl_comm_num)
             paddle.distributed.init_parallel_env()
 
+            # init hybrid parallel environment in dygraph
+            if tp._HYBRID_PARALLEL_GROUP is None:
+                self._init_hybrid_parallel_env()
+            else:
+                warnings.warn(
+                    "The dygraph hybrid parallel environment has already been initialized."
+                )
+
+    def _init_hybrid_parallel_env(self):
+        """Initialize the hybrid parallel environment.
+        """
+        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
+        self.dp_degree = self.hybrid_configs["dp_degree"]
+        self.mp_degree = self.hybrid_configs["mp_degree"]
+        self.pp_degree = self.hybrid_configs["pp_degree"]
+
+        assert self.mp_degree >= 0, "mp_degree should be greater than or equal to 0"
+        assert self.pp_degree >= 0, "pp_degree should be greater than or equal to 0"
+
+        self.mp_degree = max(self.mp_degree, 1)
+        self.pp_degree = max(self.pp_degree, 1)
+
+        if self.dp_degree < 0:
+            nranks = paddle.distributed.get_world_size()
+            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)
+
+        self.dp_degree = max(self.dp_degree, 1)
+
+        self._topology = tp.CommunicateTopology(
+            hybrid_group_names=["data", "pipe", "model"],
+            dims=[self.dp_degree, self.pp_degree, self.mp_degree])
+
+        self._hcg = tp.HybridCommunicateGroup(self._topology)
+
+    def get_hybrid_communicate_group(self):
+        assert self._hcg is not None
+        return self._hcg
+
+    def get_hybrid_parallel_topology(self):
+        assert self._topology is not None
+        return self._topology
+
     def is_first_worker(self):
         """
         Check whether the node is the first instance of worker.
diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 4e20ad50611..4dca860212c 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+import sys
 import paddle
 import collections
 import numpy as np
@@ -19,6 +21,8 @@ from itertools import product
 from functools import reduce
 
 __all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
+
+_HYBRID_PARALLEL_GROUP = None
 
 
 class CommunicateTopology(object):
     def __init__(self, hybrid_group_names, dims):
@@ -100,26 +104,31 @@ class HybridCommunicateGroup(object):
         self.global_rank = paddle.distributed.get_rank()
         self._topo = topology
 
-        self._num_data_parallel = self._topo.get_dim('data')
-        self._num_model_parallel = self._topo.get_dim('model')
-        self._num_pipe_parallel = self._topo.get_dim('pipe')
+        self._dp_degree = self._topo.get_dim('data')
+        self._mp_degree = self._topo.get_dim('model')
+        self._pp_degree = self._topo.get_dim('pipe')
 
         self._data_parallel_id = self._get_data_parallel_id()
         self._model_parallel_id = self._get_model_parallel_id()
 
         assert self._check_vaild_topo(
-        ), "Here is an unreasonable topogy setting"
+        ), "Here is an unreasonable topology setting. world_size: {}, but " \
+            "dp_num: {}, mp_num: {}, pp_num: {}".format(self.nranks, self._dp_degree,
+                                                        self._mp_degree, self._pp_degree)
 
         # create comm group for data parallel
         self._dp_group, self._dp_comm_group = self._set_comm_group("data")
-        print("data parallel group", self._dp_group)
+        print("data parallel group", self._dp_group, file=sys.stderr)
 
         # create comm group for model parallel
         self._mp_group, self._mp_comm_group = self._set_comm_group("model")
-        print("model parallel group", self._mp_group)
+        print("model parallel group", self._mp_group, file=sys.stderr)
+
+        global _HYBRID_PARALLEL_GROUP
+        _HYBRID_PARALLEL_GROUP = self
 
     def _check_vaild_topo(self):
-        return self._num_data_parallel * self._num_model_parallel * self._num_pipe_parallel == self.nranks
+        return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks
 
     def _set_comm_group(self, parallel_method="data"):
         parallel_group = []
@@ -151,7 +160,7 @@ class HybridCommunicateGroup(object):
         return self._data_parallel_id
 
     def get_data_parallel_world_size(self):
-        return self._num_data_parallel
+        return self._dp_degree
 
     def get_data_parallel_group(self):
         return self._dp_comm_group
@@ -167,7 +176,7 @@ class HybridCommunicateGroup(object):
         return self._model_parallel_id
 
     def get_model_parallel_world_size(self):
-        return self._num_model_parallel
+        return self._mp_degree
 
     def get_model_parallel_group(self):
         return self._mp_comm_group
diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py
new file mode 100644
index 00000000000..977954fc257
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mp_utils import *
diff --git a/python/paddle/distributed/fleet/meta_parallel/mp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/mp_utils/__init__.py
new file mode 100644
index 00000000000..a7da28700bc
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/mp_utils/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .layers import *
+from .random import *
diff --git a/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers.py b/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers.py
new file mode 100644
index 00000000000..b7512afd9a6
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.fluid.dygraph.layers import Layer
+from .random import get_rng_state_tracker
+from paddle.nn import functional as F
+from paddle import framework
+from ...base import topology as tp
+from .layers_help import identity_in_model_parallel, gather_in_model_parallel, reduce_in_model_parallel, scatter_in_model_parallel

+__all__ = [
+    'VocabParallelEmbedding', 'ColumnParallelLinear', 'RowParallelLinear'
+]
+
+# This file implements the model parallel layers described in:
+# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
+# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
+
+
+class VocabParallelEmbedding(Layer):
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 weight_attr=None,
+                 name=None):
+        super(VocabParallelEmbedding, self).__init__()
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        )
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        )
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+
+        self.origin_num_embeddings = num_embeddings
+
+        per_part_size = (
+            num_embeddings + self.world_size - 1) // self.world_size
+        last_part_size = num_embeddings - per_part_size * (self.world_size - 1)
+        if self.rank == self.world_size - 1:
+            per_part_size = last_part_size
+        per_part_size += 1  # reserve the last row as the padding index
+        self.per_part_size = per_part_size
+
+        self.embedding = paddle.nn.Embedding(
+            per_part_size,
+            embedding_dim,
+            padding_idx=per_part_size - 1,
+            sparse=False,
+            weight_attr=weight_attr,
+            name=name)
+        self.embedding.weight.is_distributed = True
+
+    def forward(self, x):
+        origin_input_shape = x.shape
+        if len(origin_input_shape) == 2:
+            x = paddle.unsqueeze(x, axis=-1)
+        else:
+            assert origin_input_shape[-1] == 1, (
+                "The last dimension size of x must be 1.")
+        x_shard = paddle.shard_index(x, self.origin_num_embeddings,
+                                     self.world_size, self.rank,
+                                     self.per_part_size - 1)
+        if len(origin_input_shape) == 2:
+            x_shard = paddle.squeeze(x_shard, axis=-1)
+
+        emb_out_ = self.embedding(x_shard)
+        emb_out = reduce_in_model_parallel(emb_out_)
+        return emb_out
+
+
+class ColumnParallelLinear(Layer):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 has_bias=None,
+                 gather_output=True,
+                 name=None):
+        super(ColumnParallelLinear, self).__init__()
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        )
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        )
+
+        self.name = name
+        self.gather_output = gather_output
+        assert out_features % self.world_size == 0, (
+            "The number of columns of the linear weight ({}) must be"
+            " divisible by the model parallel size ({})".format(out_features,
+                                                                self.world_size))
+        self.output_size_per_partition = out_features // self.world_size
+
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+
+        self.weight = self.create_parameter(
+            shape=[in_features, self.output_size_per_partition],
+            attr=self._weight_attr,
+            dtype=self._dtype)
+        self.weight.is_distributed = True
+
+        if has_bias:
+            # initialize bias to zero like Megatron
+            self.bias = self.create_parameter(
+                shape=[self.output_size_per_partition],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype)
+            self.bias.is_distributed = True
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        input_parallel = identity_in_model_parallel(x)
+        output_parallel = F.linear(
+            input_parallel, self.weight, self.bias, name=self.name)
+        if self.gather_output:
+            output = gather_in_model_parallel(output_parallel)
+        else:
+            output = output_parallel
+        return output
+
+
+class RowParallelLinear(Layer):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 has_bias=True,
+                 input_is_parallel=False,
+                 name=None):
+        super(RowParallelLinear, self).__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.input_is_parallel = input_is_parallel
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+        self.name = name
+
+        self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
+        )
+        self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
+        )
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+
+        assert in_features % self.world_size == 0, (
+            "The number of rows of the linear weight ({}) must be"
+            " divisible by the model parallel size ({})".format(in_features,
+                                                                self.world_size))
+
+        self.input_size_per_partition = in_features // self.world_size
+
+        self.weight = self.create_parameter(
+            shape=[self.input_size_per_partition, self.out_features],
+            attr=self._weight_attr,
+            dtype=self._dtype)
+        self.weight.is_distributed = True
+
+        if has_bias:
+            self.bias = self.create_parameter(
+                shape=[self.out_features],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype)
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        if self.input_is_parallel:
+            input_parallel = x
+        else:
+            # split the input along its last dimension
+            input_parallel = scatter_in_model_parallel(x)
+
+        output_parallel = F.linear(input_parallel, self.weight, name=self.name)
+        output_ = reduce_in_model_parallel(output_parallel)
+        output = output_ + self.bias if self.bias is not None else output_
+        return output
diff --git a/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers_help.py b/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers_help.py
new file mode 100644
index 00000000000..e32db686efd
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/mp_utils/layers_help.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.autograd import PyLayer
+from ...base import topology as tp
+import paddle
+
+# This file implements the communication primitives for the model parallel layers described in:
+# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
+# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
+
+
+def mp_reduce(x):
+    if tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() == 1:
+        return x
+
+    paddle.distributed.all_reduce(
+        x, group=tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group())
+
+    return x
+
+
+def mp_split(x):
+    world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
+
+    if world_size == 1:
+        return x
+
+    rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+    last_dim = len(x.shape) - 1
+    input_list = paddle.split(x, num_or_sections=world_size, axis=last_dim)
+    output = input_list[rank]
+
+    return output
+
+
+def mp_gather(x):
+    world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
+
+    if world_size == 1:
+        return x
+
+    output = []
+    paddle.distributed.all_gather(
+        output, x, group=tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group())
+
+    output = paddle.concat(output, axis=len(x.shape) - 1)
+
+    return output
+
+
+class _IdentityInModelParallel(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        return x
+
+    @staticmethod
+    def backward(ctx, dx):
+        return mp_reduce(dx)
+
+
+class _ReduceInModelParallel(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        return mp_reduce(x)
+
+    @staticmethod
+    def backward(ctx, dx):
+        return dx
+
+
+class _ScatterInModelParallel(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        return mp_split(x)
+
+    @staticmethod
+    def backward(ctx, dx):
+        return mp_gather(dx)
+
+
+class _GatherInModelParallel(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        return mp_gather(x)
+
+    @staticmethod
+    def backward(ctx, dx):
+        return mp_split(dx)
+
+
+def identity_in_model_parallel(x):
+    return _IdentityInModelParallel.apply(x)
+
+
+def reduce_in_model_parallel(x):
+    return _ReduceInModelParallel.apply(x)
+
+
+def scatter_in_model_parallel(x):
+    return _ScatterInModelParallel.apply(x)
+
+
+def gather_in_model_parallel(x):
+    return _GatherInModelParallel.apply(x)
diff --git a/python/paddle/distributed/fleet/meta_parallel/mp_utils/random.py b/python/paddle/distributed/fleet/meta_parallel/mp_utils/random.py
new file mode 100644
index 00000000000..56c741dbd3c
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/mp_utils/random.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import contextlib
+__all__ = [
+    'RNGStatesTracker', 'model_parallel_random_seed', 'get_rng_state_tracker'
+]
+
+MODEL_PARALLEL_RNG = 'model_parallel_rng'
+
+
+class RNGStatesTracker:
+    """
+    Track the RNG states.
+    """
+
+    def __init__(self):
+        # Map from name to the rng state.
+        self.states_ = {}
+        self.seeds_ = set()
+
+    def reset(self):
+        self.states_ = {}
+        self.seeds_ = set()
+
+    def add(self, name, seed):
+        if seed in self.seeds_:
+            raise ValueError('seed {} already exists'.format(seed))
+        self.seeds_.add(seed)
+        if name in self.states_:
+            raise ValueError('state {} already exists'.format(name))
+        orig_rng_state = paddle.get_cuda_rng_state()
+        paddle.seed(seed)
+        self.states_[name] = paddle.get_cuda_rng_state()
+        paddle.set_cuda_rng_state(orig_rng_state)
+
+    @contextlib.contextmanager
+    def rng_state(self, name=MODEL_PARALLEL_RNG):
+        if name not in self.states_:
+            raise ValueError('state {} does not exist'.format(name))
+        orig_cuda_rng_state = paddle.get_cuda_rng_state()
+        paddle.set_cuda_rng_state(self.states_[name])
+        try:
+            yield
+        finally:
+            self.states_[name] = paddle.get_cuda_rng_state()
+            paddle.set_cuda_rng_state(orig_cuda_rng_state)
+
+
+RNG_STATE_TRACKER = RNGStatesTracker()
+
+
+def get_rng_state_tracker():
+    return RNG_STATE_TRACKER
+
+
+def model_parallel_random_seed(seed=2048):
+    import paddle.distributed.fleet as fleet
+    hcg = fleet.get_hybrid_communicate_group()
+    rank = hcg.get_model_parallel_rank()
+
+    local_seed = seed + 1024 + rank
+    global_seed = seed
+
+    RNG_STATE_TRACKER.reset()
+    paddle.seed(global_seed)
+    RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 679f7651646..486ad38ae29 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -21,6 +21,7 @@ list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
+list(APPEND DIST_TEST_OPS test_parallel_dygraph_hybrid_parallel)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -166,6 +167,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_hybrid_parallel) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single) elseif(WITH_GPU) @@ -843,6 +845,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 120 LABELS "RUN_TYPE=DIST") if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/hybrid_communicate_group.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py similarity index 100% rename from python/paddle/fluid/tests/unittests/hybrid_communicate_group.py rename to python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py new file mode 100644 index 00000000000..ed5b9060e5e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py @@ -0,0 +1,273 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import paddle
+import numpy as np
+import random
+import paddle.distributed as dist
+import paddle.fluid as fluid
+import paddle.distributed.fleet as fleet
+from paddle import framework
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed)
+    paddle.seed(seed)
+    fleet.meta_parallel.model_parallel_random_seed(seed)
+
+
+class ColumnLinearNet(fluid.dygraph.Layer):
+    def __init__(self, input_size, output_size, global_dtype):
+        super(ColumnLinearNet, self).__init__()
+        self.parallel_linear = fleet.meta_parallel.ColumnParallelLinear(
+            in_features=input_size,
+            out_features=output_size,
+            weight_attr=None,
+            has_bias=True,
+            gather_output=True,
+            name="test_column_linear")
+
+    def forward(self, x):
+        output = self.parallel_linear(x)
+        return output
+
+
+class RowLinearNet(fluid.dygraph.Layer):
+    def __init__(self, input_size, output_size):
+        super(RowLinearNet, self).__init__()
+        self.parallel_linear = fleet.meta_parallel.RowParallelLinear(
+            in_features=input_size,
+            out_features=output_size,
+            has_bias=True,
+            input_is_parallel=False,
+            name="test_row_linear")
+
+    def forward(self, x):
+        output = self.parallel_linear(x)
+        return output
+
+
+class EmbeddingNet(fluid.dygraph.Layer):
+    def __init__(self, vocab_size, hidden_size):
+        super(EmbeddingNet, self).__init__()
+        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(vocab_size,
+                                                                    hidden_size)
+
+    def forward(self, x):
+        output = self.embedding(x)
+        return output
+
+
+class SimpleMatmul(fluid.dygraph.Layer):
+    def __init__(self, weight, output_size, global_dtype):
+        super(SimpleMatmul, self).__init__()
+        self.weight = paddle.create_parameter(
+            shape=weight.shape,
+            dtype=global_dtype,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(weight)))
+        self.bias = self.create_parameter(
+            shape=[output_size],
+            dtype=global_dtype,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)))
+
+    def forward(self, x):
+        output = paddle.matmul(x, self.weight) + self.bias
+        return output
+
+
+class SimpleEmbedding(fluid.dygraph.Layer):
+    def __init__(self, vocab_size, hidden_size, weight):
+        super(SimpleEmbedding, self).__init__()
+        self.embedding = paddle.nn.Embedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                name="origin_embedding",
+                initializer=paddle.nn.initializer.Assign(weight)))
+
+    def forward(self, x):
+        output = self.embedding(x)
+        return output
+
+
+class TestDistTraning(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_column_parallel_layer(self):
+        set_random_seed(1024)
+        global_dtype = "float32"
+
+        input_size_per_card = 17
+        input_size = input_size_per_card * self.model_parallel_size
+        output_size_per_card = 13
+        output_size = output_size_per_card * self.model_parallel_size
+        batch_size = 4
+
+        model_a = ColumnLinearNet(input_size, output_size, global_dtype)
+
+        # get w
+        check_group = dist.new_group(list(range(self.model_parallel_size)))
+        integral_w = []
+        partial_w = model_a.parallel_linear.weight.clone().detach()
+        paddle.distributed.all_gather(integral_w, partial_w, group=check_group)
+        integral_w = paddle.concat(integral_w, axis=1)
+
+        model_b = 
SimpleMatmul(integral_w, output_size, global_dtype) + + optimizer_a = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_a.parameters()) + optimizer_b = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_b.parameters()) + for idx in range(5): + input = paddle.randn([batch_size, input_size], global_dtype) + input.stop_gradient = True + + output_a = model_a(input) + loss_a = output_a.mean() + loss_a.backward() + + output_b = model_b(input) + loss_b = output_b.mean() + loss_b.backward() + + optimizer_a.step() + optimizer_b.step() + + np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) + + def test_row_parallel_layer(self): + global_dtype = "float32" + paddle.set_default_dtype(global_dtype) + set_random_seed(1024) + + self.hcg = fleet.get_hybrid_communicate_group() + + self.word_size = self.hcg.get_model_parallel_world_size() + self.rank_id = self.hcg.get_model_parallel_rank() + + input_size_per_card = 17 + input_size = input_size_per_card * self.model_parallel_size + output_size_per_card = 13 + output_size = output_size_per_card * self.model_parallel_size + batch_size = 4 + + model_a = RowLinearNet(input_size, output_size) + + # get w + check_group = dist.new_group(list(range(self.model_parallel_size))) + integral_w = [] + partial_w = model_a.parallel_linear.weight.clone().detach() + paddle.distributed.all_gather(integral_w, partial_w, group=check_group) + integral_w = paddle.concat(integral_w, axis=0) + + model_b = SimpleMatmul(integral_w, output_size, global_dtype) + + optimizer_a = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_a.parameters()) + + optimizer_b = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_b.parameters()) + + for idx in range(5): + input = paddle.randn([batch_size, input_size], global_dtype) + input.stop_gradient = True + + output_a = model_a(input) + loss_a = output_a.mean() + loss_a.backward() + + output_b = model_b(input) + loss_b = output_b.mean() + loss_b.backward() + + optimizer_a.step() + optimizer_b.step() + + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=1e-5) + + def test_parallel_embedding(self): + batch_size = 17 + seq_length = 23 + vocab_size_per_card = 2 + vocab_size = vocab_size_per_card * self.model_parallel_size + hidden_size = 2 + seed = 1236 + + set_random_seed(seed) + rank_id = dist.get_rank() + + # model_a + model_a = EmbeddingNet(vocab_size, hidden_size) + + # model_b + check_group = dist.new_group(list(range(self.model_parallel_size))) + integral_w = [] + partial_w = model_a.embedding.embedding.weight.clone().detach() + paddle.distributed.all_gather(integral_w, partial_w, group=check_group) + result_w = [] + for idx in range(len(integral_w)): + tmp = paddle.gather( + integral_w[idx], + paddle.to_tensor(list(range(vocab_size_per_card)))) + result_w.append(tmp) + integral_w = paddle.concat(result_w, axis=0) + + model_b = SimpleEmbedding(vocab_size, hidden_size, integral_w) + + optimizer_a = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_a.parameters()) + + optimizer_b = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model_b.parameters()) + + for _ in range(5): + np_input_data = np.random.randint(0, vocab_size, + (batch_size, seq_length)) + input_data = paddle.to_tensor(np_input_data, dtype="int32") + + output_a = model_a(input_data) + loss_a = output_a.mean() + + output_b = model_b(input_data) + loss_b = output_b.mean() + + loss_a.backward() + loss_b.backward() + + optimizer_a.step() + optimizer_b.step() + np.testing.assert_allclose( + loss_a.numpy(), 
loss_b.numpy(), rtol=1e-6) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_random.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_random.py new file mode 100644 index 00000000000..59d24066946 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_random.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +import random + + +class TestDistTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1 + } + fleet.init(is_collective=True, strategy=strategy) + + def test_cuda_rng_tracker(self): + seed_1 = 2021 + seed_2 = 1024 + + size = [20, 15] + + paddle.seed(seed_1) + target_11 = paddle.randn(size, "float32") + target_12 = paddle.randn(size, "float32") + + paddle.seed(seed_2) + target_21 = paddle.randn(size, "float32") + target_22 = paddle.randn(size, "float32") + + paddle.seed(seed_1) + + fleet.meta_parallel.get_rng_state_tracker().add("test", seed_2) + + result_11 = paddle.randn(size, "float32") + + with fleet.meta_parallel.get_rng_state_tracker().rng_state("test"): + result_21 = paddle.randn(size, "float32") + + result_12 = paddle.randn(size, "float32") + + with fleet.meta_parallel.get_rng_state_tracker().rng_state("test"): + result_22 = paddle.randn(size, "float32") + + np.testing.assert_allclose(result_11.numpy(), target_11.numpy()) + np.testing.assert_allclose(result_12.numpy(), target_12.numpy()) + np.testing.assert_allclose(result_21.numpy(), target_21.numpy()) + np.testing.assert_allclose(result_22.numpy(), target_22.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index d843e172763..52895217d3f 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -73,6 +73,17 @@ class TestStrategyConfig(unittest.TestCase): strategy.pipeline_configs = configs self.assertEqual(strategy.pipeline_configs["accumulate_steps"], 2) + def test_hybrid_parallel_configs(self): + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 2, + "pp_degree": 4 + } + self.assertEqual(strategy.hybrid_configs["dp_degree"], 1) + self.assertEqual(strategy.hybrid_configs["mp_degree"], 2) + self.assertEqual(strategy.hybrid_configs["pp_degree"], 4) + def test_localsgd(self): strategy = 
paddle.distributed.fleet.DistributedStrategy() strategy.localsgd = True diff --git a/python/paddle/fluid/tests/unittests/test_new_group.sh b/python/paddle/fluid/tests/unittests/test_new_group.sh index d0b29a64145..4914183fb46 100755 --- a/python/paddle/fluid/tests/unittests/test_new_group.sh +++ b/python/paddle/fluid/tests/unittests/test_new_group.sh @@ -17,4 +17,4 @@ set -e CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 new_group.py -CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 hybrid_communicate_group.py +CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 hybrid_parallel_communicate_group.py diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py new file mode 100644 index 00000000000..6454b3918ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import time +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestHybridParallel(TestMultipleGpus): + def test_hybrid_parallel_mp_layers(self): + self.run_mnist_2gpu('hybrid_parallel_mp_layers.py') + + def test_hybrid_parallel_mp_random(self): + self.run_mnist_2gpu('hybrid_parallel_mp_random.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 601e6e48703..c366415ebb2 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -156,6 +156,8 @@ packages=['paddle', 'paddle.distributed.fleet.metrics', 'paddle.distributed.fleet.proto', 'paddle.distributed.fleet.utils', + 'paddle.distributed.fleet.meta_parallel', + 'paddle.distributed.fleet.meta_parallel.mp_utils', 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static', -- GitLab
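
For reference, here is a minimal end-to-end sketch of the dygraph hybrid parallel API this patch introduces. It is illustrative only, not part of the patch: the script name (demo.py), layer sizes, and degree choices are arbitrary, and it must be launched on two GPUs, e.g. `python -m paddle.distributed.launch --gpus=0,1 demo.py`, mirroring the new unit tests.

    # demo.py -- illustrative sketch of the new API; sizes are arbitrary
    import paddle
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,  # would be inferred from the world size if set to -1
        "mp_degree": 2,  # split model weights across 2 GPUs
        "pp_degree": 1,
    }
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    print("model parallel rank:", hcg.get_model_parallel_rank())

    # Megatron-style pairing: the column-parallel layer keeps its output
    # sharded (gather_output=False) and feeds the row-parallel layer directly
    # (input_is_parallel=True), so the intermediate activation is never
    # gathered; the all-reduce inside RowParallelLinear restores the
    # replicated result.
    fc1 = fleet.meta_parallel.ColumnParallelLinear(
        in_features=64, out_features=128, has_bias=True, gather_output=False)
    fc2 = fleet.meta_parallel.RowParallelLinear(
        in_features=128, out_features=64, has_bias=True, input_is_parallel=True)

    paddle.seed(2021)  # same seed, so the random input replicates across ranks
    x = paddle.randn([8, 64])
    y = fc2(fc1(x))    # y is replicated across the model parallel group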
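
The RNG bookkeeping added in random.py can be exercised the same way. A minimal sketch, assuming fleet.init has already run with the hybrid configs above (CUDA is required, since the tracker snapshots the CUDA RNG state):

    import paddle
    import paddle.distributed.fleet as fleet

    # Seeds the global generator identically on every rank and registers a
    # 'model_parallel_rng' state whose seed is offset by the mp rank.
    fleet.meta_parallel.model_parallel_random_seed(2048)

    tracker = fleet.meta_parallel.get_rng_state_tracker()
    with tracker.rng_state():               # the 'model_parallel_rng' state
        local_noise = paddle.randn([4, 4])  # differs across mp ranks (dropout-style)
    shared_noise = paddle.randn([4, 4])     # identical across mp ranks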
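
Finally, the vocabulary partitioning arithmetic in VocabParallelEmbedding deserves a worked example. The snippet below simply re-runs the formulas from layers.py for an arbitrary 10-row vocabulary on 2 ranks:

    vocab_size, world_size = 10, 2
    per_part = (vocab_size + world_size - 1) // world_size  # ceil division -> 5
    last_part = vocab_size - per_part * (world_size - 1)    # rows on the last rank -> 5
    for rank in range(world_size):
        rows = (last_part if rank == world_size - 1 else per_part) + 1
        # Each rank builds a local Embedding(rows, embedding_dim) whose final
        # row is the padding index; paddle.shard_index maps out-of-range ids
        # to it, and the subsequent all-reduce sums the partial lookups.
        print(rank, rows)  # -> (0, 6) and (1, 6)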