Unverified · Commit b518fa2a, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Add distributed mul op for paddle.fluid.layers.fc (#40207)

* [Auto Parallel] Add distributed mul for the old version
Parent 9121115b
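
For context, paddle.fluid.layers.fc in static-graph mode is lowered to a mul op (plus an elementwise_add for the bias), which is why the auto-parallel pass needs the distributed mul implementations added below. A minimal sketch of that lowering, assuming static mode and the 2.x fluid API:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
    out = fluid.layers.fc(input=x, size=4096)

# The generated op list is expected to contain 'mul' (and 'elementwise_add').
print([op.type for op in main_prog.global_block().ops])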
@@ -1482,3 +1482,512 @@ register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))


class DistributedMul(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedMul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[
-1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
"""
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "op [{}] doesn't have a dist attribute!".format(
            str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
Out_var = main_block.var(kwargs['Out'][0])
# TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's col should be sharded along a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group = new_process_group(group_ranks)
# infer new var shape with op dist attr
x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
identity_var_dist_attr)
# infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
identity_var_dist_attr)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
mul_op = main_block.append_op(
type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out)
# set dist op's dist_attr with serial op's dist_attr
# c_identity
identity_op_dist_attr = OperatorDistributedAttribute()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
# input
input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
identity_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
# output
output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname,
input_dist_attr)
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
        # mul
        mul_op_dist_attr = OperatorDistributedAttribute()
        mul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        mul_op_dist_attr.impl_type = op_dist_attr.impl_type
        mul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                     input_dist_attr)
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                     tensor_dist_attr)
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            mul_op_dist_attr.set_output_dist_attr(output_varname,
                                                  output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, mul_op_dist_attr)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
def backward(ctx, *args, **kwargs):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
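
# Illustration (assumed 2-D shapes on a 1-D process mesh, for reference only):
# DistributedMulImpl0 accepts a column-parallel sharding such as
#   X:   dims_mapping = [-1, -1]   (input replicated)
#   Y:   dims_mapping = [-1,  0]   (weight split along its columns on mesh axis 0)
#   Out: dims_mapping = [-1,  0]   (output split along its last dim)
# matching the checks above: X's last dim must not be sharded, Y may only be
# sharded on its last dim, and Out's last dim must be sharded.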


# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl1, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
"""
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "op [{}] doesn't have a dist attribute!".format(
            str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
Out_var = main_block.var(kwargs['Out'][0])
# TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be sharded along a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear')
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
inputs = {'X': X_var, 'Y': Weight_var}
# infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
mul_op = main_block.append_op(
type='mul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
# set dist op's dist_attr with serial op's dist_attr
        # mul
        mul_op_dist_attr = OperatorDistributedAttribute()
        mul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        mul_op_dist_attr.impl_type = op_dist_attr.impl_type
        mul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                 input_dist_attr)
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        mul_op_dist_attr.set_output_dist_attr(output_varname,
                                              output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, mul_op_dist_attr)
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
def backward(ctx, *args, **kwargs):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
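
# Illustration (assumed 2-D shapes on a 1-D process mesh, for reference only):
# DistributedMulImpl1 accepts a row-parallel sharding such as
#   X:   dims_mapping = [-1,  0]   (input split along its last dim on mesh axis 0)
#   Y:   dims_mapping = [ 0, -1]   (weight split along its rows)
#   Out: dims_mapping = [-1, -1]   (output replicated)
# Each rank computes a partial Out from its shards of X and Y, and the
# c_allreduce_sum appended above sums the partial results across the group.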


# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl2, self).__init__(name)
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("mul",
DistributedMulImpl0("column_parallel"))
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl("mul",
DistributedMulImpl2("replicate_parallel"))