diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index c92142cf7384d2b0c76c1a5cb3b4e6ac257303a2..684db52a28d83e49e53790b0abd4db278a247865 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -1482,3 +1482,512 @@ register_distributed_operator_impl("matmul_v2",
                                    DistributedMatmulV2Impl1("row_parallel"))
 register_distributed_operator_impl(
     "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))
+
+
+class DistributedMul(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super(DistributedMul, self).__init__(op_type)
+
+
+register_distributed_operator_impl_container(DistributedMul("mul"))
+
+
+# ColumnParallel
+class DistributedMulImpl0(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedMulImpl0, self).__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        y_name = op_desc.input('Y')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        if is_dim_shard(x_dims_mapping[-1]):
+            return False
+        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[
+                -1]):
+            return False
+        for mapping in x_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        return True
+
+    def is_output_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        out_name = op_desc.output('Out')[0]
+        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+        if is_dim_replicate(out_dims_mapping[-1]):
+            return False
+        for mapping in out_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
+        if not _is_auto_compatible_for_matmul(dist_op):
+            return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        dim_changed = _update_dims_mapping_for_matmul(dist_op)
+        if dim_changed:
+            changed = True
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        """
+        kwargs: inputname_mapping & outputname_mapping
+        """
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(src_op))
+
+        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # check validation of inputs / outputs
+        for input_name in src_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                src_op.desc.input(input_name)
+            ), "number of tensor for input [{}] is not match".format(input_name)
+        for output_name in src_op.desc.output_names():
+            assert output_name in kwargs, "input [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                src_op.desc.output(output_name)
+            ), "number of tensor for input [{}] is not match".format(
+                output_name)
+
+        X_var = main_block.var(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block.var(kwargs['Out'][0])
+
+        # TODO infer logic comm presentation
+        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
+            Weight_var.name)[-1]
+        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            matmul_col_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = matmul_col_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # infer new var shape with op dist attr
+        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
+        assert x_tensor_dist_attr is not None
+        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
+        assert identity_var_dist_attr is not None
+        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
+                                  identity_var_dist_attr)
+        # infer out var shape with op dist attr
+        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
+        assert out_tensor_dist_attr is not None
+        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
+        assert out_var_dist_attr is not None
+        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
+                                    out_var_dist_attr)
+
+        intermediate_var_0 = main_block.create_var(
+            name=unique_name.generate_with_ignorable_key(".".join(
+                ["c_identity", 'tmp'])),
+            dtype=X_var.dtype,
+            shape=X_var.shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=X_var.stop_gradient)
+        # set intermediate_var_0's dist_attr with X_var's dist_attr
+        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
+                                             identity_var_dist_attr)
+
+        check_variable_and_dtype(
+            X_var, 'tensor',
+            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
+        c_identity_op = main_block.append_op(
+            type='c_identity',
+            inputs={'X': [X_var]},
+            outputs={'Out': intermediate_var_0},
+            attrs={
+                'ring_id': group.id,
+                'use_calc_stream': True,
+                'use_model_parallel': True,
+            })
+        if intermediate_var_0.shape != ref_shape_x:
+            intermediate_var_0.desc.set_shape(ref_shape_x)
+
+        check_variable_and_dtype(intermediate_var_0, 'x',
+                                 ['float16', 'float32', 'float64'], 'linear')
+        check_dtype(intermediate_var_0.dtype, 'dtype',
+                    ['float16', 'float32', 'float64'], 'linear')
+        # attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+        }
+        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
+        mul_op = main_block.append_op(
+            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
+        if Out_var.shape != ref_shape_out:
+            Out_var.desc.set_shape(ref_shape_out)
+
+        # set dist op's dist_attr with serial op's dist_attr
+        # c_identity
+        identity_op_dist_attr = OperatorDistributedAttribute()
+        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
+        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
+        # input
+        input_varname = c_identity_op.desc.input_arg_names()[0]
+        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
+        assert input_dist_attr is not None, "dist_attr is {}".format(
+            op_dist_attr)
+        identity_op_dist_attr.set_input_dist_attr(input_varname,
+                                                  input_dist_attr)
+        # output
+        output_varname = c_identity_op.desc.output_arg_names()[0]
+        identity_op_dist_attr.set_output_dist_attr(output_varname,
+                                                   input_dist_attr)
+        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
+
+        # matmulv2
+        matmulv2_op_dist_attr = OperatorDistributedAttribute()
+        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
+        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
+        for input_varname in mul_op.desc.input_arg_names():
+            if input_varname in src_op.desc.input_arg_names():
+                input_dist_attr = op_dist_attr.get_input_dist_attr(
+                    input_varname)
+                assert input_dist_attr is not None, "dist_attr is {}".format(
+                    op_dist_attr)
+                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
+                                                          input_dist_attr)
+            else:
+                input_var = main_block.var(input_varname)
+                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
+                    input_var)
+                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
+                                                          tensor_dist_attr)
+        for output_varname in mul_op.desc.output_arg_names():
+            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
+            assert output_dist_attr is not None, "dist_attr is {}".format(
+                op_dist_attr)
+            matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
+                                                       output_dist_attr)
+        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
+
+        # init param sync
+        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
+            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
+                             rank_id)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
+
+
+# RowParallel
+class DistributedMulImpl1(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedMulImpl1, self).__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        y_name = op_desc.input('Y')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        if is_dim_replicate(x_dims_mapping[-1]):
+            return False
+        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
+                -1]):
+            return False
+        # Other dimensions must be replicate except the batch dimension
+        for mapping in x_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        return True
+
+    def is_output_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        out_name = op_desc.output('Out')[0]
+        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+        if is_dim_shard(out_dims_mapping[-1]):
+            return False
+        # Other dimensions must be replicate except the batch dimension
+        for mapping in out_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
+        if not _is_auto_compatible_for_matmul(dist_op):
+            return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        dim_changed = _update_dims_mapping_for_matmul(dist_op)
+        if dim_changed:
+            changed = True
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        """
+        kwargs: inputname_mapping & outputname_mapping
+        """
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(src_op))
+
+        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # check validation of inputs / outputs
+        for input_name in src_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                src_op.desc.input(input_name)
+            ), "number of tensor for input [{}] is not match".format(input_name)
+        for output_name in src_op.desc.output_names():
+            assert output_name in kwargs, "input [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                src_op.desc.output(output_name)
+            ), "number of tensor for input [{}] is not match".format(
+                output_name)
+
+        X_var = main_block.var(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block.var(kwargs['Out'][0])
+
+        # TODO infer logic comm presentation
+        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
+            Weight_var.name)[-2]
+        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            matmul_row_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = matmul_row_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
+                                 'linear')
+        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
+                    'linear')
+        # attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+        }
+        inputs = {'X': X_var, 'Y': Weight_var}
+
+        # infer out var shape with op dist attr
+        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
+        assert out_tensor_dist_attr is not None
+        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
+        assert out_var_dist_attr is not None
+        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
+                                out_var_dist_attr)
+
+        intermediate_var_0 = main_block.create_var(
+            shape=Out_var.shape,
+            dtype=Out_var.dtype,
+            type=Out_var.type,
+            lod_level=Out_var.lod_level,
+            persistable=False,
+            is_data=False,
+            need_check_feed=Out_var.desc.need_check_feed())
+        # set intermediate_var_0's dist_attr with Out_var's dist_attr
+        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
+                                             out_var_dist_attr)
+
+        mul_op = main_block.append_op(
+            type='mul',
+            inputs=inputs,
+            outputs={'Out': intermediate_var_0},
+            attrs=attrs)
+        if intermediate_var_0.shape != ref_shape:
+            intermediate_var_0.desc.set_shape(ref_shape)
+
+        c_allreduce_sum_op = main_block.append_op(
+            type='c_allreduce_sum',
+            inputs={'X': intermediate_var_0},
+            outputs={'Out': Out_var},
+            attrs={
+                'ring_id': group.id,
+                'use_calc_stream': True,
+                'use_model_parallel': True
+            })
+        if Out_var.shape != ref_shape:
+            Out_var.desc.set_shape(ref_shape)
+
+        # set dist op's dist_attr with serial op's dist_attr
+        # matmulv2
+        matmulv2_op_dist_attr = OperatorDistributedAttribute()
+        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
+        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
+        for input_varname in mul_op.desc.input_arg_names():
+            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
+            assert input_dist_attr is not None, "dist_attr is {}".format(
+                op_dist_attr)
+            matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
+                                                      input_dist_attr)
+        output_varname = mul_op.desc.output_arg_names()[0]
+        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
+        assert output_dist_attr is not None, "dist_attr is {}".format(
+            op_dist_attr)
+        matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
+                                                   output_dist_attr)
+        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
+
+        # allreduce
+        allreduce_op_dist_attr = OperatorDistributedAttribute()
+        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
+        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
+        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
+            input_var = main_block.var(input_varname)
+            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
+            assert tensor_dist_attr is not None
+            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
+                                                       tensor_dist_attr)
+        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
+            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
+            assert output_dist_attr is not None, "dist_attr is {}".format(
+                op_dist_attr)
+            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
+                                                        output_dist_attr)
+        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
+                                         allreduce_op_dist_attr)
+
+        # init param sync
+        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
+            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
+                             rank_id)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
+
+
+# ReplicateParallel
+class DistributedMulImpl2(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedMulImpl2, self).__init__(name)
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        y_name = op_desc.input('Y')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+
+        if is_dim_shard(x_dims_mapping[-1]):
+            return False
+        if is_valid_list_index(x_dims_mapping,
+                               -2) and is_dim_shard(x_dims_mapping[-2]):
+            return False
+
+        if is_dim_shard(y_dims_mapping[-1]):
+            return False
+        if is_valid_list_index(y_dims_mapping,
+                               -2) and is_dim_shard(y_dims_mapping[-2]):
+            return False
+        return True
+
+    def is_output_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        out_name = op_desc.output('Out')[0]
+        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+
+        if is_dim_shard(out_dims_mapping[-1]):
+            return False
+        if is_valid_list_index(out_dims_mapping,
+                               -2) and is_dim_shard(out_dims_mapping[-2]):
+            return False
+
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
+        if not _is_auto_compatible_for_matmul(dist_op):
+            return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        dim_changed = _update_dims_mapping_for_matmul(dist_op)
+        if dim_changed:
+            changed = True
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
+
+
+register_distributed_operator_impl("mul",
+                                   DistributedMulImpl0("column_parallel"))
+register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
+register_distributed_operator_impl("mul",
+                                   DistributedMulImpl2("replicate_parallel"))
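
Note (editorial addition, not part of the patch): DistributedMulImpl0 and DistributedMulImpl1 above lower a serial mul op into the two standard tensor-parallel forms of Out = X * Y. With Y sharded by columns, each rank computes a column slice of Out from the replicated input (hence only a c_identity on X); with Y sharded by rows and X sharded along its reduction axis, each rank computes a partial Out that must be summed across the mesh axis (hence the trailing c_allreduce_sum). The NumPy sketch below is purely illustrative; the shapes, the num_parts value, and the variable names are arbitrary assumptions used to check these two identities, not code taken from Paddle.

# Illustrative sketch (assumed shapes and part count): verify the two
# parallel decompositions of Out = X @ Y that the column-parallel and
# row-parallel impls rely on.
import numpy as np

rng = np.random.default_rng(0)
num_parts = 4                      # stand-in for the mesh axis size
X = rng.standard_normal((8, 16))   # input, replicated across ranks
Y = rng.standard_normal((16, 12))  # weight of the serial mul op
reference = X @ Y

# Column parallel (Impl0): each rank holds a column slice of Y and the full X,
# producing a column slice of Out; no reduction is needed on the forward pass.
col_slices = np.split(Y, num_parts, axis=1)
col_out = np.concatenate([X @ y_part for y_part in col_slices], axis=1)
assert np.allclose(col_out, reference)

# Row parallel (Impl1): each rank holds a row slice of Y and the matching
# slice of X's last axis; the partial products are summed (c_allreduce_sum).
x_slices = np.split(X, num_parts, axis=1)
y_slices = np.split(Y, num_parts, axis=0)
row_out = sum(x_part @ y_part for x_part, y_part in zip(x_slices, y_slices))
assert np.allclose(row_out, reference)

print("both decompositions match the serial mul")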