From 4bacf2abd4ca58515288396dcf8fff910dff89d0 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 9 Mar 2023 15:13:35 +0800
Subject: [PATCH] Extra Sync for Tensor Parallel (#50637)

---
 .../distributed/fleet/utils/__init__.py       |  16 +-
 .../fleet/utils/tensor_parallel_utils.py      | 356 ++++++++++++++++++
 .../test_fleet_tensor_parallel_extra_sync.py  | 152 ++++++++
 3 files changed, 519 insertions(+), 5 deletions(-)
 create mode 100644 python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
 create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_tensor_parallel_extra_sync.py

diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index 30afae2b432..ef205bb8b5f 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -21,22 +21,28 @@ from paddle.distributed import fleet
 import paddle
 
 from . import log_util  # noqa: F401
 from . import hybrid_parallel_util  # noqa: F401
+from . import tensor_parallel_utils  # noqa: F401
+
 __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa
 
 
 def recompute(function, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory.
+    recompute intermediate activations to save memory.
+
     Parameters:
         function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
               whose intermediate activations will be released to save memory in forward stage and will be recomputed
              in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-                        indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-                        restored when the forward recalculation of backpropagation is performed. The default
-                        preserve_rng_state is True.
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params: key-value params belonging to
+                        the function itself, plus ``preserve_rng_state`` and ``use_reentrant``.
+                        ``preserve_rng_state`` indicates whether to save the forward RNG state. If it is True,
+                        the saved forward RNG state will be restored when the forward pass is recomputed during
+                        backpropagation. Its default value is True.
+                        ``use_reentrant`` indicates which implementation of recompute will be used: ``use_reentrant=True``
+                        means the PyLayer implementation, ``use_reentrant=False`` means the Hook implementation. Its default value is True.
 
     Returns:
         Output of function on args.
diff --git a/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py b/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
new file mode 100644
index 00000000000..85eab5e26fb
--- /dev/null
+++ b/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
@@ -0,0 +1,356 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+logger = logging.getLogger(__name__)
+formatter = logging.Formatter(
+    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
+)
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY
+from paddle.fluid import core
+from paddle.fluid.framework import Parameter
+
+_supported_optimizer_type = [
+    "adam",
+    "adamax",
+    "adamw",
+    "decayed_adagrad",
+    "momentum",
+    "dgc_momentum",
+    "lars_momentum",
+    "merged_momentum",
+    "lamb",
+    "sgd",
+]
+
+
+def tensor_parallel_sync_filter_fn(
+    param, pos_emb=True, layer_norm=True, bias=True
+):
+    """
+    Layer filter function for the tensor parallel transformer.
+
+    In tensor parallelism of a transformer-like model, there are 4 kinds of params
+    that are supposed to be the same on all tensor parallel peers:
+    * position embedding
+    * scale of layer norm
+    * bias of layer norm
+    * bias of row parallel linear
+
+    Set the corresponding input args to select specific layers.
+    NOTE: this relies on the param name patterns used by the transformer blocks.
+    """
+    p_name = param.name
+    if pos_emb and p_name.startswith("pos_embedding"):
+        return True
+
+    elif layer_norm and p_name.endswith("_layer_norm_bias"):
+        return True
+
+    elif layer_norm and p_name.endswith("_layer_norm_scale"):
+        return True
+
+    elif bias and ".b_" in p_name and (param.is_distributed is False):
+        return True
+
+    else:
+        return False
+
+
+def resolute_tensor_parallel_ring_id(program):
+    ops = program.global_block().ops
+    ring_id = None
+
+    for op in ops:
+        if op.type == "c_identity":
+            if ring_id is None:
+                ring_id = int(op.attr("ring_id"))
+            else:
+                assert ring_id == int(
+                    op.attr("ring_id")
+                ), "Found two different ring_id for Tensor Parallel: ring_id={} and ring_id={}.".format(
+                    ring_id, int(op.attr("ring_id"))
+                )
+    assert ring_id is not None, "Could NOT find ring_id for Tensor Parallel."
+
+    return ring_id
+
+
+def copy_parameters(block_, params):
+    for param in params:
+        new_p = Parameter(
+            block=block_,
+            shape=param.shape,
+            dtype=param.dtype,
+            type=param.type,
+            lod_level=param.lod_level
+            if param.type == core.VarDesc.VarType.LOD_TENSOR
+            else None,
+            stop_gradient=param.stop_gradient,
+            trainable=param.trainable,
+            optimize_attr=param.optimize_attr,
+            regularizer=param.regularizer,
+            error_clip=param.error_clip,
+            name=param.name,
+        )
+        assert (
+            param.is_distributed is False
+        ), "Try to sync Distributed Parameter: {}".format(param)
+        new_p.is_distributed = False
+
+        block_.vars[new_p.name] = new_p
+
+
+def insert_sync_op(
+    block, idx, tp_degree, sync_mode, sync_ring_id, src_rank, varname, op_role
+):
+
+    if sync_mode == "broadcast":
+        block._insert_op_without_sync(
+            idx,
+            type='c_broadcast',
+            inputs={'X': varname},
+            outputs={'Out': varname},
+            attrs={
+                'ring_id': sync_ring_id,
+                'root': src_rank,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: op_role,
+            },
+        )
+
+    elif sync_mode == "average":
+        block._insert_op_without_sync(
+            idx,
+            type='scale',
+            inputs={'X': varname},
+            outputs={'Out': varname},
+            attrs={'scale': 1.0 / tp_degree, OP_ROLE_KEY: op_role},
+        )
+        block._insert_op_without_sync(
+            idx,
+            type='c_allreduce_sum',
+            inputs={'X': varname},
+            outputs={'Out': varname},
+            attrs={
+                'ring_id': sync_ring_id,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: op_role,
+            },
+        )
+    else:
+        raise NotImplementedError(
+            'Sync mode of [{}] is NOT supported.'.format(sync_mode)
+        )
+
+
+def insert_synchronization(
+    block,
+    params_to_sync,
+    tp_degree,
+    sync_ring_id,
+    sync_param,
+    sync_grad,
+    sync_moment,
+    sync_mode,
+    src_rank,
+):
+
+    unsync_param_names = [p.name for p in params_to_sync]
+
+    for idx, op in reversed(list(enumerate(block.ops))):
+
+        if op.type in _supported_optimizer_type:
+            assert "Param" in op.input_names
+            assert len(op.input("Param")) == 1
+            param_name = op.input("Param")[0]
+            op_role = op.attr(OP_ROLE_KEY)
+
+            if param_name in unsync_param_names:
+
+                unsync_param_names.remove(param_name)
+
+                # Param sync after opt
+                if sync_param:
+                    assert (
+                        "ParamOut" in op.output_names
+                        and op.output("ParamOut")[0] == param_name
+                    )
+                    insert_sync_op(
+                        block,
+                        idx + 1,
+                        tp_degree,
+                        sync_mode,
+                        sync_ring_id,
+                        src_rank,
+                        param_name,
+                        op_role,
+                    )
+
+                    if (
+                        "MasterParamOut" in op.output_names
+                        and len(op.output("MasterParamOut")) == 1
+                    ):
+                        sync_var = op.output("MasterParamOut")[0]
+                        insert_sync_op(
+                            block,
+                            idx + 1,
+                            tp_degree,
+                            sync_mode,
+                            sync_ring_id,
+                            src_rank,
+                            sync_var,
+                            op_role,
+                        )
+
+                # Moment sync after opt
+                if sync_moment:
+                    if (
+                        "Moment1Out" in op.output_names
+                        and len(op.output("Moment1Out")) == 1
+                    ):
+                        sync_var = op.output("Moment1Out")[0]
+                        insert_sync_op(
+                            block,
+                            idx + 1,
+                            tp_degree,
+                            sync_mode,
+                            sync_ring_id,
+                            src_rank,
+                            sync_var,
+                            op_role,
+                        )
+
+                    if (
+                        "Moment2Out" in op.output_names
+                        and len(op.output("Moment2Out")) == 1
+                    ):
+                        sync_var = op.output("Moment2Out")[0]
+                        insert_sync_op(
+                            block,
+                            idx + 1,
+                            tp_degree,
+                            sync_mode,
+                            sync_ring_id,
+                            src_rank,
+                            sync_var,
+                            op_role,
+                        )
+
+                # Grad sync before opt
+                if sync_grad:
+                    assert (
+                        "Grad" in op.input_names and len(op.input("Grad")) == 1
+                    )
+                    sync_var = op.input("Grad")[0]
+                    insert_sync_op(
+                        block,
+                        idx,
+                        tp_degree,
+                        sync_mode,
+                        sync_ring_id,
+                        src_rank,
+                        sync_var,
+                        op_role,
+                    )
+
+    assert (
+        len(unsync_param_names) == 0
+    ), "The following params are left unsynced due to some error: {}".format(
+        unsync_param_names
+    )
+
+
+def add_extra_synchronization(
+    program,
+    params_filter_fn=tensor_parallel_sync_filter_fn,
+    tp_degree=8,
+    sync_mode="broadcast",
+    sync_param=True,
+    sync_grad=False,
+    sync_moment=False,
+    src_rank=0,
+    sync_ring_id=None,
+):
+    """
+    Add extra synchronization to the input program in place.
+
+    program(paddle.static.Program): the distributed training program.
+
+    params_filter_fn(callable): function used to select the parameters that need extra synchronization.
+
+    sync_mode(string): select from
+        "broadcast": parameters are synced by broadcasting from 'src_rank' to all other ranks.
+        "average": parameters are synced by averaging among all ranks.
+
+    src_rank(int): the source rank used in broadcast sync_mode.
+
+    sync_param(bool): whether to add extra synchronization for parameters.
+
+    sync_grad(bool): whether to add extra synchronization for gradients.
+
+    sync_moment(bool): whether to add extra synchronization for optimizer moments.
+
+    sync_ring_id(int): communicator id used for synchronization; if it is None, the tensor parallel ring_id is used.
+    """
+
+    logger.info("Constructing Extra Parameter Synchronization.")
+    logger.info(
+        "Tensor Parallel Degree: {}, Synchronization mode: {}".format(
+            tp_degree, sync_mode
+        )
+    )
+
+    # adapt for pipeline opt
+    if program._pipeline_opt is not None:
+        assert (
+            program._pipeline_opt['section_program'] is not None
+        ), "Pipeline is enabled but section_program is None"
+        program = program._pipeline_opt['section_program']
+
+    # step1: collect the params that need to be synced
+    params_to_sync = []
+    # TODO support multiple blocks with different parameters.
+    all_params = program.global_block().all_parameters()
+    for param in all_params:
+        if params_filter_fn(param):
+            params_to_sync.append(param)
+    logger.info(
+        "The following params are going to be synchronized every time the optimizer update phase of the program is run: "
+    )
+    logger.info([p.name for p in params_to_sync])
+
+    # step2: resolve the synchronization communicator group (ring_id)
+    if sync_ring_id is None:
+        sync_ring_id = resolute_tensor_parallel_ring_id(program)
+
+    # step3: insert synchronization
+    # TODO support gradient merge with different update block
+    block = program.global_block()
+    insert_synchronization(
+        block,
+        params_to_sync,
+        tp_degree,
+        sync_ring_id,
+        sync_param,
+        sync_grad,
+        sync_moment,
+        sync_mode,
+        src_rank,
+    )
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_tensor_parallel_extra_sync.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_tensor_parallel_extra_sync.py
new file mode 100644
index 00000000000..ba5eac00eb8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_tensor_parallel_extra_sync.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import paddle
+import paddle.distributed.fleet as fleet
+
+paddle.enable_static()
+
+
+class TensorParallelNet(paddle.fluid.dygraph.Layer):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.embedding = paddle.nn.Embedding(hidden_size, hidden_size)
+        self.col_linear = fleet.meta_parallel.ColumnParallelLinear(
+            in_features=hidden_size,
+            out_features=hidden_size,
+            weight_attr=None,
+            has_bias=True,
+            gather_output=False,
+            # name="test_column_linear",
+        )
+        self.row_linear = fleet.meta_parallel.RowParallelLinear(
+            in_features=hidden_size,
+            out_features=hidden_size,
+            has_bias=True,
+            input_is_parallel=True,
+            # name="test_row_linear",
+        )
+        self.layer_norm = paddle.nn.LayerNorm(hidden_size)
+
+    def forward(self, x):
+        out = self.embedding(x)
+        out = self.col_linear(out)
+        out = self.row_linear(out)
+        output = self.layer_norm(out)
+        return output
+
+
+def filter_fn(param, pos_emb=True, layer_norm=True, bias=True):
+    """
+    Layer filter function for the tensor parallel transformer.
+
+    In tensor parallelism of a transformer-like model, there are 4 kinds of params
+    that are supposed to be the same on all tensor parallel peers:
+    * position embedding
+    * scale of layer norm
+    * bias of layer norm
+    * bias of row parallel linear
+
+    Set the corresponding input args to select specific layers.
+    NOTE: this relies on the param name patterns used by the transformer blocks.
+    """
+    p_name = param.name
+    if pos_emb and p_name.startswith("embedding"):
+        return True
+
+    elif layer_norm and p_name.startswith("layer_norm"):
+        return True
+
+    elif bias and ".b_" in p_name and (param.is_distributed is False):
+        return True
+
+    else:
+        return False
+
+
+class TestFleetMetaOptimizer(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ID"] = "1"
+        os.environ[
+            "PADDLE_TRAINER_ENDPOINTS"
+        ] = "127.0.0.1:36001,127.0.0.1:36002"
+
+    def test_tensor_parallel_extra_sync(self):
+        import paddle.distributed.fleet as fleet
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
+        fleet.init(is_collective=True, strategy=strategy)
+
+        main_program, startup_program = (
+            paddle.static.Program(),
+            paddle.static.Program(),
+        )
+        with paddle.static.program_guard(main_program, startup_program):
+            hidden_size = 512
+            input_x = paddle.static.data(
+                name="x", shape=[-1, hidden_size], dtype='int64'
+            )
+            model_a = TensorParallelNet(hidden_size)
+            y = model_a(input_x)
+            loss = paddle.mean(y)
+
+            optimizer = paddle.fluid.optimizer.Adam(0.01)
+            optimizer = fleet.distributed_optimizer(
+                optimizer, strategy=strategy
+            )
+            optimizer.minimize(loss)
+
+        ref_ops = [
+            'lookup_table_v2',
+            'c_identity',
+            'matmul_v2',
+            'elementwise_add',
+            'matmul_v2',
+            'c_allreduce_sum',
+            'elementwise_add',
+            'layer_norm',
+            'reduce_mean',
+            'fill_constant',
+            'reduce_mean_grad',
+            'layer_norm_grad',
+            'elementwise_add_grad',
+            'c_identity',
+            'matmul_v2_grad',
+            'elementwise_add_grad',
+            'matmul_v2_grad',
+            'c_allreduce_sum',
+            'lookup_table_v2_grad',
+            'adam',
+            'adam',
+            'adam',
+            'c_broadcast',
+            'adam',
+            'c_broadcast',
+            'adam',
+            'c_broadcast',
+            'adam',
+            'c_broadcast',
+            'adam',
+        ]
+        paddle.distributed.fleet.utils.tensor_parallel_utils.add_extra_synchronization(
+            main_program, params_filter_fn=filter_fn
+        )
+        ops = [op.type for op in main_program.global_block().ops]
+        self.assertTrue(ops == ref_ops)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
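
Usage sketch: the snippet below shows how the new `add_extra_synchronization` utility is meant to be wired into a static-graph tensor parallel training script, after the distributed optimizer has been applied to the program. Only the `tensor_parallel_utils.add_extra_synchronization(...)` call and its arguments come from this patch; the surrounding strategy/model setup, the `my_filter` helper, and the launch command are illustrative assumptions, not part of the change.

    # Hypothetical usage sketch (not part of the patch). Assumes the script is
    # launched on 2 devices, e.g. with: python -m paddle.distributed.launch train.py
    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.distributed.fleet.utils import tensor_parallel_utils

    paddle.enable_static()

    strategy = fleet.DistributedStrategy()
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
    fleet.init(is_collective=True, strategy=strategy)

    main_program, startup_program = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        hidden = 512
        x = paddle.static.data(name="x", shape=[-1, hidden], dtype="int64")
        emb = paddle.nn.Embedding(hidden, hidden)
        col = fleet.meta_parallel.ColumnParallelLinear(
            in_features=hidden,
            out_features=hidden,
            weight_attr=None,
            has_bias=True,
            gather_output=False,
        )
        row = fleet.meta_parallel.RowParallelLinear(
            in_features=hidden,
            out_features=hidden,
            has_bias=True,
            input_is_parallel=True,
        )
        norm = paddle.nn.LayerNorm(hidden)
        loss = paddle.mean(norm(row(col(emb(x)))))

        optimizer = fleet.distributed_optimizer(
            paddle.fluid.optimizer.Adam(0.01), strategy=strategy
        )
        optimizer.minimize(loss)


    def my_filter(param):
        # Hypothetical name-based filter mirroring the unit test's filter_fn;
        # the default tensor_parallel_sync_filter_fn expects transformer-style
        # names such as pos_embedding* and *_layer_norm_bias.
        return (
            param.name.startswith("embedding")
            or param.name.startswith("layer_norm")
            or (".b_" in param.name and param.is_distributed is False)
        )


    # Insert a c_broadcast after every optimizer op that updates a selected param,
    # so embeddings, layer norm weights and row-parallel biases stay identical
    # across tensor parallel ranks after each training step.
    tensor_parallel_utils.add_extra_synchronization(
        main_program,
        params_filter_fn=my_filter,
        tp_degree=2,
        sync_mode="broadcast",
        sync_param=True,
        src_rank=0,
    )

Because the sync ops are inserted into the program itself, the broadcast (or allreduce average) runs as part of every optimizer update without any extra Python-side work at step time.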