Unverified commit f769f850, authored by J JZ-LIANG, committed by GitHub

[Auto Parallel] performance improvement for Sharding-DP hybrid parallelism (#46180)

* remove unneeded grad allreduce communication when sharding-dp

* remove unneeded grad allreduce communication when sharding-dp

* bugfix

* bugfix

* bugfix
Parent b1e82031
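Background for the change below: under Sharding-DP hybrid parallelism, each sharding rank only updates the parameters it owns, so the outer data-parallel gradient allreduce is only needed for gradients whose parameter sits in the local shard. A minimal sketch of the pruning rule this commit implements (the helper name is illustrative; the actual pass removes or retargets the `c_allreduce_sum` ops shown in the diff):

```python
# Illustrative sketch of the allreduce-pruning rule, not the actual pass code.
# `sharding_info` is assumed to expose get_var_rank() and local_rank, mirroring
# the ShardingInfo helpers that appear further down in the diff.
def keep_outer_dp_grad_allreduce(param_name, sharding_info, partial_sharding):
    """Return True only when the outer data-parallel grad allreduce is still needed."""
    if not partial_sharding:
        # Pure sharding: every data-parallel grad allreduce is redundant.
        return False
    # Partial sharding: only the owner rank updates this parameter, so the
    # outer-DP allreduce is kept only for gradients of locally sharded params.
    return sharding_info.get_var_rank(param_name) == sharding_info.local_rank
```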
@@ -13,7 +13,7 @@
# limitations under the License.
from functools import reduce
from collections import OrderedDict, defaultdict
from collections import OrderedDict
import numpy as np
import paddle
@@ -22,12 +22,15 @@ from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign', "send_v2"
]
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
@@ -393,7 +396,7 @@ class ShardingPass(PassBase):
dp_ring_ids = [group.id for group in self.dp_groups]
for idx, op in reversed(list(enumerate(main_block.ops))):
if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
if is_data_parallel_reduce_op(op):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
@@ -401,7 +404,8 @@
sharding_info.group.id,
sharding_info.get_var_rank(base_name),
self._dist_context)
if not self.partial_sharding:
if not self.partial_sharding or not sharding_info.is_in_local_shard(
base_name):
main_block._remove_op(idx + 1, sync=False)
else:
op._set_attr("ring_id", self.outer_dp_group.id)
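With the new condition above, the outer-DP allreduce is also dropped for gradients whose parameter belongs to another sharding rank; only allreduces for locally sharded parameters are retargeted to `outer_dp_group`. Detecting a data-parallel gradient allreduce is now delegated to the shared `is_data_parallel_reduce_op` helper from `paddle.distributed.auto_parallel.operators.common`. A rough approximation of such a check, modeled on the in-file `_is_param_grad_allreduce_op` helper this commit deletes (shown in a later hunk) and reusing the file's existing `is_backward_op` and `_get_base_name_from_grad_name` helpers, could look like:

```python
# Rough approximation only; the real is_data_parallel_reduce_op may differ in detail.
def looks_like_dp_grad_allreduce(op, block):
    if not is_backward_op(op):
        return False
    if op.type != "c_allreduce_sum":
        return False
    # A data-parallel grad allreduce writes into <param>@GRAD of a parameter var.
    grad_name = op.output_arg_names[0]
    base_name = _get_base_name_from_grad_name(grad_name)
    return block.has_var(base_name) and block.var(base_name).is_parameter
```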
@@ -439,7 +443,10 @@
continue
for input_name in op.desc.input_arg_names():
if op.type == "cast":
# NOTE: hack for the embedding op under AMP O2/O3;
# Paddle AMP forces the embedding (lookup table) op to run in fp32.
if _is_param_fp16_cast_op(main_block, op,
sharding_info.param_names):
continue
if input_name not in need_broadcast_vars:
continue
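Previously every `cast` op was skipped here; now only casts of sharded parameters to fp16 are skipped, via `_is_param_fp16_cast_op`. A hedged sketch of the shape such a check might take (the function name and the exact dtype test are assumptions; the real helper is defined elsewhere in sharding.py):

```python
# Hedged sketch only; the real _is_param_fp16_cast_op may differ in detail.
def _looks_like_param_fp16_cast(block, op, param_names):
    if op.type != "cast":
        return False
    # A param fp16 cast reads a sharded fp32 parameter ...
    if op.input_arg_names[0] not in param_names:
        return False
    # ... and writes its fp16 working copy.
    return block.var(op.output_arg_names[0]).dtype == core.VarDesc.VarType.FP16
```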
@@ -646,24 +653,6 @@ def _get_base_name_from_grad_name(grad_name):
return base_name
def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
if not is_backward_op(op):
return False
if op.type != "c_allreduce_sum":
return False
if op.attr('ring_id') not in dp_ring_ids:
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
def _is_param_grad_sum_op(op, block):
if not is_backward_op(op):
@@ -756,9 +745,14 @@ class ShardingInfo(object):
return self.param_to_rank[varname]
return -1
# determine fp32 and fp16 (cast) param
def is_in_local_shard(self, param_name):
return self.get_var_rank(param_name) == self.local_rank
# NOTE: the following logic is designed to support AMP O1, where
# the param is cast to fp16 before being used for calculation,
# and sharding should only broadcast the cast fp16 param
# instead of the original fp32 version of the param.
def get_broadcast_vars_and_param_usage(self, block):
broadcast_vars = set([])
fp16_params = set([])
......
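The trailing `ShardingInfo` changes add `is_in_local_shard` and document that, under AMP O1, sharding should broadcast the cast fp16 copy of a parameter rather than its fp32 original. A hedged sketch of a broadcast-variable selection consistent with that note (all names here are illustrative; `get_broadcast_vars_and_param_usage` in the diff is the authoritative logic):

```python
# Illustrative sketch, not the real get_broadcast_vars_and_param_usage.
def collect_broadcast_vars(block, param_names):
    fp16_cast_of = {}  # fp32 param name -> name of its fp16 cast output
    for op in block.ops:
        if op.type == "cast" and op.input_arg_names[0] in param_names:
            fp16_cast_of[op.input_arg_names[0]] = op.output_arg_names[0]
    # Broadcast the fp16 copy when one exists (AMP O1); otherwise broadcast
    # the fp32 parameter itself.
    return {fp16_cast_of.get(p, p) for p in param_names}
```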