Unverified Commit cbc5ca0f authored by Tao CHANG, committed by GitHub

add communication cost for cost model (#42727)

Parent 3052f36c
@@ -14,7 +14,16 @@
from .base_cost import _g_op_cost_factory
from .base_cost import Cost
-from .comm_op_cost import AllreduceSumCost
-from .comp_op_cost import MatmulV2OpCost
+from .base_cost import CommContext
+from .base_cost import build_comm_desc
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
+from .comp_op_cost import MatmulV2OpCost
+from .comm_op_cost import SendOpCost
+from .comm_op_cost import RecvOpCost
+from .comm_op_cost import IdentityOpCost
+from .comm_op_cost import BroadcastOpCost
+from .comm_op_cost import AllgatherOpCost
+from .comm_op_cost import AllreduceSumOpCost
@@ -13,15 +13,31 @@
# limitations under the License
from collections import OrderedDict
from functools import reduce
import paddle
from ..cluster import LinkType
from ..process_group import get_process_group
COMM_OP_TYPE = [
"send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
"send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum",
"c_identity"
]
NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
_g_op_cost_factory = {}
def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
desc = {}
desc["op"] = op_type
desc["group_ranks"] = group_ranks
desc["inputs"] = {"X": [(dtype, shape)]}
if attrs is not None:
desc["attrs"] = attrs
return desc
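
For orientation, here is a minimal sketch of the descriptor this helper produces; the ranks and shape are illustrative placeholders, not values from the patch:

# Hypothetical usage: describe a float32 allreduce over ranks 0-3.
# The resulting dict is the op_desc consumed by the cost classes below.
desc = build_comm_desc(
    op_type="c_allreduce_sum",
    group_ranks=[0, 1, 2, 3],
    dtype=paddle.float32,
    shape=[1024, 1024])
# desc == {"op": "c_allreduce_sum",
#          "group_ranks": [0, 1, 2, 3],
#          "inputs": {"X": [(paddle.float32, [1024, 1024])]}}
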
def _parse_op_to_desc(op, dist_context=None):
desc = {}
desc["op"] = op.type
......
@@ -12,17 +12,149 @@
# See the License for the specific language governing permissions and
# limitations under the License
-from .base_cost import register_op_cost, CommOpCost
+import math
+from .base_cost import register_op_cost, CommOpCost, _g_op_cost_factory
@register_op_cost
-class AllreduceSumCost(CommOpCost):
+class AllreduceSumOpCost(CommOpCost):
OP_TYPE = "c_allreduce_sum"
def __init__(self, op=None, op_desc=None, comm_context=None):
-        super(AllreduceSumCost, self).__init__(
+        super(AllreduceSumOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
# use tree if cross machine and use ring if in a single machine
time = None
cluster = self.comm_context.cluster
if not cluster.cross_machine(self.group_ranks):
time = self.calc_time_ring()
else:
time = self.calc_time_tree()
return time
def calc_time_ring(self):
alpha = self.comm_context.base_ring
alpha += 2 * (
self.rank_count - self.machine_count) * self.comm_context.intra_ring
alpha += 2 * (self.machine_count - 1) * (
self.comm_context.inter_ring + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + 2 * (self.rank_count - 1
) / self.rank_count * self.comm_count * beta
return time
def calc_time_tree(self):
alpha = self.comm_context.base_tree
alpha += 2 * (self.rank_count / self.machine_count - 1
) * self.comm_context.intra_tree
alpha += math.log2(self.machine_count) * (
self.comm_context.inter_tree + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + 2 * self.comm_count * beta
return time
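
Both branches follow the classic alpha-beta (latency-bandwidth) model. As an editorial gloss on the code above — with p = rank_count, m = machine_count, n = comm_count, h = hops, and β the per-unit cost from get_max_beta — the two estimates amount to:

\[
T_{\text{ring}} = \text{base\_ring} + 2(p-m)\,\text{intra\_ring} + 2(m-1)\,(\text{inter\_ring} + h\cdot\text{switch}) + \tfrac{2(p-1)}{p}\,n\beta
\]
\[
T_{\text{tree}} = \text{base\_tree} + 2\Big(\tfrac{p}{m}-1\Big)\,\text{intra\_tree} + \log_2(m)\,(\text{inter\_tree} + h\cdot\text{switch}) + 2\,n\beta
\]
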
@register_op_cost
class AllgatherOpCost(CommOpCost):
OP_TYPE = "c_allgather"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(AllgatherOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
time = self.calc_time_ring()
return time
def calc_time_ring(self):
alpha = self.comm_context.base_ring
alpha += (
self.rank_count - self.machine_count) * self.comm_context.intra_ring
alpha += (self.machine_count - 1) * (
self.comm_context.inter_ring + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + (self.rank_count - 1
) / self.rank_count * self.comm_count * beta
return time
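
In the same notation, the ring allgather above drops the reduce pass, which halves both the latency terms and the bandwidth factor relative to ring allreduce:

\[
T_{\text{allgather}} = \text{base\_ring} + (p-m)\,\text{intra\_ring} + (m-1)\,(\text{inter\_ring} + h\cdot\text{switch}) + \tfrac{p-1}{p}\,n\beta
\]
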
@register_op_cost
class BroadcastOpCost(CommOpCost):
OP_TYPE = "c_broadcast"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(BroadcastOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
time = self.calc_time_ring()
return time
def calc_time_ring(self):
alpha = self.comm_context.base_ring
if self.machine_count > 1:
alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
else:
alpha += self.comm_context.intra_ring
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + self.comm_count * beta
return time
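
Broadcast degenerates to a single point-to-point transfer: it pays one inter-machine hop when the group spans machines and one intra-machine hop otherwise. RecvOpCost and SendOpCost below use exactly the same form:

\[
T_{\text{p2p}} = \text{base\_ring} + \begin{cases}\text{inter\_ring} + h\cdot\text{switch} & m > 1\\ \text{intra\_ring} & m = 1\end{cases} \;+\; n\beta
\]
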
@register_op_cost
class IdentityOpCost(CommOpCost):
OP_TYPE = "c_identity"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(IdentityOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
        # NOTE: The actual formula will be added in the future.
return 0
@register_op_cost
class RecvOpCost(CommOpCost):
OP_TYPE = "recv_v2"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(RecvOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
alpha = self.comm_context.base_ring
if self.machine_count > 1:
alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
else:
alpha += self.comm_context.intra_ring
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + self.comm_count * beta
return time
@register_op_cost
class SendOpCost(CommOpCost):
OP_TYPE = "send_v2"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(SendOpCost, self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
alpha = self.comm_context.base_ring
if self.machine_count > 1:
alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
else:
alpha += self.comm_context.intra_ring
beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + self.comm_count * beta
return time
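
Since broadcast, send, and recv share the point-to-point formula, the whole family can be sanity-checked with a few lines of standalone Python. The latency and bandwidth constants below are made-up placeholders, not values from a real CommContext:

# A standalone alpha-beta point-to-point estimate mirroring the
# calc_time_ring bodies above; all constants are illustrative assumptions.
def p2p_time(comm_count, machine_count, hops,
             base_ring=1.0, intra_ring=2.0, inter_ring=10.0,
             switch=5.0, beta=0.001):
    alpha = base_ring
    if machine_count > 1:
        # cross-machine: pay the inter-node link plus the switch hops
        alpha += inter_ring + hops * switch
    else:
        # single machine: only the intra-node link latency
        alpha += intra_ring
    return alpha + comm_count * beta

print(p2p_time(32 * 10**6, machine_count=2, hops=1))  # cross-machine send
print(p2p_time(32 * 10**6, machine_count=1, hops=0))  # intra-machine send
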
@@ -29,4 +29,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS})
py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import json
import paddle
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext
from paddle.distributed.auto_parallel.cost import build_comm_desc
from paddle.distributed.auto_parallel.cost import AllreduceSumOpCost
from paddle.distributed.auto_parallel.cost import AllgatherOpCost
from paddle.distributed.auto_parallel.cost import BroadcastOpCost
from paddle.distributed.auto_parallel.cost import SendOpCost
from paddle.distributed.auto_parallel.cost import RecvOpCost
from paddle.distributed.auto_parallel.cost import IdentityOpCost
from test_cluster import cluster_json, multi_cluster_json
class TestCommOpCost(unittest.TestCase):
def test_comm_cost(self):
# Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file(cluster_json_path)
        # Build CommContext
CommContext._has_instance = None
CommContext._instance = None
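        # CommContext appears to be a singleton; clearing _has_instance and
        # _instance forces it to be rebuilt from the cluster created above.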
comm_context = CommContext(cluster)
        # Check AllreduceSumOpCost 128MB ring cost
        # (1 x 32e6 float32 elements x 4 bytes = 128MB)
        allreduce_sum_op_desc = build_comm_desc(
            "c_allreduce_sum", [0, 1, 2, 3, 4, 5, 6, 7], paddle.float32,
            [1, 32 * (10**6)])
        allreduce_sum_op_cost = AllreduceSumOpCost(
            op_desc=allreduce_sum_op_desc, comm_context=comm_context)
        self.assertTrue(allreduce_sum_op_cost.time > 0)
# Check AllgatherOpCost cost
allgather_op_desc = build_comm_desc("c_allgather",
[0, 1, 2, 3, 4, 5, 6, 7],
paddle.float32, [1, 32 * (10**6)])
allgather_op_cost = AllgatherOpCost(
op_desc=allgather_op_desc, comm_context=comm_context)
self.assertTrue(allgather_op_cost.time > 0)
# Check BroadcastOpCost cost
broadcast_op_desc = build_comm_desc("c_broadcast",
[0, 1, 2, 3, 4, 5, 6, 7],
paddle.float32, [1, 32 * (10**6)])
broadcast_op_cost = BroadcastOpCost(
op_desc=broadcast_op_desc, comm_context=comm_context)
self.assertTrue(broadcast_op_cost.time > 0)
# Check SendOpCost cost
send_op_desc = build_comm_desc("send_v2", [0, 1], paddle.float32,
[1, 32 * (10**6)])
send_op_cost = SendOpCost(
op_desc=send_op_desc, comm_context=comm_context)
self.assertTrue(send_op_cost.time > 0)
# Check RecvOpCost cost
recv_op_desc = build_comm_desc("recv_v2", [0, 1], paddle.float32,
[1, 32 * (10**6)])
recv_op_cost = RecvOpCost(
op_desc=recv_op_desc, comm_context=comm_context)
self.assertTrue(recv_op_cost.time > 0)
# Check IdentityOpCost cost
identity_op_desc = build_comm_desc("c_identity", [0, 1], paddle.float32,
[1, 32 * (10**6)])
identity_op_cost = IdentityOpCost(
op_desc=identity_op_desc, comm_context=comm_context)
self.assertTrue(identity_op_cost.time >= 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
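
    # With the single-machine cluster above, AllreduceSumOpCost should take
    # the ring branch of calc_time; the cross-machine test below is what
    # exercises the tree branch.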
def test_cross_machine_comm_cost(self):
# Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(multi_cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file(cluster_json_path)
        # Build CommContext
CommContext._has_instance = None
CommContext._instance = None
comm_context = CommContext(cluster)
        # Check AllreduceSumOpCost 128MB cross-machine cost
        # (the tree algorithm is chosen when the group spans machines)
        allreduce_sum_op_desc = build_comm_desc(
            "c_allreduce_sum",
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
            paddle.float32, [1, 32 * (10**6)])
        allreduce_sum_op_cost = AllreduceSumOpCost(
            op_desc=allreduce_sum_op_desc, comm_context=comm_context)
        self.assertTrue(allreduce_sum_op_cost.time > 0)
# Check AllgatherOpCost cost
allgather_op_desc = build_comm_desc(
"c_allgather",
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
paddle.float32, [1, 32 * (10**6)])
allgather_op_cost = AllgatherOpCost(
op_desc=allgather_op_desc, comm_context=comm_context)
self.assertTrue(allgather_op_cost.time > 0)
# Check BroadcastOpCost cost
broadcast_op_desc = build_comm_desc(
"c_broadcast",
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
paddle.float32, [1, 32 * (10**6)])
broadcast_op_cost = BroadcastOpCost(
op_desc=broadcast_op_desc, comm_context=comm_context)
self.assertTrue(broadcast_op_cost.time > 0)
# Check SendOpCost cost
send_op_desc = build_comm_desc("send_v2", [0, 1], paddle.float32,
[1, 32 * (10**6)])
send_op_cost = SendOpCost(
op_desc=send_op_desc, comm_context=comm_context)
self.assertTrue(send_op_cost.time > 0)
# Check RecvOpCost cost
recv_op_desc = build_comm_desc("recv_v2", [0, 1], paddle.float32,
[1, 32 * (10**6)])
recv_op_cost = RecvOpCost(
op_desc=recv_op_desc, comm_context=comm_context)
self.assertTrue(recv_op_cost.time > 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__":
unittest.main()
@@ -13,12 +13,17 @@
# limitations under the License.
import unittest
import os
import json
import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext
from test_cluster import cluster_json, multi_cluster_json
paddle.enable_static()
@@ -58,14 +63,31 @@ class TestCost(unittest.TestCase):
self.assertEqual(tensor_cost.cost.memory, 1600)
def test_comm_cost(self):
# Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file(cluster_json_path)
        # Build CommContext
CommContext._has_instance = None
CommContext._instance = None
comm_context = CommContext(cluster)
desc = {}
desc["op"] = "c_allreduce_sum"
desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
desc["inputs"] = {"X": [(paddle.float32, [100, 200])]}
desc["group_ranks"] = [0, 1]
allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"](
-            op_desc=desc)
+            op_desc=desc, comm_context=CommContext(cluster))
self.assertTrue(check_cost(allreduce_cost.cost))
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
def test_cost_estimator(self):
train_program = paddle.static.Program()
cost_estimator = cost_model.CostEstimator(train_program)
......