From 94c17a0f2150b89b9eac853430c76e998cdd685e Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 9 Aug 2022 14:29:04 +0800 Subject: [PATCH] [Auto Parallel] Add mul dist op cost (#44973) * add mul dist op cost * add mul unittest --- .../auto_parallel/operators/dist_matmul.py | 538 +++++++++++++++++- .../auto_parallel/test_dist_op_cost.py | 209 +++++++ .../test_auto_parallel_reshard_dpmppp.py | 8 + .../test_auto_parallel_reshard_mppp.py | 10 +- 4 files changed, 763 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index f9b5b9a532..18ceb79ea8 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -1246,6 +1246,108 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + # col parallel: matmul + allreduce + assert Y_var_dim_mapping[0] < 0 + parallel_axis = Y_var_dim_mapping[1] + + has_x_grad = len(backward_op.output("X@GRAD")) > 0 + if has_x_grad: + assert len(backward_op.output("X@GRAD")) == 1 + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + + cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # calc comm op cost + if has_x_grad: + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = backward_op.output("X@GRAD") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, + c_allreduce_sum_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # need gradient allreduce + process_mesh = dist_attr.process_mesh + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + # TODO: trans shape if trans_x or trans_y is True + comp_desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + 
comp_cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, + processes, + comp_desc_mapping, + cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-1] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + + var_names = serial_op.input("X") + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + + res_cost = [comm_op_cost_list, comp_cost_mapping] + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -1468,6 +1570,100 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + assert Y_var_dim_mapping[1] < 0 + parallel_axis = Y_var_dim_mapping[0] + + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + # calc comm op cost + var_names = [backward_op.input("Out@GRAD")[0]] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + process_mesh = dist_attr.process_mesh + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, + processes, desc_mapping, + cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-2] + 
attrs = {"use_calc_stream": True, "use_model_parallel": True} + + var_names = serial_op.output("Out") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, + cluster) + res_cost = [cost_mapping, comm_op_cost_list] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -1677,6 +1873,61 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulV2Impl2, self).__init__(name) + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + process_mesh = dist_attr.process_mesh + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, + processes, desc_mapping, + cluster) + + res_cost = [cost_mapping] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -1765,6 +2016,102 @@ class DistributedMulImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + # col parallel: matmul + allreduce + assert Y_var_dim_mapping[0] < 0 + parallel_axis = Y_var_dim_mapping[1] + + has_x_grad = len(backward_op.output("X@GRAD")) > 0 + if has_x_grad: + assert 
len(backward_op.output("X@GRAD")) == 1 + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # calc comm op cost + if has_x_grad: + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = backward_op.output("X@GRAD") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, + c_allreduce_sum_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, + desc_mapping, cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-1] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + var_names = serial_op.input("X") + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res_cost = [comm_op_cost_list, cost_mapping] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -1916,7 +2263,24 @@ class DistributedMulImpl0(DistributedOperatorImpl): "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), OP_ROLE_KEY: src_op.attr('op_role') } - inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + inputs = {'X': intermediate_var_0, 'Y': Weight_var} + + inputs_ref_shape = {} + inputs_original_shape = {} + for var_name in inputs: + if var_name == "X": + var = X_var + else: + var = inputs[var_name] + inputs_original_shape[var_name] = var.shape + input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) + input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) + input_ref_shape = infer_shape(main_block, var, + input_tensor_dist_attr, + input_var_dist_attr) + inputs_ref_shape[var_name] = input_ref_shape + var.desc.set_shape(input_ref_shape) + mul_op = main_block.append_op(type='mul', inputs=inputs, outputs={'Out': Out_var}, @@ -1924,6 +2288,11 @@ class DistributedMulImpl0(DistributedOperatorImpl): if Out_var.shape != ref_shape_out: Out_var.desc.set_shape(ref_shape_out) + for var_name in inputs: + var = inputs[var_name] + original_shape = inputs_original_shape[var_name] + var.desc.set_shape(original_shape) 
+ # set dist op's dist_attr with serial op's dist_attr # c_identity identity_op_dist_attr = OperatorDistributedAttribute() @@ -1988,6 +2357,100 @@ class DistributedMulImpl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # by now the backward function only insert the gradient allreduce for dist op itself + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + main_block = backward_op.block + vars = main_block.vars + Y_var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("Y")[0]) + assert Y_var_dim_mapping[1] < 0 + parallel_axis = Y_var_dim_mapping[0] + + # calc comm op cost + var_names = [backward_op.input("Out@GRAD")[0]] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + c_identity_desc_mapping = build_comm_desc_from_dist_op( + "c_identity", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + processes = process_mesh.processes + comm_op_cost_list = build_comm_costs_from_descs( + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + res.append(comm_op_cost_list) + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, + desc_mapping, cluster) + + # calc comm op cost + serial_op = dist_op.serial_op + vars = serial_op.block.vars + + parallel_axis = dist_op.dist_attr.get_input_dims_mapping( + serial_op.input("Y")[0])[-2] + attrs = {"use_calc_stream": True, "use_model_parallel": True} + + var_names = serial_op.output("Out") + c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( + "c_allreduce_sum", + dist_op, + ctx, + var_names, + attrs=attrs, + parallel_axis=parallel_axis) + + # print("dist_matmul.py dist_op: ", dist_op) + comm_op_cost_list = build_comm_costs_from_descs( + AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, + cluster) + + res_cost = [cost_mapping, comm_op_cost_list] + + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -2122,13 +2585,32 @@ class DistributedMulImpl1(DistributedOperatorImpl): ctx.set_tensor_dist_attr_for_program(intermediate_var_0, out_var_dist_attr) + 
inputs_ref_shape = {} + inputs_original_shape = {} + for var_name in inputs: + var = inputs[var_name] + inputs_original_shape[var_name] = var.shape + input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) + input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) + input_ref_shape = infer_shape(main_block, var, + input_tensor_dist_attr, + input_var_dist_attr) + inputs_ref_shape[var_name] = input_ref_shape + var.desc.set_shape(input_ref_shape) + mul_op = main_block.append_op(type='mul', inputs=inputs, outputs={'Out': intermediate_var_0}, attrs=attrs) + if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) + for var_name in inputs: + var = inputs[var_name] + original_shape = inputs_original_shape[var_name] + var.desc.set_shape(original_shape) + c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', inputs={'X': intermediate_var_0}, @@ -2139,6 +2621,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): 'use_model_parallel': True, OP_ROLE_KEY: src_op.attr('op_role') }) + if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -2198,6 +2681,59 @@ class DistributedMulImpl2(DistributedOperatorImpl): def __init__(self, name): super(DistributedMulImpl2, self).__init__(name) + def calc_cost(self, op_role, dist_op, ctx, cluster): + cost = None + if int(op_role) == int(OpRole.Forward): + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + elif int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + res = [] + backward_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + main_block = backward_op.block + vars = main_block.vars + + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + process_mesh = dist_attr.process_mesh + processes = process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, + processes, desc_mapping, + cluster) + res.append(cost_mapping) + + # need gradient allreduce + var_dim_mapping = dist_attr.get_input_dims_mapping( + backward_op.input("X")[0]) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[ + batch_size_axis] > 1 and is_parameter_related( + backward_op.input("Y")[0], main_block): + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [backward_op.output('Y@GRAD')[0]] + build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, + cluster) + + return res + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, + dist_context=ctx) + processes = dist_op.dist_attr.process_mesh.processes + cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, + desc_mapping, cluster) + + res_cost = [cost_mapping] + return res_cost + def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py index 2bf2f887e9..734bd7acf9 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py @@ -215,6 +215,215 @@ class TestDistOpCost(unittest.TestCase): dist_context, cluster) self.assertTrue(dist_op_cost) + def 
test_dist_op_cost_part3(self): + + def make_program(): + main_program = paddle.static.Program() + start_program = paddle.static.Program() + with paddle.static.program_guard(main_program, start_program): + x = paddle.static.data(name='x', shape=[4], dtype='float32') + x.stop_gradient = True + label = paddle.static.data(name="label", + shape=[8, 1], + dtype='float32') + label.stop_gradient = True + auto.shard_tensor(x, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0] + }) + + auto.shard_tensor(label, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + # embedding + tmp = paddle.fluid.layers.fill_constant_batch_size_like( + input=x, shape=[4], value=1, dtype='int32') + embedding = paddle.nn.Embedding(10, 8) + out = embedding(tmp) + # row parallel embedding + for op in main_program.global_block().ops: + if op.type == "lookup_table_v2": + W = main_program.global_block().vars[op.input("W")[0]] + auto.shard_tensor(W, + dist_attr={ + "process_mesh": + auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + out = paddle.fluid.layers.transpose(out, + [1, 0]) # [8, 2] [-1, 0] + + # matmul_v2 + param1 = paddle.fluid.layers.create_parameter( + [4, 8], paddle.float32) # [2, 8] [0, -1] + auto.shard_tensor(param1, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + param2 = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 4] [-1, 0] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, 0] + }) + out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1] + tmp_param = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 8] [-1, -1] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, -1] + }) + tmp_out = paddle.matmul(out1, tmp_param) + out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] + + out8 = paddle.fluid.layers.transpose(out2, + [1, 0]) # [4, 8] [0, -1] + + # reshape + out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1] + tmp_reshape_out = paddle.reshape(out9, [8, 4, 2]) + out10 = paddle.reshape(tmp_reshape_out, + [8, 8]) # [4, 8] [0, -1] + + # softmax + softmax = paddle.nn.Softmax() + out11 = softmax(out10) + error_cost = paddle.nn.functional.square_error_cost( + out11, label) + loss = paddle.mean(error_cost) + return main_program, start_program, loss + + main_program, dist_context = parallelizer(make_program, 0) + ops = main_program.global_block().ops + cluster = Cluster() + cluster.gen_default_config_cluster(device_count=2) + for idx, op in enumerate(ops): + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.processes + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise") + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type) + + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, + dist_context, cluster) + self.assertTrue(dist_op_cost) + + def test_dist_op_cost_part4(self): + + def make_program(): + main_program = paddle.static.Program() + start_program = paddle.static.Program() + with paddle.static.program_guard(main_program, start_program): + x = paddle.static.data(name='x', shape=[4], dtype='float32') + x.stop_gradient = True + label = paddle.static.data(name="label", + shape=[8, 1], + dtype='float32') + 
label.stop_gradient = True + auto.shard_tensor(x, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0] + }) + + auto.shard_tensor(label, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + # embedding + tmp = paddle.fluid.layers.fill_constant_batch_size_like( + input=x, shape=[4], value=1, dtype='int32') + embedding = paddle.nn.Embedding(10, 8) + out = embedding(tmp) + # row parallel embedding + for op in main_program.global_block().ops: + if op.type == "lookup_table_v2": + W = main_program.global_block().vars[op.input("W")[0]] + auto.shard_tensor(W, + dist_attr={ + "process_mesh": + auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + out = paddle.fluid.layers.transpose(out, + [1, 0]) # [8, 2] [-1, 0] + + # mul + param1 = paddle.fluid.layers.create_parameter( + [4, 8], paddle.float32) # [2, 8] [0, -1] + auto.shard_tensor(param1, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [0, -1] + }) + param2 = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 4] [-1, 0] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, 0] + }) + out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1] + tmp_param = paddle.fluid.layers.create_parameter( + [8, 8], paddle.float32) # [8, 8] [-1, -1] + auto.shard_tensor(param2, + dist_attr={ + "process_mesh": auto.ProcessMesh([0, 1]), + "dims_mapping": [-1, -1] + }) + tmp_out = paddle.fluid.layers.mul(out1, tmp_param) + out2 = paddle.fluid.layers.mul(tmp_out, + param2) # [8, 4] [-1, 0] + + out8 = paddle.fluid.layers.transpose(out2, + [1, 0]) # [4, 8] [0, -1] + + # reshape + out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1] + tmp_reshape_out = paddle.reshape(out9, [8, 4, 2]) + out10 = paddle.reshape(tmp_reshape_out, + [8, 8]) # [4, 8] [0, -1] + + # softmax + softmax = paddle.nn.Softmax() + out11 = softmax(out10) + error_cost = paddle.nn.functional.square_error_cost( + out11, label) + loss = paddle.mean(error_cost) + return main_program, start_program, loss + + main_program, dist_context = parallelizer(make_program, 0) + ops = main_program.global_block().ops + cluster = Cluster() + cluster.gen_default_config_cluster(device_count=2) + for idx, op in enumerate(ops): + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.processes + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise") + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type) + + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, + dist_context, cluster) + self.assertTrue(dist_op_cost) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 7544ff4571..d6d613225d 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -76,6 +76,14 @@ class MLPLayer(nn.Layer): out = self.linear0(out) out = F.gelu(out, approximate=True) out = self.linear1(out) + param = paddle.fluid.layers.create_parameter([1024, 4096], + paddle.float32) + auto.shard_tensor(param, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, 1] + }) + out = 
paddle.fluid.layers.mul(out, param) return out diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index dfb314796a..5c699881c2 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -93,6 +93,14 @@ class MLPLayer(nn.Layer): }) w_out = self.word_embeddings(input) out = self.linear0(w_out) + param = paddle.fluid.layers.create_parameter([4096, 4096], + paddle.float32) + auto.shard_tensor(param, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [0, -1] + }) + out = paddle.fluid.layers.mul(out, param) gelu_out = F.gelu(out, approximate=True) out = self.linear1(gelu_out) out1 = self.linear2(gelu_out) @@ -228,7 +236,7 @@ class TestMLPReshard(unittest.TestCase): resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() - + print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) -- GitLab
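
Note (not part of the patch): the unit tests added above exercise the new `calc_cost` entry points through the generic dist-op container lookup. Below is a minimal sketch of that call path, condensed from the loop in `test_dist_op_cost.py`. The import paths are inferred from the 2022-era `paddle.distributed.auto_parallel` module layout used by that test and may differ in other releases; the helper function name is illustrative, not part of the patch.

from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container, is_elementwise_op)


def estimate_program_op_costs(main_program, dist_context, device_count=2):
    """Walk a parallelized program and collect the per-op cost lists.

    Mirrors the loop in test_dist_op_cost.py: each op is mapped to its
    distributed implementation, whose calc_cost() dispatches to
    calc_fwd_cost()/calc_bwd_cost() based on the op role.
    """
    cluster = Cluster()
    cluster.gen_default_config_cluster(device_count=device_count)

    all_costs = []
    for op in main_program.global_block().ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        op_dist_attr = dist_op.dist_attr
        if is_elementwise_op(op.type):
            container = get_distributed_operator_impl_container("elementwise")
        else:
            container = get_distributed_operator_impl_container(
                op_dist_attr.impl_type)
        dist_impl = container.impls[op_dist_attr.impl_idx]
        # For the mul/matmul_v2 impls added in this patch, the result is a
        # list of comp/comm cost mappings, e.g. [comm_op_cost_list,
        # comp_cost_mapping] for the column-parallel forward case.
        op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
                                      dist_context, cluster)
        all_costs.append(op_cost)
    return all_costs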