# common.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Module-wide registry mapping a name to its DistributedOperatorImplContainer.
# Populated by register_distributed_operator_impl_container.
_g_distributed_operator_impl_registries = dict()


class DistributedOperatorImplContainer:
    """Ordered collection of the distributed implementations of one operator."""

    def __init__(self):
        # Implementations in registration order; the position of an entry
        # is the ``impl_idx`` accepted by ``get_impl``.
        self._impls = []
        self._name = None

    def register_impl(self, dist_impl):
        """Append ``dist_impl`` after all previously registered impls."""
        self._impls += [dist_impl]

    def get_impl(self, impl_idx):
        """Return the implementation stored at position ``impl_idx``."""
        impls = self._impls
        return impls[impl_idx]

    def get_impls(self):
        """Return the internal list of implementations (not a copy)."""
        impls = self._impls
        return impls


class DistributedOperatorImpl:
    """Base class for one distributed implementation of an operator.

    Subclasses are expected to override the hooks below. Every hook defined
    here except ``get_name`` and ``is_compatible`` raises
    ``NotImplementedError``.
    """

    def __init__(self):
        self._name = None
        # Both flags start as False; nothing in this base class reads them.
        self._forward_implemented = False
        self._backward_implemented = False

    @staticmethod
    def forward(dist_ctx, *args, **kwargs):
        """Forward hook; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        """Backward hook; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    def get_name(self):
        """Return the implementation name (``None`` unless a subclass sets it)."""
        return self._name

    def is_input_compatible(self, dist_op):
        """Input-side compatibility check; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    def is_output_compatible(self, dist_op):
        """Output-side compatibility check; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    def is_compatible(self, dist_op):
        """Check input compatibility, then output compatibility.

        Short-circuits like ``and``: a falsy input-side result is returned
        as-is and the output side is not consulted.
        """
        input_ok = self.is_input_compatible(dist_op)
        if not input_ok:
            return input_ok
        return self.is_output_compatible(dist_op)

    def is_auto_compatible(self, dist_op):
        """Auto-compatibility check; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    def update_dims_mapping(self, dist_op):
        """Dims-mapping update hook; a subclass must provide it."""
        raise NotImplementedError("Please Implement this method in Subclass.")


def register_distributed_operator_impl_container(name, dist_op_impl_container):
    """Store ``dist_op_impl_container`` in the module registry under ``name``.

    Any container previously registered under the same name is replaced.
    """
    # Only the dict's contents change (the name is never rebound), so no
    # ``global`` declaration is needed.
    registry = _g_distributed_operator_impl_registries
    registry[name] = dist_op_impl_container


def get_distributed_operator_impl_container(name):
    """Return the container registered under ``name``, or ``None`` if absent."""
    registry = _g_distributed_operator_impl_registries
    return registry.get(name)


def register_distributed_operator_impl(name, dist_impl):
    """Add ``dist_impl`` to the container registered under ``name``.

    Args:
        name: key of a container previously registered with
            ``register_distributed_operator_impl_container``.
        dist_impl: the implementation to append to that container.

    Raises:
        AssertionError: if no container is registered under ``name``.
            Raised explicitly instead of via an ``assert`` statement so the
            check is not stripped when Python runs with ``-O``. Otherwise the
            registration would be silently dropped.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is None:
        raise AssertionError(
            "Must register distributed operator registry first.")
    dist_op_impl_container.register_impl(dist_impl)


def get_distributed_operator_impl(name, impl_idx):
    """Return implementation ``impl_idx`` of the container registered as ``name``.

    Raises ``KeyError`` when no container is registered under ``name``.
    """
    container = _g_distributed_operator_impl_registries[name]
    return container.get_impl(impl_idx)


def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
    """Return the first registered implementation compatible with ``dist_op``.

    For now this simply picks the first compatible implementation in
    registration order; it may be improved by a cost model in the future.

    Args:
        name: key of the implementation container to search.
        dist_op: the distributed operator to test candidates against.
        fwd: if truthy, candidates are matched with ``is_input_compatible``;
            otherwise with ``is_output_compatible``.

    Returns:
        tuple: ``(impl, idx)`` for the first compatible implementation and its
        index in the container, or ``(None, -1)`` when no container is
        registered under ``name`` or no candidate matches.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is None:
        return None, -1

    for idx, impl in enumerate(dist_op_impl_container.get_impls()):
        check = impl.is_input_compatible if fwd else impl.is_output_compatible
        # Only the first match is ever returned, so stop here instead of
        # running the (unused) compatibility checks on later candidates.
        if check(dist_op):
            return impl, idx

    return None, -1


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
    """Infer the shape ``src_var`` has under ``op_input_dist_attr``.

    Each dimension of the variable's shape is first scaled up according to
    the source distributed attribute. It is then scaled down according to
    the op input's distributed attribute. A ``dims_mapping`` entry of -1
    leaves that dimension unscaled. Otherwise it selects the
    ``process_mesh.topology`` entry used as the factor.

    Args:
        block: object whose ``var(name)`` returns a variable with ``.shape``.
        src_var: the variable to inspect; only ``.name`` is read.
        src_var_dist_attr: distributed attribute of ``src_var``; its
            ``process_mesh.topology`` and ``dims_mapping`` are read.
        op_input_dist_attr: distributed attribute of the op input; the same
            two fields are read.

    Returns:
        list: the inferred size of every dimension of ``src_var``.
    """
    var_shape = block.var(src_var.name).shape
    src_topology = src_var_dist_attr.process_mesh.topology
    src_dims_mapping = src_var_dist_attr.dims_mapping
    dst_topology = op_input_dist_attr.process_mesh.topology
    dst_dims_mapping = op_input_dist_attr.dims_mapping

    exact_shape = []
    for idx, size in enumerate(var_shape):
        # A negative size (presumably -1 for an unknown/dynamic dim; confirm
        # against the framework) carries no extent. Multiplying and
        # floor-dividing it would yield an arbitrary negative number whenever
        # the two mesh factors differ, so pass it through unchanged.
        if size < 0:
            exact_shape.append(size)
            continue
        src_axis = src_dims_mapping[idx]
        if src_axis != -1:
            size = size * src_topology[src_axis]
        dst_axis = dst_dims_mapping[idx]
        if dst_axis != -1:
            size = size // dst_topology[dst_axis]
        exact_shape.append(size)

    return exact_shape