Unverified commit f769f850, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] performance improvement for Sharding-DP hybrid parallelism (#46180)

* remove unneeded grad allreduce communication when using sharding-dp

* remove unneeded grad allreduce communication when using sharding-dp

* bugfix

* bugfix

* bugfix
Parent commit: b1e82031
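The gist of the change: under sharding-DP hybrid parallelism, a parameter gradient only needs the outer data-parallel allreduce on the rank that owns that parameter's shard, so the pass can drop the allreduce everywhere else. Below is a minimal sketch of that pruning rule using simplified stand-in types rather than Paddle's Block/Operator API; all names here are illustrative, not the actual ShardingPass code.

```python
# Minimal sketch of the pruning rule, assuming a simplified op record and a
# precomputed set of locally owned parameters.
from dataclasses import dataclass, field


@dataclass
class GradOp:
    type: str
    param: str                      # parameter whose gradient this op reduces
    attrs: dict = field(default_factory=dict)


def prune_dp_allreduce(ops, local_params, partial_sharding, outer_dp_ring_id):
    kept = []
    for op in ops:
        if op.type == "c_allreduce_sum":
            if not partial_sharding or op.param not in local_params:
                continue                                # drop the communication
            op.attrs["ring_id"] = outer_dp_ring_id      # keep it, on the outer DP group
        kept.append(op)
    return kept
```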
@@ -13,7 +13,7 @@
 # limitations under the License.
 from functools import reduce
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 import numpy as np
 import paddle
@@ -22,12 +22,15 @@ from paddle.fluid import unique_name
 from .pass_base import PassBase, register_pass
 from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
 from paddle.distributed.auto_parallel.process_group import new_process_group
-from paddle.distributed.auto_parallel.operators.common import is_parameter_related
+from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op
 from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr

 OpRole = core.op_proto_and_checker_maker.OpRole
 OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

-_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
+_skip_ops = [
+    'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
+    'assign', "send_v2"
+]
 # update here to support new optimizers
 _supported_optimizer_type = [
     "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
@@ -393,7 +396,7 @@ class ShardingPass(PassBase):
         dp_ring_ids = [group.id for group in self.dp_groups]
         for idx, op in reversed(list(enumerate(main_block.ops))):
-            if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
+            if is_data_parallel_reduce_op(op):
                 input_name = op.input_arg_names[0]
                 base_name = _get_base_name_from_grad_name(input_name)
                 sharding_info = self.varname_to_sharding_info[base_name]
@@ -401,7 +404,8 @@ class ShardingPass(PassBase):
                                   sharding_info.group.id,
                                   sharding_info.get_var_rank(base_name),
                                   self._dist_context)
-                if not self.partial_sharding:
+                if not self.partial_sharding or not sharding_info.is_in_local_shard(
+                        base_name):
                     main_block._remove_op(idx + 1, sync=False)
                 else:
                     op._set_attr("ring_id", self.outer_dp_group.id)
@@ -439,7 +443,10 @@ class ShardingPass(PassBase):
                     continue

                 for input_name in op.desc.input_arg_names():
-                    if op.type == "cast":
+                    # NOTE hack for embedding op when AMP O2-3:
+                    # paddle amp forces embedding (lookup table) to run in fp32
+                    if _is_param_fp16_cast_op(main_block, op,
+                                              sharding_info.param_names):
                         continue
                     if input_name not in need_broadcast_vars:
                         continue
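This hunk replaces the blanket `op.type == "cast"` test with `_is_param_fp16_cast_op`, so only casts that turn one of this shard's fp32 parameters into its fp16 working copy are skipped when deciding what to broadcast. A hedged, framework-agnostic sketch of the kind of test that helper performs (the real helper operates on Paddle Block/Operator objects and may differ in detail):

```python
# Hedged sketch; inputs are simplified, names are illustrative.
def looks_like_param_fp16_cast(op_type, input_name, output_dtype, param_names):
    return (op_type == "cast"
            and input_name in param_names   # casting one of this shard's params
            and output_dtype == "float16")  # fp32 parameter -> fp16 working copy
```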
@@ -646,24 +653,6 @@ def _get_base_name_from_grad_name(grad_name):
     return base_name


-def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
-
-    if not is_backward_op(op):
-        return False
-    if op.type != "c_allreduce_sum":
-        return False
-    if op.attr('ring_id') not in dp_ring_ids:
-        return False
-
-    output_name = op.output_arg_names[0]
-    base_name = _get_base_name_from_grad_name(output_name)
-
-    if not block.has_var(base_name):
-        return False
-
-    return block.var(base_name).is_parameter
-
-
 def _is_param_grad_sum_op(op, block):

     if not is_backward_op(op):
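The local helper `_is_param_grad_allreduce_op` shown above is deleted in favor of the shared `is_data_parallel_reduce_op` imported from `paddle.distributed.auto_parallel.operators.common`. A simplified sketch of the kind of predicate that replaces it, derived from the removed code; the real function relies on Paddle's operator metadata and naming conventions and is not reproduced here:

```python
# Hedged sketch; the real is_data_parallel_reduce_op may also accept
# c_reduce_sum and uses operator dist attributes rather than name suffixes alone.
def looks_like_dp_grad_reduce(op_type, output_name):
    is_reduce = op_type in ("c_allreduce_sum", "c_reduce_sum")
    return is_reduce and output_name.endswith("@GRAD")  # gradient of a parameter
```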
@@ -756,9 +745,14 @@ class ShardingInfo(object):
             return self.param_to_rank[varname]
         return -1

+    # determine fp32 and fp16 (cast) param
     def is_in_local_shard(self, param_name):
         return self.get_var_rank(param_name) == self.local_rank

+    # NOTE the following logic is designed to support AMP O1, where
+    # the param is cast to fp16 before being used for calculation,
+    # and sharding should broadcast only the casted fp16 param
+    # instead of the original fp32 param.
     def get_broadcast_vars_and_param_usage(self, block):
         broadcast_vars = set([])
         fp16_params = set([])
...
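The new comments on ShardingInfo document the AMP O1 behaviour: when a parameter has an fp16 cast copy, sharding should broadcast that copy rather than the original fp32 tensor. A minimal sketch of that selection rule, assuming a precomputed map from fp32 param name to its fp16 cast name (the real get_broadcast_vars_and_param_usage presumably derives this from the block's cast ops; names here are illustrative):

```python
# Hedged sketch of the broadcast-selection rule under AMP O1.
def select_broadcast_vars(param_names, fp16_cast_of):
    broadcast_vars = set()
    for p in param_names:
        # prefer the fp16 working copy when one exists, else the fp32 param
        broadcast_vars.add(fp16_cast_of.get(p, p))
    return broadcast_vars
```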