未验证 提交 926c4bd2 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParalle] balancing the calculation of global_norm in data parallel (#49510)

* [AutoParalle] balancing the calculation of global_norm in data parallel

* fix unittest

* update cond pure_data_parallel
上级 c549c6b9
...@@ -21,8 +21,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole ...@@ -21,8 +21,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.executor import _is_enable_standalone_executor from paddle.fluid.executor import _is_enable_standalone_executor
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from ..auto_parallel.operators.common import SyncMode from ..auto_parallel.operators.common import (
from ..auto_parallel.process_group import get_world_process_group SyncMode,
is_data_parallel_reduce_op,
)
from ..auto_parallel.process_group import (
get_all_process_groups,
get_world_process_group,
)
from ..auto_parallel.process_mesh import ProcessMesh from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.reshard import Resharder from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import ( from ..auto_parallel.utils import (
...@@ -31,6 +37,7 @@ from ..auto_parallel.utils import ( ...@@ -31,6 +37,7 @@ from ..auto_parallel.utils import (
is_gradient_clip_op, is_gradient_clip_op,
is_optimize_op, is_optimize_op,
) )
from .auto_parallel_sharding import ShardingPass
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
...@@ -145,46 +152,65 @@ def _is_about_global_norm( ...@@ -145,46 +152,65 @@ def _is_about_global_norm(
class ClipHelper: class ClipHelper:
def __init__(self, params_grads, rank_id, block, dist_context): def __init__(
self, params_grads, rank_id, block, dist_context, pass_context
):
params, _ = zip(*params_grads) params, _ = zip(*params_grads)
self.params = list(params) self.params = list(params)
self.params_name = [p.name for p in self.params] self.params_name = [p.name for p in self.params]
self.rank_id = rank_id self.rank_id = rank_id
self.block = block self.block = block
self.dist_context = dist_context self.dist_context = dist_context
self.pass_context = pass_context
self.sharding_group = None self.sharding_group = None
self.world_ranks = get_world_process_group().ranks self.world_ranks = get_world_process_group().ranks
if hasattr(dist_context, '_sharding_group'): if hasattr(dist_context, '_sharding_group'):
self.sharding_group = dist_context._sharding_group self.sharding_group = dist_context._sharding_group
def _is_calcuate_norm(self, name): self.world_nranks = len(self.world_ranks)
if not self._is_local_param(name): self.pure_data_parallel = self._is_pure_data_parallel()
return False, [] self.rank_to_params = self._partition_parameters(params)
param = self.params[self.params_name.index(name)] def is_calcuate_norm(self, name):
dist_attr = self._get_dist_attr(name) """
topology = dist_attr.process_mesh.shape whether the param_name@GRAD paticipate in the calculation of global_norm
processes = dist_attr.process_mesh.process_ids """
dims_mapping = dist_attr.dims_mapping if not self.is_local_param(name):
return _is_about_global_norm( return False
self.rank_id,
param.shape,
topology,
processes,
dims_mapping,
self.sharding_group,
)
def _get_dist_attr(self, name): param = self.params[self.params_name.index(name)]
var = self.block.vars[name] if not self.pure_data_parallel:
return self.dist_context.get_tensor_dist_attr_for_program(var) dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.process_ids
dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm(
self.rank_id,
param.shape,
topology,
processes,
dims_mapping,
self.sharding_group,
)
else:
return param.name in self.rank_to_params[self.rank_id]
def _is_local_param(self, name): def is_local_param(self, name):
"""
whether the param_name is updated with opt in cur_rank
"""
if name not in self.params_name: if name not in self.params_name:
return False return False
return True return True
def _is_local_var(self, name): def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)
def is_local_var_with_dist_attr(self, name):
"""
whether the var_name is belong to cur_rank
"""
dist_attr = self._get_dist_attr(name) dist_attr = self._get_dist_attr(name)
assert dist_attr is not None assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.process_ids return self.rank_id in dist_attr.process_mesh.process_ids
...@@ -212,6 +238,50 @@ class ClipHelper: ...@@ -212,6 +238,50 @@ class ClipHelper:
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr) self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
def _is_pure_data_parallel(self):
for applied_pass in self.pass_context.passes:
if isinstance(applied_pass, ShardingPass):
return False
groups = get_all_process_groups()
for g in groups:
if g.nranks != self.world_nranks:
return False
for op in self.block.ops:
if op.type in [
"c_reduce_sum",
"c_allreduce_sum",
] and not is_data_parallel_reduce_op(op):
return False
return True
def _partition_parameters(self, params):
"""
build rank_id_to_params by the param's numel
to guarantee params in every rank of dp_group as even as possible.
"""
mapping = {}
if not self.pure_data_parallel:
for rank_ in range(self.world_nranks):
mapping[rank_] = [p.name for p in params]
else:
for rank_ in range(self.world_nranks):
mapping[rank_] = []
sizes = [0] * self.world_nranks
for param in params:
rank = sizes.index(min(sizes))
mapping[rank].append(param.name)
numel = reduce(lambda x, y: x * y, param.shape)
assert (
numel > 0
), "param [{}] should larger than 0, but it is [{}]".format(
param.name, numel
)
sizes[rank] += numel
return mapping
@register_pass("auto_parallel_grad_clip") @register_pass("auto_parallel_grad_clip")
class ClipGradByGloblNormPass(PassBase): class ClipGradByGloblNormPass(PassBase):
...@@ -248,14 +318,13 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -248,14 +318,13 @@ class ClipGradByGloblNormPass(PassBase):
# dist_params_grads = _get_params_grads(block) # dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper( self.clip_helper = ClipHelper(
dist_params_grads, rank_id, block, dist_context dist_params_grads, rank_id, block, dist_context, context
) )
self._remove_no_need_ops_vars(block) self._remove_no_need_ops_vars(block)
def _remove_no_need_ops_vars(self, block): def _remove_no_need_ops_vars(self, block):
removed_op_out_type = [ removed_op_out_type = [
'clip_by_norm',
'squared_l2_norm', 'squared_l2_norm',
'square', 'square',
'reduce_sum', 'reduce_sum',
...@@ -267,31 +336,40 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -267,31 +336,40 @@ class ClipGradByGloblNormPass(PassBase):
if not is_gradient_clip_op(op): if not is_gradient_clip_op(op):
continue continue
if op.type in removed_op_out_type: if op.type == 'clip_by_norm':
# remove 'clip_by_norm' op if the param is not updated with opt in current rank
input_name = op.input("X")[0] input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1: if input_name.find("@GRAD") != -1:
# 'clip_by_norm', 'squared_l2_norm', 'square'
param_name = input_name[: input_name.find("@GRAD")] param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name) is_local = self.clip_helper.is_local_param(param_name)
is_calculate = self.clip_helper._is_calcuate_norm( if not is_local:
param_name removed_op_idx.add(idx)
) removed_tmp_var.update(set(op.output_arg_names))
if not is_local or (
not is_calculate and op.type != 'clip_by_norm' elif op.type in removed_op_out_type:
): input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
# remove 'squared_l2_norm' and 'square' ops,
# if the param@GRAD in cur_rank does not participate in the calculation of global_norm
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper.is_local_param(param_name)
is_calculate = self.clip_helper.is_calcuate_norm(param_name)
if not is_local or not is_calculate:
removed_op_idx.add(idx) removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names)) removed_tmp_var.update(set(op.output_arg_names))
else: else:
# 'reduce_sum' # 'reduce_sum' must be behind 'square'
if idx - 1 in removed_op_idx: if idx - 1 in removed_op_idx:
removed_op_idx.add(idx) removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names)) removed_tmp_var.update(set(op.output_arg_names))
elif op.type == 'elementwise_mul': elif op.type == 'elementwise_mul':
# 'elementwise_mul' scale the param@GRAD with global_norm
# remove 'elementwise_mul' op if the param is not updated with opt in current rank
input_name = op.input("X")[0] input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1: if input_name.find("@GRAD") != -1:
param_name = input_name[: input_name.find("@GRAD")] param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name) is_local = self.clip_helper.is_local_param(param_name)
if not is_local: if not is_local:
removed_op_idx.add(idx) removed_op_idx.add(idx)
if block.ops[idx - 1].type == 'cast': if block.ops[idx - 1].type == 'cast':
...@@ -301,11 +379,14 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -301,11 +379,14 @@ class ClipGradByGloblNormPass(PassBase):
) )
elif op.type == 'sum': elif op.type == 'sum':
# 'sum' op is used to calculate global_norm, and need to filter inputs which is not in cur_rank
reserved_vars = [] reserved_vars = []
for input_name in op.input_arg_names: for input_name in op.input_arg_names:
if ( if (
input_name not in removed_tmp_var input_name not in removed_tmp_var
and self.clip_helper._is_local_var(input_name) and self.clip_helper.is_local_var_with_dist_attr(
input_name
)
): ):
reserved_vars.append(input_name) reserved_vars.append(input_name)
if not reserved_vars: if not reserved_vars:
......
...@@ -31,6 +31,7 @@ def apply_pass(use_sharding=False): ...@@ -31,6 +31,7 @@ def apply_pass(use_sharding=False):
strategy.reinit = True strategy.reinit = True
if use_sharding: if use_sharding:
sharding = strategy.sharding sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2 sharding.degree = 2
sharding.stage = 2 sharding.stage = 2
return strategy return strategy
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册