dist_eltwise.py 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
18
from .common import register_distributed_operator_impl, is_parameter_related
19 20 21 22 23 24 25 26 27
from .common import is_elementwise_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
J
Jiabin Yang 已提交
28
from paddle.fluid.framework import _non_static_mode
29 30 31 32 33 34
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
35 36 37
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
38 39 40


class DistributedElementwise(DistributedOperatorImplContainer):
    """Implementation container for distributed elementwise operators.

    A single instance is created for the ``"elementwise"`` op type and
    registered right below via ``register_distributed_operator_impl_container``.
    """

    def __init__(self, op_type):
        super(DistributedElementwise, self).__init__(op_type)


register_distributed_operator_impl_container(
    DistributedElementwise("elementwise"))


def _max_dims_mapping_len(dims_mappings):
    """Return the length of the longest dims mapping, or -1 if there are none."""
    return max((len(dims_mapping) for dims_mapping in dims_mappings),
               default=-1)


def _trailing_dim_mappings(dims_mapping_list, idx):
    """Collect entry ``idx`` counted from the END of every dims mapping.

    Mappings too short to have that entry are skipped, so mappings of
    different lengths are compared aligned on their last dimension.
    """
    return [
        dims_mapping[-(idx + 1)] for dims_mapping in dims_mapping_list
        if idx < len(dims_mapping)
    ]


def _is_trailing_aligned_compatible(dims_mapping_list):
    """Return True if every trailing-aligned position is compatible.

    A position is compatible when ``compute_compatible_dim_mapping`` does not
    return None for the dim mappings collected at that position.
    """
    for idx in range(_max_dims_mapping_len(dims_mapping_list)):
        dim_mappings = _trailing_dim_mappings(dims_mapping_list, idx)
        if compute_compatible_dim_mapping(dim_mappings) is None:
            return False
    return True


def _left_pad_dims_mapping(dims_mapping, target_len):
    """Return a copy of ``dims_mapping`` left-padded with -1 to ``target_len``."""
    return [-1] * (target_len - len(dims_mapping)) + list(dims_mapping)


