From 91a0acdb2390d196caf181db2c112a34e12bc037 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 25 Jun 2021 16:02:09 +0800 Subject: [PATCH] static support mp_layers (#33700) --- python/paddle/distributed/collective.py | 26 ++- .../distributed/fleet/base/fleet_base.py | 34 ++++ .../paddle/distributed/fleet/base/topology.py | 28 +++ .../parallel_layers/mp_layers.py | 6 +- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../unittests/test_fleet_static_mp_layers.py | 183 ++++++++++++++++++ 6 files changed, 273 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 1a09cf5394f..3f0d97075c8 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -92,8 +92,6 @@ class Group(): return True def get_group_rank(self, rank): - if self.id == 0: - return rank if self.is_member() and rank in self.ranks: return self.ranks.index(rank) else: @@ -126,7 +124,8 @@ def _get_group_map(): global _group_map if not _group_map: genv = _get_global_env() - _group_map[0] = Group(genv.rank, genv.world_size, 0) + _group_map[0] = Group(genv.rank, genv.world_size, + list(range(genv.world_size))) return _group_map @@ -1014,6 +1013,27 @@ def _c_softmax_with_cross_entropy(logits, else: return loss, softmax + attrs = { + 'ring_id': ring_id, + 'rank': rank, + 'nranks': nranks, + } + helper = LayerHelper('c_softmax_with_cross_entropy', **locals()) + softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) + loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + helper.append_op( + type='c_softmax_with_cross_entropy', + inputs={'Logits': logits, + 'Label': label}, + outputs={'Softmax': softmax, + 'Loss': loss}, + attrs=attrs) + + if return_softmax: + return loss, softmax + + return loss + def _linear(x, weight, bias=None, name=None): """ diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 9e5a31d6899..3f67d8ab619 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -253,6 +253,40 @@ class Fleet(object): warnings.warn( "The dygraph hybrid parallel environment has been initialized." 
) + elif self._is_collective: + use_sharding = self._user_defined_strategy.sharding + + # global group + global_rank = self.worker_index() + global_world_size = self.worker_num() + # NOTE(wangxi): see sharding_optimizer + global_ring_id = 3 if use_sharding else 0 + global_ranks = list(range(global_world_size)) + + if tp._HYBRID_PARALLEL_GROUP is None: tp._CommunicateGroup() + cg = tp._HYBRID_PARALLEL_GROUP + self._hcg = cg + cg.set_comm_group('global', global_rank, global_world_size, + global_ring_id, global_ranks) + + # hybrid group + if use_sharding is False: return + + sharding_configs = self._user_defined_strategy.sharding_configs + mp_degree = int(sharding_configs['mp_degree']) + + if mp_degree > 1: + assert global_world_size % mp_degree == 0 + # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups + mp_ring_id = 0 + mp_rank = global_rank % mp_degree + mp_group_id = global_rank // mp_degree + mp_group_ranks = [ + idx for idx in global_ranks + if idx // mp_degree == mp_group_id + ] + cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id, + mp_group_ranks) def _init_hybrid_parallel_env(self): """initialize the hybrid environment diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 850f3581421..0eb840c08a2 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -262,3 +262,31 @@ class HybridCommunicateGroup(object): def get_rank_from_stage(self, stage_id, **kwargs): return self._topo.get_rank_from_stage( self.global_rank, pipe=stage_id, **kwargs) + + +class _CommunicateGroup(object): + """ tmp for static """ + + def __init__(self): + global _HYBRID_PARALLEL_GROUP + _HYBRID_PARALLEL_GROUP = self + self.groups = dict() + + def set_comm_group(self, group_name, group_rank, group_size, ring_id, + group_ranks): + group = paddle.distributed.collective.Group(group_rank, group_size, + ring_id, group_ranks) + self.groups[group_name] = group + + def get_group(self, group_name): + assert group_name in self.groups + return self.groups[group_name] + + def get_model_parallel_group(self): + return self.get_group('model') + + def get_model_parallel_world_size(self): + return self.get_group('model').nranks + + def get_model_parallel_rank(self): + return self.get_group('model').rank diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py index f091c890f68..2555d73462b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py @@ -56,7 +56,7 @@ class VocabParallelEmbedding(Layer): self._weight_attr = weight_attr self._name = name - if self.is_mp: + if self.is_mp and paddle.in_dynamic_mode(): with get_rng_state_tracker().rng_state(): self.weight = self.create_parameter( attr=self._weight_attr, @@ -121,7 +121,7 @@ class ColumnParallelLinear(Layer): self._weight_attr = weight_attr self._dtype = self._helper.get_default_dtype() - if self.is_mp: + if self.is_mp and paddle.in_dynamic_mode(): with get_rng_state_tracker().rng_state(): self.weight = self.create_parameter( shape=[in_features, self.output_size_per_partition], @@ -198,7 +198,7 @@ class RowParallelLinear(Layer): self.input_size_per_partition = in_features // self.world_size - if self.is_mp: + if self.is_mp and paddle.in_dynamic_mode(): with get_rng_state_tracker().rng_state(): self.weight 
= self.create_parameter( shape=[self.input_size_per_partition, self.out_features], diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 023b092b774..9bb88abcea9 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -70,6 +70,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor) list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy) list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -525,6 +526,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS}) py_test_modules(test_fleet_distributed_strategy MODULES test_fleet_distributed_strategy) + py_test_modules(test_fleet_static_mp_layers MODULES test_fleet_static_mp_layers) #py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS}) if(NOT WIN32) py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py new file mode 100644 index 00000000000..6c7fab25a30 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle import framework +import os + +paddle.enable_static() + + +class ColumnLinearNet(fluid.dygraph.Layer): + def __init__(self, input_size, output_size): + super(ColumnLinearNet, self).__init__() + self.parallel_linear = fleet.meta_parallel.ColumnParallelLinear( + in_features=input_size, + out_features=output_size, + weight_attr=None, + has_bias=True, + gather_output=True, + name="test_column_linear") + + def forward(self, x): + output = self.parallel_linear(x) + return output + + +class RowLinearNet(fluid.dygraph.Layer): + def __init__(self, input_size, output_size): + super(RowLinearNet, self).__init__() + self.parallel_linear = fleet.meta_parallel.RowParallelLinear( + in_features=input_size, + out_features=output_size, + has_bias=True, + input_is_parallel=False, + name="test_row_linear") + + def forward(self, x): + output = self.parallel_linear(x) + return output + + +class EmbeddingNet(fluid.dygraph.Layer): + def __init__(self, vocab_size, hidden_size): + super(EmbeddingNet, self).__init__() + self.embedding = fleet.meta_parallel.VocabParallelEmbedding(vocab_size, + hidden_size) + + def forward(self, x): + output = self.embedding(x) + return output + + +class TestDistTraning(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "2" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004" + + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.sharding = True + strategy.sharding_configs = { + "mp_degree": self.model_parallel_size, + "sharding_degree": 2, + } + fleet.init(is_collective=True, strategy=strategy) + + def get_program(self): + return paddle.static.Program(), paddle.static.Program() + + def test_column_parallel_layer(self): + main_program, startup_program = self.get_program() + with paddle.static.program_guard(main_program, startup_program): + input_size, output_size = 28, 64 + model_a = ColumnLinearNet(input_size, output_size) + + x = paddle.static.data(name='x', shape=[None, input_size]) + y = model_a(x) + + #print(main_program) + ops = main_program.global_block().ops + ops = [op.type for op in ops] + self.assertEqual( + ops, ['c_identity', 'matmul', 'elementwise_add', 'c_concat']) + + weight = model_a.parallel_linear.weight + bias = model_a.parallel_linear.bias + self.assertEqual(weight.shape, (input_size, output_size // + self.model_parallel_size)) + self.assertEqual(bias.shape, + (output_size // self.model_parallel_size, )) + + def test_row_parallel_layer(self): + main_program, startup_program = self.get_program() + with paddle.static.program_guard(main_program, startup_program): + input_size, output_size = 28, 64 + model_a = RowLinearNet(input_size, output_size) + + x = paddle.static.data(name='x', shape=[None, input_size]) + y = model_a(x) + + #print(main_program) + ops = main_program.global_block().ops + ops = [op.type for op in ops] + self.assertEqual( + ops, + ['c_split', 'matmul', 'c_allreduce_sum', 'elementwise_add']) + + weight = model_a.parallel_linear.weight + bias = model_a.parallel_linear.bias + self.assertEqual(weight.shape, ( + input_size // self.model_parallel_size, output_size)) + self.assertEqual(bias.shape, (output_size, )) + + def test_parallel_embedding(self): + main_program, startup_program = 
self.get_program() + with paddle.static.program_guard(main_program, startup_program): + vocab_size, hidden_size = 1000, 512 + seq_len = 128 + + # model_a + model_a = EmbeddingNet(vocab_size, hidden_size) + + x = paddle.static.data( + name='x', shape=[None, seq_len], dtype='int64') + y = model_a(x) + + #print(main_program) + ops = main_program.global_block().ops + ops = [op.type for op in ops] + self.assertEqual(ops, ['c_embedding', 'c_allreduce_sum']) + + weight = model_a.embedding.weight + self.assertEqual(weight.shape, ( + vocab_size // self.model_parallel_size, hidden_size)) + + def test_parallel_cross_entropy(self): + main_program, startup_program = self.get_program() + with paddle.static.program_guard(main_program, startup_program): + batch_size = 8 + seq_length = 16 + class_size = 1000 + class_size_per_card = class_size // self.model_parallel_size + + # model_a + model_a = fleet.meta_parallel.ParallelCrossEntropy() + + x = paddle.static.data( + name='x', shape=[batch_size, seq_length, class_size_per_card]) + label = paddle.static.data( + name='label', shape=[batch_size, seq_length], dtype='int64') + loss_a = model_a(x, label) + + #print(main_program) + ops = main_program.global_block().ops + ops = [op.type for op in ops] + self.assertEqual(ops, + ['unsqueeze2', 'c_softmax_with_cross_entropy']) + + +if __name__ == '__main__': + unittest.main() -- GitLab
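
Usage sketch (not part of the patch): the snippet below is distilled from the new test_fleet_static_mp_layers.py added above and shows how the static-graph path is exercised. With `paddle.enable_static()`, `fleet.init(is_collective=True, strategy=...)` with a sharding strategy builds the global and model-parallel communicate groups, and `ColumnParallelLinear` then appends `c_identity`/`matmul`/`elementwise_add`/`c_concat` ops to the program instead of running collectives eagerly. The trainer id and endpoint values are placeholders copied from the test for a single-process illustration; a real job would get them from the launcher.

```python
import os
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# Placeholder environment copied from the unit test: pretend to be rank 2
# of a 4-card job so fleet.init() can set up the communicate groups
# without actually launching distributed workers.
os.environ["PADDLE_TRAINER_ID"] = "2"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = (
    "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004")

# mp_degree=2 splits each parallel layer across two model-parallel ranks.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"mp_degree": 2, "sharding_degree": 2}
fleet.init(is_collective=True, strategy=strategy)

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    input_size, output_size = 28, 64
    linear = fleet.meta_parallel.ColumnParallelLinear(
        in_features=input_size,
        out_features=output_size,
        weight_attr=None,
        has_bias=True,
        gather_output=True,
        name="column_linear")

    x = paddle.static.data(name='x', shape=[None, input_size])
    y = linear(x)

# In static mode the layer only records collective ops in the program;
# the local weight holds this rank's column slice of the full matrix.
op_types = [op.type for op in main_program.global_block().ops]
print(op_types)             # ['c_identity', 'matmul', 'elementwise_add', 'c_concat']
print(linear.weight.shape)  # (28, 32): output dim 64 split over mp_degree=2
```

The other layers in the patch follow the same pattern: `RowParallelLinear` records `c_split`/`matmul`/`c_allreduce_sum`/`elementwise_add`, and `VocabParallelEmbedding` records `c_embedding`/`c_allreduce_sum`, as asserted by the test above.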