From 6ac08db5713c831e5bcf70fcecc5996d6a552f91 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 10 May 2022 14:15:00 +0800 Subject: [PATCH] update base of cost model (#42601) --- .../auto_parallel/cost/__init__.py | 2 +- .../auto_parallel/cost/base_cost.py | 305 ++++++++++++++---- .../auto_parallel/cost/comm_op_cost.py | 6 +- .../auto_parallel/cost/comp_op_cost.py | 6 +- .../auto_parallel/test_new_cost_model.py | 9 +- 5 files changed, 262 insertions(+), 66 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/cost/__init__.py b/python/paddle/distributed/auto_parallel/cost/__init__.py index 7bc8a81b79..9ea58d6979 100644 --- a/python/paddle/distributed/auto_parallel/cost/__init__.py +++ b/python/paddle/distributed/auto_parallel/cost/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License -from .base_cost import OP_COST_FACTORY +from .base_cost import _g_op_cost_factory from .base_cost import Cost from .comm_op_cost import AllreduceSumCost from .comp_op_cost import MatmulV2OpCost diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py index c4ebd83612..cb16d522bc 100644 --- a/python/paddle/distributed/auto_parallel/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py @@ -19,7 +19,7 @@ COMM_OP_TYPE = [ "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum" ] NON_COMP_TYPE = ["while"] + COMM_OP_TYPE -OP_COST_FACTORY = {} +_g_op_cost_factory = {} def _parse_op_to_desc(op, dist_context=None): @@ -126,66 +126,136 @@ class CommContext: _instance = None _has_instance = False - def __init__(self, cluster): - if CommContext._has_instance: - return - self.cluster = cluster - self._alpha_base_ring = 8.4 - self._alpha_base_tree = 0 - self._alpha_inter = None - self._alpha_intra - self._beta = {} - def __new__(cls, *args, **kwargs): if cls._instance is None: - cls._instance = super().__new__(cls, *args, **kwargs) + cls._instance = super().__new__(cls) _has_instance = True return cls._instance - @property - def alpha_inter(self): - if self._alpha_inter is None: - if cluster.alpha.inter == "NVL": - self._alpha_inter = 3.4 - elif cluster.alpha.inter == "PHB": - self._alpha_inter = 5.7 - return self._alpha_inter - - @property - def alpha_intra(self): - if self._alpha_intra is None: - if cluster.alpha.intra == "NVL": - self._alpha_intra = 28 - elif cluster.alpha.intra == "PHB": - self._alpha_intra = 28 - return self._alpha_intra - - @property - def alpha_base_ring(self): - return self._alpha_base_ring - - @property - def alpha_base_tree(self): - return self._alpha_base_tree - - def get_beta(self, ranks): + def __init__(self, cluster): + if CommContext._has_instance: + return + self.beta = {} + self.hops = {} + self.cluster = cluster + # if cluster has no info about those vars, it will be set by default + self.base_ring = None + self.base_tree = None + # self.base_inter_ring = None + # self.base_inter_tree = None + self.intra_ring = None + self.intra_tree = None + self.inter_ring = None + self.inter_tree = None + self.switch = None + self._post_init() + + def _post_init(self): + alpha_latency = self.cluster.alpha_latency + if alpha_latency is None: + # set default + self.base_ring = 8.4 + self.base_tree = 0. + # NVL in default + self.intra_ring = 3.4 + self.intra_tree = 28 + # NET in default + self.inter_ring = 9.6 + self.inter_tree = 28 + self.switch = 10.0 + else: + base_ring = alpha_latency.base_ring + self.base_ring = base_ring if base_ring is not None else 8.4 + + base_tree = alpha_latency.base_tree + self.base_tree = base_tree if base_tree is not None else 0. + + intra_ring = alpha_latency.intra_ring + if intra_ring == LinkType.NVL: + self.intra_ring = 3.4 + elif intra_ring == LinkType.PHB: + self.intra_ring = 5.7 + elif intra_ring is not None: + self.intra_ring = intra_ring + else: + # NVL Default + self.intra_ring = 3.4 + + intra_tree = alpha_latency.intra_tree + if intra_tree == LinkType.NVL: + self.intra_tree = 28 + elif intra_tree == LinkType.PHB: + self.intra_tree = 28 + elif intra_tree is not None: + self.intra_tree = intra_tree + else: + # NVL Default + self.intra_tree = 28 + + inter_ring = alpha_latency.inter_ring + if inter_ring == LinkType.NET: + self.inter_ring = 9.6 + elif inter_ring is not None: + self.inter_ring = inter_ring + else: + # NET Default + self.inter_ring = 9.6 + + inter_tree = alpha_latency.inter_tree + if inter_tree == LinkType.NET: + self.inter_tree = 28 + elif inter_tree is not None: + self.inter_tree = inter_tree + else: + # NET Default + self.inter_tree = 28 + + switch = alpha_latency.switch + self.switch = switch if switch is not None else 10 + + assert self.base_ring is not None + assert self.base_tree is not None + assert self.intra_ring is not None + assert self.intra_tree is not None + assert self.inter_ring is not None + assert self.inter_tree is not None + assert self.switch is not None + + def get_max_beta(self, ranks): + # NOTE: Get beta by ring, even in the case of tree such as tree broadcast + ranks = self.cluster.convert_rank_to_device_id(ranks) key = ','.join(map(str, sorted(ranks))) max_beta = None - if key in self._beta.keys: - max_beta = self._beta[key] + if key in self.beta: + max_beta = self.beta[key] else: for i in range(len(ranks)): for j in range(i + 1, len(ranks)): - if min_beta == None: - min_beta = cluster.get_beta(ranks[i], ranks[j]) + forward_order_beta = self.cluster.get_beta(ranks[i], + ranks[j]) + backward_order_beta = self.cluster.get_beta(ranks[j], + ranks[i]) + beta = forward_order_beta if forward_order_beta > backward_order_beta else backward_order_beta + if max_beta == None: + max_beta = beta else: - beta = cluster.get_beta(ranks[i], ranks[j]) if beta > max_beta: max_beta = beta - self._beta[key] = max_beta + self.beta[key] = max_beta return max_beta + def get_hops(self, ranks): + key = ','.join(map(str, sorted(ranks))) + hops = 0 + for i in range(len(ranks)): + for j in range(i + 1, len(ranks)): + hop = self.cluster.get_hop(ranks[i], ranks[j]) + hops += hop + self.hops[key] = hops + + return hops + class Cost: def __init__(self, time=0, memory=0, flops=0): @@ -198,11 +268,13 @@ class Cost: def _check_memory(self, val): assert isinstance( - val, int) and val >= 0, "Memory must be int and greater than 0." + val, + int) and val >= 0, "Memory must be int and greater than equal to 0." def _check_flops(self, val): assert isinstance( - val, int) and val >= 0, "FLOPs must be int and greater than 0." + val, + int) and val >= 0, "FLOPs must be int and greater than equal to 0." @property def time(self): @@ -254,7 +326,7 @@ class OpCost: op_desc is not None) self._op = op self._op_desc = op_desc - self._cost = self.calc_cost() + self._cost = None @property def op(self): @@ -264,6 +336,18 @@ class OpCost: def op_desc(self): return self._op_desc + @property + def time(self): + return self.cost.time + + @property + def memory(self): + return self.cost.memory + + @property + def flops(self): + return self.cost.flops + @property def cost(self): return self._cost @@ -284,6 +368,40 @@ class OpCost: cost = Cost(time, memory, flops) return cost + def __add__(self, rhs): + assert isinstance(rhs, (OpCost, Cost)) + time = 0 + memory = 0 + flops = 0 + if isinstance(rhs, OpCost): + time = self.cost.time + rhs.cost.time + memory = self.cost.memory + rhs.cost.memory + flops = self.cost.flops + rhs.cost.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + elif isinstance(rhs, Cost): + time = self.time + rhs.time + memory = self.memory + rhs.memory + flops = self.flops + rhs.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + return Cost(time, memory, flops) + + def __sub__(self, rhs): + assert isinstance(rhs, (OpCost, Cost)) + time = 0 + memory = 0 + flops = 0 + if isinstance(rhs, OpCost): + time = self.cost.time - rhs.cost.time + memory = self.cost.memory - rhs.cost.memory + flops = self.cost.flops - rhs.cost.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + elif isinstance(rhs, Cost): + time = self.time - rhs.time + memory = self.memory - rhs.memory + flops = self.flops - rhs.flops + assert (time >= 0 and memory >= 0 and flops >= 0) + return Cost(time, memory, flops) + class CommOpCost(OpCost): OP_TYPE = "COMM" @@ -292,11 +410,83 @@ class CommOpCost(OpCost): super(CommOpCost, self).__init__(op=op, op_desc=op_desc) self._check_comm_op_type() self._comm_context = comm_context + self._group_ranks = None + self._comm_count = None + self._hops = None + self._rank_count = len(self.group_ranks) + self._machine_count = None + self._cost = self.calc_cost() @property def comm_context(self): return self._comm_context + @property + def comm_count(self): + if self._comm_count is None: + dtype = None + shape = None + if self.op is not None: + vars = self.op.block.vars + # NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided + var_name = self.op.input("X")[0] + var = vars[var_name] + dtype = var.dtype + shape = var.shape + elif self.op_desc is not None: + dtype = self.op_desc["inputs"]["X"][0][0] + shape = self.op_desc["inputs"]["X"][0][1] + + factor = None + if dtype == paddle.float32 or dtype == paddle.int32: + factor = 4 + elif dtype == paddle.int64: + factor = 8 + elif dtype == paddle.uint8: + factor = 1 + elif dtype == paddle.float16: + factor = 2 + else: + raise TypeError("This dtype {} is not supported now".format( + dtype)) + comm_count = reduce(lambda x, y: x * y, shape) * factor + self._comm_count = comm_count + + return self._comm_count + + @property + def rank_count(self): + return self._rank_count + + @property + def machine_count(self): + if self._machine_count is None: + cluster = self._comm_context.cluster + self._machine_count = cluster.get_involved_machine_count( + self.group_ranks) + return self._machine_count + + @property + def hops(self): + if self._hops is None: + self._hops = self.comm_context.get_hops(self.group_ranks) + return self._hops + + @property + def group_ranks(self): + if self._group_ranks is None: + if self.op_desc is not None: + self._group_ranks = self.op_desc["group_ranks"] + elif self.op is not None: + ring_id = op.attrs("ring_id") + process_group = get_process_group(ring_id) + if process_group is None: + raise ValueError( + "There not exists process group whose ring_id is {}.". + format(ring_id)) + self._group_ranks = process_group.ranks + return self._group_ranks + @classmethod def _check_comm_op_type(cls): if cls.OP_TYPE != "COMM": @@ -311,6 +501,7 @@ class CompOpCost(OpCost): def __init__(self, op=None, op_desc=None, cluster=None): super(CompOpCost, self).__init__(op=op, op_desc=op_desc) self._check_comp_op_type() + self._cost = self.calc_cost() self.cluster = cluster @classmethod @@ -325,18 +516,22 @@ def register_op_cost(cls): op_type = cls.OP_TYPE def register(op_type): - OP_COST_FACTORY[op_type] = cls + global _g_op_cost_factory + _g_op_cost_factory[op_type] = cls - return register(op_type) + register(op_type) + return cls -def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None): +def calc_time_by_modeling(op=None, desc=None, cluster=None): op_type = op.type if op is not None else desc["op"] if op_type in COMM_OP_TYPE: - op_cost = OP_COST_FACTORY[op_type](op=op, - op_desc=desc, - comm_context=comm_context) + op_cost = _g_op_cost_factory[op_type](op=op, + op_desc=desc, + comm_context=CommContext(cluster)) elif op_type not in NON_COMP_TYPE: - op_cost = OP_COST_FACTORY[op_type](op=op, op_desc=desc, cluster=cluster) + op_cost = _g_op_cost_factory[op_type](op=op, + op_desc=desc, + cluster=cluster) time = op_cost.calc_time() return time diff --git a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py index 359f6b6e78..235741ba12 100644 --- a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License -from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY +from .base_cost import register_op_cost, CommOpCost @register_op_cost @@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost): OP_TYPE = "c_allreduce_sum" def __init__(self, op=None, op_desc=None, comm_context=None): - super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__( + super(AllreduceSumCost, self).__init__( op=op, op_desc=op_desc, comm_context=comm_context) def calc_time(self): diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index c4d88cb25d..067ad48028 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License -from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY +from .base_cost import Cost, register_op_cost, CompOpCost @register_op_cost @@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost): OP_TYPE = "matmul_v2" def __init__(self, op=None, op_desc=None, cluster=None): - super(OP_COST_FACTORY["matmul_v2"], self).__init__( + super(MatmulV2OpCost, self).__init__( op=op, op_desc=op_desc, cluster=cluster) # For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py index 0cd3041ea4..6d6fbfe78e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py @@ -18,7 +18,7 @@ import paddle import paddle.distributed.auto_parallel.cost as cost_model from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str -from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model +from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling paddle.enable_static() @@ -45,13 +45,13 @@ class TestCost(unittest.TestCase): if op.type == "matmul_v2": matmul_v2_op = op break - matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"]( + matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"]( op=matmul_v2_op) desc = parse_to_desc(op=matmul_v2_op) desc_str = parse_desc_to_str(desc) self.assertIsNotNone(desc_str) self.assertTrue(check_cost(matmul_v2_cost.cost)) - time = calc_time_from_model(op=matmul_v2_op) + time = calc_time_by_modeling(op=matmul_v2_op) self.assertEqual(time, matmul_v2_cost.cost.time) tensor_cost = cost_model.TensorCost(tensor=x) # check memory @@ -61,7 +61,8 @@ class TestCost(unittest.TestCase): desc = {} desc["op"] = "c_allreduce_sum" desc["inputs"] = {"X": [([100, 200], paddle.float32)]} - allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"]( + desc["group_ranks"] = [0, 1] + allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"]( op_desc=desc) self.assertTrue(check_cost(allreduce_cost.cost)) -- GitLab