From ec1e0d5a6b73e7b4ca6aa077b879fbc328e4a322 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Fri, 29 Jul 2022 17:10:20 +0800 Subject: [PATCH] add dist op costs (#44701) --- .../auto_parallel/cost/__init__.py | 30 +- .../auto_parallel/cost/comp_op_cost.py | 19 ++ .../distributed/auto_parallel/dist_context.py | 6 +- .../auto_parallel/operators/dist_embedding.py | 92 ++++++ .../auto_parallel/operators/dist_matmul.py | 300 +++++++++++++++++- .../auto_parallel/operators/dist_reshape.py | 241 +++++++++++++- .../auto_parallel/operators/dist_softmax.py | 62 ++++ .../auto_parallel/operators/dist_transpose.py | 62 ++++ .../auto_parallel/test_dist_op_cost.py | 112 ++++++- 9 files changed, 902 insertions(+), 22 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/cost/__init__.py b/python/paddle/distributed/auto_parallel/cost/__init__.py index fa665bde1ec..e8ba0300d45 100644 --- a/python/paddle/distributed/auto_parallel/cost/__init__.py +++ b/python/paddle/distributed/auto_parallel/cost/__init__.py @@ -12,20 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License -from .base_cost import _g_op_cost_factory from .base_cost import Cost from .base_cost import CommContext +from .base_cost import _g_op_cost_factory from .base_cost import build_comm_desc -from .base_cost import build_comp_desc_from_op -from .base_cost import build_comp_desc_from_dist_op from .base_cost import build_dp_costs +from .base_cost import build_comp_desc_str_for_predict +from .base_cost import build_comp_desc_from_dist_op +from .base_cost import build_comm_desc_from_dist_op +from .base_cost import build_comm_costs_from_descs from .base_cost import build_comp_costs_from_descs -from .tensor_cost import TensorCost -from .estimate_cost import CostEstimator +from .comp_op_cost import EmbeddingOpCost +from .comp_op_cost import EmbeddingGradOpCost +from .comp_op_cost import ConcatOpCost +from .comp_op_cost import MatmulOpCost +from .comp_op_cost import MatmulGradOpCost from .comp_op_cost import MatmulV2OpCost +from .comp_op_cost import MatmulV2GradOpCost +from .comp_op_cost import MulOpCost +from .comp_op_cost import MulGradOpCost +from .comp_op_cost import Reshape2OpCost +from .comp_op_cost import Reshape2GradOpCost +from .comp_op_cost import SliceOpCost +from .comp_op_cost import SplitOpCost +from .comp_op_cost import SoftmaxOpCost +from .comp_op_cost import SoftmaxGradOpCost +from .comp_op_cost import Transpose2OpCost +from .comp_op_cost import Transpose2GradOpCost from .comp_op_cost import FillConstantBatchSizeLikeOpCost +from .tensor_cost import TensorCost + +from .estimate_cost import CostEstimator + from .comm_op_cost import SendOpCost from .comm_op_cost import RecvOpCost from .comm_op_cost import IdentityOpCost diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index 8d8abe8d8e4..bdfcbfe06d3 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -15,6 +15,25 @@ from .base_cost import Cost, register_op_cost, CompOpCost, _g_op_cost_factory +@register_op_cost +class AdamOpCost(CompOpCost): + OP_TYPE = "adam" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(AdamOpCost, self).__init__(op=op, + op_desc=op_desc, + cluster=cluster) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + @register_op_cost class AssignOpCost(CompOpCost): OP_TYPE = "assign" diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 04b7f6aded7..b821d12a12f 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -831,8 +831,10 @@ class DistributedContext: if (dist_tensor is not None) and (not dist_tensor.validate_dist_attr()): assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format( - dist_tensor.serial_tensor.name, dist_tensor.desc.id(), - dist_tensor.desc.original_id(), dist_tensor.dist_attr) + dist_tensor.serial_tensor.name, + dist_tensor.serial_tensor.desc.id(), + dist_tensor.serial_tensor.desc.original_id(), + dist_tensor.dist_attr) for op in block.ops: dist_op = self.get_dist_op_for_program(op) assert dist_op is not None, \ diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index aa463398139..85b8c469aa4 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process_group import new_process_group from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank +from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op +from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs +from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost class DistributedEmbedding(DistributedOperatorImplContainer): @@ -53,6 +56,95 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + """Calculate the cost by the op role.""" + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + # embedding need start_index + cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx, + processes, desc_mapping, + cluster) + + serial_op = dist_op.serial_op + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("W")[0])[0] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = serial_op.output("Out") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, + cluster) + + res_cost = [cost_mapping, comm_op_cost_list] + + return res_cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + main_block = backward_op.block + dist_attr = dist_op.dist_attr + + embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("W")[0])[0] + parallel_axis = embedding_row_dim_mapping + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = [backward_op.input("Out@GRAD")[0]] + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Ids")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('W@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + + return res + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 4e9aefd168c..5ca6366d6b5 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -13,6 +13,7 @@ # limitations under the License import copy + from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl @@ -35,6 +36,10 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank from .dist_default import DistributedDefaultImpl0 +from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs +from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs +from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost +from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): @@ -58,6 +63,14 @@ def _update_dims_mapping_for_matmul(dist_op): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] out_name = op_desc.output('Out')[0] + trans_x = None + trans_y = None + if op_desc.type() == "matmul_v2": + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + elif op_desc.type() == "matmul": + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) @@ -67,27 +80,34 @@ def _update_dims_mapping_for_matmul(dist_op): # Add dim mapping to Make sure the length dims_mapping be at least 2 if x_dims_mapping_len == 1: + assert trans_x is False x_dims_mapping.insert(0, -1) + out_dims_mapping.insert(out_dims_mapping_len - 1, 0) if y_dims_mapping_len == 1: + assert trans_y is False y_dims_mapping.insert(1, -1) + out_dims_mapping.insert(out_dims_mapping_len, 0) + new_x_dims_mapping_len = len(x_dims_mapping) + new_y_dims_mapping_len = len(y_dims_mapping) + new_out_dims_mapping_len = len(out_dims_mapping) # Deal with dim > 2 and take care of broadcasting - if out_dims_mapping_len > 2: + if new_out_dims_mapping_len > 2: broadcast_x_dims_mapping = [] broadcast_y_dims_mapping = [] broadcast_out_dims_mapping = [] - for i in range(out_dims_mapping_len - x_dims_mapping_len): + for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len): broadcast_x_dims_mapping.append(out_dims_mapping[i]) - for i in range(x_dims_mapping_len - 2): + for i in range(new_x_dims_mapping_len - 2): broadcast_x_dims_mapping.append(x_dims_mapping[i]) - for i in range(out_dims_mapping_len - y_dims_mapping_len): + for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len): broadcast_y_dims_mapping.append(out_dims_mapping[i]) - for i in range(y_dims_mapping_len - 2): + for i in range(new_y_dims_mapping_len - 2): broadcast_y_dims_mapping.append(y_dims_mapping[i]) - for i in range(out_dims_mapping_len - 2): + for i in range(new_out_dims_mapping_len - 2): broadcast_out_dims_mapping.append(out_dims_mapping[i]) compatible_dims_mapping = compute_compatible_dims_mapping([ @@ -97,23 +117,30 @@ def _update_dims_mapping_for_matmul(dist_op): if compatible_dims_mapping is None: return False - for i in range(x_dims_mapping_len - 2): - new_idx = i + (out_dims_mapping_len - x_dims_mapping_len) + for i in range(new_x_dims_mapping_len - 2): + new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len) if x_dims_mapping[i] != compatible_dims_mapping[new_idx]: x_dims_mapping[i] = compatible_dims_mapping[new_idx] changed = True - for i in range(y_dims_mapping_len - 2): - new_idx = i + (out_dims_mapping_len - y_dims_mapping_len) + for i in range(new_y_dims_mapping_len - 2): + new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len) if y_dims_mapping[i] != compatible_dims_mapping[new_idx]: y_dims_mapping[i] = compatible_dims_mapping[new_idx] changed = True - for i in range(out_dims_mapping_len - 2): + for i in range(new_out_dims_mapping_len - 2): if out_dims_mapping[i] != compatible_dims_mapping[i]: out_dims_mapping[i] = compatible_dims_mapping[i] changed = True + if trans_x: + x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ + -2], x_dims_mapping[-1] + if trans_y: + y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ + -2], y_dims_mapping[-1] + # The following which uses negative index can be work # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 dim_changed = compute_compatible_and_update_dim_mapping( @@ -131,11 +158,20 @@ def _update_dims_mapping_for_matmul(dist_op): if dim_changed: changed = True + if trans_x: + x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ + -2], x_dims_mapping[-1] + if trans_y: + y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ + -2], y_dims_mapping[-1] + # Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor if x_dims_mapping_len == 1: x_dims_mapping.pop(0) + out_dims_mapping.pop(out_dims_mapping_len - 1) if y_dims_mapping_len == 1: y_dims_mapping.pop(1) + out_dims_mapping.pop(out_dims_mapping_len) assert len(x_dims_mapping) == x_dims_mapping_len assert len(y_dims_mapping) == y_dims_mapping_len @@ -484,6 +520,102 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + # col parallel: matmul + allreduce + assert Y_var_dim_mapping[0] < 0 + parallel_axis = Y_var_dim_mapping[1] + + has_x_grad = len(backward_op.output("X@GRAD")) > 0 + if has_x_grad: + assert len(backward_op.output("X@GRAD")) == 1 + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # calc comm op cost + if has_x_grad: + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = backward_op.output("X@GRAD") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, + c_allreduce_sum_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, + desc_mapping, cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-1] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = serial_op.input("X") + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res_cost = [comm_op_cost_list, cost_mapping] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -710,6 +842,99 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + assert Y_var_dim_mapping[1] < 0 + parallel_axis = Y_var_dim_mapping[0] + + # calc comm op cost + var_names = [backward_op.input("Out@GRAD")[0]] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, + desc_mapping, cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-2] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + + var_names = serial_op.output("Out") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, + cluster) + + res_cost = [cost_mapping, comm_op_cost_list] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -920,6 +1145,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl2, self).__init__(name) + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, + desc_mapping, cluster) + + res_cost = [cost_mapping] + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 790e97cf4e1..d896667008c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -15,7 +15,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container -from .common import register_distributed_operator_impl +from .common import register_distributed_operator_impl, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -28,6 +28,11 @@ from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from .dist_default import DistributedDefaultImpl0 +from ..cost import build_comp_desc_from_dist_op, build_comp_costs_from_descs +from ..cost import build_comm_costs_from_descs +from ..cost import Reshape2OpCost +from ..cost import Reshape2GradOpCost +from paddle.distributed.fleet.meta_optimizers.common import OpRole class DistributedReshape2(DistributedOperatorImplContainer): @@ -46,6 +51,84 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = False + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + res = [] + op = dist_op.serial_op + vars = op.block.vars + dist_attr = dist_op.dist_attr + + shape_list = op.desc.attr("shape") + # got dist attribute info + dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) + process_mesh_shape = dist_attr.process_mesh.topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[ + idx] = shape_list[idx] // process_mesh_shape[axis] + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_attr.process_mesh.processes + for key in desc_mapping: + desc_mapping[key]["shape"] = shape_list + + cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + return res + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + op_type = dist_op.serial_op.type + + cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + backward_op = dist_op.serial_op + main_block = backward_op.block + need_gradient_allreduce = False + vars = main_block.vars + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and is_parameter_related( + varname, main_block): + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs(res, dist_op, ctx, var_names, attrs, + parallel_axis, cluster) + + return res + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -199,6 +282,84 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = False + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + res = [] + op = dist_op.serial_op + vars = op.block.vars + dist_attr = dist_op.dist_attr + + shape_list = op.desc.attr("shape") + # got dist attribute info + dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) + process_mesh_shape = dist_attr.process_mesh.topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[ + idx] = shape_list[idx] // process_mesh_shape[axis] + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_attr.process_mesh.processes + for key in desc_mapping: + desc_mapping[key]["shape"] = shape_list + + cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + return res + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + op_type = dist_op.serial_op.type + + cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + backward_op = dist_op.serial_op + main_block = backward_op.block + need_gradient_allreduce = False + vars = main_block.vars + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and not is_parameter_related( + varname, main_block): + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs(res, dist_op, ctx, var_names, attrs, + parallel_axis, cluster) + + return res + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -355,6 +516,84 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = False + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + res = [] + op = dist_op.serial_op + vars = op.block.vars + dist_attr = dist_op.dist_attr + + shape_list = op.desc.attr("shape") + # got dist attribute info + dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) + process_mesh_shape = dist_attr.process_mesh.topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[ + idx] = shape_list[idx] // process_mesh_shape[axis] + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_attr.process_mesh.processes + for key in desc_mapping: + desc_mapping[key]["shape"] = shape_list + + cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + return res + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + op_type = dist_op.serial_op.type + + cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + backward_op = dist_op.serial_op + main_block = backward_op.block + need_gradient_allreduce = False + vars = main_block.vars + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and not is_parameter_related( + varname, main_block): + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs(res, dist_op, ctx, var_names, attrs, + parallel_axis, cluster) + + return res + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index afcdea4f045..bef18d1da8a 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl +from .common import is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -23,6 +24,11 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping from .dist_default import DistributedDefaultImpl0 +from ..cost import AllreduceSumOpCost, _g_op_cost_factory +from ..cost import build_comp_desc_from_dist_op, build_dp_costs +from ..cost import build_comp_costs_from_descs +from ..cost import SoftmaxOpCost, SoftmaxGradOpCost +from paddle.distributed.fleet.meta_optimizers.common import OpRole class DistributedSoftmax(DistributedOperatorImplContainer): @@ -41,6 +47,62 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): self._forward_implemented = False self._backward_implemented = False + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(SoftmaxOpCost, ctx, + processes, desc_mapping, + cluster) + + res_cost = [cost_mapping] + return res_cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(SoftmaxGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + backward_op = dist_op.serial_op + main_block = backward_op.block + need_gradient_allreduce = False + vars = main_block.vars + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and is_parameter_related( + varname, main_block): + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs(res, dist_op, ctx, var_names, attrs, + parallel_axis, cluster) + + return res + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 0dc4177399e..e5b4a51c4db 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl +from .common import is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -23,6 +24,10 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping from .dist_default import DistributedDefaultImpl0 +from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost +from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs +from ..cost import build_comp_costs_from_descs +from paddle.distributed.fleet.meta_optimizers.common import OpRole class DistributedTranspose2(DistributedOperatorImplContainer): @@ -116,6 +121,63 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): return changed + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + op_type = dist_op.serial_op.type + cost_mapping = build_comp_costs_from_descs(Transpose2OpCost, ctx, + processes, desc_mapping, + cluster) + + res_cost = [cost_mapping] + return res_cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + op_type = dist_op.serial_op.type + cost_mapping = build_comp_costs_from_descs(Transpose2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + backward_op = dist_op.serial_op + main_block = backward_op.block + need_gradient_allreduce = False + vars = main_block.vars + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and is_parameter_related( + varname, main_block): + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs(res, dist_op, ctx, var_names, attrs, + parallel_axis, cluster) + return res + @staticmethod def forward(ctx, *args, **kwargs): DistributedDefaultImpl0.forward(ctx, *args, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py index 0956c5bae61..2bf2f887e9d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py @@ -47,7 +47,7 @@ def parallelizer(program_func, rank): completer.complete_backward_annotation(main_program) dist_context.block_state.parse_backward_blocks(main_program) - optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = paddle.optimizer.Adam(learning_rate=0.001) # generate opt and complete opt with program_guard(main_program, startup_program): optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads) @@ -59,7 +59,7 @@ def parallelizer(program_func, rank): class TestDistOpCost(unittest.TestCase): - def test_dist_fill_constatnt_batch_size_like_op_cost(self): + def test_dist_op_cost_part1(self): def make_program(): main_program = paddle.static.Program() @@ -79,7 +79,7 @@ class TestDistOpCost(unittest.TestCase): tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=x, shape=[2, 8], value=1, dtype='float32') weight_attr = paddle.ParamAttr() - linear = paddle.nn.Linear(8, 8, weight_attr=weight_attr) + linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr) linear_out = linear(x) gelu_out = paddle.nn.functional.gelu(linear_out) # default op with dp @@ -109,6 +109,112 @@ class TestDistOpCost(unittest.TestCase): dist_context, cluster) self.assertTrue(dist_op_cost) + def test_dist_op_cost_part2(self): + + def make_program(): + main_program = paddle.static.Program() + start_program = paddle.static.Program() + with paddle.static.program_guard(main_program, start_program): + x = paddle.static.data(name='x', shape=[4], dtype='float32') + x.stop_gradient = True + label = paddle.static.data(name="label", + shape=[8, 1], + dtype='float32') + label.stop_gradient = True + auto.shard_tensor(x, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0] + }) + + auto.shard_tensor(label, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + # embedding + tmp = paddle.fluid.layers.fill_constant_batch_size_like( + input=x, shape=[4], value=1, dtype='int32') + embedding = paddle.nn.Embedding(10, 8) + out = embedding(tmp) + # row parallel embedding + for op in main_program.global_block().ops: + if op.type == "lookup_table_v2": + W = main_program.global_block().vars[op.input("W")[0]] + auto.shard_tensor(W, + dist_attr={ + "process_mesh": + auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + out = paddle.fluid.layers.transpose(out, + [1, 0]) # [8, 2] [-1, 0] + + # matmul + param1 = paddle.fluid.layers.create_parameter( + [4, 8], paddle.float32) # [2, 8] [0, -1] + auto.shard_tensor(param1, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + param2 = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 4] [-1, 0] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, 0] + }) + out1 = paddle.fluid.layers.matmul(out, + param1) # [8, 8] [-1, -1] + tmp_param = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 8] [-1, -1] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, -1] + }) + tmp_out = paddle.fluid.layers.matmul(out1, tmp_param) + out2 = paddle.fluid.layers.matmul(tmp_out, + param2) # [8, 4] [-1, 0] + + out8 = paddle.fluid.layers.transpose(out2, + [1, 0]) # [4, 8] [0, -1] + + # reshape + out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1] + tmp_reshape_out = paddle.reshape(out9, [8, 4, 2]) + out10 = paddle.reshape(tmp_reshape_out, + [8, 8]) # [4, 8] [0, -1] + + # softmax + softmax = paddle.nn.Softmax() + out11 = softmax(out10) + error_cost = paddle.nn.functional.square_error_cost( + out11, label) + loss = paddle.mean(error_cost) + return main_program, start_program, loss + + main_program, dist_context = parallelizer(make_program, 0) + ops = main_program.global_block().ops + cluster = Cluster() + cluster.gen_default_config_cluster(device_count=2) + for idx, op in enumerate(ops): + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.processes + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise") + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type) + + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, + dist_context, cluster) + self.assertTrue(dist_op_cost) + if __name__ == "__main__": unittest.main() -- GitLab