# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import math

from .base_cost import CommOpCost, register_op_cost


@register_op_cost
class AllreduceSumOpCost(CommOpCost):
    """Alpha-beta time estimate for the ``c_allreduce_sum`` op.

    Every estimate has the form ``alpha + factor * comm_count * beta``.
    ``alpha`` accumulates latency coefficients read from ``comm_context``.
    ``beta`` is ``comm_context.get_max_beta(group_ranks)``.

    ``rank_count``, ``machine_count``, ``hops``, ``comm_count`` and
    ``group_ranks`` are read from ``self``. They are presumably populated
    by ``CommOpCost`` (not shown here) -- confirm there.
    """

    OP_TYPE = "c_allreduce_sum"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(AllreduceSumOpCost, self).__init__(op=op,
                                                 op_desc=op_desc,
                                                 comm_context=comm_context)

    def calc_time(self):
        """Return the estimated allreduce time, picking the algorithm by topology.

        A tree is used when ``cluster.cross_machine(group_ranks)`` is truthy.
        That presumably means the group spans more than one machine.
        A ring is used otherwise.
        """
        cluster = self.comm_context.cluster
        if cluster.cross_machine(self.group_ranks):
            return self.calc_time_tree()
        return self.calc_time_ring()

    def calc_time_ring(self):
        """Estimate the time of a ring allreduce over ``group_ranks``.

        Returns:
            ``alpha + 2 * (rank_count - 1) / rank_count * comm_count * beta``.

            ``alpha`` is ``base_ring`` plus two groups of steps:

            * ``2 * (rank_count - machine_count)`` intra-machine steps
              (``intra_ring`` each).
            * ``2 * (machine_count - 1)`` inter-machine steps
              (``inter_ring + hops * switch`` each).
        """
        ctx = self.comm_context
        alpha = ctx.base_ring
        # Latency terms are counted twice, presumably once for the
        # reduce-scatter half and once for the allgather half of the ring.
        alpha += 2 * (self.rank_count - self.machine_count) * ctx.intra_ring
        alpha += 2 * (self.machine_count - 1) * (
            ctx.inter_ring + self.hops * ctx.switch)
        beta = ctx.get_max_beta(self.group_ranks)
        # 2 * (n - 1) / n: the per-rank traffic factor of a ring allreduce.
        traffic = 2 * (self.rank_count - 1) / self.rank_count
        return alpha + traffic * self.comm_count * beta

    def calc_time_tree(self):
        """Estimate the time of a tree allreduce over ``group_ranks``.

        Returns:
            ``alpha + 2 * comm_count * beta``.

            ``alpha`` is ``base_tree`` plus two groups of steps:

            * ``2 * (rank_count / machine_count - 1)`` intra-machine steps
              (``intra_tree`` each).
            * ``log2(machine_count)`` inter-machine steps
              (``inter_tree + hops * switch`` each).

        Raises:
            ValueError: from ``math.log2`` if ``machine_count <= 0``.
                ``calc_time`` only calls this for cross-machine groups.
        """
        ctx = self.comm_context
        alpha = ctx.base_tree
        # rank_count / machine_count is the average number of ranks per machine.
        alpha += 2 * (self.rank_count / self.machine_count -
                      1) * ctx.intra_tree
        alpha += math.log2(self.machine_count) * (
            ctx.inter_tree + self.hops * ctx.switch)
        beta = ctx.get_max_beta(self.group_ranks)
        return alpha + 2 * self.comm_count * beta


@register_op_cost
class AllgatherOpCost(CommOpCost):
    """Alpha-beta time estimate for the ``c_allgather`` op (ring model only).

    ``rank_count``, ``machine_count``, ``hops``, ``comm_count`` and
    ``group_ranks`` are read from ``self``. They are presumably populated
    by ``CommOpCost`` (not shown here) -- confirm there.
    """

    OP_TYPE = "c_allgather"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(AllgatherOpCost, self).__init__(op=op,
                                              op_desc=op_desc,
                                              comm_context=comm_context)

    def calc_time(self):
        """Return the estimated time; allgather always uses the ring model."""
        return self.calc_time_ring()

    def calc_time_ring(self):
        """Estimate the time of a ring allgather over ``group_ranks``.

        Returns:
            ``alpha + (rank_count - 1) / rank_count * comm_count * beta``.

            ``alpha`` is ``base_ring`` plus two groups of steps:

            * ``rank_count - machine_count`` intra-machine steps
              (``intra_ring`` each).
            * ``machine_count - 1`` inter-machine steps
              (``inter_ring + hops * switch`` each).

            ``beta`` is ``comm_context.get_max_beta(group_ranks)``.
        """
        ctx = self.comm_context
        alpha = ctx.base_ring
        alpha += (self.rank_count - self.machine_count) * ctx.intra_ring
        alpha += (self.machine_count - 1) * (
            ctx.inter_ring + self.hops * ctx.switch)
        beta = ctx.get_max_beta(self.group_ranks)
        # (n - 1) / n: each rank receives every shard except its own.
        traffic = (self.rank_count - 1) / self.rank_count
        return alpha + traffic * self.comm_count * beta


