From d0f4465d3554346a4704500a1dce2a756ea938e8 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Mon, 25 Jul 2022 19:07:33 +0800
Subject: [PATCH] [Auto Parallel] Add dist op cost (#44146)

* update comp cost

* add dist default op cost

* add dist fill constant batch size like op cost

* add elewise op cost

* add fill_constant_batch_size_like op cost unittest

* add unittest and remove fill_constant_batch_size_like grad op cost

* add to cmakelist

* fix unittest bug
---
 .../auto_parallel/cost/__init__.py            |   5 +
 .../auto_parallel/cost/comp_op_cost.py        |  18 ---
 .../auto_parallel/operators/dist_default.py   |  71 +++++++++++
 .../auto_parallel/operators/dist_eltwise.py   |  73 ++++++++++-
 .../dist_fill_constant_batch_size_like.py     |  28 +++++
 .../unittests/auto_parallel/CMakeLists.txt    |   1 +
 .../unittests/auto_parallel/test_comp_cost.py |   6 -
 .../auto_parallel/test_dist_op_cost.py        | 114 ++++++++++++++++++
 8 files changed, 291 insertions(+), 25 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py

diff --git a/python/paddle/distributed/auto_parallel/cost/__init__.py b/python/paddle/distributed/auto_parallel/cost/__init__.py
index ea6b3bc5b7..fa665bde1e 100644
--- a/python/paddle/distributed/auto_parallel/cost/__init__.py
+++ b/python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -16,10 +16,15 @@ from .base_cost import _g_op_cost_factory
 from .base_cost import Cost
 from .base_cost import CommContext
 from .base_cost import build_comm_desc
+from .base_cost import build_comp_desc_from_op
+from .base_cost import build_comp_desc_from_dist_op
+from .base_cost import build_dp_costs
+from .base_cost import build_comp_costs_from_descs
 from .tensor_cost import TensorCost
 from .estimate_cost import CostEstimator
 
 from .comp_op_cost import MatmulV2OpCost
+from .comp_op_cost import FillConstantBatchSizeLikeOpCost
 
 from .comm_op_cost import SendOpCost
 from .comm_op_cost import RecvOpCost
diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
index 6556a1110d..8d8abe8d8e 100644
--- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -357,24 +357,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
         return 0
 
 
-@register_op_cost
-class FillConstantBatchSizeLikeGradOpCost(CompOpCost):
-    OP_TYPE = "fill_constant_batch_size_like_grad"
-
-    def __init__(self, op=None, op_desc=None, cluster=None):
-        super(FillConstantBatchSizeLikeGradOpCost,
-              self).__init__(op=op, op_desc=op_desc, cluster=cluster)
-
-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-
 @register_op_cost
 class GatherOpCost(CompOpCost):
     OP_TYPE = "gather"
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 9d9d5371ac..9b288d36e4 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
+from ..cost import _g_op_cost_factory
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
 
 __op_not_need_param_init__ = ["while", "cond"]
 
@@ -99,6 +102,74 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         self._forward_implemented = True
         self._backward_implemented = True
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res_cost = [cost_mapping]
+
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.processes
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        vars = main_block.vars
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                        varname, main_block):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.topology
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                            varname, main_block):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname)
+                        mesh_shape = process_mesh.topology
+                        batch_size_axis = var_dim_mapping[0]
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(res, dist_op, ctx, var_names, attrs,
+                                       parallel_axis, cluster)
+        return res
+
     def is_input_compatible(self, dist_op):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
index 02f2741d88..348e2ee457 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
@@ -15,7 +15,7 @@
 from .common import DistributedOperatorImplContainer
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
-from .common import register_distributed_operator_impl
+from .common import register_distributed_operator_impl, is_parameter_related
 from .common import is_elementwise_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
@@ -32,6 +32,9 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
 from .dist_default import DistributedDefaultImpl0
+from ..cost import _g_op_cost_factory
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
 
 
 class DistributedElementwise(DistributedOperatorImplContainer):
@@ -52,6 +55,74 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         self._forward_implemented = False
         self._backward_implemented = False
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res_cost = [cost_mapping]
+
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.processes
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        vars = main_block.vars
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                        varname, main_block):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.topology
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                            varname, main_block):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname)
+                        mesh_shape = process_mesh.topology
+                        batch_size_axis = var_dim_mapping[0]
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(res, dist_op, ctx, var_names, attrs,
+                                       parallel_axis, cluster)
+        return res
+
     def is_input_compatible(self, dist_op):
         op_desc = dist_op.serial_op.desc
         if not is_elementwise_op(op_desc.type()):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
index 27e8983707..d39a775d16 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
@@ -27,7 +27,12 @@ from paddle.fluid import core, unique_name
 from paddle.fluid.framework import _non_static_mode
 from paddle.fluid.framework import Program, Parameter, Variable, program_guard
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from .dist_default import DistributedDefaultImpl0
+from ..cost import FillConstantBatchSizeLikeOpCost
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
+from ..cost import AllreduceSumOpCost
 
 
 class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
@@ -47,6 +52,29 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         self._forward_implemented = True
         self._backward_implemented = True
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            raise ValueError(
+                "The fill_constant_batch_size_like has no grad op.")
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            FillConstantBatchSizeLikeOpCost, ctx, processes, desc_mapping,
+            cluster)
+
+        res_cost = [cost_mapping]
+        return res_cost
+
     def is_input_compatible(self, dist_op):
         return True
 
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index 6c51ce1fff..85b13b38a4 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -55,4 +55,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
   py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
   py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
+  py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py
index 8472354826..0a3a5993ff 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py
@@ -35,7 +35,6 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import EmbeddingOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import EmbeddingGradOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeOpCost
-from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeGradOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GatherOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluGradOpCost
@@ -184,11 +183,6 @@ class TestCompOpCost(unittest.TestCase):
         self.assertTrue(op_cost.time >= 0)
         self.assertTrue(op_cost.memory >= 0)
 
-        op_cost = FillConstantBatchSizeLikeGradOpCost(cluster=cluster)
-        self.assertTrue(op_cost.flops >= 0)
-        self.assertTrue(op_cost.time >= 0)
-        self.assertTrue(op_cost.memory >= 0)
-
         op_cost = GatherOpCost(cluster=cluster)
         self.assertTrue(op_cost.flops >= 0)
         self.assertTrue(op_cost.time >= 0)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
new file mode 100644
index 0000000000..0956c5bae6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import copy
+
+import paddle
+import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.cluster import Cluster
+from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container, is_elementwise_op
+
+from paddle.fluid import program_guard
+from paddle.fluid.backward import append_backward
+from paddle.fluid.backward import append_backward
+
+paddle.enable_static()
+
+
+def parallelizer(program_func, rank):
+    from paddle.distributed.auto_parallel.completion import Completer
+    from paddle.distributed.auto_parallel.partitioner import Partitioner
+    from paddle.distributed.auto_parallel.dist_context import DistributedContext
+
+    main_program, startup_program, loss = program_func()
+
+    # complete forward
+    dist_context = DistributedContext()
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation(main_program)
+    dist_context.block_state.parse_forward_blocks(main_program)
+
+    # generate backward and complete backward
+    with paddle.static.program_guard(main_program, startup_program):
+        params_grads = append_backward(
+            loss, None, None, None, distop_context=dist_context.dist_op_context)
+    completer.complete_backward_annotation(main_program)
+    dist_context.block_state.parse_backward_blocks(main_program)
+
+    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
+    # generate opt and complete opt
+    with program_guard(main_program, startup_program):
+        optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)
+
+    completer.complete_update_annotation(main_program)
+
+    return main_program, dist_context
+
+
+class TestDistOpCost(unittest.TestCase):
+
+    def test_dist_fill_constant_batch_size_like_op_cost(self):
+
+        def make_program():
+            main_program = paddle.static.Program()
+            start_program = paddle.static.Program()
+            with paddle.static.program_guard(main_program, start_program):
+                x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
+                x.stop_gradient = True
+                label = paddle.static.data(name="label",
+                                           shape=[4, 1],
+                                           dtype='float32')
+                label.stop_gradient = True
+                auto.shard_tensor(x,
+                                  dist_attr={
+                                      "process_mesh": auto.ProcessMesh([0, 1]),
+                                      "dims_mapping": [0, -1]
+                                  })
+                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
+                    input=x, shape=[2, 8], value=1, dtype='float32')
+                weight_attr = paddle.ParamAttr()
+                linear = paddle.nn.Linear(8, 8, weight_attr=weight_attr)
+                linear_out = linear(x)
+                gelu_out = paddle.nn.functional.gelu(linear_out)
+                # default op with dp
+                tmp = paddle.static.nn.layer_norm(gelu_out)
+                error_cost = paddle.nn.functional.square_error_cost(tmp, label)
+                loss = paddle.mean(error_cost)
+            return main_program, start_program, loss
+
+        main_program, dist_context = parallelizer(make_program, 0)
+        ops = main_program.global_block().ops
+        cluster = Cluster()
+        cluster.gen_default_config_cluster(device_count=2)
+        for idx, op in enumerate(ops):
+            if op.type != "matmul_v2" and op.type != "matmul_v2_grad" and op.type != "sgd":
+                dist_op = dist_context.get_dist_op_for_program(op)
+                op_dist_attr = dist_op.dist_attr
+                processes = op_dist_attr.process_mesh.processes
+                if is_elementwise_op(op.type):
+                    container = get_distributed_operator_impl_container(
+                        "elementwise")
+                else:
+                    container = get_distributed_operator_impl_container(
+                        op_dist_attr.impl_type)
+
+                dist_impl = container.impls[op_dist_attr.impl_idx]
+                dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
+                                                   dist_context, cluster)
+                self.assertTrue(dist_op_cost)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
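
For context, every calc_cost hook added in this patch is driven the same way: map a program op to its DistributedOperator, look up the registered impl, and ask that impl for a cost given the op role, the DistributedContext and the Cluster. Below is a minimal sketch of that lookup, assuming the same setup the new test builds (a completed main_program, a dist_context and a cluster from parallelizer() and Cluster().gen_default_config_cluster()); the helper name query_op_cost is illustrative and not part of the patch.

# Sketch only: mirrors the loop in test_dist_op_cost.py above.
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container, is_elementwise_op)


def query_op_cost(op, dist_context, cluster):
    # Map the serial op back to its distributed counterpart.
    dist_op = dist_context.get_dist_op_for_program(op)
    op_dist_attr = dist_op.dist_attr
    # Elementwise ops share one impl container; other ops are keyed by impl_type.
    key = "elementwise" if is_elementwise_op(op.type) else op_dist_attr.impl_type
    container = get_distributed_operator_impl_container(key)
    dist_impl = container.impls[op_dist_attr.impl_idx]
    # calc_cost dispatches to calc_fwd_cost or calc_bwd_cost by op role; for
    # backward ops whose inputs shard the batch axis, the result also carries
    # the data-parallel allreduce costs produced by build_dp_costs.
    return dist_impl.calc_cost(op.attr('op_role'), dist_op, dist_context, cluster)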