未验证 提交 70770d0d 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Unify gradient synchronization procedure of data parallel (#44815)

上级 0e6bf744
......@@ -223,8 +223,8 @@ class Engine:
assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
batch_size = self._user_tuning_config["batch_size"]
dataset = self._user_tuning_config["dataset"]
dataset.dp_world_size = self._dp_world_size
dataset.dp_rank = self._dp_rank
dataset.dp_world_size = self._input_split_size
dataset.dp_rank = self._input_split_rank
from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
......@@ -262,7 +262,7 @@ class Engine:
if var.name in block.vars:
feed_list.append(block.vars[var.name])
self._dp_world_size, self._dp_rank = self._get_data_parallel_info(
self._input_split_size, self._input_split_rank = self._get_input_split_info(
feed_list[0], self._dist_contexts[mode])
def _parallel(self, mode, all_ranks):
......@@ -554,8 +554,8 @@ class Engine:
batch_size,
epochs,
steps_per_epoch,
data_parallel_world_size=self._dp_world_size,
data_parallel_rank=self._dp_rank)
data_parallel_world_size=self._input_split_size,
data_parallel_rank=self._input_split_rank)
# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
......@@ -615,8 +615,8 @@ class Engine:
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _get_data_parallel_info(self, var, dist_context):
# get data parallel world size and current data parallel rank
def _get_input_split_info(self, var, dist_context):
# deduce how the input data is split among the cluster
from .utils import _get_comm_group, _get_corresponding_rank
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
......
......@@ -13,7 +13,11 @@
# limitations under the License
import abc
import paddle
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..dist_attribute import OperatorDistributedAttribute
from ..utils import _get_comm_group, _get_corresponding_rank
from ..process_group import new_process_group
_g_distributed_operator_impl_containers = {}
......@@ -24,6 +28,16 @@ _g_elementwise_ops = [
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
class ParallelMode():
"""
the parallel mode for communication or auxiliary operator
"""
DataParallel = "auto_parallel/data_parallel"
ModelParallel = "auto_parallel/model_parallel"
PipelineParalel = "auto_parallel/pipeline_paralel"
MoEParallel = "auto_parallel/moe_parallel"
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
......@@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
new_op.output(output_name)[0], ref_tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
"""
deduce the data parallel communication group for current operator.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
dp_group = None
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
mesh_shape = process_mesh.topology
# FIXME Hack for Pipeline Parallelism where the current operator
# not belong to the mesh the current rank belong to.
if rank not in process_mesh.processes:
rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)
for var_name in act_grad_names:
var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
# consider that the variable's shape is None
# TODO utilize the batch_dim attr instead of "0" in future
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank)
dp_group = new_process_group(group_ranks)
break
return dp_group
def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
"""
insert the allreudce and scale ops for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
allreduce_var_names (list): list of the parameter's grads variable name in the current operator output.
"""
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
dist_op_context = dist_ctx.dist_op_context
main_block = dist_op_context.work_block
dp_degree = len(dp_group.ranks)
for var_name in allreduce_var_names:
added_ops = []
grad_var = main_block.var(var_name)
allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [grad_var]},
outputs={'Out': [grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
allreduce_op._set_attr('op_namescope',
str('/') + ParallelMode.DataParallel)
added_ops.append(allreduce_op)
if dist_ctx.gradient_scale:
scale_op = main_block.append_op(type='scale',
inputs={'X': grad_var},
outputs={'Out': grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
scale_op._set_attr('op_namescope',
str('/') + ParallelMode.DataParallel)
added_ops.append(scale_op)
dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert dims_mapping is not None, "Unexception: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for new_op in added_ops:
new_op_attr = OperatorDistributedAttribute()
new_op_attr.process_mesh = process_mesh
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)
def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
rank):
"""
conduct the allreudce and scaling(dp size)for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
if len(act_grad_names) == 0 or len(out_grad_names) == 0:
return
dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
if not dp_group:
return
sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)
......@@ -15,6 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -537,87 +538,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output
need_gradient_allreduce = False
# data parallel gradient synchronization
act_grad_names = []
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
act_grad_names.append(varname)
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(
ctx, process_mesh, rank_id)
# NOTE: consider that the variable's shape is None
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] if len(
var_dim_mapping) > 0 else -1
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
break
if need_gradient_allreduce:
allreduce_vars = []
for output_name in backward_op.desc.output_names():
for varname in backward_op.desc.output(output_name):
if varname in kwargs["grad_var_to_var"]:
fwd_name = kwargs["grad_var_to_var"][varname]
if fwd_name not in main_block.vars:
continue
if is_parameter_related(fwd_name, main_block):
allreduce_vars.append(varname)
if len(allreduce_vars) > 0:
for varname in allreduce_vars:
added_ops = []
grad_var = main_block.var(varname)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [grad_var]},
outputs={'Out': [grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(
type='scale',
inputs={'X': grad_var},
outputs={'Out': grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping)
ctx.set_op_dist_attr_for_program(op, op_attr)
out_grad_names = []
for output_name in backward_op.desc.output_names():
for varname in backward_op.desc.output(output_name):
if varname in kwargs["grad_var_to_var"]:
fwd_name = kwargs["grad_var_to_var"][varname]
if fwd_name not in main_block.vars:
continue
if is_parameter_related(fwd_name, main_block):
out_grad_names.append(varname)
gradient_synchronization(ctx, backward_op, act_grad_names,
out_grad_names, rank_id)
register_distributed_operator_impl(
......
......@@ -16,6 +16,7 @@ from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op,
ctx)
# check if need gradient allreduce
need_gradient_allreduce = False
# data parallel gradient synchronization
act_grad_names = [Ids_var.name]
out_grad_names = [kwargs['W@GRAD'][0]]
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
if need_gradient_allreduce:
added_ops = []
W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [W_Grad_var]},
outputs={'Out': [W_Grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(type='scale',
inputs={'X': W_Grad_var},
outputs={'Out': W_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
ctx.set_op_dist_attr_for_program(op, op_attr)
gradient_synchronization(ctx, backward_op, act_grad_names,
out_grad_names, rank_id)
register_distributed_operator_impl("lookup_table_v2",
......
......@@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import gradient_synchronization
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
backward_op, **kwargs)
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, batch_size_axis,
rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
added_ops = []
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [Y_Grad_var]},
outputs={'Out': [Y_Grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(type='scale',
inputs={'X': Y_Grad_var},
outputs={'Out': Y_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
Y_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping)
ctx.set_op_dist_attr_for_program(op, op_attr)
# data parallel gradient synchronization
act_grad_names = [X_var.name]
out_grad_names = []
if is_parameter_related(Y_var.name, main_block):
out_grad_names = [kwargs['Y@GRAD'][0]]
gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names,
rank_id)
def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册