Unverified commit 6ac08db5, authored by caozhou, committed by GitHub

update base of cost model (#42601)

Parent cc077693
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
-from .base_cost import OP_COST_FACTORY
+from .base_cost import _g_op_cost_factory
 from .base_cost import Cost
 from .comm_op_cost import AllreduceSumCost
 from .comp_op_cost import MatmulV2OpCost
...
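
The factory dict is renamed from OP_COST_FACTORY to _g_op_cost_factory, matching the `_g_` prefix convention for module-level globals elsewhere in auto_parallel. A condensed, self-contained sketch of the decorator-plus-dict registration pattern this module uses (DummyOpCost is a made-up class for illustration):

# Condensed sketch of the registration pattern behind register_op_cost.
_g_op_cost_factory = {}

def register_op_cost(cls):
    _g_op_cost_factory[cls.OP_TYPE] = cls
    return cls  # returning cls keeps the decorated class usable by its own name

@register_op_cost
class DummyOpCost:  # made-up example class
    OP_TYPE = "dummy"

assert _g_op_cost_factory["dummy"] is DummyOpCost
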
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
     "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
 ]
 NON_COMP_TYPE = ["while"] + COMM_OP_TYPE

-OP_COST_FACTORY = {}
+_g_op_cost_factory = {}


 def _parse_op_to_desc(op, dist_context=None):
@@ -126,66 +126,136 @@ class CommContext:
     _instance = None
     _has_instance = False

-    def __init__(self, cluster):
-        if CommContext._has_instance:
-            return
-        self.cluster = cluster
-        self._alpha_base_ring = 8.4
-        self._alpha_base_tree = 0
-        self._alpha_inter = None
-        self._alpha_intra
-        self._beta = {}
-
     def __new__(cls, *args, **kwargs):
         if cls._instance is None:
-            cls._instance = super().__new__(cls, *args, **kwargs)
+            cls._instance = super().__new__(cls)
             _has_instance = True
         return cls._instance

-    @property
-    def alpha_inter(self):
-        if self._alpha_inter is None:
-            if cluster.alpha.inter == "NVL":
-                self._alpha_inter = 3.4
-            elif cluster.alpha.inter == "PHB":
-                self._alpha_inter = 5.7
-        return self._alpha_inter
-
-    @property
-    def alpha_intra(self):
-        if self._alpha_intra is None:
-            if cluster.alpha.intra == "NVL":
-                self._alpha_intra = 28
-            elif cluster.alpha.intra == "PHB":
-                self._alpha_intra = 28
-        return self._alpha_intra
-
-    @property
-    def alpha_base_ring(self):
-        return self._alpha_base_ring
-
-    @property
-    def alpha_base_tree(self):
-        return self._alpha_base_tree
-
-    def get_beta(self, ranks):
+    def __init__(self, cluster):
+        if CommContext._has_instance:
+            return
+        self.beta = {}
+        self.hops = {}
+        self.cluster = cluster
+        # if the cluster has no info about these vars, defaults are used
+        self.base_ring = None
+        self.base_tree = None
+        # self.base_inter_ring = None
+        # self.base_inter_tree = None
+        self.intra_ring = None
+        self.intra_tree = None
+        self.inter_ring = None
+        self.inter_tree = None
+        self.switch = None
+        self._post_init()
+
+    def _post_init(self):
+        alpha_latency = self.cluster.alpha_latency
+        if alpha_latency is None:
+            # set defaults
+            self.base_ring = 8.4
+            self.base_tree = 0.
+            # NVL by default
+            self.intra_ring = 3.4
+            self.intra_tree = 28
+            # NET by default
+            self.inter_ring = 9.6
+            self.inter_tree = 28
+            self.switch = 10.0
+        else:
+            base_ring = alpha_latency.base_ring
+            self.base_ring = base_ring if base_ring is not None else 8.4
+            base_tree = alpha_latency.base_tree
+            self.base_tree = base_tree if base_tree is not None else 0.
+            intra_ring = alpha_latency.intra_ring
+            if intra_ring == LinkType.NVL:
+                self.intra_ring = 3.4
+            elif intra_ring == LinkType.PHB:
+                self.intra_ring = 5.7
+            elif intra_ring is not None:
+                self.intra_ring = intra_ring
+            else:
+                # NVL by default
+                self.intra_ring = 3.4
+            intra_tree = alpha_latency.intra_tree
+            if intra_tree == LinkType.NVL:
+                self.intra_tree = 28
+            elif intra_tree == LinkType.PHB:
+                self.intra_tree = 28
+            elif intra_tree is not None:
+                self.intra_tree = intra_tree
+            else:
+                # NVL by default
+                self.intra_tree = 28
+            inter_ring = alpha_latency.inter_ring
+            if inter_ring == LinkType.NET:
+                self.inter_ring = 9.6
+            elif inter_ring is not None:
+                self.inter_ring = inter_ring
+            else:
+                # NET by default
+                self.inter_ring = 9.6
+            inter_tree = alpha_latency.inter_tree
+            if inter_tree == LinkType.NET:
+                self.inter_tree = 28
+            elif inter_tree is not None:
+                self.inter_tree = inter_tree
+            else:
+                # NET by default
+                self.inter_tree = 28
+            switch = alpha_latency.switch
+            self.switch = switch if switch is not None else 10
+        assert self.base_ring is not None
+        assert self.base_tree is not None
+        assert self.intra_ring is not None
+        assert self.intra_tree is not None
+        assert self.inter_ring is not None
+        assert self.inter_tree is not None
+        assert self.switch is not None
+
+    def get_max_beta(self, ranks):
+        # NOTE: get beta by ring, even in tree cases such as tree broadcast
+        ranks = self.cluster.convert_rank_to_device_id(ranks)
         key = ','.join(map(str, sorted(ranks)))
         max_beta = None
-        if key in self._beta.keys:
-            max_beta = self._beta[key]
+        if key in self.beta:
+            max_beta = self.beta[key]
         else:
             for i in range(len(ranks)):
                 for j in range(i + 1, len(ranks)):
-                    if min_beta == None:
-                        min_beta = cluster.get_beta(ranks[i], ranks[j])
+                    forward_order_beta = self.cluster.get_beta(
+                        ranks[i], ranks[j])
+                    backward_order_beta = self.cluster.get_beta(
+                        ranks[j], ranks[i])
+                    beta = forward_order_beta if forward_order_beta > backward_order_beta else backward_order_beta
+                    if max_beta == None:
+                        max_beta = beta
                     else:
-                        beta = cluster.get_beta(ranks[i], ranks[j])
                         if beta > max_beta:
                             max_beta = beta
-            self._beta[key] = max_beta
+            self.beta[key] = max_beta
         return max_beta

+    def get_hops(self, ranks):
+        key = ','.join(map(str, sorted(ranks)))
+        hops = 0
+        for i in range(len(ranks)):
+            for j in range(i + 1, len(ranks)):
+                hop = self.cluster.get_hop(ranks[i], ranks[j])
+                hops += hop
+        self.hops[key] = hops
+        return hops
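
CommContext is a cluster-wide singleton: it resolves the alpha (latency) constants from the cluster's alpha_latency description, falling back to NVLink-style intra-machine and network-style inter-machine defaults, and caches the worst-case beta (per-byte cost) and total hop count per rank group. To make the role of these constants concrete, here is a minimal, self-contained sketch of the classic alpha-beta estimate for a ring all-reduce; the formula and the numbers are illustrative, not necessarily the exact model used by the concrete cost classes:

# Hedged sketch: alpha-beta time estimate for ring all-reduce.
# alpha: per-step latency, beta: per-byte transfer cost.
def ring_allreduce_time(num_ranks, message_bytes, alpha, beta):
    if num_ranks <= 1:
        return 0.0
    # Ring all-reduce performs 2 * (n - 1) steps (reduce-scatter + all-gather),
    # each moving message_bytes / n bytes per rank.
    steps = 2 * (num_ranks - 1)
    return steps * (alpha + (message_bytes / num_ranks) * beta)

# Example with the NVLink-style default alpha of 3.4 and a made-up beta:
print(ring_allreduce_time(8, 4 * 1024 * 1024, alpha=3.4, beta=0.001))
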

 class Cost:
     def __init__(self, time=0, memory=0, flops=0):
@@ -198,11 +268,13 @@ class Cost:
     def _check_memory(self, val):
         assert isinstance(
-            val, int) and val >= 0, "Memory must be int and greater than 0."
+            val,
+            int) and val >= 0, "Memory must be int and greater than or equal to 0."

     def _check_flops(self, val):
         assert isinstance(
-            val, int) and val >= 0, "FLOPs must be int and greater than 0."
+            val,
+            int) and val >= 0, "FLOPs must be int and greater than or equal to 0."

     @property
     def time(self):
@@ -254,7 +326,7 @@ class OpCost:
             op_desc is not None)
         self._op = op
         self._op_desc = op_desc
-        self._cost = self.calc_cost()
+        self._cost = None

     @property
     def op(self):
@@ -264,6 +336,18 @@ class OpCost:
     def op_desc(self):
         return self._op_desc

+    @property
+    def time(self):
+        return self.cost.time
+
+    @property
+    def memory(self):
+        return self.cost.memory
+
+    @property
+    def flops(self):
+        return self.cost.flops
+
     @property
     def cost(self):
         return self._cost
@@ -284,6 +368,40 @@ class OpCost:
         cost = Cost(time, memory, flops)
         return cost

+    def __add__(self, rhs):
+        assert isinstance(rhs, (OpCost, Cost))
+        time = 0
+        memory = 0
+        flops = 0
+        if isinstance(rhs, OpCost):
+            time = self.cost.time + rhs.cost.time
+            memory = self.cost.memory + rhs.cost.memory
+            flops = self.cost.flops + rhs.cost.flops
+            assert (time >= 0 and memory >= 0 and flops >= 0)
+        elif isinstance(rhs, Cost):
+            time = self.time + rhs.time
+            memory = self.memory + rhs.memory
+            flops = self.flops + rhs.flops
+            assert (time >= 0 and memory >= 0 and flops >= 0)
+        return Cost(time, memory, flops)
+
+    def __sub__(self, rhs):
+        assert isinstance(rhs, (OpCost, Cost))
+        time = 0
+        memory = 0
+        flops = 0
+        if isinstance(rhs, OpCost):
+            time = self.cost.time - rhs.cost.time
+            memory = self.cost.memory - rhs.cost.memory
+            flops = self.cost.flops - rhs.cost.flops
+            assert (time >= 0 and memory >= 0 and flops >= 0)
+        elif isinstance(rhs, Cost):
+            time = self.time - rhs.time
+            memory = self.memory - rhs.memory
+            flops = self.flops - rhs.flops
+            assert (time >= 0 and memory >= 0 and flops >= 0)
+        return Cost(time, memory, flops)
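
OpCost now defers cost computation (self._cost starts as None and subclasses fill it in), exposes time/memory/flops pass-through properties, and supports + and - against either another OpCost or a plain Cost, always yielding a new Cost. A self-contained sketch of the same field-wise arithmetic, using a stand-in class so the example runs without Paddle:

# Hedged sketch of the new Cost arithmetic; CostLike stands in for Cost.
from dataclasses import dataclass

@dataclass
class CostLike:
    time: float = 0.0
    memory: int = 0
    flops: int = 0

    def __add__(self, rhs):
        # Field-wise addition returning a new object, mirroring OpCost.__add__.
        return CostLike(self.time + rhs.time, self.memory + rhs.memory,
                        self.flops + rhs.flops)

total = CostLike(2.0, 1024, 4096) + CostLike(1.5, 512, 2048)
print(total)  # CostLike(time=3.5, memory=1536, flops=6144)
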

 class CommOpCost(OpCost):
     OP_TYPE = "COMM"
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
         super(CommOpCost, self).__init__(op=op, op_desc=op_desc)
         self._check_comm_op_type()
         self._comm_context = comm_context
+        self._group_ranks = None
+        self._comm_count = None
+        self._hops = None
+        self._rank_count = len(self.group_ranks)
+        self._machine_count = None
+        self._cost = self.calc_cost()

     @property
     def comm_context(self):
         return self._comm_context

+    @property
+    def comm_count(self):
+        if self._comm_count is None:
+            dtype = None
+            shape = None
+            if self.op is not None:
+                vars = self.op.block.vars
+                # NOTE: the communicated input is named "X" by default;
+                # otherwise this property should be overridden
+                var_name = self.op.input("X")[0]
+                var = vars[var_name]
+                dtype = var.dtype
+                shape = var.shape
+            elif self.op_desc is not None:
+                dtype = self.op_desc["inputs"]["X"][0][0]
+                shape = self.op_desc["inputs"]["X"][0][1]
+            factor = None
+            if dtype == paddle.float32 or dtype == paddle.int32:
+                factor = 4
+            elif dtype == paddle.int64:
+                factor = 8
+            elif dtype == paddle.uint8:
+                factor = 1
+            elif dtype == paddle.float16:
+                factor = 2
+            else:
+                raise TypeError(
+                    "This dtype {} is not supported now".format(dtype))
+            comm_count = reduce(lambda x, y: x * y, shape) * factor
+            self._comm_count = comm_count
+        return self._comm_count
+
+    @property
+    def rank_count(self):
+        return self._rank_count
+
+    @property
+    def machine_count(self):
+        if self._machine_count is None:
+            cluster = self._comm_context.cluster
+            self._machine_count = cluster.get_involved_machine_count(
+                self.group_ranks)
+        return self._machine_count
+
+    @property
+    def hops(self):
+        if self._hops is None:
+            self._hops = self.comm_context.get_hops(self.group_ranks)
+        return self._hops
+
+    @property
+    def group_ranks(self):
+        if self._group_ranks is None:
+            if self.op_desc is not None:
+                self._group_ranks = self.op_desc["group_ranks"]
+            elif self.op is not None:
+                ring_id = self.op.attr("ring_id")
+                process_group = get_process_group(ring_id)
+                if process_group is None:
+                    raise ValueError(
+                        "There is no process group whose ring_id is {}.".
+                        format(ring_id))
+                self._group_ranks = process_group.ranks
+        return self._group_ranks
+
     @classmethod
     def _check_comm_op_type(cls):
         if cls.OP_TYPE != "COMM":
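
CommOpCost now derives everything a communication model needs lazily from the op or its descriptor: comm_count is the message size in bytes (number of elements times dtype width), rank_count and machine_count size the group, and hops comes from the cluster topology. A self-contained sketch of the byte-count computation, with a plain dtype table standing in for Paddle's dtype objects:

# Hedged sketch of the comm_count computation: numel * bytes-per-element.
from functools import reduce

_DTYPE_BYTES = {"float32": 4, "int32": 4, "int64": 8, "uint8": 1, "float16": 2}

def comm_bytes(shape, dtype):
    if dtype not in _DTYPE_BYTES:
        raise TypeError("This dtype {} is not supported now".format(dtype))
    numel = reduce(lambda x, y: x * y, shape)
    return numel * _DTYPE_BYTES[dtype]

print(comm_bytes([100, 200], "float32"))  # 80000 bytes
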
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super(CompOpCost, self).__init__(op=op, op_desc=op_desc)
         self._check_comp_op_type()
+        self._cost = self.calc_cost()
         self.cluster = cluster

     @classmethod
@@ -325,18 +516,22 @@ def register_op_cost(cls):
     op_type = cls.OP_TYPE

     def register(op_type):
-        OP_COST_FACTORY[op_type] = cls
+        global _g_op_cost_factory
+        _g_op_cost_factory[op_type] = cls

-    return register(op_type)
+    register(op_type)
+    return cls


-def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None):
+def calc_time_by_modeling(op=None, desc=None, cluster=None):
     op_type = op.type if op is not None else desc["op"]
     if op_type in COMM_OP_TYPE:
-        op_cost = OP_COST_FACTORY[op_type](op=op,
-                                           op_desc=desc,
-                                           comm_context=comm_context)
+        op_cost = _g_op_cost_factory[op_type](op=op,
+                                              op_desc=desc,
+                                              comm_context=CommContext(cluster))
     elif op_type not in NON_COMP_TYPE:
-        op_cost = OP_COST_FACTORY[op_type](op=op, op_desc=desc, cluster=cluster)
+        op_cost = _g_op_cost_factory[op_type](op=op,
+                                              op_desc=desc,
+                                              cluster=cluster)
     time = op_cost.calc_time()
     return time
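
Two behavioral fixes land here: register_op_cost now returns the class (previously the decorator returned None, which is why subclasses had to name themselves through OP_COST_FACTORY[...] in their super() calls), and the renamed calc_time_by_modeling builds the CommContext from the cluster instead of taking a prebuilt comm_context. A hedged usage sketch for the descriptor path; `cluster` is assumed to be a constructed auto_parallel Cluster object:

# Hedged usage sketch; the desc layout follows the updated test below.
import paddle
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling

def estimate_allreduce_time(cluster):
    desc = {
        "op": "c_allreduce_sum",
        "inputs": {"X": [([100, 200], paddle.float32)]},
        "group_ranks": [0, 1],
    }
    return calc_time_by_modeling(desc=desc, cluster=cluster)
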
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
-from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY
+from .base_cost import register_op_cost, CommOpCost


 @register_op_cost
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
     OP_TYPE = "c_allreduce_sum"

     def __init__(self, op=None, op_desc=None, comm_context=None):
-        super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__(
+        super(AllreduceSumCost, self).__init__(
             op=op, op_desc=op_desc, comm_context=comm_context)

     def calc_time(self):
...
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
-from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY
+from .base_cost import Cost, register_op_cost, CompOpCost


 @register_op_cost
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
     OP_TYPE = "matmul_v2"

     def __init__(self, op=None, op_desc=None, cluster=None):
-        super(OP_COST_FACTORY["matmul_v2"], self).__init__(
+        super(MatmulV2OpCost, self).__init__(
             op=op, op_desc=op_desc, cluster=cluster)

     # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
...
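
As the comment above says, each concrete computation op supplies its own calc_time and calc_flops. A hedged sketch of what such a registered subclass looks like; ReluOpCost and its constant return values are made up for illustration and are not part of this commit:

# Hedged sketch of a registered computation-op cost class; the returned
# values are placeholders, not Paddle's real model.
from paddle.distributed.auto_parallel.cost.base_cost import (
    CompOpCost, register_op_cost)

@register_op_cost
class ReluOpCost(CompOpCost):  # hypothetical subclass for illustration
    OP_TYPE = "relu"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(ReluOpCost, self).__init__(op=op, op_desc=op_desc,
                                         cluster=cluster)

    def calc_time(self):
        return 0.001  # placeholder per-op time

    def calc_flops(self):
        return 0  # placeholder
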
@@ -18,7 +18,7 @@ import paddle
 import paddle.distributed.auto_parallel.cost as cost_model
 from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
 from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
-from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model
+from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling

 paddle.enable_static()

@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
             if op.type == "matmul_v2":
                 matmul_v2_op = op
                 break
-        matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"](
+        matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"](
             op=matmul_v2_op)
         desc = parse_to_desc(op=matmul_v2_op)
         desc_str = parse_desc_to_str(desc)
         self.assertIsNotNone(desc_str)
         self.assertTrue(check_cost(matmul_v2_cost.cost))
-        time = calc_time_from_model(op=matmul_v2_op)
+        time = calc_time_by_modeling(op=matmul_v2_op)
         self.assertEqual(time, matmul_v2_cost.cost.time)

         tensor_cost = cost_model.TensorCost(tensor=x)
         # check memory
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
         desc = {}
         desc["op"] = "c_allreduce_sum"
         desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
-        allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"](
+        desc["group_ranks"] = [0, 1]
+        allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"](
             op_desc=desc)
         self.assertTrue(check_cost(allreduce_cost.cost))
...