# Replicated Elementwise
class DistributedElementwiseImpl0(DistributedOperatorImpl):
    """Default distributed implementation for elementwise ops.

    It is registered under the name ``"replicate_parallel"``. Forward and
    backward program construction is delegated to ``DistributedDefaultImpl0``.
    Dims mappings of operands are compared aligned on their trailing
    dimension, so operands of different ranks can be checked together.
    """

    def __init__(self, name):
        super(DistributedElementwiseImpl0, self).__init__(name)
        self._forward_implemented = False
        self._backward_implemented = False

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role.

        Dispatches to ``calc_bwd_cost`` when ``op_role`` equals
        ``OpRole.Backward`` and to ``calc_fwd_cost`` otherwise.
        """
        if int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        else:
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        """Return a one-element list holding the forward computation cost."""
        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
                                                    dist_context=ctx)
        processes = dist_op.dist_attr.process_mesh.processes
        op_type = dist_op.serial_op.type
        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
                                                   ctx, processes, desc_mapping,
                                                   cluster)
        return [cost_mapping]

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        """Return the list of backward costs of ``dist_op``.

        The first entry is the computation cost of the backward op.

        If any input that is neither a gradient (``@GRAD``) nor
        parameter-related has its first dims-mapping entry mapped to a
        process-mesh axis of size > 1, data-parallel costs are appended via
        ``build_dp_costs``. They are added for the ``<name>@GRAD`` variable of
        every parameter-related input.
        """
        res = []
        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
                                                    dist_context=ctx)
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        backward_op = dist_op.serial_op
        op_type = backward_op.type
        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
                                                   ctx, processes, desc_mapping,
                                                   cluster)
        res.append(cost_mapping)

        main_block = backward_op.block
        mesh_shape = process_mesh.topology
        # Inputs of the backward op that are not gradients themselves.
        non_grad_var_names = [
            varname for input_name in backward_op.desc.input_names()
            for varname in backward_op.desc.input(input_name)
            if "@GRAD" not in varname
        ]

        def _is_batch_sharded(varname):
            # The first dims-mapping entry is treated as the batch axis.
            batch_size_axis = dist_attr.get_input_dims_mapping(varname)[0]
            return batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1

        # Stop at the first data input found sharded on its batch axis.
        need_gradient_allreduce = any(
            _is_batch_sharded(varname) for varname in non_grad_var_names
            if not is_parameter_related(varname, main_block))
        if not need_gradient_allreduce:
            return res

        for varname in non_grad_var_names:
            if not is_parameter_related(varname, main_block):
                continue
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
            # NOTE(review): the parallel axis comes from the parameter's own
            # first dims-mapping entry, not from the batch axis of the data
            # input that triggered the allreduce above; confirm intended.
            parallel_axis = var_dim_mapping[0]
            attrs = {"use_calc_stream": True}
            var_names = [varname + "@GRAD"]
            build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
                           cluster)
        return res

    def is_input_compatible(self, dist_op):
        """Return True if the input dims mappings are mutually compatible.

        Returns False for non-elementwise ops. Mappings are compared aligned
        on their trailing dimension.
        """
        op_desc = dist_op.serial_op.desc
        if not is_elementwise_op(op_desc.type()):
            return False
        op_dist_attr = dist_op.dist_attr
        dims_mapping_list = [
            op_dist_attr.get_input_dims_mapping(arg_name)
            for arg_name in op_desc.input_arg_names()
        ]
        return _is_trailing_aligned_compatible(dims_mapping_list)

    def is_output_compatible(self, dist_op):
        """Return True if the output dims mappings are mutually compatible.

        Returns False for non-elementwise ops. Mappings are compared aligned
        on their trailing dimension.
        """
        op_desc = dist_op.serial_op.desc
        if not is_elementwise_op(op_desc.type()):
            return False
        op_dist_attr = dist_op.dist_attr
        dims_mapping_list = [
            op_dist_attr.get_output_dims_mapping(arg_name)
            for arg_name in op_desc.output_arg_names()
        ]
        return _is_trailing_aligned_compatible(dims_mapping_list)

    def is_auto_compatible(self, dist_op):
        """Return True if all input and output dims mappings agree exactly.

        At every trailing-aligned position, all mappings long enough to have
        that position must hold the same value.

        Raises:
            AssertionError: if the longest input mapping and the longest
                output mapping differ in length.
        """
        op_desc = dist_op.serial_op.desc
        if not is_elementwise_op(op_desc.type()):
            return False
        op_dist_attr = dist_op.dist_attr

        input_dims_mappings = [
            op_dist_attr.get_input_dims_mapping(arg_name)
            for arg_name in op_desc.input_arg_names()
        ]
        output_dims_mappings = [
            op_dist_attr.get_output_dims_mapping(arg_name)
            for arg_name in op_desc.output_arg_names()
        ]
        input_max_dims_mapping_len = _max_dims_mapping_len(input_dims_mappings)
        output_max_dims_mapping_len = _max_dims_mapping_len(
            output_dims_mappings)
        assert input_max_dims_mapping_len == output_max_dims_mapping_len

        dims_mapping_list = input_dims_mappings + output_dims_mappings
        for idx in range(input_max_dims_mapping_len):
            dim_mappings = _trailing_dim_mappings(dims_mapping_list, idx)
            if any(dim_mapping != dim_mappings[0]
                   for dim_mapping in dim_mappings):
                return False
        return True

    def update_dims_mapping(self, dist_op):
        """Unify the dims mappings of all inputs and outputs of ``dist_op``.

        1. Each input mapping is left-padded with -1 to the longest input
           length. Each output mapping is padded the same way to the longest
           output length.
        2. ``compute_compatible_dims_mapping`` merges all padded mappings
           into one.
        3. Each argument's trailing slice of that merged mapping is written
           back wherever it differs from the current mapping.

        Returns:
            bool: True if any dims mapping was updated. False if nothing
            changed or no compatible mapping exists.

        Raises:
            AssertionError: if the longest input mapping and the longest
                output mapping differ in length.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr

        input_arg_names = op_desc.input_arg_names()
        input_dims_mappings = {
            arg_name: op_dist_attr.get_input_dims_mapping(arg_name)
            for arg_name in input_arg_names
        }
        output_arg_names = op_desc.output_arg_names()
        output_dims_mappings = {
            arg_name: op_dist_attr.get_output_dims_mapping(arg_name)
            for arg_name in output_arg_names
        }

        input_max_dims_mapping_len = _max_dims_mapping_len(
            input_dims_mappings.values())
        output_max_dims_mapping_len = _max_dims_mapping_len(
            output_dims_mappings.values())
        assert input_max_dims_mapping_len == output_max_dims_mapping_len
        max_dims_mapping_len = input_max_dims_mapping_len

        # One entry per arg name (duplicates included), inputs first.
        dims_mapping_list = [
            _left_pad_dims_mapping(input_dims_mappings[arg_name],
                                   input_max_dims_mapping_len)
            for arg_name in input_arg_names
        ] + [
            _left_pad_dims_mapping(output_dims_mappings[arg_name],
                                   output_max_dims_mapping_len)
            for arg_name in output_arg_names
        ]

        compatible_dims_mapping = compute_compatible_dims_mapping(
            dims_mapping_list)
        if compatible_dims_mapping is None:
            return False

        changed = False
        for arg_name in input_arg_names:
            current = input_dims_mappings[arg_name]
            # Keep only the trailing entries matching this argument's rank.
            new_dims_mapping = compatible_dims_mapping[max_dims_mapping_len -
                                                       len(current):]
            if new_dims_mapping != current:
                op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
                changed = True

        for arg_name in output_arg_names:
            current = output_dims_mappings[arg_name]
            new_dims_mapping = compatible_dims_mapping[max_dims_mapping_len -
                                                       len(current):]
            if new_dims_mapping != current:
                op_dist_attr.set_output_dims_mapping(arg_name,
                                                     new_dims_mapping)
                changed = True

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Delegate to ``DistributedDefaultImpl0.forward``."""
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate to ``DistributedDefaultImpl0.backward``."""
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "elementwise", DistributedElementwiseImpl0("replicate_parallel"))