diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index a63dfd7b091f7ea16d07459cfe1f48c2bed219c1..bdb8b0a3ce2e449015007c3978d8abc38c09a68d 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -47,6 +47,7 @@ message HybridConfig { optional int32 dp_degree = 1 [ default = -1 ]; optional int32 mp_degree = 2 [ default = 1 ]; optional int32 pp_degree = 3 [ default = 1 ]; + optional int32 sharding_degree = 4 [ default = 1 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 3f67d8ab619898b5e38299f440eef4f8c884b896..2a9b15c732541a22ff73b18b8f9aff0b6b3facc2 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -30,7 +30,7 @@ from paddle.fluid.dygraph import parallel_helper from . import topology as tp from .topology import ParallelMode from ..meta_parallel import TensorParallel, model_parallel_random_seed -from ..meta_parallel import PipelineParallel +from ..meta_parallel import PipelineParallel, ShardingParallel from ..meta_optimizers import HybridParallelOptimizer from ..meta_optimizers import HybridParallelGradScaler @@ -295,9 +295,11 @@ class Fleet(object): self.dp_degree = self.hybrid_configs["dp_degree"] self.mp_degree = self.hybrid_configs["mp_degree"] self.pp_degree = self.hybrid_configs["pp_degree"] + self.sharding_degree = self.hybrid_configs["sharding_degree"] assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0" assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0" + assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0" self.mp_degree = max(self.mp_degree, 1) self.pp_degree = max(self.pp_degree, 1) @@ -309,8 +311,11 @@ class Fleet(object): self.dp_degree = max(self.dp_degree, 1) self._topology = tp.CommunicateTopology( - hybrid_group_names=["data", "pipe", "model"], - dims=[self.dp_degree, self.pp_degree, self.mp_degree]) + hybrid_group_names=["data", "pipe", "sharding", "model"], + dims=[ + self.dp_degree, self.pp_degree, self.sharding_degree, + self.mp_degree + ]) self._hcg = tp.HybridCommunicateGroup(self._topology) @@ -886,7 +891,11 @@ class Fleet(object): assert model is not None, "model should not be None" if self.worker_num() <= 1: return model - if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL: + + if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL: + distributed_model = ShardingParallel( + model, self._hcg, strategy=self._user_defined_strategy) + elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL: distributed_model = paddle.DataParallel( model, comm_buffer_size=self._user_defined_strategy. 
@@ -901,6 +910,7 @@ class Fleet(object): elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: distributed_model = PipelineParallel( model, self._hcg, strategy=self._user_defined_strategy) + return distributed_model @dygraph_only diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 0eb840c08a2f8623dfc6a9ac1ea0b5f219866a29..3e89e9de181bc0684893adb079ea6dcfc8811c35 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -30,12 +30,13 @@ class ParallelMode(object): DATA_PARALLEL = 0 TENSOR_PARALLEL = 1 PIPELINE_PARALLEL = 2 + SHARDING_PARALLEL = 3 class CommunicateTopology(object): def __init__(self, - hybrid_group_names=["data", "pipe", "model"], - dims=[1, 1, 1]): + hybrid_group_names=["data", "pipe", "sharding", "model"], + dims=[1, 1, 1, 1]): self._parallel_names = hybrid_group_names self._dims = dims self.coordinate = collections.namedtuple('Coordinate', @@ -122,15 +123,17 @@ class HybridCommunicateGroup(object): self._dp_degree = self._topo.get_dim('data') self._mp_degree = self._topo.get_dim('model') self._pp_degree = self._topo.get_dim('pipe') + self._sharding_degree = self._topo.get_dim('sharding') self._data_parallel_id = self._get_data_parallel_id() self._model_parallel_id = self._get_model_parallel_id() + self._sharding_parallel_id = self._get_sharding_parallel_id() self.stage_id = self._get_pipe_parallel_id() assert self._check_vaild_topo( ), "Here is an unreasonable topogy setting. world_size: {}, but" \ - "dp_num: {}, mp_num: {}, pp_num: {}".format(self.nranks, self._dp_degree, - self._mp_degree, self._pp_degree) + "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(self.nranks, + self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree) # create comm group for data parallel self._dp_group, self._dp_comm_group = self._set_comm_group("data") @@ -141,6 +144,10 @@ class HybridCommunicateGroup(object): # create comm group for pipe parallel self._pp_group, self._pp_comm_group = self._set_comm_group("pipe") + # create comm group for sharding parallel + self._sharding_group, self._sharding_comm_group = self._set_comm_group( + "sharding") + # create global group for check inf_nan / clip global norm self._check_group, self._check_comm_group = self._set_check_group( "data") @@ -149,19 +156,26 @@ class HybridCommunicateGroup(object): self.is_first_stage = (self.stage_id == 0) self.is_last_stage = (self.stage_id == (self._pp_degree - 1)) - debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \ - "mp_degree: %d, pp_degree: %d" % (self.global_rank, self._dp_degree, - self._mp_degree,self._pp_degree) - debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % ( - self._dp_group, self._mp_group, self._pp_group, self._check_group) + debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \ + "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree, + self._sharding_degree, self._pp_degree, self._dp_degree) + debug_str += ", mp_group: %s, sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s" % ( + self._mp_group, self._sharding_group, self._pp_group, + self._dp_group, self._check_group) logger.info(debug_str) global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self def get_parallel_mode(self): - # there are three modes : DataParallel / TensorParallel / PipelineParallel - if self._mp_degree == 1 and self._pp_degree == 1: + # there are 
four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
+        # NOTE when sharding is used together with another parallelism, sharding should act like an optimizer
+        # and add its parallel logic within that parallelism
+        # when sharding is used alone, it needs its own parallelism for its parallel logic
+        # TODO modify the 3 other parallelisms to support sharding
+        if self._mp_degree == 1 and self._pp_degree == 1 and self._dp_degree == 1 and self._sharding_degree > 1:
+            return ParallelMode.SHARDING_PARALLEL
+        elif self._mp_degree == 1 and self._pp_degree == 1:
             return ParallelMode.DATA_PARALLEL
         elif self._mp_degree > 1 and self._pp_degree == 1:
             # initialize the seed
@@ -170,7 +184,7 @@ class HybridCommunicateGroup(object):
             return ParallelMode.PIPELINE_PARALLEL
 
     def _check_vaild_topo(self):
-        return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks
+        return self._dp_degree * self._mp_degree * self._pp_degree * self._sharding_degree == self.nranks
 
     def _set_comm_group(self, parallel_method="data"):
         parallel_group = []
@@ -255,6 +269,23 @@ class HybridCommunicateGroup(object):
     def get_pipe_parallel_group(self):
         return self._pp_comm_group
 
+    # sharding parallel message:
+    def _get_sharding_parallel_id(self):
+        return self._topo.get_coord(self.global_rank).sharding
+
+    def get_sharding_parallel_rank(self):
+        return self._sharding_parallel_id
+
+    def get_sharding_parallel_world_size(self):
+        return self._sharding_degree
+
+    def get_sharding_parallel_group(self):
+        return self._sharding_comm_group
+
+    def get_sharding_parallel_group_src_rank(self):
+        # TODO should the src rank be related to the shard rank for each parameter ?
+        return self._sharding_comm_group.ranks[0]
+
     # check parallel group
     def get_check_parallel_group(self):
        return self._check_comm_group
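The 4-D topology above enumerates ranks in row-major (data, pipe, sharding, model) order, and HybridCommunicateGroup exposes the new axis through the sharding getters just added. A minimal sketch of the coordinate-to-rank arithmetic, mirroring the test_topology_4D case further down; the dims and expected values are assumptions for a hypothetical 8-card layout, and no communication groups are created here:

    # Illustrative sketch, not part of the patch: pure 4-D topology arithmetic.
    from paddle.distributed import fleet

    # dp=2, pp=1, sharding=2, mp=2  ->  world size of 8
    topo = fleet.CommunicateTopology(["data", "pipe", "sharding", "model"],
                                     [2, 1, 2, 2])

    assert topo.world_size() == 8
    # ranks are laid out row-major over (data, pipe, sharding, model)
    assert topo.get_rank(data=1, pipe=0, sharding=1, model=0) == 6
    assert topo.get_coord(5) == topo.coordinate(1, 0, 0, 1)
    # all ranks whose sharding coordinate is 0
    assert topo.get_axis_list("sharding", 0) == [0, 1, 4, 5]
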
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..4bddde6b5b62e6a09d6b64194f8cc5bbe4e976e7
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+######
+from functools import reduce
+
+import paddle
+from paddle import framework
+from ...utils.log_util import logger
+
+
+def _is_trainable(param: paddle.Tensor) -> bool:
+    return not param.stop_gradient
+
+
+class DygraphShardingOptimizer(object):
+    """
+    A wrapper for Sharding Optimizer in Dygraph.
+
+    .. warning: DygraphShardingOptimizer is experimental and subject to change.
+
+    .. ZeRO: https://arxiv.org/abs/1910.02054
+
+    """
+
+    # TODO (JZ-LIANG)
+    # To support the following features in the future:
+    # 1. fused update parameter sync
+    # 2. parameters_groups
+    # 3. dynamic trainable params, which is the case between pretraining and finetuning
+    # 4. option to choose fused comm (needs more GPU memory) or un-fused comm
+
+    def __init__(
+            self,
+            hcg,
+            user_defined_strategy,
+            params,
+            inner_optimizer_class,
+            **inner_optimizer_kargs, ):
+
+        if not isinstance(params, list):
+            raise TypeError(
+                "`parameters` argument given to the DygraphShardingOptimizer should be "
+                "an iterable of paddle Tensors, but got argument type is `{}`.".
+                format(type(params)))
+        self._parameter_list = params
+        self._reference_is_trainable_params = list(
+            map(_is_trainable, self._parameter_list))
+
+        self._inner_optimizer_class = inner_optimizer_class
+        self._inner_optimizer_kargs = inner_optimizer_kargs
+
+        # sharding parallel information
+        # TODO better way to get the hcg & user_defined_strategy
+        self._hcg = hcg
+        self._user_defined_strategy = user_defined_strategy
+        self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
+        self._sharding_rank = self._hcg.get_sharding_parallel_rank()
+
+        # logic partitioning
+        self._build_sharding_mapping()
+
+        # actually create opt ops
+        self._buid_inner_optimizer()
+
+    def clear_grad(self):
+        """
+        should clear grad for all parameters in model
+        """
+        for p in self._parameter_list:
+            if not p.stop_gradient:
+                p.clear_gradient()
+
+    def _build_sharding_mapping(self):
+
+        self._rank2params = self._partition_parameters()
+        self._param2rank = self._map_param_to_rank()
+
+    def _partition_parameters(self):
+        """
+        Partitions parameters among sharding ranks.
+
+        Return:
+        Dict[int, List]
+        """
+        # TODO(JZ-LIANG) support multiple partition methods
+        # method1: greedy even but unordered
+        # method2: roughly even with order
+
+        mapping = {}
+        for rank_ in range(self._sharding_world_size):
+            mapping[rank_] = []
+        sizes = [0] * self._sharding_world_size
+        for param in self._parameter_list:
+            rank = sizes.index(min(sizes))
+            mapping[rank].append(param)
+            numel = reduce(lambda x, y: x * y, param.shape)
+            assert numel > 0, "param [{}] should be larger than 0, but it is [{}]".format(
+                param.name, numel)
+            sizes[rank] += numel
+
+        return mapping
+
+    def _map_param_to_rank(self):
+        """
+        mapping parameters to the shard which holds it.
+
+        Return:
+        Dict[str, int]
+        """
+        mapping = {}
+        for rank, params in self._rank2params.items():
+            for param in params:
+                mapping[param.name] = rank
+        return mapping
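The partitioning above is a plain greedy size balance: parameters are assigned, in list order, to whichever shard currently holds the fewest elements. A self-contained sketch of the same idea on hypothetical toy (name, shape) pairs, illustrative only and not part of the patch:

    # Greedy size-balanced partitioning, the same scheme as _partition_parameters above.
    from functools import reduce

    def partition(named_shapes, sharding_world_size):
        mapping = {rank: [] for rank in range(sharding_world_size)}
        sizes = [0] * sharding_world_size
        for name, shape in named_shapes:
            rank = sizes.index(min(sizes))  # the lightest shard so far
            mapping[rank].append(name)
            sizes[rank] += reduce(lambda x, y: x * y, shape)
        return mapping

    print(partition(
        [("embedding.w", [20, 10]), ("fc1.w", [10, 8]), ("fc1.b", [8]),
         ("fc2.w", [8, 10])], sharding_world_size=2))
    # -> {0: ['embedding.w'], 1: ['fc1.w', 'fc1.b', 'fc2.w']}

Because assignment is driven by size alone, the resulting shards are roughly even in element count but not contiguous in declaration order, which is what the "greedy even but unordered" comment refers to.
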
+    def _buid_inner_optimizer(self):
+        # we rely on the inner opt to determine whether a parameter is stop_gradient or not:
+        # create moment
+        # update related ops: clip, regular, opt
+        self._inner_optimizer = self._inner_optimizer_class(
+            parameters=self._rank2params[self._sharding_rank],
+            **self._inner_optimizer_kargs)
+
+    def _sharding_sync_parameters(self):
+        """
+        sync parameters across the sharding group
+        """
+        # TODO speed up this function
+
+        logger.debug("sharding start sync parameters")
+        with framework.no_grad():
+            # TODO detach not needed (?)
+            for rank, params in self._rank2params.items():
+                for param in params:
+                    paddle.distributed.broadcast(
+                        param,
+                        # the collective API needs the src rank to be the global rank id
+                        # instead of the relative logical rank id within the group
+                        src=self._hcg.get_sharding_parallel_group().ranks[rank],
+                        group=self._hcg.get_sharding_parallel_group(),
+                        use_calc_stream=True)
+
+    def _update_trainable(self):
+        """
+        allow user to update trainable parameters list during training
+        """
+        raise NotImplementedError
+
+    def minimize(self,
+                 loss,
+                 startup_program=None,
+                 parameters=None,
+                 no_grad_set=None):
+
+        # NOTE in dygraph mode, the only difference between step and minimize is that minimize
+        # allows the user to customize the parameters for updating on each step
+
+        input_param_names = set([param.name for param in parameters])
+        parameters = list(
+            filter(lambda x: x.name in input_param_names, self._rank2params[
+                self._sharding_rank]))
+        result = self._inner_optimizer.minimize(loss, startup_program,
+                                                parameters, no_grad_set)
+
+        # sync parameters across sharding ranks
+        self._sharding_sync_parameters()
+
+        return result
+
+    def step(self):
+        # TODO check whether the model's trainable params changed and update state accordingly
+
+        # actually updating
+        self._inner_optimizer.step()
+
+        # sync parameters across sharding ranks
+        self._sharding_sync_parameters()
+
+    # TODO is it a good way to make _grad_clip a property
+    @property
+    def _grad_clip(self):
+        assert self._inner_optimizer is not None, "inner opt of sharding is not initialized."
+        return self._inner_optimizer._grad_clip
+
+    def __getattr__(self, item):
+        return getattr(self._inner_optimizer, item)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index bceabeee3c3dce9f355bb9a31a037a13cca4edd3..e3a5947bf60fc1aa152dd1ecfd89689cc204536e 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -17,7 +17,7 @@ import sys
 import paddle
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
-from ...utils.hybrid_parallel_util import fused_allreduce_gradients
+from ...utils.hybrid_parallel_util import fused_allreduce_gradients, sharding_reduce_gradients
 from ...base.topology import ParallelMode
 from paddle.fluid.dygraph import base as imperative_base
 from paddle.fluid import framework
@@ -98,6 +98,9 @@ class HybridParallelOptimizer:
 
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
+        self._sharding_enable = (
+            self._hcg.get_sharding_parallel_world_size() > 1)
+
         if isinstance(self._inner_opt._grad_clip,
                       ClipGradByGlobalNorm) and not self._use_dp_mode:
             logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
@@ -108,6 +111,11 @@ class HybridParallelOptimizer:
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
+        # Here we should use the global parameter list
+        if self._sharding_enable:
+            sharding_reduce_gradients(
+                list(self._inner_opt._parameter_list), self._hcg)
+
         if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
@@ -119,15 +127,19 @@ class HybridParallelOptimizer:
                  startup_program=None,
                  parameters=None,
                  no_grad_set=None):
-        assert isinstance(loss, Variable), "The loss should be an Tensor."
parameter_list = parameters if parameters \ - else self._parameter_list + else self._inner_opt._parameter_list + + # Here should use global parameter list + if self._sharding_enable: + sharding_reduce_gradients( + list(self._inner_opt._parameter_list), self._hcg) if not self._use_dp_mode and self._need_dp: fused_allreduce_gradients(list(parameter_list), self._hcg) - return self._inner_opt.minimize(loss, startup_program, parameters, + return self._inner_opt.minimize(loss, startup_program, parameter_list, no_grad_set) def __getattr__(self, item): diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index 4e32ff5723c4181892620281e21232e7f2267f08..fe7f23f3d8cc33ef0c977c1ce9c07039eb67dc83 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -24,5 +24,6 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401 from .parallel_layers import get_rng_state_tracker # noqa: F401 from .tensor_parallel import TensorParallel # noqa: F401 from .pipeline_parallel import PipelineParallel # noqa: F401 +from .sharding_parallel import ShardingParallel # noqa: F401 __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..953a76d874e558ccfce0188805f5fe961cf9426f --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/sharding_parallel.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from paddle.fluid.dygraph.layers import Layer
+from .meta_parallel_base import MetaParallelBase
+from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
+from ..utils.log_util import logger
+
+__all__ = []
+
+
+class ShardingParallel(MetaParallelBase):
+    def __init__(self, layers, hcg, **kwargs):
+        super(ShardingParallel, self).__init__(layers, hcg, **kwargs)
+
+    def _prepare_for_model(self):
+        logger.info("start broadcast sharding parameters")
+        broadcast_sharding_parameters(self._layers, self._hcg)
+
+        # TODO (JZ-LIANG) to support Sharding-DP
+
+        logger.info("sharding's parameters are ready")
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index ddbd6111b460994f93adc90751c339db0c808234..81bed60050de2992d1ad35190477a743a39b8f9f 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -119,3 +119,46 @@ def fused_allreduce_gradients(parameter_list, hcg):
     logger.debug("dp start fuse allreduce gradients")
     with framework.no_grad():
         _apply_collective_grads(parameter_list, data_parallel_group)
+
+
+def sharding_reduce_gradients(parameter_list, hcg):
+    # TODO allreduce --> reduce
+    # TODO merge grad / nrank with dp
+    logger.debug("sharding start gradients sync")
+    with framework.no_grad():
+
+        sharding_nrank = hcg.get_sharding_parallel_group().nranks
+        for param in parameter_list:
+            if param.trainable and (param._grad_ivar() is not None):
+
+                g_var = param._grad_ivar()
+
+                # need to use trace_op to allreduce
+                # paddle.distributed.all_reduce(
+                #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="c_allreduce_sum",
+                    inputs={'X': g_var},
+                    outputs={'Out': g_var},
+                    attrs={
+                        'ring_id': hcg.get_sharding_parallel_group().id,
+                        'use_calc_stream': True
+                    })
+
+                # grad = grad / sharding_nrank
+                div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="elementwise_div",
+                    inputs={'X': g_var,
+                            'Y': div_factor},
+                    outputs={'Out': g_var},
+                    attrs={'axis': -1})
+
+
+def broadcast_sharding_parameters(model, hcg):
+    # TODO to save memory, use un-fused broadcast to avoid potential OOM
+    logger.debug("sharding start init parameters sync")
+    sharding_parallel_group = hcg.get_sharding_parallel_group()
+    src_rank = hcg.get_sharding_parallel_group_src_rank()
+    sync_params_buffers(
+        model, sharding_parallel_group, src_rank, is_model_parallel=False)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 9bb88abcea9cb249be2d413952a2471273e1f23e..21d241224cab43b075d236716516af42794fc647 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
+list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
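For reference, sharding_reduce_gradients above amounts to a gradient average over the sharding group: each trainable gradient is summed with c_allreduce_sum and then divided by the group size. A rough equivalent written against the public collective API is sketched below; it is illustrative only, since the patch itself keeps the trace_op form that it notes is needed for the dygraph tracer:

    # Rough public-API equivalent of sharding_reduce_gradients (illustrative sketch,
    # not part of the patch; the patch uses _dygraph_tracer().trace_op instead).
    import paddle
    from paddle import framework

    def sharding_average_gradients(parameter_list, hcg):
        group = hcg.get_sharding_parallel_group()
        with framework.no_grad():
            for param in parameter_list:
                if param.trainable and param._grad_ivar() is not None:
                    g_var = param._grad_ivar()
                    # sum the gradient across the sharding group ...
                    paddle.distributed.all_reduce(
                        g_var, group=group, use_calc_stream=True)
                    # ... then rescale to get the average
                    div_factor = paddle.to_tensor(group.nranks, dtype=g_var.dtype)
                    paddle.assign(g_var / div_factor, output=g_var)
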
@@ -185,6 +186,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single) @@ -882,6 +884,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) + set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py index 0a9785475b561a9f35ecd34532558f176cf77e03..53d0f95a2366720676f4d4f14e1aef55c56488a3 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_communicate_group.py @@ -21,7 +21,8 @@ from paddle.distributed import fleet class TestNewGroupAPI(object): def __init__(self): paddle.distributed.init_parallel_env() - topo = fleet.CommunicateTopology(["data", "model", "pipe"], [2, 1, 1]) + topo = fleet.CommunicateTopology(["data", "model", "sharding", "pipe"], + [2, 1, 1, 1]) self.hcg = fleet.HybridCommunicateGroup(topo) d1 = np.array([1, 2, 3]) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2995e4dbf84018fae3782b72325dec0ae81faada --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py @@ -0,0 +1,297 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle.io import DataLoader, Dataset +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer +import unittest + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 +STEPS = 10 + + +def parallel_matmul(lm_output, logit_weights, parallel_output): + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + world_size = hcg.get_model_parallel_world_size() + rank = hcg.get_model_parallel_rank() + + if world_size > 1: + input_parallel = paddle.distributed.collective._c_identity( + lm_output, group=model_parallel_group) + + logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True) + + if parallel_output: + return logits + + return paddle.distributed.collective._c_concat( + logits, group=model_parallel_group) + else: + logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + return logits + + +class SimpleMPNet(fluid.dygraph.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1, + np_fc2, mp_id): + super(SimpleMPNet, self).__init__() + + if mp_id == 0: + init_fc1_data = np_fc1[:, :(inner_size // 2)] + init_fc2_data = np_fc2[:(inner_size // 2), :] + else: + init_fc1_data = np_fc1[:, (inner_size // 2):] + init_fc2_data = np_fc2[(inner_size // 2):, :] + + self.linear1 = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc1_data)), + gather_output=False, + has_bias=True) + + self.linear2 = fleet.meta_parallel.RowParallelLinear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(init_fc2_data)), + input_is_parallel=True, + has_bias=True) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.embedding = fleet.meta_parallel.VocabParallelEmbedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5)) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = parallel_matmul(x, self.embedding.weight, False) + return x + + +class SimpleDPNet(fluid.dygraph.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1, + np_fc2): + + super(SimpleDPNet, self).__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2)), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0))) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0)), + bias_attr=paddle.framework.ParamAttr( + 
initializer=paddle.nn.initializer.Constant(0.0))) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5)) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class TestDistMPTraning(unittest.TestCase): + def setUp(self): + random.seed(2021) + np.random.seed(2021) + paddle.seed(2021) + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": 2, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=self.strategy) + self.data = [ + np.random.randint(0, vocab_size, ( + batch_size, + seq_length, )) for _ in range(STEPS) + ] + + def train_batch(self, batch, model, optimizer): + + output = model(batch) + loss = output.mean() + loss.backward() # do backward + optimizer.step() # update parameters + optimizer.clear_grad() + return loss + + def build_optimizer(self, + model, + strategy=None, + is_sharding=True, + Optimizer="adam"): + + if Optimizer == "adam": + if is_sharding: + optimizer = DygraphShardingOptimizer( + hcg=fleet.get_hybrid_communicate_group(), + user_defined_strategy=strategy, + params=model.parameters(), + inner_optimizer_class=paddle.optimizer.Adam, + learning_rate=0.001, + weight_decay=0.00001, ) + else: + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, ) + else: + if is_sharding: + optimizer = DygraphShardingOptimizer( + hcg=fleet.get_hybrid_communicate_group(), + user_defined_strategy=strategy, + params=model.parameters(), + inner_optimizer_class=paddle.optimizer.Momentum, + learning_rate=0.001, ) + else: + optimizer = paddle.optimizer.Momentum( + learning_rate=0.001, parameters=model.parameters()) + return optimizer + + def build_model_optimizer(self, Optimizer="adam"): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + sharding_id = hcg.get_sharding_parallel_rank() + dp_id = hcg.get_data_parallel_rank() + rank_id = dist.get_rank() + + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model_a = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size, + np_fc1, np_fc2) + optimizer_a = self.build_optimizer( + model_a, + strategy=self.strategy, + is_sharding=True, + Optimizer=Optimizer) + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + + model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size, + np_fc1, np_fc2) + optimizer_b = self.build_optimizer( + model_b, + strategy=self.strategy, + is_sharding=False, + Optimizer=Optimizer) + + return model_a, optimizer_a, model_b, optimizer_b + + def sharding_model(self, Optimizer, sharded_accumulators): + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( + Optimizer=Optimizer) + + self.assertTrue( + isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer)) + + for idx in range(STEPS): + + if idx == 2 and paddle.distributed.get_rank() == 0: + self.assertTrue( + set(optimizer_a._inner_opt._inner_optimizer.state_dict() + .keys()) == sharded_accumulators) + + if paddle.distributed.get_rank() == 0: + batch_sharding = paddle.to_tensor(self.data[idx][:2]) + else: + batch_sharding = paddle.to_tensor(self.data[idx][2:]) + + batch_single = 
paddle.to_tensor(self.data[idx]) + loss_a = self.train_batch(batch_sharding, model_a, optimizer_a) + loss_b = self.train_batch(batch_single, model_b, optimizer_b) + + for j in range(len(model_a.parameters())): + np.testing.assert_allclose( + model_a.parameters()[j].numpy(), + model_b.parameters()[j].numpy(), + rtol=1e-6) + + def test_sharding_adam(self): + sharded_accumulators = set([ + 'linear_0.w_0_moment1_0', 'linear_1.b_0_moment1_0', + 'linear_2.b_0_moment1_0', 'embedding_0.w_0_moment1_0', + 'linear_0.w_0_moment2_0', 'linear_1.b_0_moment2_0', + 'linear_2.b_0_moment2_0', 'embedding_0.w_0_moment2_0', + 'linear_0.w_0_beta1_pow_acc_0', 'linear_1.b_0_beta1_pow_acc_0', + 'linear_2.b_0_beta1_pow_acc_0', 'embedding_0.w_0_beta1_pow_acc_0', + 'linear_0.w_0_beta2_pow_acc_0', 'linear_1.b_0_beta2_pow_acc_0', + 'linear_2.b_0_beta2_pow_acc_0', 'embedding_0.w_0_beta2_pow_acc_0' + ]) + self.sharding_model( + Optimizer="adam", sharded_accumulators=sharded_accumulators) + + def test_sharding_momentum(self): + sharded_accumulators = set([ + 'linear_6.w_0_velocity_0', 'linear_7.b_0_velocity_0', + 'linear_8.b_0_velocity_0', 'embedding_2.w_0_velocity_0' + ]) + self.sharding_model( + Optimizer="Momentum", sharded_accumulators=sharded_accumulators) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py b/python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py index e4c469599d72c05849b34440faa586fa5f66d7e4..e8300113ddc42e789b6b10c1ada01a779317ce06 100644 --- a/python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py +++ b/python/paddle/fluid/tests/unittests/test_hybrid_parallel_topology.py @@ -79,6 +79,99 @@ class TestCommunicateTopology(unittest.TestCase): self.assertEqual(topo.get_dim_size("mp"), 2) self.assertEqual(topo.get_dim_size("pp"), 2) + def test_topology_4D(self): + topo = fleet.CommunicateTopology(["dp", "pp", "sharding", "mp"], + [2, 2, 2, 2]) + + # test get_comm_list + dp_comm_list = [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], + [6, 14], [7, 15]] + mp_comm_list = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], + [12, 13], [14, 15]] + pp_comm_list = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], + [10, 14], [11, 15]] + sharding_comm_list = [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], + [12, 14], [13, 15]] + + np.testing.assert_array_equal(dp_comm_list, topo.get_comm_list("dp")) + np.testing.assert_array_equal(mp_comm_list, topo.get_comm_list("mp")) + np.testing.assert_array_equal(pp_comm_list, topo.get_comm_list("pp")) + np.testing.assert_array_equal(sharding_comm_list, + topo.get_comm_list("sharding")) + + # test get_hybrid_group_names + parallel_names = ["dp", "pp", "sharding", "mp"] + np.testing.assert_array_equal(parallel_names, + topo.get_hybrid_group_names()) + + # test get_dims + np.testing.assert_array_equal(2, topo.get_dim("dp")) + np.testing.assert_array_equal(2, topo.get_dim("mp")) + np.testing.assert_array_equal(2, topo.get_dim("pp")) + np.testing.assert_array_equal(2, topo.get_dim("sharding")) + + # test world size + self.assertEqual(topo.world_size(), 16) + + # test get_rank + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, mp=0), 0) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, mp=1), 1) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, mp=0), 2) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, mp=1), 3) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, mp=0), 4) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, 
mp=1), 5) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, mp=0), 6) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, mp=1), 7) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, mp=0), 8) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, mp=1), 9) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, mp=0), 10) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, mp=1), 11) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, mp=0), 12) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, mp=1), 13) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, mp=0), 14) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, mp=1), 15) + + # test get_coord + self.assertEqual(topo.get_coord(0), topo.coordinate(0, 0, 0, 0)) + self.assertEqual(topo.get_coord(1), topo.coordinate(0, 0, 0, 1)) + self.assertEqual(topo.get_coord(2), topo.coordinate(0, 0, 1, 0)) + self.assertEqual(topo.get_coord(3), topo.coordinate(0, 0, 1, 1)) + self.assertEqual(topo.get_coord(4), topo.coordinate(0, 1, 0, 0)) + self.assertEqual(topo.get_coord(5), topo.coordinate(0, 1, 0, 1)) + self.assertEqual(topo.get_coord(6), topo.coordinate(0, 1, 1, 0)) + self.assertEqual(topo.get_coord(7), topo.coordinate(0, 1, 1, 1)) + self.assertEqual(topo.get_coord(8), topo.coordinate(1, 0, 0, 0)) + self.assertEqual(topo.get_coord(9), topo.coordinate(1, 0, 0, 1)) + self.assertEqual(topo.get_coord(10), topo.coordinate(1, 0, 1, 0)) + self.assertEqual(topo.get_coord(11), topo.coordinate(1, 0, 1, 1)) + self.assertEqual(topo.get_coord(12), topo.coordinate(1, 1, 0, 0)) + self.assertEqual(topo.get_coord(13), topo.coordinate(1, 1, 0, 1)) + self.assertEqual(topo.get_coord(14), topo.coordinate(1, 1, 1, 0)) + self.assertEqual(topo.get_coord(15), topo.coordinate(1, 1, 1, 1)) + + # test get_axis_list + self.assertEqual(topo.get_axis_list("dp", 0), [0, 1, 2, 3, 4, 5, 6, 7]) + self.assertEqual( + topo.get_axis_list("dp", 1), [8, 9, 10, 11, 12, 13, 14, 15]) + self.assertEqual( + topo.get_axis_list("mp", 0), [0, 2, 4, 6, 8, 10, 12, 14]) + self.assertEqual( + topo.get_axis_list("mp", 1), [1, 3, 5, 7, 9, 11, 13, 15]) + self.assertEqual( + topo.get_axis_list("pp", 0), [0, 1, 2, 3, 8, 9, 10, 11]) + self.assertEqual( + topo.get_axis_list("pp", 1), [4, 5, 6, 7, 12, 13, 14, 15]) + self.assertEqual( + topo.get_axis_list("sharding", 0), [0, 1, 4, 5, 8, 9, 12, 13]) + self.assertEqual( + topo.get_axis_list("sharding", 1), [2, 3, 6, 7, 10, 11, 14, 15]) + + # test get_dim_size + self.assertEqual(topo.get_dim_size("dp"), 2) + self.assertEqual(topo.get_dim_size("mp"), 2) + self.assertEqual(topo.get_dim_size("pp"), 2) + self.assertEqual(topo.get_dim_size("sharding"), 2) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index f3cd97ee1ec86916ecebb8ddf0895443c6e14567..d15e55eb0fa1460b60b1b06582ad049875c7e54e 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -124,6 +124,8 @@ class TestMultipleGpus(unittest.TestCase): break time.sleep(3) + +class TestDataParallelGradientCheck(TestMultipleGpus): def test_multiple_gpus_dynamic(self): self.run_mnist_2gpu('parallel_dygraph_gradient_check.py') diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py new 
file mode 100644 index 0000000000000000000000000000000000000000..b7e8e06029d937d091759a6a80f0b11d42ca7189 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sharding_parallel.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestHybridParallel(TestMultipleGpus): + + # check sharding logic as well as the accuracy with single mode + def test_hybrid_parallel_sharding_logic(self): + self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') + + +if __name__ == "__main__": + unittest.main()
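
Taken together, the intended user-facing flow mirrors hybrid_parallel_sharding_model.py above: configure the new sharding_degree axis, wrap the inner optimizer in DygraphShardingOptimizer, and let fleet return the ShardingParallel / HybridParallelOptimizer wrappers. A condensed sketch follows; it is illustrative only, SimpleNet and data_loader are hypothetical, and the script is assumed to be launched on two GPUs (e.g. via paddle.distributed.launch):

    # Condensed usage sketch based on hybrid_parallel_sharding_model.py above.
    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 2,  # the new axis added by this patch
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = SimpleNet()  # hypothetical user-defined Layer
    optimizer = DygraphShardingOptimizer(
        hcg=fleet.get_hybrid_communicate_group(),
        user_defined_strategy=strategy,
        params=model.parameters(),
        inner_optimizer_class=paddle.optimizer.Adam,
        learning_rate=0.001)  # each rank only builds optimizer states for its own shard

    model = fleet.distributed_model(model)  # returns a ShardingParallel wrapper
    optimizer = fleet.distributed_optimizer(optimizer)  # returns a HybridParallelOptimizer

    for batch in data_loader:  # hypothetical data loader
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()  # allreduce-averages grads, updates the local shard, re-broadcasts params
        optimizer.clear_grad()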