未验证 提交 926c4bd2 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParalle] balancing the calculation of global_norm in data parallel (#49510)

* [AutoParalle] balancing the calculation of global_norm in data parallel

* fix unittest

* update cond pure_data_parallel
上级 c549c6b9
......@@ -21,8 +21,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.executor import _is_enable_standalone_executor
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from ..auto_parallel.operators.common import SyncMode
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.operators.common import (
SyncMode,
is_data_parallel_reduce_op,
)
from ..auto_parallel.process_group import (
get_all_process_groups,
get_world_process_group,
)
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
......@@ -31,6 +37,7 @@ from ..auto_parallel.utils import (
is_gradient_clip_op,
is_optimize_op,
)
from .auto_parallel_sharding import ShardingPass
from .pass_base import PassBase, register_pass
......@@ -145,23 +152,34 @@ def _is_about_global_norm(
class ClipHelper:
def __init__(self, params_grads, rank_id, block, dist_context):
def __init__(
self, params_grads, rank_id, block, dist_context, pass_context
):
params, _ = zip(*params_grads)
self.params = list(params)
self.params_name = [p.name for p in self.params]
self.rank_id = rank_id
self.block = block
self.dist_context = dist_context
self.pass_context = pass_context
self.sharding_group = None
self.world_ranks = get_world_process_group().ranks
if hasattr(dist_context, '_sharding_group'):
self.sharding_group = dist_context._sharding_group
def _is_calcuate_norm(self, name):
if not self._is_local_param(name):
return False, []
self.world_nranks = len(self.world_ranks)
self.pure_data_parallel = self._is_pure_data_parallel()
self.rank_to_params = self._partition_parameters(params)
def is_calcuate_norm(self, name):
"""
whether the param_name@GRAD paticipate in the calculation of global_norm
"""
if not self.is_local_param(name):
return False
param = self.params[self.params_name.index(name)]
if not self.pure_data_parallel:
dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.process_ids
......@@ -174,17 +192,25 @@ class ClipHelper:
dims_mapping,
self.sharding_group,
)
else:
return param.name in self.rank_to_params[self.rank_id]
def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)
def _is_local_param(self, name):
def is_local_param(self, name):
"""
whether the param_name is updated with opt in cur_rank
"""
if name not in self.params_name:
return False
return True
def _is_local_var(self, name):
def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)
def is_local_var_with_dist_attr(self, name):
"""
whether the var_name is belong to cur_rank
"""
dist_attr = self._get_dist_attr(name)
assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.process_ids
......@@ -212,6 +238,50 @@ class ClipHelper:
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
def _is_pure_data_parallel(self):
for applied_pass in self.pass_context.passes:
if isinstance(applied_pass, ShardingPass):
return False
groups = get_all_process_groups()
for g in groups:
if g.nranks != self.world_nranks:
return False
for op in self.block.ops:
if op.type in [
"c_reduce_sum",
"c_allreduce_sum",
] and not is_data_parallel_reduce_op(op):
return False
return True
def _partition_parameters(self, params):
"""
build rank_id_to_params by the param's numel
to guarantee params in every rank of dp_group as even as possible.
"""
mapping = {}
if not self.pure_data_parallel:
for rank_ in range(self.world_nranks):
mapping[rank_] = [p.name for p in params]
else:
for rank_ in range(self.world_nranks):
mapping[rank_] = []
sizes = [0] * self.world_nranks
for param in params:
rank = sizes.index(min(sizes))
mapping[rank].append(param.name)
numel = reduce(lambda x, y: x * y, param.shape)
assert (
numel > 0
), "param [{}] should larger than 0, but it is [{}]".format(
param.name, numel
)
sizes[rank] += numel
return mapping
@register_pass("auto_parallel_grad_clip")
class ClipGradByGloblNormPass(PassBase):
......@@ -248,14 +318,13 @@ class ClipGradByGloblNormPass(PassBase):
# dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(
dist_params_grads, rank_id, block, dist_context
dist_params_grads, rank_id, block, dist_context, context
)
self._remove_no_need_ops_vars(block)
def _remove_no_need_ops_vars(self, block):
removed_op_out_type = [
'clip_by_norm',
'squared_l2_norm',
'square',
'reduce_sum',
......@@ -267,31 +336,40 @@ class ClipGradByGloblNormPass(PassBase):
if not is_gradient_clip_op(op):
continue
if op.type in removed_op_out_type:
if op.type == 'clip_by_norm':
# remove 'clip_by_norm' op if the param is not updated with opt in current rank
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
# 'clip_by_norm', 'squared_l2_norm', 'square'
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
is_calculate = self.clip_helper._is_calcuate_norm(
param_name
)
if not is_local or (
not is_calculate and op.type != 'clip_by_norm'
):
is_local = self.clip_helper.is_local_param(param_name)
if not is_local:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
elif op.type in removed_op_out_type:
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
# remove 'squared_l2_norm' and 'square' ops,
# if the param@GRAD in cur_rank does not participate in the calculation of global_norm
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper.is_local_param(param_name)
is_calculate = self.clip_helper.is_calcuate_norm(param_name)
if not is_local or not is_calculate:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
else:
# 'reduce_sum'
# 'reduce_sum' must be behind 'square'
if idx - 1 in removed_op_idx:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
elif op.type == 'elementwise_mul':
# 'elementwise_mul' scale the param@GRAD with global_norm
# remove 'elementwise_mul' op if the param is not updated with opt in current rank
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
param_name = input_name[: input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
is_local = self.clip_helper.is_local_param(param_name)
if not is_local:
removed_op_idx.add(idx)
if block.ops[idx - 1].type == 'cast':
......@@ -301,11 +379,14 @@ class ClipGradByGloblNormPass(PassBase):
)
elif op.type == 'sum':
# 'sum' op is used to calculate global_norm, and need to filter inputs which is not in cur_rank
reserved_vars = []
for input_name in op.input_arg_names:
if (
input_name not in removed_tmp_var
and self.clip_helper._is_local_var(input_name)
and self.clip_helper.is_local_var_with_dist_attr(
input_name
)
):
reserved_vars.append(input_name)
if not reserved_vars:
......
......@@ -31,6 +31,7 @@ def apply_pass(use_sharding=False):
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 2
return strategy
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册