Unverified commit 6ac08db5 authored by caozhou, committed by GitHub

update base of cost model (#42601)

Parent cc077693
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import OP_COST_FACTORY
from .base_cost import _g_op_cost_factory
from .base_cost import Cost
from .comm_op_cost import AllreduceSumCost
from .comp_op_cost import MatmulV2OpCost
......
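For orientation, downstream code now reaches the registry under its new name; a small hypothetical check (assumes a Paddle build containing this commit, mirroring the unit test at the end of this diff):

import paddle.distributed.auto_parallel.cost as cost_model

# _g_op_cost_factory maps op-type strings to OpCost subclasses as the
# comm/comp cost modules register themselves on import.
print(sorted(cost_model._g_op_cost_factory.keys()))
# e.g. ['c_allreduce_sum', 'matmul_v2'] at this point in the series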
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
"send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
]
NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
OP_COST_FACTORY = {}
_g_op_cost_factory = {}
def _parse_op_to_desc(op, dist_context=None):
@@ -126,66 +126,136 @@ class CommContext:
_instance = None
_has_instance = False
def __init__(self, cluster):
if CommContext._has_instance:
return
self.cluster = cluster
self._alpha_base_ring = 8.4
self._alpha_base_tree = 0
self._alpha_inter = None
self._alpha_intra
self._beta = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
cls._instance = super().__new__(cls)
return cls._instance
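CommContext is meant to be a per-process singleton: __new__ caches the first instance, and __init__ returns early once an instance exists. A minimal standalone sketch of that idiom (hypothetical _Singleton class, not the committed code):

class _Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # object.__new__ takes no extra arguments, so they are dropped here,
        # which is also why the committed line above changed.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, value):
        # Initialize only once; later constructions keep the first state.
        if getattr(self, "_initialized", False):
            return
        self.value = value
        self._initialized = True

first, second = _Singleton(1), _Singleton(2)
assert first is second and first.value == 1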
@property
def alpha_inter(self):
if self._alpha_inter is None:
if cluster.alpha.inter == "NVL":
self._alpha_inter = 3.4
elif cluster.alpha.inter == "PHB":
self._alpha_inter = 5.7
return self._alpha_inter
@property
def alpha_intra(self):
if self._alpha_intra is None:
if cluster.alpha.intra == "NVL":
self._alpha_intra = 28
elif cluster.alpha.intra == "PHB":
self._alpha_intra = 28
return self._alpha_intra
@property
def alpha_base_ring(self):
return self._alpha_base_ring
@property
def alpha_base_tree(self):
return self._alpha_base_tree
def get_beta(self, ranks):
def __init__(self, cluster):
if CommContext._has_instance:
return
self.beta = {}
self.hops = {}
self.cluster = cluster
        # fields the cluster does not provide are filled with defaults in _post_init
self.base_ring = None
self.base_tree = None
# self.base_inter_ring = None
# self.base_inter_tree = None
self.intra_ring = None
self.intra_tree = None
self.inter_ring = None
self.inter_tree = None
self.switch = None
        self._post_init()
        CommContext._has_instance = True
def _post_init(self):
alpha_latency = self.cluster.alpha_latency
if alpha_latency is None:
# set default
self.base_ring = 8.4
self.base_tree = 0.
            # NVL by default
self.intra_ring = 3.4
self.intra_tree = 28
            # NET by default
self.inter_ring = 9.6
self.inter_tree = 28
self.switch = 10.0
else:
base_ring = alpha_latency.base_ring
self.base_ring = base_ring if base_ring is not None else 8.4
base_tree = alpha_latency.base_tree
self.base_tree = base_tree if base_tree is not None else 0.
intra_ring = alpha_latency.intra_ring
if intra_ring == LinkType.NVL:
self.intra_ring = 3.4
elif intra_ring == LinkType.PHB:
self.intra_ring = 5.7
elif intra_ring is not None:
self.intra_ring = intra_ring
else:
# NVL Default
self.intra_ring = 3.4
intra_tree = alpha_latency.intra_tree
if intra_tree == LinkType.NVL:
self.intra_tree = 28
elif intra_tree == LinkType.PHB:
self.intra_tree = 28
elif intra_tree is not None:
self.intra_tree = intra_tree
else:
# NVL Default
self.intra_tree = 28
inter_ring = alpha_latency.inter_ring
if inter_ring == LinkType.NET:
self.inter_ring = 9.6
elif inter_ring is not None:
self.inter_ring = inter_ring
else:
# NET Default
self.inter_ring = 9.6
inter_tree = alpha_latency.inter_tree
if inter_tree == LinkType.NET:
self.inter_tree = 28
elif inter_tree is not None:
self.inter_tree = inter_tree
else:
# NET Default
self.inter_tree = 28
switch = alpha_latency.switch
self.switch = switch if switch is not None else 10
assert self.base_ring is not None
assert self.base_tree is not None
assert self.intra_ring is not None
assert self.intra_tree is not None
assert self.inter_ring is not None
assert self.inter_tree is not None
assert self.switch is not None
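The pattern in _post_init is uniform: take the measured value from alpha_latency when present, otherwise fall back to a documented constant (NVL intra-node, NET inter-node). A standalone sketch of that fallback (hypothetical helper; the numbers are the defaults above):

def resolve(measured, default):
    # Prefer the cluster-provided latency; otherwise use the default.
    return measured if measured is not None else default

base_ring = resolve(None, 8.4)   # no measurement -> default 8.4
inter_ring = resolve(12.5, 9.6)  # hypothetical measured value wins
assert base_ring == 8.4 and inter_ring == 12.5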
def get_max_beta(self, ranks):
        # NOTE: Get beta by ring, even for tree-based collectives such as tree broadcast
ranks = self.cluster.convert_rank_to_device_id(ranks)
key = ','.join(map(str, sorted(ranks)))
max_beta = None
if key in self._beta.keys:
max_beta = self._beta[key]
if key in self.beta:
max_beta = self.beta[key]
else:
for i in range(len(ranks)):
for j in range(i + 1, len(ranks)):
if min_beta == None:
min_beta = cluster.get_beta(ranks[i], ranks[j])
forward_order_beta = self.cluster.get_beta(ranks[i],
ranks[j])
backward_order_beta = self.cluster.get_beta(ranks[j],
ranks[i])
beta = forward_order_beta if forward_order_beta > backward_order_beta else backward_order_beta
if max_beta == None:
max_beta = beta
else:
beta = cluster.get_beta(ranks[i], ranks[j])
if beta > max_beta:
max_beta = beta
self._beta[key] = max_beta
self.beta[key] = max_beta
return max_beta
def get_hops(self, ranks):
key = ','.join(map(str, sorted(ranks)))
hops = 0
for i in range(len(ranks)):
for j in range(i + 1, len(ranks)):
hop = self.cluster.get_hop(ranks[i], ranks[j])
hops += hop
self.hops[key] = hops
return hops
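These two lookups feed the alpha-beta communication model: get_max_beta takes the worst (largest) per-byte cost over both directions of every rank pair, and get_hops sums pairwise hop counts. A standalone sketch with a hypothetical two-rank cluster stub (StubCluster is illustrative, not a Paddle API):

class StubCluster:
    # Hypothetical asymmetric beta table, indexed by (src, dst).
    _betas = {(0, 1): 100.0, (1, 0): 120.0}

    def get_beta(self, src, dst):
        return self._betas[(src, dst)]

    def get_hop(self, src, dst):
        return 1

def max_beta(cluster, ranks):
    # Mirrors CommContext.get_max_beta without the caching layer.
    best = None
    for i in range(len(ranks)):
        for j in range(i + 1, len(ranks)):
            beta = max(cluster.get_beta(ranks[i], ranks[j]),
                       cluster.get_beta(ranks[j], ranks[i]))
            best = beta if best is None else max(best, beta)
    return best

def total_hops(cluster, ranks):
    # Mirrors CommContext.get_hops: hops summed over all rank pairs.
    return sum(cluster.get_hop(ranks[i], ranks[j])
               for i in range(len(ranks))
               for j in range(i + 1, len(ranks)))

cluster = StubCluster()
assert max_beta(cluster, [0, 1]) == 120.0
assert total_hops(cluster, [0, 1]) == 1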
class Cost:
def __init__(self, time=0, memory=0, flops=0):
@@ -198,11 +268,13 @@ class Cost:
def _check_memory(self, val):
assert isinstance(
val, int) and val >= 0, "Memory must be int and greater than 0."
val,
int) and val >= 0, "Memory must be int and greater than equal to 0."
def _check_flops(self, val):
assert isinstance(
val, int) and val >= 0, "FLOPs must be int and greater than 0."
val,
int) and val >= 0, "FLOPs must be int and greater than equal to 0."
@property
def time(self):
@@ -254,7 +326,7 @@ class OpCost:
op_desc is not None)
self._op = op
self._op_desc = op_desc
self._cost = self.calc_cost()
self._cost = None
@property
def op(self):
@@ -264,6 +336,18 @@ class OpCost:
def op_desc(self):
return self._op_desc
@property
def time(self):
return self.cost.time
@property
def memory(self):
return self.cost.memory
@property
def flops(self):
return self.cost.flops
@property
def cost(self):
return self._cost
@@ -284,6 +368,40 @@ class OpCost:
cost = Cost(time, memory, flops)
return cost
def __add__(self, rhs):
assert isinstance(rhs, (OpCost, Cost))
time = 0
memory = 0
flops = 0
if isinstance(rhs, OpCost):
time = self.cost.time + rhs.cost.time
memory = self.cost.memory + rhs.cost.memory
flops = self.cost.flops + rhs.cost.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
elif isinstance(rhs, Cost):
time = self.time + rhs.time
memory = self.memory + rhs.memory
flops = self.flops + rhs.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
return Cost(time, memory, flops)
def __sub__(self, rhs):
assert isinstance(rhs, (OpCost, Cost))
time = 0
memory = 0
flops = 0
if isinstance(rhs, OpCost):
time = self.cost.time - rhs.cost.time
memory = self.cost.memory - rhs.cost.memory
flops = self.cost.flops - rhs.cost.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
elif isinstance(rhs, Cost):
time = self.time - rhs.time
memory = self.memory - rhs.memory
flops = self.flops - rhs.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
return Cost(time, memory, flops)
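These overloads let callers combine an OpCost with another OpCost or with a bare Cost and always get a plain Cost back; the asserts also mean subtraction is only valid when every field of the left side dominates the right. A field-wise sketch (hypothetical values; assumes the elided time setter accepts ints like the memory/flops checks above):

a = Cost(time=5, memory=1024, flops=4096)
b = Cost(time=2, memory=512, flops=2048)

# The same field-wise arithmetic OpCost.__add__/__sub__ performs on .cost:
summed = Cost(a.time + b.time, a.memory + b.memory, a.flops + b.flops)
diff = Cost(a.time - b.time, a.memory - b.memory, a.flops - b.flops)
assert (summed.time, summed.memory, summed.flops) == (7, 1536, 6144)
assert (diff.time, diff.memory, diff.flops) == (3, 512, 2048)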
class CommOpCost(OpCost):
OP_TYPE = "COMM"
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
super(CommOpCost, self).__init__(op=op, op_desc=op_desc)
self._check_comm_op_type()
self._comm_context = comm_context
self._group_ranks = None
self._comm_count = None
self._hops = None
self._rank_count = len(self.group_ranks)
self._machine_count = None
self._cost = self.calc_cost()
@property
def comm_context(self):
return self._comm_context
@property
def comm_count(self):
if self._comm_count is None:
dtype = None
shape = None
if self.op is not None:
vars = self.op.block.vars
                # NOTE: The communicated input is named "X" by default; otherwise, this function should be overridden
var_name = self.op.input("X")[0]
var = vars[var_name]
dtype = var.dtype
shape = var.shape
elif self.op_desc is not None:
dtype = self.op_desc["inputs"]["X"][0][0]
shape = self.op_desc["inputs"]["X"][0][1]
factor = None
if dtype == paddle.float32 or dtype == paddle.int32:
factor = 4
elif dtype == paddle.int64:
factor = 8
elif dtype == paddle.uint8:
factor = 1
elif dtype == paddle.float16:
factor = 2
else:
raise TypeError("This dtype {} is not supported now".format(
dtype))
comm_count = reduce(lambda x, y: x * y, shape) * factor
self._comm_count = comm_count
return self._comm_count
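As a worked check of the byte count above: the unit test at the end of this diff communicates a float32 tensor of shape [100, 200], so the element count times the 4-byte factor gives

from functools import reduce

shape, factor = [100, 200], 4  # float32 -> 4 bytes per element
comm_count = reduce(lambda x, y: x * y, shape) * factor
assert comm_count == 80000  # bytes moved by the collective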
@property
def rank_count(self):
return self._rank_count
@property
def machine_count(self):
if self._machine_count is None:
cluster = self._comm_context.cluster
self._machine_count = cluster.get_involved_machine_count(
self.group_ranks)
return self._machine_count
@property
def hops(self):
if self._hops is None:
self._hops = self.comm_context.get_hops(self.group_ranks)
return self._hops
@property
def group_ranks(self):
if self._group_ranks is None:
if self.op_desc is not None:
self._group_ranks = self.op_desc["group_ranks"]
elif self.op is not None:
                ring_id = self.op.attr("ring_id")
process_group = get_process_group(ring_id)
if process_group is None:
                    raise ValueError(
                        "There is no process group whose ring_id is {}.".
                        format(ring_id))
self._group_ranks = process_group.ranks
return self._group_ranks
@classmethod
def _check_comm_op_type(cls):
if cls.OP_TYPE != "COMM":
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super(CompOpCost, self).__init__(op=op, op_desc=op_desc)
self._check_comp_op_type()
        self.cluster = cluster
        self._cost = self.calc_cost()
@classmethod
@@ -325,18 +516,22 @@ def register_op_cost(cls):
op_type = cls.OP_TYPE
def register(op_type):
OP_COST_FACTORY[op_type] = cls
global _g_op_cost_factory
_g_op_cost_factory[op_type] = cls
return register(op_type)
register(op_type)
return cls
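A minimal standalone sketch of what register_op_cost does (a local dict stands in for the module-level _g_op_cost_factory; FakeOpCost is hypothetical):

_factory = {}  # stand-in for the module-level _g_op_cost_factory

def register_op_cost(cls):
    # Map the class's OP_TYPE string to the class itself, then return the
    # class unchanged so the function works as a decorator.
    _factory[cls.OP_TYPE] = cls
    return cls

@register_op_cost
class FakeOpCost:
    OP_TYPE = "fake_op"

assert _factory["fake_op"] is FakeOpCost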
def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None):
def calc_time_by_modeling(op=None, desc=None, cluster=None):
op_type = op.type if op is not None else desc["op"]
if op_type in COMM_OP_TYPE:
op_cost = OP_COST_FACTORY[op_type](op=op,
op_cost = _g_op_cost_factory[op_type](op=op,
op_desc=desc,
comm_context=comm_context)
comm_context=CommContext(cluster))
elif op_type not in NON_COMP_TYPE:
op_cost = OP_COST_FACTORY[op_type](op=op, op_desc=desc, cluster=cluster)
op_cost = _g_op_cost_factory[op_type](op=op,
op_desc=desc,
cluster=cluster)
time = op_cost.calc_time()
return time
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY
from .base_cost import register_op_cost, CommOpCost
@register_op_cost
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
OP_TYPE = "c_allreduce_sum"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__(
super(AllreduceSumCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
......
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY
from .base_cost import Cost, register_op_cost, CompOpCost
@register_op_cost
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
OP_TYPE = "matmul_v2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(OP_COST_FACTORY["matmul_v2"], self).__init__(
super(MatmulV2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster)
    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
......
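As the comment above notes, each concrete COMP op cost overrides calc_time and calc_flops. A hedged sketch of what such a subclass could look like (hypothetical op type and constants, not Paddle's tuned model; assumes the same imports as comp_op_cost.py):

@register_op_cost
class HypotheticalReluOpCost(CompOpCost):
    OP_TYPE = "relu"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(HypotheticalReluOpCost, self).__init__(
            op=op, op_desc=op_desc, cluster=cluster)

    def calc_time(self):
        # Illustrative flat kernel-launch latency in the model's time unit.
        return 0.9

    def calc_flops(self):
        # Illustrative: treat the elementwise op as negligible FLOPs.
        return 0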
@@ -18,7 +18,7 @@ import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
paddle.enable_static()
@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
if op.type == "matmul_v2":
matmul_v2_op = op
break
matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"](
matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"](
op=matmul_v2_op)
desc = parse_to_desc(op=matmul_v2_op)
desc_str = parse_desc_to_str(desc)
self.assertIsNotNone(desc_str)
self.assertTrue(check_cost(matmul_v2_cost.cost))
time = calc_time_from_model(op=matmul_v2_op)
time = calc_time_by_modeling(op=matmul_v2_op)
self.assertEqual(time, matmul_v2_cost.cost.time)
tensor_cost = cost_model.TensorCost(tensor=x)
# check memory
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
desc = {}
desc["op"] = "c_allreduce_sum"
desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"](
desc["group_ranks"] = [0, 1]
allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"](
op_desc=desc)
self.assertTrue(check_cost(allreduce_cost.cost))
......