# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_g_distributed_operator_impl_registries = {}


class DistributedOperatorImplContainer:
    def __init__(self):
        self._impls = []
        self._name = None

    def register_impl(self, dist_impl):
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        return self._impls[impl_idx]

    def get_impls(self):
        return self._impls


class DistributedOperatorImpl:
    def __init__(self):
        self._name = None
        self._forward_implemented = False
        self._backward_implemented = False

    @staticmethod
    def forward(dist_ctx, *args, **kwargs):
        raise NotImplementedError("Please implement this method in a subclass.")

    @staticmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        raise NotImplementedError("Please implement this method in a subclass.")

    def get_name(self):
        return self._name

    def is_input_compatible(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")

    def is_output_compatible(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")

    def is_compatible(self, dist_op):
        return self.is_input_compatible(dist_op) and \
            self.is_output_compatible(dist_op)

    def update_dims_mapping(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")


def register_distributed_operator_impl_container(name, dist_op_impl_container):
    global _g_distributed_operator_impl_registries
    _g_distributed_operator_impl_registries[name] = dist_op_impl_container


def get_distributed_operator_impl_container(name):
    global _g_distributed_operator_impl_registries
    return _g_distributed_operator_impl_registries.get(name, None)


def register_distributed_operator_impl(name, dist_impl):
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is not None:
        dist_op_impl_container.register_impl(dist_impl)
    else:
        assert False, "Must register the distributed operator impl container first."


def get_distributed_operator_impl(name, impl_idx):
    global _g_distributed_operator_impl_registries
    return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)


def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
    """
    For now, just return the first compatible implementation.
    This will be improved by a cost model in the future.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is None:
        return None, -1
    compatible_impls = []
    impls = dist_op_impl_container.get_impls()
    if fwd:
        for idx, impl in enumerate(impls):
            if impl.is_input_compatible(dist_op):
                compatible_impls.append((impl, idx))
    else:
        for idx, impl in enumerate(impls):
            if impl.is_output_compatible(dist_op):
                compatible_impls.append((impl, idx))

    if compatible_impls:
        best_compatible_impl, idx = compatible_impls[0]
    else:
        best_compatible_impl, idx = None, -1

    return best_compatible_impl, idx
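

# Illustrative usage sketch (not part of this module): how a concrete impl is
# typically registered and then looked up. The operator name "matmul", the
# AlwaysCompatibleImpl class, and the dist_op object below are hypothetical
# placeholders, not definitions from this file.
#
#     class AlwaysCompatibleImpl(DistributedOperatorImpl):
#         def is_input_compatible(self, dist_op):
#             return True
#
#         def is_output_compatible(self, dist_op):
#             return True
#
#     register_distributed_operator_impl_container(
#         "matmul", DistributedOperatorImplContainer())
#     register_distributed_operator_impl("matmul", AlwaysCompatibleImpl())
#
#     # dist_op is assumed to be the distributed operator being planned.
#     impl, idx = find_best_compatible_distributed_operator_impl(
#         "matmul", dist_op, fwd=True)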


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
    """Infer the exact local shape of an op input from its sharded source var."""
    var_shape = block.var(src_var.name).shape
    var_topology = src_var_dist_attr.process_mesh.topology
    var_dims_mapping = src_var_dist_attr.dims_mapping

    complete_shape = []
    for idx, shape in enumerate(var_shape):
        if var_dims_mapping[idx] == -1:
            complete_shape.append(shape)
        else:
            new_shape = shape * var_topology[var_dims_mapping[idx]]
            complete_shape.append(new_shape)

    exact_shape = []
    input_topology = op_input_dist_attr.process_mesh.topology
    input_dims_mapping = op_input_dist_attr.dims_mapping
    for idx, shape in enumerate(complete_shape):
        if input_dims_mapping[idx] == -1:
            exact_shape.append(shape)
        else:
            new_shape = shape // input_topology[input_dims_mapping[idx]]
            exact_shape.append(new_shape)

    return exact_shape
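

# Worked example (assumed values) for infer_shape: given a sharded source var
# whose block shape is (2, 4), a process mesh topology of [2, 2], and source
# dims_mapping [0, -1], the complete shape is (2 * 2, 4) = (4, 4). If the op
# input's dist attr maps dims as [-1, 1] on the same mesh, the exact shape
# returned is (4, 4 // 2) = (4, 2).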