@register_op_cost
class BroadcastOpCost(CommOpCost):
    """Alpha-beta time estimate for the ``c_broadcast`` op (ring model only).

    ``machine_count``, ``hops``, ``comm_count`` and ``group_ranks`` are read
    from ``self``. They are presumably populated by ``CommOpCost``
    (not shown here) -- confirm there.
    """

    OP_TYPE = "c_broadcast"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(BroadcastOpCost, self).__init__(op=op,
                                              op_desc=op_desc,
                                              comm_context=comm_context)

    def calc_time(self):
        """Return the estimated time; broadcast always uses the ring model."""
        return self.calc_time_ring()

    def calc_time_ring(self):
        """Estimate the broadcast time.

        Returns:
            ``base_ring + link + comm_count * beta``.

            ``link`` depends on ``machine_count``:

            * ``inter_ring + hops * switch`` when ``machine_count > 1``.
            * ``intra_ring`` otherwise.

            ``beta`` is ``comm_context.get_max_beta(group_ranks)``.
        """
        ctx = self.comm_context
        if self.machine_count > 1:
            link = ctx.inter_ring + self.hops * ctx.switch
        else:
            link = ctx.intra_ring
        alpha = ctx.base_ring + link
        beta = ctx.get_max_beta(self.group_ranks)
        # The full payload is charged once, regardless of rank count.
        return alpha + self.comm_count * beta


@register_op_cost
class IdentityOpCost(CommOpCost):
    """Cost model for the ``c_identity`` op, which is charged no time."""

    OP_TYPE = "c_identity"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(IdentityOpCost, self).__init__(op=op,
                                             op_desc=op_desc,
                                             comm_context=comm_context)

    def calc_time(self):
        """Return 0: no communication time is attributed to ``c_identity``."""
        return 0


@register_op_cost
class RecvOpCost(CommOpCost):
    """Alpha-beta time estimate for the point-to-point ``recv_v2`` op.

    ``machine_count``, ``hops``, ``comm_count`` and ``group_ranks`` are read
    from ``self``. They are presumably populated by ``CommOpCost``
    (not shown here) -- confirm there.
    """

    OP_TYPE = "recv_v2"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(RecvOpCost, self).__init__(op=op,
                                         op_desc=op_desc,
                                         comm_context=comm_context)

    def calc_time(self):
        """Estimate the receive time.

        Returns:
            ``base_ring + link + comm_count * beta``.

            ``link`` depends on ``machine_count``:

            * ``inter_ring + hops * switch`` when ``machine_count > 1``.
            * ``intra_ring`` otherwise.

            ``beta`` is ``comm_context.get_max_beta(group_ranks)``.
        """
        ctx = self.comm_context
        if self.machine_count > 1:
            link = ctx.inter_ring + self.hops * ctx.switch
        else:
            link = ctx.intra_ring
        alpha = ctx.base_ring + link
        beta = ctx.get_max_beta(self.group_ranks)
        return alpha + self.comm_count * beta


@register_op_cost
class SendOpCost(CommOpCost):
    """Alpha-beta time estimate for the point-to-point ``send_v2`` op.

    ``machine_count``, ``hops``, ``comm_count`` and ``group_ranks`` are read
    from ``self``. They are presumably populated by ``CommOpCost``
    (not shown here) -- confirm there.
    """

    OP_TYPE = "send_v2"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(SendOpCost, self).__init__(op=op,
                                         op_desc=op_desc,
                                         comm_context=comm_context)

    def calc_time(self):
        """Estimate the send time.

        Returns:
            ``base_ring + link + comm_count * beta``.

            ``link`` depends on ``machine_count``:

            * ``inter_ring + hops * switch`` when ``machine_count > 1``.
            * ``intra_ring`` otherwise.

            ``beta`` is ``comm_context.get_max_beta(group_ranks)``.
        """
        ctx = self.comm_context
        if self.machine_count > 1:
            link = ctx.inter_ring + self.hops * ctx.switch
        else:
            link = ctx.intra_ring
        alpha = ctx.base_ring + link
        beta = ctx.get_max_beta(self.group_ranks)
        return alpha + self.comm_count * beta