# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index, is_prim_op
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank

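# op types for which DistributedDefaultImpl0.forward skips the
# parameter initialization sync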
__op_not_need_param_init__ = ["while", "cond"]


def prim_operator_data_parallel_functor(ctx, src_op):
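    # For primitive-operator (prim) mode: if src_op produces a parameter
    # gradient, allreduce that gradient over the data-parallel group in the
    # main block, broadcast the corresponding parameter from rank 0 in the
    # startup block, and attach a matching dist attribute to the allreduce op.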
    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    startup_block = dist_op_context.startup_block

    var_name = src_op.output_arg_names[0]
    if var_name in ctx.grads_params:
        assert var_name not in ctx.synced_gradient, "in primitive mode, grad {} has already been synced".format(
            var_name)
        ctx.synced_gradient.add(var_name)
        sync_group = new_process_group(ctx.data_parallel_group)

        allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                            inputs={'X': [var_name]},
                                            outputs={'Out': [var_name]},
                                            attrs={
                                                'ring_id': sync_group.id,
                                                'use_calc_stream': True,
                                                OP_ROLE_KEY: OpRole.Backward
                                            })

        param = ctx.grads_params[var_name]
        startup_block = dist_op_context.startup_block
        new_op = startup_block.append_op(type='c_broadcast',
                                         inputs={'X': [param]},
                                         outputs={'Out': [param]},
                                         attrs={
                                             'ring_id': sync_group.id,
                                             'root': 0,
                                             'use_calc_stream': True,
                                             OP_ROLE_KEY: OpRole.Forward
                                         })

        grad_var = main_block.var(var_name)
        dims_mapping = ctx.get_tensor_dist_attr_for_program(
            grad_var).dims_mapping
        dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        process_mesh = dist_attr.process_mesh
        op_attr = OperatorDistributedAttribute()
        op_attr.process_mesh = process_mesh
        op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
        op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
        ctx.set_op_dist_attr_for_program(allreduce_op, op_attr)

    return


class DistributedDefault(DistributedOperatorImplContainer):

    def __init__(self, op_type):
        super(DistributedDefault, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedDefault("default"))


# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):

    def __init__(self, name):
        super(DistributedDefaultImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
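        # Inputs fit this replicated implementation when parameters are fully
        # replicated and every other input is replicated on all dims except
        # (possibly) the batch dim; for XShape args the leading dummy dim must
        # be replicated and dim 1 acts as the batch dim. All collected batch
        # dims must also be mutually compatible.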
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        batch_dim_mappings = []
        input_names = op_desc.input_names()
        xshape_arg_names = []
        if "XShape" in input_names:
            xshape_arg_names = op_desc.input("XShape")
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        if compute_compatible_dim_mapping(batch_dim_mappings) is None:
            return False

        return True

    def is_output_compatible(self, dist_op):
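        # Same replication check as is_input_compatible, applied to the outputs.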
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        output_names = op_desc.output_names()
        batch_dim_mappings = []
        xshape_arg_names = []
        if "XShape" in output_names:
            xshape_arg_names = op_desc.output("XShape")
        for arg_name in op_desc.output_arg_names():
            serial_tensor = dist_op.get_serial_output(arg_name)
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        if compute_compatible_dim_mapping(batch_dim_mappings) is None:
            return False

        return True

    def is_auto_compatible(self, dist_op):
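        # Compatible only if, across all inputs and outputs, nothing but the
        # batch dim is sharded and every batch dim uses the same mapping.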
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        batch_dim_mappings = []
        # Check input compatibility
        input_names = op_desc.input_names()
        xshape_arg_names = []
        if "XShape" in input_names:
            xshape_arg_names = op_desc.input("XShape")
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if serial_tensor is not None and serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        # Check output compatibility
        output_names = op_desc.output_names()
        xshape_arg_names = []
        if "XShape" in output_names:
            xshape_arg_names = op_desc.output("XShape")
        for arg_name in op_desc.output_arg_names():
            serial_tensor = dist_op.get_serial_output(arg_name)
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if serial_tensor is not None and serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        # Check batch dim mapping compatibility
        if not all(batch_dim_mappings[0] == dim_mapping
                   for dim_mapping in batch_dim_mappings):
            return False

        return True

    def update_dims_mapping(self, dist_op):
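        # Propagate one compatible batch-dim mapping across all non-parameter
        # inputs and outputs ("while" is skipped entirely; "shape" and "slice"
        # outputs are left unchanged). Returns True if any mapping was updated.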
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr

        if op_desc.type() == "while":
            return False

        input_names = op_desc.input_names()
        input_xshape_arg_names = []
        if "XShape" in input_names:
            input_xshape_arg_names = op_desc.input("XShape")

        output_names = op_desc.output_names()
        output_xshape_arg_names = []
        if "XShape" in output_names:
            output_xshape_arg_names = op_desc.output("XShape")

        batch_dim_mappings = []
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if arg_name not in input_xshape_arg_names:
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                batch_dim_mappings.append(dims_mapping[1])
        for arg_name in op_desc.output_arg_names():
            if op_desc.type() == "fill_zeros_like":
                input_tensor = dist_op.get_serial_input(
                    op_desc.input_arg_names()[0])
                if input_tensor.is_parameter:
                    continue
            serial_tensor = dist_op.get_serial_output(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if arg_name not in output_xshape_arg_names:
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                batch_dim_mappings.append(dims_mapping[1])

        if not batch_dim_mappings:
            return changed

        compatible_dim_mapping = compute_compatible_dim_mapping(
            batch_dim_mappings)
        if compatible_dim_mapping is None:
            return False

        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if arg_name not in input_xshape_arg_names:
                if len(dims_mapping) >= 1 and \
                    compatible_dim_mapping != dims_mapping[0]:
                    dims_mapping[0] = compatible_dim_mapping
                    changed = True
            else:
                if len(dims_mapping) >= 2 and \
                    compatible_dim_mapping != dims_mapping[1]:
                    dims_mapping[1] = compatible_dim_mapping
                    changed = True
        for arg_name in op_desc.output_arg_names():
            if op_desc.type() == "fill_zeros_like":
                input_tensor = dist_op.get_serial_input(
                    op_desc.input_arg_names()[0])
                if input_tensor.is_parameter:
                    continue
            if op_desc.type() in ["shape", "slice"]:
                continue
            serial_tensor = dist_op.get_serial_output(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if arg_name not in output_xshape_arg_names:
                if len(dims_mapping) >= 1 and \
                    compatible_dim_mapping != dims_mapping[0]:
                    dims_mapping[0] = compatible_dim_mapping
                    changed = True
            else:
                if len(dims_mapping) >= 2 and \
                    compatible_dim_mapping != dims_mapping[1]:
                    dims_mapping[1] = compatible_dim_mapping
                    changed = True

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
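        # Replicate the serial op into the dist main program. In prim mode,
        # data-parallel sync is delegated to prim_operator_data_parallel_functor;
        # otherwise every parameter input is broadcast in the startup program
        # along each mesh axis it is not split on, so all ranks start from the
        # same parameter values.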
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name)

        # replicate op in dist program
        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(src_op.desc)
        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
        for input_name in src_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in src_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

        # data parallel synchronization for primitive operators
        from paddle.incubate.autograd import prim_enabled
        if prim_enabled():
            assert is_prim_op(src_op)
            prim_operator_data_parallel_functor(ctx, src_op)
            return

        # param initialization sync
        if src_op.type in __op_not_need_param_init__:
            return

        for varname in dist_op_desc.input_arg_names():
            if startup_block.has_var(varname) and startup_block.var(
                    varname
            ).is_parameter and varname not in dist_op_context.already_init_sync_vars:
                dist_op_context.already_init_sync_vars.add(varname)
                param = startup_block.var(varname)
                param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
                process_mesh = param_dist_attr.process_mesh
                dims_mapping = param_dist_attr.dims_mapping

                # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
                if rank_id not in process_mesh.processes:
                    rank_id = _get_corresponding_rank(ctx, process_mesh,
                                                      rank_id)

                # NOTE the parameter must be synced (broadcast) along every
                # mesh axis it is not split on
                for axis, size in enumerate(process_mesh.topology):
                    if size <= 1 or axis in dims_mapping:
                        pass
                    else:
                        group_ranks = _get_comm_group(process_mesh.processes,
                                                      process_mesh.topology,
                                                      axis, rank_id)
                        sync_group = new_process_group(group_ranks)

                        new_op = startup_block.append_op(
                            type='c_broadcast',
                            inputs={'X': param},
                            outputs={'Out': param},
                            attrs={
                                'ring_id': sync_group.id,
                                'root': 0,
                                'use_calc_stream': True,
                                OP_ROLE_KEY: OpRole.Forward
                            })

                        # set distributed attribute
                        op_attr = OperatorDistributedAttribute()
                        op_attr.process_mesh = process_mesh
                        op_attr.set_output_dims_mapping(param.name,
                                                        dims_mapping)
                        op_attr.set_input_dims_mapping(param.name, dims_mapping)
                        ctx.set_op_dist_attr_for_program(new_op, op_attr)

    @staticmethod
    def backward(ctx, *args, **kwargs):

        # for now, the backward function only inserts the gradient allreduce
        # for the dist op itself
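        # a parameter gradient gets an allreduce (plus a scale op when
        # ctx.gradient_scale is set) only when some non-gradient, non-parameter
        # input of the backward op has its batch dimension sharded on the mesh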
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        backward_op = dist_op_context.cur_src_op
        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
        assert dist_attr is not None, "backward op [{}] doesn't have a dist attribute!".format(
            str(backward_op))
        rank_id = dist_op_context.rank_id

        # check the validity of inputs / outputs
        for input_name in backward_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                backward_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name)
        for output_name in backward_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                backward_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name)

        # replicate op in dist program
        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(backward_op.desc)
        # Refer to the related dist op
        set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
        for input_name in backward_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in backward_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

        # check whether a gradient allreduce is needed:
        # if there is a non-gradient & non-parameter input whose batch dimension
        # is sharded, we need to insert a gradient allreduce for the parameter
        # gradients among this op's outputs
        need_gradient_allreduce = False
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and not is_parameter_related(
                        varname, main_block):

                    # NOTE the dims_mapping should be taken from the backward op's own
                    # input var, not from the corresponding varname of the forward op
                    process_mesh = dist_attr.process_mesh
                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

                    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
                    if rank_id not in process_mesh.processes:
                        rank_id = _get_corresponding_rank(
                            ctx, process_mesh, rank_id)

                    mesh_shape = process_mesh.topology
                    batch_size_axis = var_dim_mapping[0]
                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                        need_gradient_allreduce = True
                        group_ranks = _get_comm_group(process_mesh.processes,
                                                      process_mesh.topology,
                                                      batch_size_axis, rank_id)
                        dp_degree = len(group_ranks)
                        dp_group = new_process_group(group_ranks)
                        break

        if need_gradient_allreduce:
            allreduce_vars = []
            for output_name in backward_op.desc.output_names():
                for varname in backward_op.desc.output(output_name):
                    if varname in kwargs["grad_var_to_var"]:
                        fwd_name = kwargs["grad_var_to_var"][varname]
                        if fwd_name not in main_block.vars:
                            continue
                        if is_parameter_related(fwd_name, main_block):
                            allreduce_vars.append(varname)

            if len(allreduce_vars) > 0:

                for varname in allreduce_vars:
                    added_ops = []

                    grad_var = main_block.var(varname)
                    allreduce_op = main_block.append_op(
                        type='c_allreduce_sum',
                        inputs={'X': [grad_var]},
                        outputs={'Out': [grad_var]},
                        attrs={
                            'ring_id': dp_group.id,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Backward
                        })
                    added_ops.append(allreduce_op)

                    if ctx.gradient_scale:
                        scale_op = main_block.append_op(
                            type='scale',
                            inputs={'X': grad_var},
                            outputs={'Out': grad_var},
                            attrs={
                                'scale': 1.0 / dp_degree,
                                OP_ROLE_KEY: OpRole.Backward
                            })
                        added_ops.append(scale_op)

                    dims_mapping = ctx.get_tensor_dist_attr_for_program(
                        grad_var).dims_mapping
                    process_mesh = dist_attr.process_mesh
                    for op in added_ops:
                        op_attr = OperatorDistributedAttribute()
                        op_attr.process_mesh = process_mesh
                        op_attr.set_output_dims_mapping(grad_var.name,
                                                        dims_mapping)
                        op_attr.set_input_dims_mapping(grad_var.name,
                                                       dims_mapping)
                        ctx.set_op_dist_attr_for_program(op, op_attr)


register_distributed_operator_impl(
    "default", DistributedDefaultImpl0("replicate_parallel"))