common.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

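# Registry mapping an operator type name to the container that holds its
# registered distributed implementations.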
_g_distributed_operator_impl_registries = {}


class DistributedOperatorImplContainer:
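    """Container for the distributed implementations of one operator type.

    Implementations are appended in registration order and retrieved by index.
    """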
    def __init__(self):
        self._impls = []
        self._name = None

    def register_impl(self, dist_impl):
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        return self._impls[impl_idx]

    def get_impls(self):
        return self._impls


class DistributedOperatorImpl:
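    """Base class for a single distributed implementation of an operator.

    Subclasses are expected to override the compatibility checks, the
    dims-mapping update, and the forward/backward static methods below.
    """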
    def __init__(self):
        self._name = None
        self._forward_implemented = False
        self._backward_implemented = False

    @staticmethod
    def forward(dist_ctx, *args, **kwargs):
        raise NotImplementedError("Please implement this method in a subclass.")

    @staticmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        raise NotImplementedError("Please implement this method in a subclass.")

    def get_name(self):
        return self._name

    def is_input_compatible(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")

    def is_output_compatible(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")

    def is_compatible(self, dist_op):
        return (self.is_input_compatible(dist_op)
                and self.is_output_compatible(dist_op))

    def update_dims_mapping(self, dist_op):
        raise NotImplementedError("Please implement this method in a subclass.")
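
# A minimal subclass sketch (illustrative only; the class name and the
# trivially permissive checks are hypothetical, not part of this module):
#
#     class AlwaysCompatibleImpl(DistributedOperatorImpl):
#         def __init__(self):
#             super().__init__()
#             self._name = "always_compatible"
#
#         def is_input_compatible(self, dist_op):
#             return True
#
#         def is_output_compatible(self, dist_op):
#             return True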


def register_distributed_operator_impl_container(name, dist_op_impl_container):
    global _g_distributed_operator_impl_registries
    _g_distributed_operator_impl_registries[name] = dist_op_impl_container


def get_distributed_operator_impl_container(name):
    global _g_distributed_operator_impl_registries
    return _g_distributed_operator_impl_registries.get(name, None)


def register_distributed_operator_impl(name, dist_impl):
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is not None:
        dist_op_impl_container.register_impl(dist_impl)
    else:
        assert False, ("Must register the distributed operator "
                       "implementation container first.")


def get_distributed_operator_impl(name, impl_idx):
    global _g_distributed_operator_impl_registries
    return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
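
# Illustrative usage of the registry API above (the "matmul" name and the
# MatmulDistImpl class are hypothetical, defined elsewhere):
#
#     register_distributed_operator_impl_container(
#         "matmul", DistributedOperatorImplContainer())
#     register_distributed_operator_impl("matmul", MatmulDistImpl())
#     impl = get_distributed_operator_impl("matmul", 0)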


def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
    """
    For now, simply return the first compatible implementation; this will be
    improved by a cost model in the future. When ``fwd`` is True, input
    compatibility is checked; otherwise output compatibility is checked.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(name)
    if dist_op_impl_container is None:
        return None, -1
    compatible_impls = []
    impls = dist_op_impl_container.get_impls()
    if fwd:
        for idx, impl in enumerate(impls):
            if impl.is_input_compatible(dist_op):
                compatible_impls.append((impl, idx))
    else:
        for idx, impl in enumerate(impls):
            if impl.is_output_compatible(dist_op):
                compatible_impls.append((impl, idx))

    if compatible_impls:
        best_compatible_impl, idx = compatible_impls[0]
    else:
        best_compatible_impl, idx = None, -1

    return best_compatible_impl, idx


def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
    """
    copy src var's dist_attr to dst var
    """
    dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
    dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)


def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
                                      src_op_dist_attr):
    """
    copy src op's dist_attr to dst dist op
    """
    from ..dist_attribute import OperatorDistributedAttribute
    # Need to check the dist op attr together with its inputs and outputs.

    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
    op_dist_attr.impl_idx = src_op_dist_attr.impl_idx

    for input_varname in dist_op.desc.input_arg_names():
        input_var = dst_block.var(input_varname)
        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            input_var)
        op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)

    for output_varname in dist_op.desc.output_arg_names():
        output_var = dst_block.var(output_varname)
        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            output_var)
        op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)

    dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
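    # Read the attribute back from the context; presumably this verifies
    # that the registration above took effect.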
    op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)