diff --git a/python/paddle/distributed/auto_parallel/cost/__init__.py b/python/paddle/distributed/auto_parallel/cost/__init__.py
index 9ea58d697952775b028a796122a2a115030e2271..ea6b3bc5b7e76940f67badf2bced4e54a87f436c 100644
--- a/python/paddle/distributed/auto_parallel/cost/__init__.py
+++ b/python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -14,7 +14,16 @@
 
 from .base_cost import _g_op_cost_factory
 from .base_cost import Cost
-from .comm_op_cost import AllreduceSumCost
-from .comp_op_cost import MatmulV2OpCost
+from .base_cost import CommContext
+from .base_cost import build_comm_desc
 from .tensor_cost import TensorCost
 from .estimate_cost import CostEstimator
+
+from .comp_op_cost import MatmulV2OpCost
+
+from .comm_op_cost import SendOpCost
+from .comm_op_cost import RecvOpCost
+from .comm_op_cost import IdentityOpCost
+from .comm_op_cost import BroadcastOpCost
+from .comm_op_cost import AllgatherOpCost
+from .comm_op_cost import AllreduceSumOpCost
diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py
index cb16d522bc9e3788175a71734a2359b7986debc8..f1843b8f16527280ff477be4b85d722ffb54ef75 100644
--- a/python/paddle/distributed/auto_parallel/cost/base_cost.py
+++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -13,15 +13,31 @@
 # limitations under the License
 
 from collections import OrderedDict
+from functools import reduce
+
 import paddle
 
+from ..cluster import LinkType
+from ..process_group import get_process_group
+
 COMM_OP_TYPE = [
-    "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
+    "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum",
+    "c_identity"
 ]
 NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
 _g_op_cost_factory = {}
 
 
+def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
+    desc = {}
+    desc["op"] = op_type
+    desc["group_ranks"] = group_ranks
+    desc["inputs"] = {"X": [(dtype, shape)]}
+    if attrs is not None:
+        desc["attrs"] = attrs
+    return desc
+
+
 def _parse_op_to_desc(op, dist_context=None):
     desc = {}
     desc["op"] = op.type
diff --git a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
index 235741ba12f4fb67d072043a06733549d19df36a..a32fdf1824e6293be31352a6ac2665c250fca790 100644
--- a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+++ b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
@@ -12,17 +12,149 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-from .base_cost import register_op_cost, CommOpCost
+import math
+
+from .base_cost import register_op_cost, CommOpCost, _g_op_cost_factory
 
 
 @register_op_cost
-class AllreduceSumCost(CommOpCost):
+class AllreduceSumOpCost(CommOpCost):
     OP_TYPE = "c_allreduce_sum"
 
     def __init__(self, op=None, op_desc=None, comm_context=None):
-        super(AllreduceSumCost, self).__init__(
+        super(AllreduceSumOpCost, self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        # use tree if cross machine and use ring if in a single machine
+        time = None
+        cluster = self.comm_context.cluster
+        if not cluster.cross_machine(self.group_ranks):
+            time = self.calc_time_ring()
+        else:
+            time = self.calc_time_tree()
+
+        return time
+
+    def calc_time_ring(self):
+        alpha = self.comm_context.base_ring
+        alpha += 2 * (
+            self.rank_count - self.machine_count) * self.comm_context.intra_ring
+        alpha += 2 * (self.machine_count - 1) * (
+            self.comm_context.inter_ring + self.hops * self.comm_context.switch)
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+        time = alpha + 2 * (self.rank_count - 1
+                            ) / self.rank_count * self.comm_count * beta
+
+        return time
+
+    def calc_time_tree(self):
+        alpha = self.comm_context.base_tree
+        alpha += 2 * (self.rank_count / self.machine_count - 1
+                      ) * self.comm_context.intra_tree
+        alpha += math.log2(self.machine_count) * (
+            self.comm_context.inter_tree + self.hops * self.comm_context.switch)
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+
+        time = alpha + 2 * self.comm_count * beta
+
+        return time
+
+
+@register_op_cost
+class AllgatherOpCost(CommOpCost):
+    OP_TYPE = "c_allgather"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(AllgatherOpCost, self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        time = self.calc_time_ring()
+        return time
+
+    def calc_time_ring(self):
+        alpha = self.comm_context.base_ring
+        alpha += (
+            self.rank_count - self.machine_count) * self.comm_context.intra_ring
+        alpha += (self.machine_count - 1) * (
+            self.comm_context.inter_ring + self.hops * self.comm_context.switch)
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+        time = alpha + (self.rank_count - 1
+                        ) / self.rank_count * self.comm_count * beta
+        return time
+
+
+@register_op_cost
+class BroadcastOpCost(CommOpCost):
+    OP_TYPE = "c_broadcast"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(BroadcastOpCost, self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        time = self.calc_time_ring()
+        return time
+
+    def calc_time_ring(self):
+        alpha = self.comm_context.base_ring
+        if self.machine_count > 1:
+            alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
+        else:
+            alpha += self.comm_context.intra_ring
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+        time = alpha + self.comm_count * beta
+
+        return time
+
+
+@register_op_cost
+class IdentityOpCost(CommOpCost):
+    OP_TYPE = "c_identity"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(IdentityOpCost, self).__init__(
             op=op, op_desc=op_desc, comm_context=comm_context)
 
     def calc_time(self):
-        # NOTE: The actual formula will be filled in the future.
         return 0
+
+
+@register_op_cost
+class RecvOpCost(CommOpCost):
+    OP_TYPE = "recv_v2"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(RecvOpCost, self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        alpha = self.comm_context.base_ring
+        if self.machine_count > 1:
+            alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
+        else:
+            alpha += self.comm_context.intra_ring
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+        time = alpha + self.comm_count * beta
+        return time
+
+
+@register_op_cost
+class SendOpCost(CommOpCost):
+    OP_TYPE = "send_v2"
+
+    def __init__(self, op=None, op_desc=None, comm_context=None):
+        super(SendOpCost, self).__init__(
+            op=op, op_desc=op_desc, comm_context=comm_context)
+
+    def calc_time(self):
+        alpha = self.comm_context.base_ring
+        if self.machine_count > 1:
+            alpha += self.comm_context.inter_ring + self.hops * self.comm_context.switch
+        else:
+            alpha += self.comm_context.intra_ring
+        beta = self.comm_context.get_max_beta(self.group_ranks)
+        time = alpha + self.comm_count * beta
+
+        return time
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index 7c747338593a393f135b64faa509ca736074da8d..1f846f5d7361c0210be022b453d0dd178d24821e 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -29,4 +29,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS})
     py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
     py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
+    py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0ad1f4ed314de8b35f8909f9f8ee8355261b52f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import os
+import json
+
+import paddle
+from paddle.distributed.auto_parallel.cluster import Cluster
+from paddle.distributed.auto_parallel.cost import CommContext
+from paddle.distributed.auto_parallel.cost import build_comm_desc
+from paddle.distributed.auto_parallel.cost import AllreduceSumOpCost
+from paddle.distributed.auto_parallel.cost import AllgatherOpCost
+from paddle.distributed.auto_parallel.cost import BroadcastOpCost
+from paddle.distributed.auto_parallel.cost import SendOpCost
+from paddle.distributed.auto_parallel.cost import RecvOpCost
+from paddle.distributed.auto_parallel.cost import IdentityOpCost
+
+from test_cluster import cluster_json, multi_cluster_json
+
+
+class TestCommOpCost(unittest.TestCase):
+    def test_comm_cost(self):
+        # Build cluster
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
+        cluster_json_object = json.loads(cluster_json)
+        with open(cluster_json_path, "w") as cluster_json_file:
+            json.dump(cluster_json_object, cluster_json_file)
+        cluster = Cluster()
+        cluster.build_from_file(cluster_json_path)
+
+        # Build CommContext
+        CommContext._has_instance = None
+        CommContext._instance = None
+        comm_context = CommContext(cluster)
+
+        # Check AllreduceSumOpCost 128MB ring cost
+        allreduce_sum_op_desc = build_comm_desc(
+            "c_allreduce_sum", [0, 1, 2, 3, 4, 5, 6, 7], paddle.float32,
+            [1, 32 * (10**6)])
+        allreduce_sum_op_cost = AllreduceSumOpCost(
+            op_desc=allreduce_sum_op_desc, comm_context=comm_context)
+
+        # Check AllgatherOpCost cost
+        allgather_op_desc = build_comm_desc("c_allgather",
+                                            [0, 1, 2, 3, 4, 5, 6, 7],
+                                            paddle.float32, [1, 32 * (10**6)])
+        allgather_op_cost = AllgatherOpCost(
+            op_desc=allgather_op_desc, comm_context=comm_context)
+        self.assertTrue(allgather_op_cost.time > 0)
+
+        # Check BroadcastOpCost cost
+        broadcast_op_desc = build_comm_desc("c_broadcast",
+                                            [0, 1, 2, 3, 4, 5, 6, 7],
+                                            paddle.float32, [1, 32 * (10**6)])
+        broadcast_op_cost = BroadcastOpCost(
+            op_desc=broadcast_op_desc, comm_context=comm_context)
+        self.assertTrue(broadcast_op_cost.time > 0)
+
+        # Check SendOpCost cost
+        send_op_desc = build_comm_desc("send_v2", [0, 1], paddle.float32,
+                                       [1, 32 * (10**6)])
+        send_op_cost = SendOpCost(
+            op_desc=send_op_desc, comm_context=comm_context)
+        self.assertTrue(send_op_cost.time > 0)
+
+        # Check RecvOpCost cost
+        recv_op_desc = build_comm_desc("recv_v2", [0, 1], paddle.float32,
+                                       [1, 32 * (10**6)])
+        recv_op_cost = RecvOpCost(
+            op_desc=recv_op_desc, comm_context=comm_context)
+        self.assertTrue(recv_op_cost.time > 0)
+
+        # Check IdentityOpCost cost
+        identity_op_desc = build_comm_desc("c_identity", [0, 1], paddle.float32,
+                                           [1, 32 * (10**6)])
+        identity_op_cost = IdentityOpCost(
+            op_desc=identity_op_desc, comm_context=comm_context)
+        self.assertTrue(identity_op_cost.time >= 0)
+
+        # Remove unnecessary files
+        if os.path.exists(cluster_json_path):
+            os.remove(cluster_json_path)
+
+    def test_cross_machine_comm_cost(self):
+        # Build cluster
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
+        cluster_json_object = json.loads(multi_cluster_json)
+        with open(cluster_json_path, "w") as cluster_json_file:
+            json.dump(cluster_json_object, cluster_json_file)
+        cluster = Cluster()
+        cluster.build_from_file(cluster_json_path)
+
+        # Build CommContext
+        CommContext._has_instance = None
+        CommContext._instance = None
+        comm_context = CommContext(cluster)
+
+        # Check AllreduceSumOpCost 128MB tree cost
+        allreduce_sum_op_desc = build_comm_desc(
+            "c_allreduce_sum",
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+            paddle.float32, [1, 32 * (10**6)])
+        allreduce_sum_op_cost = AllreduceSumOpCost(
+            op_desc=allreduce_sum_op_desc, comm_context=comm_context)
+
+        # Check AllgatherOpCost cost
+        allgather_op_desc = build_comm_desc(
+            "c_allgather",
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+            paddle.float32, [1, 32 * (10**6)])
+        allgather_op_cost = AllgatherOpCost(
+            op_desc=allgather_op_desc, comm_context=comm_context)
+        self.assertTrue(allgather_op_cost.time > 0)
+
+        # Check BroadcastOpCost cost
+        broadcast_op_desc = build_comm_desc(
+            "c_broadcast",
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+            paddle.float32, [1, 32 * (10**6)])
+        broadcast_op_cost = BroadcastOpCost(
+            op_desc=broadcast_op_desc, comm_context=comm_context)
+        self.assertTrue(broadcast_op_cost.time > 0)
+
+        # Check SendOpCost cost
+        send_op_desc = build_comm_desc("send_v2", [0, 1], paddle.float32,
+                                       [1, 32 * (10**6)])
+        send_op_cost = SendOpCost(
+            op_desc=send_op_desc, comm_context=comm_context)
+        self.assertTrue(send_op_cost.time > 0)
+
+        # Check RecvOpCost cost
+        recv_op_desc = build_comm_desc("recv_v2", [0, 1], paddle.float32,
+                                       [1, 32 * (10**6)])
+        recv_op_cost = RecvOpCost(
+            op_desc=recv_op_desc, comm_context=comm_context)
+        self.assertTrue(recv_op_cost.time > 0)
+
+        # Remove unnecessary files
+        if os.path.exists(cluster_json_path):
+            os.remove(cluster_json_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
index 6d6fbfe78e9e6d9a94a1accb53daa84ffcacd654..c0df01ada58f9feb26a61b379d96bb4b675f2e24 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import unittest
+import os
+import json
 
 import paddle
 import paddle.distributed.auto_parallel.cost as cost_model
 from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
 from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
 from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
+from paddle.distributed.auto_parallel.cluster import Cluster
+from paddle.distributed.auto_parallel.cost import CommContext
+from test_cluster import cluster_json, multi_cluster_json
 
 paddle.enable_static()
 
@@ -58,14 +63,31 @@ class TestCost(unittest.TestCase):
         self.assertEqual(tensor_cost.cost.memory, 1600)
 
     def test_comm_cost(self):
+        # Build cluster
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
+        cluster_json_object = json.loads(cluster_json)
+        with open(cluster_json_path, "w") as cluster_json_file:
+            json.dump(cluster_json_object, cluster_json_file)
+        cluster = Cluster()
+        cluster.build_from_file(cluster_json_path)
+
+        # Build CommContext
+        CommContext._has_instance = None
+        CommContext._instance = None
+        comm_context = CommContext(cluster)
         desc = {}
         desc["op"] = "c_allreduce_sum"
-        desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
+        desc["inputs"] = {"X": [(paddle.float32, [100, 200])]}
         desc["group_ranks"] = [0, 1]
         allreduce_cost = cost_model._g_op_cost_factory["c_allreduce_sum"](
-            op_desc=desc)
+            op_desc=desc, comm_context=comm_context)
         self.assertTrue(check_cost(allreduce_cost.cost))
 
+        # Remove unnecessary files
+        if os.path.exists(cluster_json_path):
+            os.remove(cluster_json_path)
+
     def test_cost_estimator(self):
         train_program = paddle.static.Program()
         cost_estimator = cost_model.CostEstimator(train_program)
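
Editor's note (not part of the patch): the new calc_time_* methods follow an alpha-beta (latency-bandwidth) model: a start-up term alpha, assembled from the CommContext constants (base_ring/base_tree, intra/inter link latencies, switch hops), plus a payload term comm_count * beta, where beta is the per-byte cost of the slowest link in the group. The standalone sketch below only illustrates how the ring and tree allreduce formulas scale; the function names and the alpha/beta values are invented for the example and do not come from Paddle.

import math


def ring_allreduce_time(size, rank_count, alpha, beta):
    # Ring allreduce moves each byte roughly 2 * (rank_count - 1) / rank_count
    # times over the bottleneck link, on top of the start-up latency alpha.
    return alpha + 2 * (rank_count - 1) / rank_count * size * beta


def tree_allreduce_time(size, machine_count, alpha, beta):
    # Tree allreduce pays about log2(machine_count) latency steps but moves
    # the payload roughly twice (reduce phase plus broadcast phase).
    return alpha * math.log2(machine_count) + 2 * size * beta


if __name__ == "__main__":
    size = 128 * 10**6  # 128MB payload, matching the tensors used in the tests
    alpha, beta = 10.0, 1e-3  # illustrative latency and per-byte values only
    print(ring_allreduce_time(size, rank_count=8, alpha=alpha, beta=beta))
    print(tree_allreduce_time(size, machine_count=2, alpha=alpha, beta=beta))

Within one machine the ring formula is preferred because its latency term does not grow with the machine count, while across machines the tree formula keeps the latency proportional to log2 of the machine count; this is why AllreduceSumOpCost.calc_time switches on cluster.cross_machine(self.group_ranks).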