From f769f85037a704c504732bdcc33ae668b6805b8c Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 20 Sep 2022 13:41:13 +0800
Subject: [PATCH] [Auto Parallel] performance improvement for Sharding-DP hybrid parallelism (#46180)

* remove unneeded grad allreduce communication when sharding-dp

* remove unneeded grad allreduce communication when sharding-dp

* bugfix

* bugfix

* bugfix
---
 .../passes/auto_parallel_sharding.py | 42 ++++++++-----------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index dcc786f8ff..5840c16fc0 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from functools import reduce
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 import numpy as np
 
 import paddle
@@ -22,12 +22,15 @@ from paddle.fluid import unique_name
 from .pass_base import PassBase, register_pass
 from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
 from paddle.distributed.auto_parallel.process_group import new_process_group
-from paddle.distributed.auto_parallel.operators.common import is_parameter_related
+from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op
 from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
 
 OpRole = core.op_proto_and_checker_maker.OpRole
 OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
+_skip_ops = [
+    'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
+    'assign', "send_v2"
+]
 # update here to support new optimizers
 _supported_optimizer_type = [
     "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
@@ -393,7 +396,7 @@ class ShardingPass(PassBase):
 
         dp_ring_ids = [group.id for group in self.dp_groups]
         for idx, op in reversed(list(enumerate(main_block.ops))):
-            if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
+            if is_data_parallel_reduce_op(op):
                 input_name = op.input_arg_names[0]
                 base_name = _get_base_name_from_grad_name(input_name)
                 sharding_info = self.varname_to_sharding_info[base_name]
@@ -401,7 +404,8 @@ class ShardingPass(PassBase):
                                   sharding_info.group.id,
                                   sharding_info.get_var_rank(base_name),
                                   self._dist_context)
-                if not self.partial_sharding:
+                if not self.partial_sharding or not sharding_info.is_in_local_shard(
+                        base_name):
                     main_block._remove_op(idx + 1, sync=False)
                 else:
                     op._set_attr("ring_id", self.outer_dp_group.id)
@@ -439,7 +443,10 @@ class ShardingPass(PassBase):
                     continue
 
                 for input_name in op.desc.input_arg_names():
-                    if op.type == "cast":
+                    # NOTE hack for the embedding op when AMP O2-3 is used:
+                    # paddle amp forces the embedding (lookup table) to be run in fp32
+                    if _is_param_fp16_cast_op(main_block, op,
+                                              sharding_info.param_names):
                         continue
                     if input_name not in need_broadcast_vars:
                         continue
@@ -646,24 +653,6 @@ def _get_base_name_from_grad_name(grad_name):
     return base_name
 
 
-def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
-
-    if not is_backward_op(op):
-        return False
-    if op.type != "c_allreduce_sum":
-        return False
-    if op.attr('ring_id') not in dp_ring_ids:
-        return False
-
-    output_name = op.output_arg_names[0]
-    base_name = _get_base_name_from_grad_name(output_name)
-
-    if not block.has_var(base_name):
-        return False
-
-    return block.var(base_name).is_parameter
-
-
 def _is_param_grad_sum_op(op, block):
     if not is_backward_op(op):
         return False
@@ -756,9 +745,14 @@ class ShardingInfo(object):
             return self.param_to_rank[varname]
         return -1
 
+    # determine whether the param (fp32, or its fp16 cast) is in the local shard
     def is_in_local_shard(self, param_name):
         return self.get_var_rank(param_name) == self.local_rank
 
+    # NOTE the following logic is designed to support AMP O1, where
+    # the param is cast to fp16 before being used for calculation,
+    # and sharding should only broadcast the cast fp16 param
+    # instead of the original fp32 version of the param.
    def get_broadcast_vars_and_param_usage(self, block):
         broadcast_vars = set([])
         fp16_params = set([])
--
GitLab
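The core behavioral change in this patch is the decision added to _shard_gradient_synchronization: every data-parallel grad allreduce is first replaced by a reduce onto the grad's owner rank inside the sharding group, and the original c_allreduce_sum then survives (retargeted to the outer data-parallel ring) only when partial sharding (Sharding-DP hybrid) is active and the parameter belongs to the local shard; otherwise it is removed. The self-contained Python sketch below restates that decision table; GradSyncPlan and decide_grad_sync are hypothetical names used for illustration only and are not Paddle APIs.

# Standalone sketch (not part of the patch) of the grad-synchronization decision
# introduced above. GradSyncPlan and decide_grad_sync are illustrative names only.
from dataclasses import dataclass


@dataclass
class GradSyncPlan:
    # reduce the grad onto its owner rank inside the sharding group
    insert_sharding_reduce: bool
    # keep the original c_allreduce_sum, retargeted to the outer DP ring
    keep_outer_dp_allreduce: bool


def decide_grad_sync(partial_sharding: bool, owner_rank: int,
                     local_rank: int) -> GradSyncPlan:
    """Mirror the control flow added to _shard_gradient_synchronization.

    partial_sharding: True in Sharding-DP hybrid mode, i.e. the sharding group
        covers only part of the data-parallel group, so an outer-DP allreduce
        may still be needed for some grads.
    owner_rank: rank inside the sharding group that owns the parameter.
    local_rank: this process's rank inside the sharding group.
    """
    in_local_shard = owner_rank == local_rank
    # The patch keeps the original allreduce only for grads whose parameter is
    # owned by the local shard; on every other rank it is removed, which is the
    # communication saving this PR targets.
    keep = partial_sharding and in_local_shard
    return GradSyncPlan(insert_sharding_reduce=True, keep_outer_dp_allreduce=keep)


if __name__ == "__main__":
    # Hybrid Sharding-DP, grad owned by rank 0, running on rank 1:
    # the outer-DP allreduce is dropped on this rank.
    print(decide_grad_sync(partial_sharding=True, owner_rank=0, local_rank=1))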