From e5eb3f55f645b55456fb8c197f1d6eba79e511e6 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 8 Nov 2022 20:13:11 +0800 Subject: [PATCH] =?UTF-8?q?[Auto=20Parallel]=20Sharding=20Optimization?= =?UTF-8?q?=EF=BC=9APartition=20Algorithm=20&=20Stage2=20Parameter=20Bucke?= =?UTF-8?q?t=20communication=20=20(#47180)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * partition param by order * add logging * reorder opt * config * stage2 bucket * update unitest --- .../distributed/auto_parallel/constants.py | 4 +- .../paddle/distributed/auto_parallel/utils.py | 13 + ...uto_parallel_data_parallel_optimization.py | 14 +- .../passes/auto_parallel_sharding.py | 370 ++++++++++++++++-- .../unittests/auto_parallel/test_strategy.py | 4 +- 5 files changed, 370 insertions(+), 35 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 82c5011faf0..51afad94c53 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -82,7 +82,9 @@ SHARDING = "sharding" set_field_default_config(SHARDING, "enable", False) set_field_default_config(SHARDING, "stage", 1) set_field_default_config(SHARDING, "degree", 8) -set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0) +set_field_default_config(SHARDING, "overlap_grad_comm", False) +set_field_default_config(SHARDING, "bucket_size_numel", -1) +set_field_default_config(SHARDING, "partition_algor", "greedy_even") set_field_default_config(SHARDING, "enable_tuning", False) set_field_default_config(SHARDING, "tuning_range", []) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index d9357db3714..bf4aa34303a 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -22,6 +22,7 @@ import logging from functools import reduce import paddle.fluid.core as core +from paddle.fluid.framework import Variable from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.auto_parallel.process_group import ( get_all_process_groups, @@ -1790,6 +1791,18 @@ def find_higher_order_backward_op(program): return False +def get_var_numel(var): + """ + input: + - var: variable + return: + number of elemnet in var + """ + assert isinstance(var, Variable) + assert -1 not in var.shape + return reduce(lambda x, y: x * y, var.shape) + + def get_lr(optimizer): if isinstance(optimizer, paddle.optimizer.Optimizer): return optimizer.get_lr() diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index ec3d799ee84..cbc9170a1e4 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -13,12 +13,12 @@ # limitations under the License. from collections import OrderedDict -import numpy as np import paddle from paddle.fluid import unique_name from paddle.fluid.framework import default_main_program from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from .pass_base import PassBase, PassType, register_pass from paddle.distributed.auto_parallel.operators.common import ( is_data_parallel_scale_op, is_data_parallel_reduce_op, @@ -28,8 +28,8 @@ from paddle.distributed.auto_parallel.utils import ( is_loss_grad_op, is_optimize_op, ring_id_to_process_group, + get_var_numel, ) -from .pass_base import PassBase, PassType, register_pass # add new optimizers supporting rescale_grad here __rescale_grad_supported_opts__ = [ @@ -44,10 +44,6 @@ __rescale_grad_supported_opts__ = [ __max_stream_num_allow__ = 16 -def numel(var): - return np.prod(list(var.shape)) - - @register_pass("auto_parallel_data_parallel_optimization") class DataParallelOptimizationPass(PassBase): """ @@ -430,7 +426,7 @@ class DataParallelOptimizationPass(PassBase): ring_id = op.attr("ring_id") grad_name = op.output_arg_names[0] grad_var = block.var(grad_name) - grad_numel = numel(grad_var) + grad_numel = get_var_numel(grad_var) if cur_group.acceptable(grad_var, ring_id): assert grad_name not in grouped_grad_names @@ -594,7 +590,7 @@ class GradientsGroup: return True if ring_id != self.ring_id: return False - if numel(grad_var) + self.numel > self.max_group_size: + if get_var_numel(grad_var) + self.numel > self.max_group_size: return False if grad_var.dtype != self.dtype: return False @@ -605,7 +601,7 @@ class GradientsGroup: self.gradients.append(grad_var) self.ring_id = ring_id self.dtype = grad_var.dtype - self.numel += numel(grad_var) + self.numel += get_var_numel(grad_var) # remove auxiliary ops in non-fuse dp allreduce self.remove_allreduce_op_indices.append(i) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index a80af73c2bc..49583e3ae66 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -13,15 +13,20 @@ # limitations under the License. from functools import reduce +import logging + +import paddle from paddle.framework import core +from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid import unique_name from .pass_base import PassBase, register_pass +from paddle.distributed.auto_parallel.process_group import new_process_group +from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size from paddle.distributed.fleet.meta_optimizers.common import ( is_backward_op, is_optimizer_op, ) -from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.operators.common import ( is_parameter_related, is_data_parallel_reduce_op, @@ -30,6 +35,8 @@ from paddle.distributed.auto_parallel.utils import ( _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, + get_var_numel, + get_logger, ) OpRole = core.op_proto_and_checker_maker.OpRole @@ -57,6 +64,8 @@ _supported_optimizer_type = [ "sgd", ] +_logger = get_logger(logging.INFO) + def _is_reshard_op(op): return op.desc.has_attr( @@ -76,6 +85,9 @@ class ShardingPass(PassBase): self.set_attr("stage", None) self.set_attr("sharding_degree", None) # for parallelizer self.set_attr("degree", None) # for parallelizer_v2 + self.set_attr("overlap_grad_comm", None) + self.set_attr("bucket_size_numel", None) + self.set_attr("partition_algor", None) self.set_attr("params_grads", []) self.set_attr("global_rank", -1) self.dp_groups = set() @@ -109,6 +121,12 @@ class ShardingPass(PassBase): "global_rank" ) < 0: return False + if self.get_attr("overlap_grad_comm") is None: + return False + if self.get_attr("bucket_size_numel") is None: + return False + if self.get_attr("partition_algor") is None: + return False return True @@ -122,22 +140,35 @@ class ShardingPass(PassBase): ) self.stage = int(self.get_attr("stage")) self.global_rank = int(self.get_attr("global_rank")) + self.overlap_grad_comm = self.get_attr("overlap_grad_comm") + self.bucket_size_numel = int(self.get_attr("bucket_size_numel")) + self.partition_algor = self.get_attr("partition_algor") params_grads = self.get_attr("params_grads") main_block, startup_block = ( main_program.global_block(), startup_program.global_block(), ) + # NOTE Multi / Sub-Block Support + # we assume that only parameter are present and partitioned in main_block, + # there is NO new param in sub_block, and all params in sub_block follows the same + # partition as main_block. the above contraint fullfill the 3 most common use-cases in Paddle sub_block: + # 1. subblock for lr scheduler + # 2. sub-block uses the same or partial network of main-block, e.g. GPT3 generation model + # 3. sub-block used for double backward + self._build_sharding_groups(main_block, params_grads) - self._shard_optimizer(main_block, startup_block, params_grads, context) - self._shard_gradient_synchronization(main_block) - self._shard_parameter(main_block, startup_block) + for block in main_program.blocks: + self._shard_optimizer(block, startup_block, params_grads, context) + self._shard_gradient_synchronization(block) + self._shard_parameter(block, startup_block) context.set_attr("params_grads", self.shared_params_grads) + self._optimization_pass(main_program, startup_program) def _build_sharding_groups(self, main_block, params_grads): self._collective_data_parallel_groups(main_block) - self._build_sharding_infos(params_grads) + self._build_sharding_infos(main_block, params_grads) def _collective_data_parallel_groups(self, main_block): for op in main_block.ops: @@ -162,8 +193,14 @@ class ShardingPass(PassBase): ) ) - def _build_sharding_infos(self, params_grads): + def _build_sharding_infos(self, main_block, params_grads): + + # order params + params_grads = re_order_program( + main_block, params_grads, self._dist_context + ) + # partition for dp_group in self.dp_groups: assert ( @@ -204,7 +241,10 @@ class ShardingPass(PassBase): self._dist_context._sharding_group = sharding_group # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group sharding_info = ShardingInfo( - sharding_group, self.global_rank, params_grads + sharding_group, + self.global_rank, + params_grads, + self.partition_algor, ) self.sharding_infos.append(sharding_info) for param in sharding_info.params: @@ -317,7 +357,7 @@ class ShardingPass(PassBase): reserved_vars.append(input_name) op.desc.set_input("X", reserved_vars) - sum_op_output = op.desc.output_arg_names()[0] + sum_op_output = op.output_arg_names[0] for i, sharding_info in enumerate(self.sharding_infos): new_op = main_block._insert_op( idx + i + 1, @@ -401,7 +441,7 @@ class ShardingPass(PassBase): def _insert_optimizer_broadcasts(self, main_block, startup_block): - if self.stage > 2: + if self.stage > 2 or self.bucket_size_numel > 1: return for sharding_info in self.sharding_infos: @@ -508,7 +548,7 @@ class ShardingPass(PassBase): if is_optimizer_op(op): continue - for input_name in op.desc.input_arg_names(): + for input_name in op.input_arg_names: # NOTE hack for embedding op when AMP 02-3 # paddle amp force embedding (lookup table) to be run on fp32 if _is_param_fp16_cast_op( @@ -601,6 +641,24 @@ class ShardingPass(PassBase): main_block._sync_with_cpp() startup_block._sync_with_cpp() + def _optimization_pass(self, main_program, startup_program): + + with paddle.static.program_guard(main_program, startup_program): + if self.overlap_grad_comm: + _fuse_overlap_gradient_comm() + # TODO support multiple sub_blocks + if self.bucket_size_numel > 1: + if self.stage == 2: + _fuse_overlap_parameter_comm_stage_two( + self.sharding_infos, + self._dist_context, + fuse_size=self.bucket_size_numel, + ) + elif self.stage == 3: + _fuse_overlap_parameter_comm_stage_three( + self.sharding_infos, fuse_size=self.bucket_size_numel + ) + def _insert_init_and_broadcast_op( block, @@ -723,7 +781,7 @@ def _is_param_grad_fp32_cast_op(block, op): block, op, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32 ): return False - output_name = op.desc.output_arg_names()[0] + output_name = op.output_arg_names[0] base_name = output_name[: output_name.find("@")] if not block.has_var(base_name): return False @@ -736,7 +794,7 @@ def _is_param_fp16_cast_op(block, op, params): return False if not _is_desired_cast_op(block, op): return False - input_name = op.desc.input_arg_names()[0] + input_name = op.input_arg_names[0] if input_name not in params: return False return True @@ -750,10 +808,10 @@ def _is_desired_cast_op( ): if op.type != "cast": return False - assert len(op.desc.input_arg_names()) == 1 - assert len(op.desc.output_arg_names()) == 1 - input_var = block.var(op.desc.input_arg_names()[0]) - output_var = block.var(op.desc.output_arg_names()[0]) + assert len(op.input_arg_names) == 1 + assert len(op.output_arg_names) == 1 + input_var = block.var(op.input_arg_names[0]) + output_var = block.var(op.output_arg_names[0]) if input_var.dtype != src_var_type or output_var.dtype != dst_var_type: return False @@ -828,10 +886,36 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): return dp_group -def shard_parameters(params, group_size): - # TODO(JZ-LIANG) support multiple partition methods - # method1: greedy even but unorder - # method2: roughly even with oreder +def partition_by_use_order(params, group_size): + """ + shard the continouse param into same rank and divide the forward&backward computation into segement, + which will favor the fuse pass in later. + + we assume that the params is already sorted by utilization order. + """ + mapping = {} + total_param_mem = 0.0 + param2mem = [] + for param in params: + mem = get_var_size(param) + total_param_mem += mem + param2mem.append((param, mem)) + mapping = {x: [] for x in range(group_size)} + cur_rank = 0 + mem_accu = 0.0 + for param, mem in param2mem: + if mem_accu > total_param_mem * 1.0 * (cur_rank + 1) / group_size: + cur_rank += 1 + mapping[cur_rank].append(param) + mem_accu += mem + + return mapping + + +def partition_by_greedy_even(params, group_size): + """ + use greedy alogrithm to partition parameter as even as possible. + """ mapping = {} for rank_ in range(group_size): mapping[rank_] = [] @@ -850,8 +934,212 @@ def shard_parameters(params, group_size): return mapping -class ShardingInfo: - def __init__(self, group, rank, params_grads): +def partition_parameters(params, group_size, algor="greedy_even"): + if algor == "greedy_even": + rank_to_params = partition_by_greedy_even(params, group_size) + else: + rank_to_params = partition_by_use_order(params, group_size) + + _logger.info("Sharding Parameter Partition:") + for k, v in rank_to_params.items(): + _logger.info( + "Rank:{}, Parameter Size:{} MB.".format( + k, sum([get_var_size(var) for var in v]) + ) + ) + _logger.info("Params in this rank: {}.".format([var.name for var in v])) + + return rank_to_params + + +def re_order_program(block, param_grads, dist_context): + + # record order + pname_to_pg_pairs = {} + for p, g in param_grads: + pname_to_pg_pairs[p.name] = (p, g) + + use_order = [] + for op in block.ops: + for input_name in op.input_arg_names: + if (input_name in pname_to_pg_pairs) and ( + input_name not in use_order + ): + use_order.append(input_name) + if len(use_order) == len(pname_to_pg_pairs): + break + + # reorder optimzier + last_op = block.ops[-1] + pname_to_op = {} + num_ops = len(block.ops) + remove_op_indices = [] + # TODO support case when optimizer is not the last op + if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: + # record optimizer + for idx, op in reversed(list(enumerate(block.ops))): + if op.type not in _supported_optimizer_type: + break + assert len(op.input("Param")) == 1 + pname_to_op[op.input("Param")[0]] = op + remove_op_indices.append(idx) + assert len(use_order) == len(pname_to_op) + + # append new opts + for pname in use_order: + new_op = block.append_op(type='nop') + new_op.desc.copy_from(pname_to_op[pname].desc) + dist_context.set_op_dist_attr_for_program( + new_op, + dist_context.get_op_dist_attr_for_program(pname_to_op[pname]), + ) + + # remove old opts + for idx in remove_op_indices: + block._remove_op(idx, sync=False) + + block._sync_with_cpp() + assert len(block.ops) == num_ops + + # TODO reorder gradient clip order + _logger.info( + "Sharding the Order of param being used: {}.".format(use_order) + ) + return [pname_to_pg_pairs[p] for p in use_order] + + +def group_param(sharding_info, fuse_size): + """ + param are group by: + rank id + fuse_size + dtype + """ + group_to_param_map = {} + param_to_group_map = {} + bucket = [] + cur_group = ParameterGroup(fuse_size) + for param in sharding_info.params: + rank = sharding_info.get_var_rank(param.name) + + if cur_group.acceptable(param, rank): + cur_group.collect(param, rank) + else: + cur_group = ParameterGroup(fuse_size) + cur_group.collect(param, rank) + + if cur_group in group_to_param_map: + group_to_param_map[cur_group].append(param.name) + else: + group_to_param_map[cur_group] = [param.name] + + param_to_group_map[param.name] = cur_group + + return group_to_param_map, param_to_group_map + + +def _fuse_overlap_gradient_comm(): + pass + + +def _fuse_overlap_parameter_comm_stage_two( + sharding_infos, dist_context, fuse_size +): + + assert ( + len(sharding_infos) == 1 + ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos) + ) + sharding_info = sharding_infos[0] + + main_block = default_main_program().global_block() + startup_block = default_startup_program().global_block() + + group_to_param_map, param_to_group_map = group_param( + sharding_info, fuse_size + ) + _logger.info("Sharding Stage2 Optimization:") + _logger.info( + "Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format( + fuse_size, + len(param_to_group_map.keys()), + len(group_to_param_map.keys()), + ) + ) + for i, group in enumerate(group_to_param_map.keys()): + + assert len(group) >= 1 + if len(group) > 1: + coalesce_var_name = unique_name.generate( + 'coalecse_param_{}'.format(i) + ) + startup_block.create_var( + name=coalesce_var_name, + dtype=group.dtype, + persistable=True, + stop_gradient=True, + ) + group.coalesce_var = main_block.create_var( + name=coalesce_var_name, + dtype=group.dtype, + persistable=True, + stop_gradient=True, + ) + startup_block.append_op( + type="coalesce_tensor", + inputs={"Input": group.params}, + outputs={ + "Output": group.params, + "FusedOutput": group.coalesce_var, + }, + attrs={ + "copy_data": True, + "use_align": True, + "dtype": group.dtype, + OP_ROLE_KEY: OpRole.Forward, + }, + ) + else: + group.coalesce_var = group.params[0] + _logger.info( + "Bucket[{}] size [{}]MB : {}".format( + i, + sum([get_var_size(p) for p in group.params]), + [p.name for p in group.params], + ) + ) + + # TODO Overlap broadcast with opt and next forward + new_op = main_block.append_op( + type='c_broadcast', + inputs={'X': group.coalesce_var}, + outputs={'Out': group.coalesce_var}, + attrs={ + 'ring_id': sharding_info.group.id, + 'root': group.rank, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }, + ) + + # NOTE the current dist context lack the presentation for bucket tensor which + # composes many tensor with different dims_mapping. we assign a fake dist attr + # for it currently. + + +def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size): + + assert ( + len(sharding_infos) == 1 + ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos) + ) + sharding_info = sharding_infos[0] + + +class ShardingInfo(object): + def __init__(self, group, rank, params_grads, partition_algor): self.group = group self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads]) assert len(self.params_grads) == len( @@ -863,8 +1151,11 @@ class ShardingInfo: self.group_size = group.nranks self.global_rank = rank self.local_rank = group.ranks.index(self.global_rank) + self.partition_algor = partition_algor # rank in below mapping are local rank in this sharding group - self.rank_to_params = shard_parameters(self.params, self.group_size) + self.rank_to_params = partition_parameters( + self.params, self.group_size, self.partition_algor + ) # include fp32 and fp16 param self.param_to_rank = dict() self._map_param_to_rank() @@ -899,7 +1190,7 @@ class ShardingInfo: for op in block.ops: if is_optimizer_op(op): continue - for input_name in op.desc.input_arg_names(): + for input_name in op.input_arg_names: if input_name in self.param_names: param_usage[input_name] += 1 @@ -927,3 +1218,34 @@ class ShardingInfo: if param_name not in self.params_grads: raise ValueError('param[{}] not in params_grads'.format(param_name)) return self.params_grads.get(param_name, None) + + +class ParameterGroup(object): + def __init__(self, max_size): + self.max_siez = max_size + self.dtype = None + self.rank = -1 + self.numel = 0 + self.params = [] + self.coalesce_var = None + + def acceptable(self, param, rank): + if self.numel == 0: + return True + else: + if param.dtype != self.dtype: + return False + if rank != self.rank: + return False + if self.numel + get_var_numel(param) > self.max_siez: + return False + return True + + def collect(self, param, rank): + self.dtype = param.dtype + self.rank = rank + self.numel += get_var_numel(param) + self.params.append(param) + + def __len__(self): + return len(self.params) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index cbe899a7e6e..58641a1ec3a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -44,7 +44,9 @@ class TestStrategy(unittest.TestCase): self.assertEqual(sharding.enable, False) self.assertEqual(sharding.stage, 1) self.assertEqual(sharding.degree, 8) - self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0) + self.assertAlmostEqual(sharding.overlap_grad_comm, False) + self.assertAlmostEqual(sharding.bucket_size_numel, -1) + self.assertAlmostEqual(sharding.partition_algor, "greedy_even") self.assertEqual(sharding.enable_tuning, False) self.assertEqual(sharding.tuning_range, []) -- GitLab