# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistAttr,
    TensorDistAttr,
)
from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op,
    is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
    find_higher_order_backward_op,
    get_var_numel,
    insert_dependencies_for_vars,
    is_forward_op,
    is_loss_grad_op,
    is_optimize_op,
    ring_id_to_process_group,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.executor import _is_enable_standalone_executor
from paddle.static import default_main_program
from paddle.utils import unique_name

from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum',
    'sparse_momentum',
    'dgc_momentum',
    'momentum',
    'merge_momentum',
]

# a heuristic upper bound for the number of dp comm streams
__max_stream_num_allow__ = 16


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply optimizations specialized for data parallelism in Auto Parallel.
    1. prune grad scaling
    2. overlap comm and calc
    3. fuse allreduce
    """

    def __init__(self):
        super().__init__()
        # NOTE: do not depend on loss and param_grads
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        self.set_attr("use_sharding", False)
        # {grad1: group1, grad2: group1, grad3: group2}
        # record the order used when fusing grad memory
        self._grad_name_to_group_map = OrderedDict()
        # {group1:[grad1, grad2] , group2:[grad3]}
        self._group_to_grad_name_map = OrderedDict()
        self._support_rescale_grad = False

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
            "global_rank"
        ) < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        return PassType.COMM_OPT

    def _apply_single_impl(self, main_program, startup_program, context):

        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))
        self.use_sharding = self.get_attr("use_sharding")
        self.coalesce_prefix = 'coalesce_grad'
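        # Under the standalone executor, dp gradient communication is assigned
        # to a dedicated named stream so that it can overlap with computation
        # (see _add_dependencies).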
        if _is_enable_standalone_executor():
            self.gradient_sync_stream = "gradient_sync_stream"

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()

            # TODO refactor here to first fuse then overlap
            if self.is_data_parallel_applied():
                self._prune_grad_scaling()
                self._calc_comm_overlap()
                grad_group = self._fuse_allreduce()
                self._add_dependencies(grad_group)
                self.summary(grad_group)

    def _prune_grad_scaling(self):
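        # Fold the 1/dp_degree factor into the initial loss gradient when all dp
        # groups share the same degree, otherwise into each optimizer's
        # rescale_grad attribute; the standalone grad scale ops then become
        # redundant and are removed.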

        if not self._could_be_prune():
            return

        if self._all_dp_groups_same_degree():
            self._scale_backward_initial_grad()
        else:
            self._update_opt_rescale_grad()

        self._remove_grad_scaling()

    def _calc_comm_overlap(self):
        if not self._could_be_overlap():
            return
        self._comms_overlap_calc()
        self._calc_wait_comms()

    def _fuse_allreduce(self):

        if not self._could_be_fuse():
            return []

        grad_group = self._group_grads()
        self._update_program(grad_group)

        return grad_group

    def _analyze_program(self):
        """
        build two maps
        {param_grad_name: data_parallel_group}
        {data_parallel_group: [param_grad_name, ...]}
        """

        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []
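        # Scan the block once: each data-parallel allreduce yields a
        # grad -> process-group mapping, and gradients touched by a dp scale op
        # are collected so we can verify that every scaled gradient is also
        # synchronized.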

        for op in ops:

            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), "Unexpected: comm op [{}] has NOT ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert (
                    group is not None
                ), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op)
                )

                self._grad_name_to_group_map[grad_name] = group

                if group not in self._group_to_grad_name_map:
                    self._group_to_grad_name_map[group] = [grad_name]
                else:
                    self._group_to_grad_name_map[group].append(grad_name)

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

            # TODO support multiple optimizers in one network in the future.
            # here we assume that the optimizer is unique in the network.
            elif (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                self._support_rescale_grad = True

        not_synchronized_grads = []
        for grad_name in scaled_grads:
            if grad_name not in self._grad_name_to_group_map:
                not_synchronized_grads.append(grad_name)
        assert (
            len(not_synchronized_grads) == 0
        ), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads
        )

    def is_data_parallel_applied(self):
        return len(self._group_to_grad_name_map) > 0

    def _could_be_prune(self):

        return self.dist_context.gradient_scale and (
            self._support_rescale_grad or self._all_dp_groups_same_degree()
        )

    def _all_dp_groups_same_degree(self):
        return (
            len(
                {
                    len(group.ranks)
                    for group in self._group_to_grad_name_map.keys()
                }
            )
            == 1
        )

    def _scale_backward_initial_grad(self):

        block = default_main_program().global_block()
        dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', (
                    "loss_grad_op must be fill_constant op, "
                    "but this op is {}".format(op.type)
                )
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree
                op._set_attr('value', loss_scale)
                break

    def _remove_grad_scaling(self):
        block = default_main_program().global_block()

        for op_idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_scale_op(op):
                block._remove_op(op_idx, False)

        block._sync_with_cpp()

    def _update_opt_rescale_grad(self):

        block = default_main_program().global_block()
        scaled_grads = set()

        for idx, op in reversed(list(enumerate(block.ops))):
            if (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                assert op.has_attr(
                    'rescale_grad'
257
                ), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format(
                    str(op)
                )
                assert (
                    len(op.input("Grad")) == 1
                ), "Unexpected: op [{}] is supported to have only one input grad var.".format(
                    str(op)
                )

                grad_name = op.input("Grad")[0]
                dp_degree = len(
                    list(self._grad_name_to_group_map[grad_name].ranks)
                )
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(
            self._grad_name_to_group_map.keys()
        ), "Unexpected: gradients [{}] are unscaled.".format(
            set(self._grad_name_to_group_map.keys()) - scaled_grads
        )

    def _could_be_overlap(self):
        # NOTE currently each nccl comm group uses its own cuda stream, so too
        # many dp groups would require too many streams to create and synchronize.
        # Revise this once the framework supports custom streams in static graph mode.
        num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
        if num_dp_comm_stream > __max_stream_num_allow__:
            return False
        if self.use_sharding:
            return False
        return True

    def _comms_overlap_calc(self):
        # TODO support overlap with the InterpreterCore executor, whose
        # overlapping logic differs from the use_calc_stream mechanism.
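        # Each dp allreduce is moved off the calc stream (use_calc_stream=False)
        # and a c_wait_compute is inserted before it, so the comm stream waits
        # only for the producing computation while later calc ops keep running.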
        block = default_main_program().global_block()

        # comm wait calc to finish
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                assert op.has_attr('use_calc_stream')
                assert op.has_attr('ring_id')

                op._set_attr('use_calc_stream', False)
                ring_id = op.attr("ring_id")
                block._insert_op_without_sync(
                    idx,
                    type='c_wait_compute',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )

        block._sync_with_cpp()

    def _calc_wait_comms(self):

        if _is_enable_standalone_executor():
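            # the standalone executor synchronizes via the explicit dependencies
            # added in _add_dependencies(), so no c_wait_comm ops are needed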
            return

        block = default_main_program().global_block()

        # NOTE the naive overlap implementation in static hybrid parallel only syncs the comm stream
        # at the end of the Backward phase, based on the strong constraint that
        # no communicating gradient is used again after its communication within the Backward phase.
        # BUT this constraint fails for scenarios like Weight-Sharing and Higher-Order Differentiation,
        # where a gradient is involved in other computation between the data-parallel allreduce kernel
        # submitted to the comm stream and the comm-stream synchronization at the end of the Backward phase.
        # Therefore the comm-stream synchronization is inserted according to the usage of the
        # communicating gradients, to support overlapping for Weight-Sharing and Higher-Order Differentiation.

        ring_id_to_un_sync_grad_map = {}
        op_idx_to_sync_ring_id_map = {}
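        # ring_id -> gradients whose allreduce has been launched on that ring
        # but not yet synced; op index -> ring ids that must be synced
        # (c_wait_comm) right before that op runs.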
        for group in self._group_to_grad_name_map.keys():
            ring_id_to_un_sync_grad_map[group.id] = []

        # analyze where the syncs are needed
        for i, op in enumerate(block.ops):
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                ring_id_to_un_sync_grad_map[ring_id].append(grad_name)
            elif is_data_parallel_scale_op(op):
                continue
            # other ops that might use communicating grad
            else:
                for input_var_name in op.input_arg_names:
                    for (
                        ring_id,
                        unsync_grad_names,
                    ) in ring_id_to_un_sync_grad_map.items():
                        if input_var_name in unsync_grad_names:
                            # need to sync before op_i
                            if i in op_idx_to_sync_ring_id_map:
                                op_idx_to_sync_ring_id_map[i].append(ring_id)
                            else:
                                op_idx_to_sync_ring_id_map[i] = [ring_id]
                            # all grads in this comm stream are synced
                            ring_id_to_un_sync_grad_map[ring_id] = []

        # insert synchronization
        indices = list(op_idx_to_sync_ring_id_map.keys())
        # TODO the synchronization could be optimized:
        # we should record an event when a gradient finishes communicating and
        # only wait for that event to complete.
        # BUT paddle static graph mode currently has no op api for event recording only, so
        # here we wait for all kernels in that comm stream to finish, which is not optimal.
        for i in sorted(indices, reverse=True):
            for ring_id in op_idx_to_sync_ring_id_map[i]:

                block._insert_op_without_sync(
                    i,
                    type='c_wait_comm',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )
        block._sync_with_cpp()

    def _could_be_fuse(self):
        # TODO support gradient fusion for higher order gradients.
        # should analyze the dependencies of gradients in backward.
        if find_higher_order_backward_op(default_main_program()):
            return False
        if self.use_sharding:
            return False
        return True

    def _group_grads(self):
        """
        conditions for gradients to be grouped:
        1. group size < max_fuse_numel
        2. same dp group
        3. same dtype
        4. dependency: grad would NOT be used by other ops within group segment

        gradients inside the same group will be fused into one coalesce tensor
        """

        block = default_main_program().global_block()
        ops = block.ops

        # group individual grad vars
        # TODO consider fusing gradients for sharding reduce
        # TODO let the user set fuse_grad_size
        # emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
        h = 2048
        ffn_numel = 2 * (4 * h) * h
        mha_numel = 3 * h * h + h * h
        max_fuse_numel = ffn_numel + mha_numel
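        # i.e. roughly the gradient numel of one transformer layer (ffn +
        # attention) at hidden size h; a group is closed once adding another
        # gradient would exceed this threshold.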
        grad_groups = []
        cur_group = GradientsGroup(ops, max_fuse_numel)
        grouped_grad_names = set()

        def collect_group(cur_group, grad_var, ring_id, i):
            if len(cur_group.gradients) == 0:
                cur_group = None
            else:
                cur_group.finalize()
                grad_groups.append(cur_group)

            new_group = GradientsGroup(ops, max_fuse_numel)
            if grad_var:
                new_group.add(grad_var, ring_id, i)
                grouped_grad_names.add(grad_var.name)
            return new_group

        def op_depend_on_group(op, group):
            vars_ = set(op.input_arg_names + op.output_arg_names)
            grad_names = {grad.name for grad in group.gradients}
            return len(vars_.intersection(grad_names)) > 0

        for i, op in enumerate(ops):
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                grad_var = block.var(grad_name)
                grad_numel = get_var_numel(grad_var)

                if cur_group.acceptable(grad_var, ring_id):
                    assert grad_name not in grouped_grad_names
                    grouped_grad_names.add(grad_name)
                    cur_group.add(grad_var, ring_id, i)
                else:
                    cur_group = collect_group(cur_group, grad_var, ring_id, i)
            else:
                if op_depend_on_group(op, cur_group):
                    cur_group = collect_group(cur_group, None, None, None)

        # collect last group
        collect_group(cur_group, None, None, None)

        return grad_groups

    def _update_program(self, grad_groups):

        block = default_main_program().global_block()

        remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
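        # Process groups back-to-front: create one coalesce var per group,
        # retarget the surviving allreduce (and scale) op to it, drop the
        # per-gradient allreduce/scale/wait ops, and insert a single
        # coalesce_tensor op that maps the individual gradients into the
        # fused buffer.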

        for i, group in enumerate(grad_groups[::-1]):

            # skip unfused big tensor
            if len(group.gradients) <= 1:
                group.coalesce_var = group.gradients[0]
                continue

            ref_process_mesh = set()
            concated_shapes = []
            concated_ranks = []
            for grad_ in group.gradients:
                grad_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(grad_)
                )
                ref_process_mesh.update(
                    set(grad_dist_attr.process_mesh.process_ids)
                )

                shape = grad_.shape
                concated_shapes.extend(shape)
                concated_ranks.append(len(shape))

            # create coalesce tensor
            group.coalesce_var = block.create_var(
                name=unique_name.generate(
                    self.coalesce_prefix + '_{}'.format(i)
                ),
                dtype=group.dtype,
                persistable=False,
                stop_gradient=True,
            )

            tensor_dist_attr = TensorDistAttr()
            tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
            tensor_dist_attr.dims_mapping = []
            self.dist_context.set_tensor_dist_attr_for_program(
                group.coalesce_var, tensor_dist_attr
            )

            # update allreduce & scale op
            if group.scale_op_idx != -1:
                scale_op = block.ops[group.scale_op_idx]
                assert (
                    scale_op.type == 'scale'
                ), "should found scale op but found {}".format(str(scale_op))
                scale_op._rename_input(
                    scale_op.input_arg_names[0], group.coalesce_var.name
                )
                scale_op._rename_output(
                    scale_op.output_arg_names[0], group.coalesce_var.name
                )

            allreduce_op = block.ops[group.allreduce_op_idx]
            assert (
                allreduce_op.type == 'c_allreduce_sum'
            ), "should found c_allreduce_sum op but found {}".format(
                str(allreduce_op)
            )
            allreduce_op_dist_attr = (
                self.dist_context.get_op_dist_attr_for_program(allreduce_op)
            )
            old_in_name = allreduce_op.input_arg_names[0]
            new_in_name = group.coalesce_var.name
            allreduce_op._rename_input(old_in_name, new_in_name)
            input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr(
                old_in_name
            )
            allreduce_op_dist_attr.set_input_dist_attr(
                new_in_name, input_dist_attr
            )

            old_out_name = allreduce_op.output_arg_names[0]
            new_out_name = group.coalesce_var.name
            allreduce_op._rename_output(old_out_name, new_out_name)
            out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr(
                old_out_name
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                new_out_name, out_dist_attr
            )

            # remove unused ops
            remove_op_indices = (
                group.remove_wait_op_indices
                + group.remove_allreduce_op_indices
                + group.remove_scale_op_indices
            )
            for idx in sorted(remove_op_indices, reverse=True):
                assert (
                    block.ops[idx].type in remove_op_types
                ), "Unexpected: try to remove op {}".format(str(block.ops[idx]))
                block._remove_op(idx, False)

            # insert coalesce op
            grad_names = [grad.name for grad in group.gradients]
            coalesce_op = block._insert_op_without_sync(
                group.coalesce_op_idx,
                type="coalesce_tensor",
                inputs={"Input": grad_names},
                outputs={
                    "Output": grad_names,
                    "FusedOutput": group.coalesce_var,
                },
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": group.dtype,
                    "concated_shapes": concated_shapes,
                    "concated_ranks": concated_ranks,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

            op_dist_attr = OperatorDistAttr()
            op_dist_attr.impl_idx = 0
            op_dist_attr.impl_type = "default"
            op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
            for in_name in coalesce_op.input_arg_names:
                in_var = block.var(in_name)
                in_var_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(in_var)
                )
                op_dist_attr.set_input_dims_mapping(
                    in_name, in_var_dist_attr.dims_mapping
                )
            for out_name in coalesce_op.output_arg_names:
                out_var = block.var(out_name)
                out_var_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(out_var)
                )
                op_dist_attr.set_output_dims_mapping(
                    out_name, out_var_dist_attr.dims_mapping
                )

            self.dist_context.set_op_dist_attr_for_program(
                coalesce_op, op_dist_attr
            )

        block._sync_with_cpp()

    def _add_dependencies(self, grad_groups):
        # NOTE Currently, auto_parallel needs to support two executors: the sequential executor (old exe)
        # and the graph-based multi-stream executor (standalone exe). This function is only for the
        # standalone exe. Refactor here in the future when only one executor remains.

        if not _is_enable_standalone_executor() or len(grad_groups) == 0:
            return
        block = default_main_program().global_block()

        # Build maps
        coalesce_to_vars_map = {}
        for group in grad_groups:
            coalesce_to_vars_map[group.coalesce_var.name] = group

        # analyze dependencies
        dep_map = {}
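        # For each fused allreduce at op index idx, record two explicit
        # dependencies: the group's last gradient -> coalesce var (inserted
        # before the allreduce) and coalesce var -> individual gradients
        # (inserted after it), so the multi-stream executor keeps the
        # calc/comm ordering correct.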
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_forward_op(op):
                break
            if is_optimize_op(op):
                continue

            if is_data_parallel_reduce_op(op):
                coalesce_var_name = op.output_arg_names[0]
                if self.coalesce_prefix in coalesce_var_name:
                    group = coalesce_to_vars_map[coalesce_var_name]
                    dep_map[idx] = [
                        (
                            idx,
                            group.gradients[-1],
                            group.coalesce_var,
                            op.attr(OP_ROLE_KEY),
                        )
                    ]
                    dep_map[idx].append(
                        (
                            idx + 1,
                            group.coalesce_var,
                            group.gradients,
                            op.attr(OP_ROLE_KEY),
                        )
                    )

        # insert dependency op
        indice = sorted(dep_map.keys(), reverse=True)
        for i in indice:
            for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]:
                depend_op = insert_dependencies_for_vars(
                    block,
                    idx,
                    prior_vars,
                    post_vars,
                    self.dist_context,
                    op_role,
                    is_recompute=False,
                    sync=False,
                    op_namescope="data_parallel_overlap_dep",
                )
                depend_op.dist_attr.execution_stream = self.gradient_sync_stream
        block._sync_with_cpp()

        # remove naive synchronization & assign allreduce stream
        def remove_cond(op):
            if op.type != "c_wait_compute":
                return False
            if len(op.input_arg_names) != 0:
                return False
            if len(op.output_arg_names) != 0:
                return False
            return True

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                op._set_attr('use_calc_stream', True)
                op.dist_attr.execution_stream = self.gradient_sync_stream

            if remove_cond(op):
                block._remove_op(idx, sync=False)

        block._sync_with_cpp()

    def summary(self, grad_groups=[]):
        # TODO: add logger module
        import logging

        self._logger = logging.getLogger()
        self._logger.propagate = False
        if not self._logger.handlers:
            self._logger.setLevel(logging.INFO)
            log_handler = logging.StreamHandler()
            log_format = logging.Formatter(
                '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
            )
            log_handler.setFormatter(log_format)
            self._logger.addHandler(log_handler)

        if len(grad_groups) > 0:
            self._logger.info("Data Parallel Optimization: ")
            self._logger.info(
                " {} Allreduce ops are fused into {} coalesce allreduce ops.".format(
                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
                )
            )
            self._logger.debug("gradient fusing groups are as follows: ")
            fused_grads = set()
            for i, group in enumerate(grad_groups):
                self._logger.debug(
                    "coalesce gradient [{}] is composed by: {}".format(
                        i, [grad.name for grad in group.gradients]
                    )
                )
                fused_grads.update([grad.name for grad in group.gradients])
            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
                fused_grads
            )
            self._logger.debug(
                "the following [{}] gradients are not fused: ".format(
                    len(individual_grads)
                )
            )
            self._logger.debug(
                "individual gradient {}".format(individual_grads)
            )


class GradientsGroup:
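    """
    Bookkeeping for one gradient fusion group: the gradient vars it contains,
    the per-gradient ops that become redundant once the group is fused, and the
    position where the coalesce_tensor op should be inserted.
    """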
    def __init__(self, ops, max_group_size):
        self.max_group_size = max_group_size
        self.ops = ops

        self.gradients = []
        self.numel = 0
        self.dtype = None
        self.ring_id = None
        self.coalesce_var = None
        self.coalesce_op_idx = -1
        self.allreduce_op_idx = -1
        self.scale_op_idx = -1
        self.remove_wait_op_indices = []
        self.remove_allreduce_op_indices = []
        self.remove_scale_op_indices = []

    def acceptable(self, grad_var, ring_id):
        if len(self.gradients) == 0:
            return True
        if ring_id != self.ring_id:
            return False
        if get_var_numel(grad_var) + self.numel > self.max_group_size:
            return False
        if grad_var.dtype != self.dtype:
            return False

        return True

    def add(self, grad_var, ring_id, i):
        self.gradients.append(grad_var)
        self.ring_id = ring_id
        self.dtype = grad_var.dtype
        self.numel += get_var_numel(grad_var)

        # remove the auxiliary ops of the non-fused dp allreduce
        self.remove_allreduce_op_indices.append(i)

        # NOTE this pass relies on the original synchronization added by previous passes
        # (same stream or calc_wait_comm & comm_wait_calc)
        # to guarantee the correctness of the comm/calc execution order,
        # so the calc_wait_comm should be kept.
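        # Walk backwards from the allreduce at index i, past an optional
        # c_wait_compute, to find the op that produced this gradient; a trailing
        # per-gradient scale op (if any) is also marked for removal.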
        grad_op_idx = i - 1
        if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
            self.remove_wait_op_indices.append(i - 1)
            grad_op_idx -= 1
        if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i + 1]):
            self.remove_scale_op_indices.append(i + 1)

        if len(self.gradients) == 1:
            # TODO Remove this; it is a temporary hack for Tensor Parallel. The logic
            # for finding grad_op should be more general.
            if self.ops[grad_op_idx].type == "c_allreduce_sum":
                grad_op_idx -= 1

            grad_op = self.ops[grad_op_idx]
            assert (
                grad_var.name in grad_op.output_arg_names
            ), "grad [{}] should be output of {}".format(
                grad_var.name, str(grad_op)
            )
            self.coalesce_op_idx = grad_op_idx

    def finalize(self):
        self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
        if len(self.remove_wait_op_indices) > 1:
            self.remove_wait_op_indices.pop()
        if len(self.remove_scale_op_indices) > 1:
            self.scale_op_idx = self.remove_scale_op_indices.pop()