From c1c9368ff3e748ac6ebef6c4f4824e2e0abd35a3 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Thu, 24 Mar 2022 10:39:11 +0800
Subject: [PATCH] [Auto Parallel] Update cost model (#40457)

* refactor cost model
---
 .../auto_parallel/cost/__init__.py            |  20 +
 .../auto_parallel/cost/base_cost.py           | 342 ++++++++++++++++++
 .../auto_parallel/cost/comm_op_cost.py        |  28 ++
 .../auto_parallel/cost/comp_op_cost.py        |  33 ++
 .../auto_parallel/cost/estimate_cost.py       |  69 ++++
 .../auto_parallel/cost/tensor_cost.py         | 110 ++++++
 .../unittests/auto_parallel/CMakeLists.txt    |   1 +
 .../auto_parallel/test_new_cost_model.py      |  75 ++++
 python/setup.py.in                            |   1 +
 9 files changed, 679 insertions(+)
 create mode 100644 python/paddle/distributed/auto_parallel/cost/__init__.py
 create mode 100644 python/paddle/distributed/auto_parallel/cost/base_cost.py
 create mode 100644 python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
 create mode 100644 python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
 create mode 100644 python/paddle/distributed/auto_parallel/cost/estimate_cost.py
 create mode 100644 python/paddle/distributed/auto_parallel/cost/tensor_cost.py
 create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py

diff --git a/python/paddle/distributed/auto_parallel/cost/__init__.py b/python/paddle/distributed/auto_parallel/cost/__init__.py
new file mode 100644
index 00000000000..7bc8a81b79f
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from .base_cost import OP_COST_FACTORY
+from .base_cost import Cost
+from .comm_op_cost import AllreduceSumCost
+from .comp_op_cost import MatmulV2OpCost
+from .tensor_cost import TensorCost
+from .estimate_cost import CostEstimator
diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py
new file mode 100644
index 00000000000..c4ebd836129
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from collections import OrderedDict
+import paddle
+
+COMM_OP_TYPE = [
+    "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
+]
+NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
+OP_COST_FACTORY = {}
+
+
+def _parse_op_to_desc(op, dist_context=None):
+    desc = {}
+    desc["op"] = op.type
+    vars = op.block.vars
+    input_desc = OrderedDict()
+    for input_name in op.input_names:
+        var_name_list = op.input(input_name)
+        var_desc = []
+        for var_name in var_name_list:
+            var = vars[var_name]
+            shape = None
+            if dist_context is not None:
+                dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                shape = dist_tensor.local_sizes()
+            else:
+                shape = var.shape
+            assert shape is not None
+            var_desc.append((var.dtype, shape))
+        input_desc[input_name] = var_desc
+    desc["inputs"] = input_desc
+
+    output_desc = OrderedDict()
+    for out_name in op.output_names:
+        var_name_list = op.output(out_name)
+        var_desc = []
+        for var_name in var_name_list:
+            var = vars[var_name]
+            shape = None
+            if dist_context is not None:
+                dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                shape = dist_tensor.local_sizes()
+            else:
+                shape = var.shape
+            assert shape is not None
+            var_desc.append((var.dtype, shape))
+        output_desc[out_name] = var_desc
+    desc["outputs"] = output_desc
+
+    attr_desc = op.all_attrs()
+    desc["attrs"] = attr_desc
+
+    return desc
+
+
+def parse_to_desc(op=None, dist_op=None, dist_context=None):
+    desc = None
+    if op is None and dist_op is not None and dist_context is not None:
+        desc = _parse_op_to_desc(
+            op=dist_op.serial_op, dist_context=dist_context)
+    elif op is not None and dist_op is None and dist_context is None:
+        desc = _parse_op_to_desc(op)
+
+    return desc
+
+
+def parse_desc_to_str(desc):
+    def _parse_dtype(dtype):
+        dtype_str = ""
+        if dtype == paddle.float32:
+            dtype_str = "float32"
+        elif dtype == paddle.float16:
+            dtype_str = "float16"
+        elif dtype == paddle.int32:
+            dtype_str = "int32"
+        elif dtype == paddle.int64:
+            dtype_str = "int64"
+        elif dtype == paddle.uint8:
+            dtype_str = "uint8"
+        else:
+            raise TypeError("Unsupported dtype {}".format(dtype))
+        return dtype_str
+
+    assert isinstance(desc, dict)
+    desc_str_list = []
+    desc_str = None
+    dtype_str_list = []
+    dims_list = []
+    shape_list = []
+
+    desc_str_list.append(desc["op"])
+    inputs = desc["inputs"]
+    for key, item in inputs.items():
+        for dtype, shape in item:
+            dtype_str_list.append(_parse_dtype(dtype))
+            shape_list += list(shape)
+            dims = len(shape)
+            dims_list.append(dims)
+
+    dtype_str = "*".join(dtype_str_list)
+    dims_list = [str(item) for item in dims_list]
+    dims_str = "*".join(dims_list)
+
+    shape_list = [str(item) for item in shape_list]
+    shape_str = "[" + ",".join(shape_list) + "]"
+    desc_str_list += [dtype_str, dims_str, shape_str]
+    desc_str = "_".join(desc_str_list)
+
+    return desc_str
+
+
+class CommContext:
+    _instance = None
+    _has_instance = False
+
+    def __init__(self, cluster):
+        if CommContext._has_instance:
+            return
+        self.cluster = cluster
+        self._alpha_base_ring = 8.4
+        self._alpha_base_tree = 0
+        self._alpha_inter = None
+        self._alpha_intra = None
+        self._beta = {}
+        CommContext._has_instance = True
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    @property
+    def alpha_inter(self):
+        if self._alpha_inter is None:
+            if self.cluster.alpha.inter == "NVL":
+                self._alpha_inter = 3.4
+            elif self.cluster.alpha.inter == "PHB":
"PHB": + self._alpha_inter = 5.7 + return self._alpha_inter + + @property + def alpha_intra(self): + if self._alpha_intra is None: + if cluster.alpha.intra == "NVL": + self._alpha_intra = 28 + elif cluster.alpha.intra == "PHB": + self._alpha_intra = 28 + return self._alpha_intra + + @property + def alpha_base_ring(self): + return self._alpha_base_ring + + @property + def alpha_base_tree(self): + return self._alpha_base_tree + + def get_beta(self, ranks): + key = ','.join(map(str, sorted(ranks))) + max_beta = None + if key in self._beta.keys: + max_beta = self._beta[key] + else: + for i in range(len(ranks)): + for j in range(i + 1, len(ranks)): + if min_beta == None: + min_beta = cluster.get_beta(ranks[i], ranks[j]) + else: + beta = cluster.get_beta(ranks[i], ranks[j]) + if beta > max_beta: + max_beta = beta + self._beta[key] = max_beta + + return max_beta + + +class Cost: + def __init__(self, time=0, memory=0, flops=0): + self.time = time + self.memory = memory + self.flops = flops + + def _check_time(self, val): + assert val >= 0, "Time must be greater than or equal to 0." + + def _check_memory(self, val): + assert isinstance( + val, int) and val >= 0, "Memory must be int and greater than 0." + + def _check_flops(self, val): + assert isinstance( + val, int) and val >= 0, "FLOPs must be int and greater than 0." + + @property + def time(self): + return self._time + + @time.setter + def time(self, val): + self._check_time(val) + self._time = val + + @property + def memory(self): + return self._memory + + @memory.setter + def memory(self, val): + self._check_memory(val) + self._memory = val + + @property + def flops(self): + return self._flops + + @flops.setter + def flops(self, val): + self._check_flops(val) + self._flops = val + + def __add__(self, rhs): + assert isinstance(rhs, Cost) + time = self.time + rhs.time + memory = self.memory + rhs.memory + flops = self.flops + rhs.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + return Cost(time, memory, flops) + + def __sub__(self, rhs): + assert isinstance(rhs, Cost) + time = self.time - rhs.time + memory = self.memory - rhs.memory + flops = self.flops - rhs.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + return Cost(time, memory, flops) + + +class OpCost: + def __init__(self, op=None, op_desc=None): + assert (op is not None and op_desc is None) or (op is None and + op_desc is not None) + self._op = op + self._op_desc = op_desc + self._cost = self.calc_cost() + + @property + def op(self): + return self._op + + @property + def op_desc(self): + return self._op_desc + + @property + def cost(self): + return self._cost + + def calc_time(self): + return 0 + + def calc_memory(self): + return 0 + + def calc_flops(self): + return 0 + + def calc_cost(self): + time = self.calc_time() + memory = self.calc_memory() + flops = self.calc_flops() + cost = Cost(time, memory, flops) + return cost + + +class CommOpCost(OpCost): + OP_TYPE = "COMM" + + def __init__(self, op=None, op_desc=None, comm_context=None): + super(CommOpCost, self).__init__(op=op, op_desc=op_desc) + self._check_comm_op_type() + self._comm_context = comm_context + + @property + def comm_context(self): + return self._comm_context + + @classmethod + def _check_comm_op_type(cls): + if cls.OP_TYPE != "COMM": + if cls.OP_TYPE not in COMM_OP_TYPE: + raise TypeError("Please Check op type in {}, but got {}.". 
+
+
+class CompOpCost(OpCost):
+    OP_TYPE = "COMP"
+
+    def __init__(self, op=None, op_desc=None, cluster=None):
+        super(CompOpCost, self).__init__(op=op, op_desc=op_desc)
+        self._check_comp_op_type()
+        self.cluster = cluster
+
+    @classmethod
+    def _check_comp_op_type(cls):
+        if cls.OP_TYPE != "COMP":
+            if cls.OP_TYPE in NON_COMP_TYPE:
+                raise TypeError("Please check op type not in {}, but got {}.".
+                                format(NON_COMP_TYPE, cls.OP_TYPE))
+
+
+def register_op_cost(cls):
+    op_type = cls.OP_TYPE
+
+    def register(op_type):
+        OP_COST_FACTORY[op_type] = cls
+        return cls
+    return register(op_type)
+
+
+def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None):
+    op_type = op.type if op is not None else desc["op"]
+    if op_type in COMM_OP_TYPE:
+        op_cost = OP_COST_FACTORY[op_type](op=op,
+                                           op_desc=desc,
+                                           comm_context=comm_context)
+    elif op_type not in NON_COMP_TYPE:
+        op_cost = OP_COST_FACTORY[op_type](op=op, op_desc=desc, cluster=cluster)
+    time = op_cost.calc_time()
+    return time
diff --git a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
new file mode 100644
index 00000000000..359f6b6e786
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY
+
+
+@register_op_cost
+class AllreduceSumCost(CommOpCost):
+    OP_TYPE = "c_allreduce_sum"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        # NOTE: The actual formula will be filled in the future.
+        return 0
diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
new file mode 100644
index 00000000000..c4d88cb25dc
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY
+
+
+@register_op_cost
+class MatmulV2OpCost(CompOpCost):
+    OP_TYPE = "matmul_v2"
+
+    def __init__(self, op=None, op_desc=None, cluster=None):
+        super(OP_COST_FACTORY["matmul_v2"], self).__init__(
+            op=op, op_desc=op_desc, cluster=cluster)
+
+    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
+    def calc_flops(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
+
+    def calc_time(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py
new file mode 100644
index 00000000000..7bd535af8be
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+
+class CostEstimator:
+    def __init__(self,
+                 program,
+                 cluster=None,
+                 dist_context=None,
+                 mode="modeling"):
+        self._program = program
+        self._cluster = cluster
+        self._dist_context = dist_context
+        self._check_mode(mode)
+        self._mode = mode
+        self._global_cost = None
+        self._local_cost = {}
+
+    @property
+    def program(self):
+        return self._program
+
+    @property
+    def dist_context(self):
+        return self._dist_context
+
+    @property
+    def cluster(self):
+        return self._cluster
+
+    @property
+    def mode(self):
+        return self._mode
+
+    @property
+    def global_cost(self):
+        return self._global_cost
+
+    @property
+    def local_cost(self):
+        return self._local_cost
+
+    def get_op_cost(self):
+        return 0
+
+    def get_tensor_cost(self):
+        return 0
+
+    def get_global_cost(self):
+        return 0
+
+    def get_local_cost(self, rank=None):
+        return 0
+
+    def _check_mode(self, mode):
+        if mode not in ["modeling", "profiling"]:
+            raise ValueError(
+                "Only modeling and profiling are supported, but got {}".format(mode))
diff --git a/python/paddle/distributed/auto_parallel/cost/tensor_cost.py b/python/paddle/distributed/auto_parallel/cost/tensor_cost.py
new file mode 100644
index 00000000000..2db1c06d596
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost/tensor_cost.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from functools import reduce
+
+import paddle
+from paddle.fluid.framework import Variable
+from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+
+from .base_cost import Cost
+
+
+class TensorCost:
+    def __init__(self, tensor=None, dist_tensor=None, shape=None, dtype=None):
+        self._check_args(tensor, dist_tensor, shape, dtype)
+        self._tensor = tensor
+        self._dist_tensor = dist_tensor
+        self._shape = shape
+        self._dtype = dtype
+        self._cost = self.calc_cost()
+
+    @property
+    def tensor(self):
+        return self._tensor
+
+    @property
+    def dist_tensor(self):
+        return self._dist_tensor
+
+    @property
+    def shape(self):
+        return self._shape
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def _check_args(self, tensor, dist_tensor, shape, dtype):
+        if tensor is not None:
+            assert (shape is None and dist_tensor is None and dtype is None)
+
+            if not isinstance(tensor, Variable):
+                raise TypeError(
+                    "Please check tensor type is Variable, but got {}".format(
+                        type(tensor)))
+
+        elif dist_tensor is not None:
+            assert (tensor is None and shape is None)
+            if not isinstance(dist_tensor, DistributedTensor):
+                raise TypeError(
+                    "Please check dist_tensor type is DistributedTensor, but got {}".
+                    format(type(dist_tensor)))
+
+        elif shape is not None:
+            assert (tensor is None and dist_tensor is None and
+                    dtype is not None)
+            if not isinstance(shape, (list, set)):
+                raise TypeError(
+                    "Please check shape type is list or set, but got {}".format(
+                        type(shape)))
+
+        elif dtype is not None:
+            assert (tensor is None and dist_tensor is None and
+                    shape is not None)
+
+    @property
+    def cost(self):
+        return self._cost
+
+    def calc_cost(self):
+        dtype = None
+        shape = None
+
+        if self.dist_tensor:
+            shape = self.dist_tensor.local_sizes()
+            dtype = self.dist_tensor.serial_tensor.dtype
+        elif self.tensor:
+            shape = self.tensor.shape
+            dtype = self.tensor.dtype
+        elif self.shape and self.dtype:
+            shape = self.shape
+            dtype = self.dtype
+
+        total_count = reduce(lambda x, y: x * y, shape)
+
+        if dtype == paddle.float32 or dtype == paddle.int32:
+            dtype_factor = 4
+        elif dtype == paddle.int64:
+            dtype_factor = 8
+        elif dtype == paddle.uint8:
+            dtype_factor = 1
+        else:
+            dtype_factor = 2
+
+        memory = total_count * dtype_factor
+        assert memory >= 0
+        cost = Cost(memory=memory)
+
+        return cost
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index a730d21afa5..c16936db5a3 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -17,4 +17,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
     py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
     py_test_modules(test_trial MODULES test_trial ENVS ${dist_ENVS})
+    py_test_modules(test_new_cost_model MODULES test_new_cost_model ENVS ${dist_ENVS})
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
new file mode 100644
index 00000000000..0cd3041ea4d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+import paddle.distributed.auto_parallel.cost as cost_model
+from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
+from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
+from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model
+
+paddle.enable_static()
+
+
+def check_cost(cost):
+    if cost.memory >= 0 and cost.flops >= 0 and cost.time >= 0:
+        return True
+    return False
+
+
+class TestCost(unittest.TestCase):
+    def test_base_cost(self):
+        cost = cost_model.Cost(memory=100, flops=200, time=0.5)
+        self.assertTrue(check_cost(cost))
+
+    def test_comp_cost(self):
+        x = paddle.static.data(name="x", shape=[20, 20], dtype='float32')
+        y = paddle.static.data(name="y", shape=[20, 20], dtype='float32')
+
+        z = paddle.matmul(x, y)
+        matmul_v2_op = None
+        ops = paddle.static.default_main_program().global_block().ops
+        for op in ops:
+            if op.type == "matmul_v2":
+                matmul_v2_op = op
+                break
+        matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"](
+            op=matmul_v2_op)
+        desc = parse_to_desc(op=matmul_v2_op)
+        desc_str = parse_desc_to_str(desc)
+        self.assertIsNotNone(desc_str)
+        self.assertTrue(check_cost(matmul_v2_cost.cost))
+        time = calc_time_from_model(op=matmul_v2_op)
+        self.assertEqual(time, matmul_v2_cost.cost.time)
+        tensor_cost = cost_model.TensorCost(tensor=x)
+        # check memory: 20 * 20 float32 elements * 4 bytes = 1600
+        self.assertEqual(tensor_cost.cost.memory, 1600)
+
+    def test_comm_cost(self):
+        desc = {}
+        desc["op"] = "c_allreduce_sum"
+        desc["inputs"] = {"X": [(paddle.float32, [100, 200])]}
+        allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"](
+            op_desc=desc)
+        self.assertTrue(check_cost(allreduce_cost.cost))
+
+    def test_cost_estimator(self):
+        train_program = paddle.static.Program()
+        cost_estimator = cost_model.CostEstimator(train_program)
+        self.assertIsNotNone(cost_estimator)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/setup.py.in b/python/setup.py.in
index 2dbefb20bb6..7c1232c1d41 100755
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -307,6 +307,7 @@ packages=['paddle',
           'paddle.distributed.auto_parallel',
           'paddle.distributed.auto_parallel.operators',
           'paddle.distributed.auto_parallel.tuner',
+          'paddle.distributed.auto_parallel.cost',
           'paddle.distributed.passes',
           'paddle.framework',
           'paddle.jit',
-- 
GitLab
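Not part of the patch above, but for orientation: a minimal usage sketch of the cost API this change introduces, assuming a Paddle build that includes these files. It mirrors what test_new_cost_model.py exercises; the tensor names and shapes below are illustrative only, and the per-op formulas are still stubs that return 0.

import paddle
import paddle.distributed.auto_parallel.cost as cost_model

paddle.enable_static()

# Build a tiny static program containing one matmul_v2 op.
x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
y = paddle.static.data(name="y", shape=[8, 4], dtype="float32")
out = paddle.matmul(x, y)

main_program = paddle.static.default_main_program()
matmul_op = next(op for op in main_program.global_block().ops
                 if op.type == "matmul_v2")

# Per-op cost via the factory populated by @register_op_cost.
op_cost = cost_model.OP_COST_FACTORY["matmul_v2"](op=matmul_op)
print(op_cost.cost.time, op_cost.cost.memory, op_cost.cost.flops)

# Per-tensor memory cost: 4 * 8 float32 elements * 4 bytes = 128.
print(cost_model.TensorCost(tensor=x).cost.memory)

# Whole-program estimator in "modeling" mode; its methods are placeholders for now.
estimator = cost_model.CostEstimator(main_program, mode="modeling")
print(estimator.get_global_cost())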