# auto_parallel_data_parallel_optimization.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import OrderedDict

import paddle
from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op,
    is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.utils import (
    find_higher_order_backward_op,
    get_var_numel,
    insert_dependencies_for_vars,
    is_forward_op,
    is_loss_grad_op,
    is_optimize_op,
    ring_id_to_process_group,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import unique_name
from paddle.fluid.executor import _is_enable_standalone_executor
from paddle.fluid.framework import default_main_program

from .pass_base import PassBase, PassType, register_pass


# Optimizer op types whose ``rescale_grad`` attribute can absorb the
# data-parallel gradient scaling (see _update_opt_rescale_grad).
# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum',
    'sparse_momentum',
    'dgc_momentum',
    'momentum',
    'merge_momentum',
]

# a heuristic number: comm/calc overlap is skipped when there are more
# distinct data-parallel comm groups (streams) than this.
__max_stream_num_allow__ = 16


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply Optimizations that specialized for data parallelism in Auto Parallel.
55
    1. prune grad scaling
56 57 58 59 60
    2. overlap comm and calc
    3. fuse allreduce
    """

    def __init__(self):
        super().__init__()
        # NOTE: this pass does not depend on loss / param_grads attributes;
        # its inputs are the three attributes below (checked in _check_self).
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        self.set_attr("use_sharding", False)
        # {grad1: group1, grad2: group1, grad3: group2}
        # record the order for fuse grad data memory
        self._grad_name_to_group_map = OrderedDict()
        # {group1: [grad1, grad2], group2: [grad3]}
        self._group_to_grad_name_map = OrderedDict()
        # Set by _analyze_program() when an optimizer op listed in
        # __rescale_grad_supported_opts__ is found.
        self._support_rescale_grad = False

    def _check_self(self):
        """Return True when the pass attributes allow it to run.

        Requires a non-None ``dist_context`` and an ``int`` ``global_rank``
        that is not negative.
        """
        if self.get_attr("dist_context") is None:
            return False

        global_rank = self.get_attr("global_rank")
        if not isinstance(global_rank, int) or global_rank < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        # Declared compatible with any other pass; ``other_pass`` is ignored.
        no_conflict = True
        return no_conflict

    def _type(self):
        # Registered as a communication-optimization pass.
        pass_type = PassType.COMM_OPT
        return pass_type

    def _apply_single_impl(self, main_program, startup_program, context):
        """Run the data-parallel optimizations on ``main_program``.

        Reads the pass attributes, builds the grad <-> dp-group maps and, if
        any data-parallel allreduce was found, prunes gradient scaling,
        overlaps comm with calc, fuses allreduces, adds executor
        dependencies and logs a summary.
        """
        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))
        self.use_sharding = self.get_attr("use_sharding")
        # Name prefix of fused gradient vars; _add_dependencies uses it to
        # recognize fused allreduce outputs.
        self.coalesce_prefix = 'coalesce_grad'
        if _is_enable_standalone_executor():
            # Only read by _add_dependencies, which returns early when the
            # standalone executor is disabled.
            self.gradient_sync_stream = "gradient_sync_stream"

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()

            # TODO refactor here to first fuse then overlap
            if self.is_data_parallel_applied():
                self._prune_grad_scaling()
                self._calc_comm_overlap()
                grad_group = self._fuse_allreduce()
                self._add_dependencies(grad_group)
                self.summary(grad_group)

    def _prune_grad_scaling(self):
        """Fold the per-gradient data-parallel scaling into a single place.

        If every dp group has the same degree, the loss-gradient fill value
        is divided by it. Otherwise each supported optimizer's
        ``rescale_grad`` is divided by its gradient's dp degree. The
        individual scale ops are then removed.
        """
        if self._could_be_prune():
            if not self._all_dp_groups_same_degree():
                self._update_opt_rescale_grad()
            else:
                self._scale_backward_initial_grad()
            self._remove_grad_scaling()

121 122 123
    def _calc_comm_overlap(self):
        if not self._could_be_overlap():
            return
124 125
        self._comms_overlap_calc()
        self._calc_wait_comms()
126 127

    def _fuse_allreduce(self):
        """Fuse data-parallel allreduce ops over coalesced gradient buffers.

        Returns:
            list[GradientsGroup]: the groups built by ``_group_grads``, or an
            empty list when fusion is not applicable.
        """
        if not self._could_be_fuse():
            return []

        grad_group = self._group_grads()
        self._update_program(grad_group)
        return grad_group

    def _analyze_program(self):
        """
        Build two maps from the data-parallel allreduce ops of the program:

        {param_grad_name: data_parallel_group}
        {data_parallel_group: [param_grad_name, ...]}

        Also records whether a supported optimizer exposes ``rescale_grad``,
        and asserts that every data-parallel scaled gradient is also
        synchronized by an allreduce.
        """

        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []

        for op in ops:
            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), "Unexpected: comm op [{}] has NOT ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert (
                    group is not None
                ), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op)
                )

                self._grad_name_to_group_map[grad_name] = group
                self._group_to_grad_name_map.setdefault(group, []).append(
                    grad_name
                )

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

            # TODO support multiple optimizers in one network in future.
            # here we assume that the optimizer is unique in network.
            elif (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                self._support_rescale_grad = True

        not_synchronized_grads = [
            grad_name
            for grad_name in scaled_grads
            if grad_name not in self._grad_name_to_group_map
        ]
        assert (
            len(not_synchronized_grads) == 0
        ), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads
        )

J
JZ-LIANG 已提交
194 195 196
    def is_data_parallel_applied(self):
        return len(self._group_to_grad_name_map) > 0

197 198
    def _could_be_prune(self):

199
        return self.dist_context.gradient_scale and (
200 201
            self._support_rescale_grad or self._all_dp_groups_same_degree()
        )
202 203

    def _all_dp_groups_same_degree(self):
        """Return True iff every data-parallel group has the same rank count.

        Returns False when no data-parallel group was recorded.
        """
        degrees = {len(group.ranks) for group in self._group_to_grad_name_map}
        return len(degrees) == 1

    def _scale_backward_initial_grad(self):
        """Divide the loss-gradient ``fill_constant`` value by the dp degree.

        Only valid when all dp groups share one degree; the first recorded
        group's degree is used. Only the last loss-grad op is updated.
        """
        block = default_main_program().global_block()
        first_group = next(iter(self._group_to_grad_name_map))
        dp_degree = len(first_group.ranks)

        for op in reversed(block.ops):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', (
                    "loss_grad_op must be fill_constant op, "
                    "but this op is {}".format(op.type)
                )
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree
                op._set_attr('value', loss_scale)
                break

    def _remove_grad_scaling(self):
        """Delete every data-parallel scale op from the global block."""
        global_block = default_main_program().global_block()

        doomed = [
            pos
            for pos, candidate in enumerate(global_block.ops)
            if is_data_parallel_scale_op(candidate)
        ]
        # Delete from the highest index down so pending positions stay valid.
        for pos in reversed(doomed):
            global_block._remove_op(pos, False)

        global_block._sync_with_cpp()

    def _update_opt_rescale_grad(self):
        """Fold dp gradient scaling into optimizer ``rescale_grad`` attributes.

        Each supported optimizer op divides its ``rescale_grad`` by the
        degree of its gradient's dp group. Asserts that every synchronized
        gradient was rescaled this way.
        """
        block = default_main_program().global_block()
        scaled_grads = set()

        for op in reversed(block.ops):
            if (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                assert op.has_attr(
                    'rescale_grad'
                ), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format(
                    str(op)
                )
                assert (
                    len(op.input("Grad")) == 1
                ), "Unexpected: op [{}] is supported to have only one input grad var.".format(
                    str(op)
                )

                grad_name = op.input("Grad")[0]
                dp_degree = len(self._grad_name_to_group_map[grad_name].ranks)
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(
            self._grad_name_to_group_map.keys()
        ), "Unexpected: gradients [{}] are unscaled.".format(
            set(self._grad_name_to_group_map.keys()) - scaled_grads
        )

    def _could_be_overlap(self):
        """Return whether comm/calc overlap may be applied."""
        # NOTE currently each nccl comm uses its own cuda stream, so too many
        # dp groups would mean too many streams to create and synchronize.
        # revise here when framework support custom stream in static graph mode.
        num_dp_comm_stream = len(self._group_to_grad_name_map)
        if num_dp_comm_stream > __max_stream_num_allow__:
            return False
        if self.use_sharding:
            return False
        return True

290
    def _comms_overlap_calc(self):
291 292 293 294 295 296 297 298 299 300 301 302 303
        # TODO support InterpreterCore executor for overlap.
        # InterpreterCore has a different logic for overlapping
        # which is different from use_calc_stream
        block = default_main_program().global_block()

        # comm wait calc to finish
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                assert op.has_attr('use_calc_stream')
                assert op.has_attr('ring_id')

                op._set_attr('use_calc_stream', False)
                ring_id = op.attr("ring_id")
304 305 306 307 308 309 310
                block._insert_op_without_sync(
                    idx,
                    type='c_wait_compute',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )
311 312 313

        block._sync_with_cpp()

314
    def _calc_wait_comms(self):
315

316
        if _is_enable_standalone_executor():
317 318
            return

319 320
        block = default_main_program().global_block()

321 322 323 324 325 326 327 328 329 330 331
        # NOTE the naive overlap implement in static hybird parallel only sync comm stream
        # at the end of Backward phase, based on a strong constraint that
        # all communicating gradient would NOT be used after communication in Backward phase.
        # BUT this constraint will fail for scenario like Weight-Sharing and Higher-Order Differentiation,
        # where gradient will be involved in other calculation between data-parallel allreduce kernel submmited
        # into comm streams and the synchronization of comm stream at the end of Backward phase.
        # synchronization of  comm stream should add according to the usage of communicating gradients
        # to support Overlapping for Weight-Sharing and Higher-Order Differentiation.

        ring_id_to_un_sync_grad_map = {}
        op_idx_to_sync_ring_id_map = {}
332
        for group in self._group_to_grad_name_map.keys():
333 334 335
            ring_id_to_un_sync_grad_map[group.id] = []

        # analyze the where need to sync
336
        for i, op in enumerate(block.ops):
337 338 339 340 341 342 343 344 345
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                ring_id_to_un_sync_grad_map[ring_id].append(grad_name)
            elif is_data_parallel_scale_op(op):
                continue
            # other ops that might use communicating grad
            else:
                for input_var_name in op.input_arg_names:
346 347 348 349
                    for (
                        ring_id,
                        unsync_grad_names,
                    ) in ring_id_to_un_sync_grad_map.items():
350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
                        if input_var_name in unsync_grad_names:
                            # need to sync before op_i
                            if i in op_idx_to_sync_ring_id_map:
                                op_idx_to_sync_ring_id_map[i].append(ring_id)
                            else:
                                op_idx_to_sync_ring_id_map[i] = [ring_id]
                            # all grads in this comm stream are synced
                            ring_id_to_un_sync_grad_map[ring_id] = []

        # insert synchronization
        indices = list(op_idx_to_sync_ring_id_map.keys())
        # TODO the synchronization could be optimized
        # we should record the event of a gradient is communicating and
        # only wait for that event to be completed.
        # BUT paddle static currently not support op api for event record only, so
        # here we try to wait for all kernel in that comm stream to be finish which is not that optimized.
        for i in sorted(indices, reverse=True):
            for ring_id in op_idx_to_sync_ring_id_map[i]:

369 370 371 372 373 374 375
                block._insert_op_without_sync(
                    i,
                    type='c_wait_comm',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )
376
        block._sync_with_cpp()
377 378 379 380 381 382 383 384 385 386 387 388 389 390

    def _could_be_fuse(self):
        """Return whether allreduce fusion may be applied.

        Fusion is disabled for programs containing higher-order backward ops
        and when sharding is in use.
        """
        # TODO  support gradient fuse higher order gradient.
        # should analyse the dependencies of gradient in backward.
        program = default_main_program()
        if find_higher_order_backward_op(program) or self.use_sharding:
            return False
        return True

    def _group_grads(self):
        """
        Partition data-parallel gradients into fusion groups.

        conditions for gradients to be grouped:
        1. total numel <= max_fuse_numel (a single larger grad forms its own
           group)
        2. same dp group
        3. same dtype
        4. dependency: grad would NOT be used by other ops within group segment

        gradients inside same group would be fused into one coalesce tensor

        Returns:
            list[GradientsGroup]: finalized, non-empty groups in program order.
        """

        block = default_main_program().global_block()
        ops = block.ops

        # group individual grad vars
        # TODO consider fuse gradient for sharding reduce
        # TODO let user to set fuse_grad_size
        # emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
        h = 2048
        ffn_numel = 2 * (4 * h) * h
        mha_numel = 3 * h * h + h * h
        max_fuse_numel = ffn_numel + mha_numel
        grad_groups = []
        cur_group = GradientsGroup(ops, max_fuse_numel)
        grouped_grad_names = set()

        def collect_group(cur_group, grad_var, ring_id, i):
            # Finalize and keep ``cur_group`` if non-empty, then start a new
            # group, seeded with ``grad_var`` when one is given.
            if len(cur_group.gradients) > 0:
                cur_group.finalize()
                grad_groups.append(cur_group)

            new_group = GradientsGroup(ops, max_fuse_numel)
            if grad_var:
                new_group.add(grad_var, ring_id, i)
                grouped_grad_names.add(grad_var.name)
            return new_group

        def op_depend_on_group(op, group):
            # True if ``op`` reads or writes any gradient of ``group``.
            vars_ = set(op.input_arg_names + op.output_arg_names)
            grad_names = {grad.name for grad in group.gradients}
            return not vars_.isdisjoint(grad_names)

        for i, op in enumerate(ops):
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                grad_var = block.var(grad_name)

                if cur_group.acceptable(grad_var, ring_id):
                    assert grad_name not in grouped_grad_names
                    grouped_grad_names.add(grad_name)
                    cur_group.add(grad_var, ring_id, i)
                else:
                    cur_group = collect_group(cur_group, grad_var, ring_id, i)
            elif op_depend_on_group(op, cur_group):
                # a grouped grad is used inside the segment: close the group
                cur_group = collect_group(cur_group, None, None, None)

        # collect last group
        collect_group(cur_group, None, None, None)

        return grad_groups

    def _update_program(self, grad_groups):
        """Rewrite the program so each multi-gradient group is fused.

        Groups are processed from last to first. Presumably this is so edits
        for a later group do not shift the recorded op indices of earlier
        groups (TODO confirm). For every group with more than one gradient:

        * create a coalesce variable named from ``self.coalesce_prefix``,
        * rename the kept scale / ``c_allreduce_sum`` op to use it,
        * remove the per-gradient wait / allreduce / scale ops,
        * insert a ``coalesce_tensor`` op at ``group.coalesce_op_idx``.

        A single-gradient group only gets ``coalesce_var`` set to that
        gradient.
        """

        block = default_main_program().global_block()

        remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']

        for i, group in enumerate(grad_groups[::-1]):

            # skip unfused big tensor
            if len(group.gradients) <= 1:
                group.coalesce_var = group.gradients[0]
                continue

            # create coalesce tensor
            group.coalesce_var = block.create_var(
                name=unique_name.generate(
                    self.coalesce_prefix + '_{}'.format(i)
                ),
                dtype=group.dtype,
                persistable=False,
                stop_gradient=True,
            )

            # update allreduce & scale op
            if group.scale_op_idx != -1:
                scale_op = block.ops[group.scale_op_idx]
                assert (
                    scale_op.type == 'scale'
                ), "should found scale op but found {}".format(str(scale_op))
                scale_op._rename_input(
                    scale_op.input_arg_names[0], group.coalesce_var.name
                )
                scale_op._rename_output(
                    scale_op.output_arg_names[0], group.coalesce_var.name
                )

            allreduce_op = block.ops[group.allreduce_op_idx]
            assert (
                allreduce_op.type == 'c_allreduce_sum'
            ), "should found c_allreduce_sum op but found {}".format(
                str(allreduce_op)
            )
            allreduce_op._rename_input(
                allreduce_op.input_arg_names[0], group.coalesce_var.name
            )
            allreduce_op._rename_output(
                allreduce_op.output_arg_names[0], group.coalesce_var.name
            )

            # remove unused ops; descending order keeps lower indices valid
            remove_op_indices = (
                group.remove_wait_op_indices
                + group.remove_allreduce_op_indices
                + group.remove_scale_op_indices
            )
            for idx in sorted(remove_op_indices, reverse=True):
                assert (
                    block.ops[idx].type in remove_op_types
                ), "Unexpected: try to remove op {}".format(str(block.ops[idx]))
                block._remove_op(idx, False)

            # insert coalesce op
            concated_shapes = []
            concated_ranks = []
            for grad_ in group.gradients:
                shape = grad_.shape
                concated_shapes.extend(shape)
                concated_ranks.append(len(shape))

            grad_names = [grad.name for grad in group.gradients]
            block._insert_op_without_sync(
                group.coalesce_op_idx,
                type="coalesce_tensor",
                inputs={"Input": grad_names},
                outputs={
                    "Output": grad_names,
                    "FusedOutput": group.coalesce_var,
                },
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": group.dtype,
                    "concated_shapes": concated_shapes,
                    "concated_ranks": concated_ranks,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        block._sync_with_cpp()
        # TODO update dist attr

544 545 546 547 548
    def _add_dependencies(self, grad_groups):
        # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based
        # multiple stream executor(standalone exe). This function just for standalone exe. Refactor here
        # in future when only one executor stay.

549
        if not _is_enable_standalone_executor() or len(grad_groups) == 0:
550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596
            return
        block = default_main_program().global_block()

        # Build maps
        vars_to_coalesce_map = {}
        coalesce_to_vars_map = {}

        for group in grad_groups:
            grad_names = []
            coalesce_name = group.coalesce_var.name
            for grad in group.gradients:
                vars_to_coalesce_map[grad.name] = coalesce_name
                grad_names.append(grad.name)
            coalesce_to_vars_map[coalesce_name] = grad_names

        # analyze dependencies
        # Record ONLY the last grad that generated before allreduce
        # NOTE need to be update when we allow multiple calc stream for backward calc
        not_sync_coalesces = []
        prior_allreduce_deps = {}
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_forward_op(op):
                break
            if is_optimize_op(op):
                continue

            if is_data_parallel_reduce_op(op):
                coalesce_var_name = op.output_arg_names[0]

                # NOTE only add extra deps for fused tensor, other tensor rely on
                # data flow analysis of executor.
                if self.coalesce_prefix in coalesce_var_name:
                    prior_allreduce_deps[coalesce_var_name] = [
                        idx,
                        None,
                        coalesce_var_name,
                    ]
                    not_sync_coalesces.append(coalesce_var_name)
                continue

            for out_name in op.output_arg_names:
                var_name = vars_to_coalesce_map.get(out_name, None)
                if var_name in not_sync_coalesces:
                    prior_allreduce_deps[var_name][1] = out_name
                    not_sync_coalesces.remove(var_name)
        assert (
            len(not_sync_coalesces) == 0
597
        ), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
            not_sync_coalesces
        )

        # Record ONLY the first grad that used after allreduce
        # NOTE need to be update when we allow multiple calc stream for backward calc
        not_sync_coalesces = []
        post_allreduce_deps = {}
        for idx, op in enumerate(block.ops):
            if is_forward_op(op):
                continue

            if is_data_parallel_reduce_op(op):
                coalesce_var_name = op.input_arg_names[0]
                if self.coalesce_prefix in coalesce_var_name:
                    post_allreduce_deps[coalesce_var_name] = [
                        None,
                        coalesce_var_name,
                        None,
                    ]
                    not_sync_coalesces.append(coalesce_var_name)
                continue

            for out_name in op.input_arg_names:
                var_name = vars_to_coalesce_map.get(out_name, None)
                if var_name in not_sync_coalesces:
                    post_allreduce_deps[var_name][0] = idx
                    post_allreduce_deps[var_name][2] = out_name
                    not_sync_coalesces.remove(var_name)

        assert (
            len(not_sync_coalesces) == 0
629
        ), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
630 631 632 633 634 635 636 637 638 639 640 641 642
            not_sync_coalesces
        )

        # Update program IR insert dependencise op
        dep_var_pairs = []
        for deps in [prior_allreduce_deps, post_allreduce_deps]:
            for pair in deps.values():
                dep_var_pairs.append(pair)

        dep_var_pairs.sort(key=lambda x: x[0], reverse=True)
        for idx, prior_name, post_name in dep_var_pairs:
            prior_var = block.var(prior_name)
            post_var = block.var(post_name)
643
            depend_op = insert_dependencies_for_vars(
644 645 646 647 648 649 650 651
                block,
                idx,
                prior_var,
                post_var,
                self.dist_context,
                OpRole.Backward,
                process_mesh=[
                    -1
652
                ],  # hack to avoid initialize the dist attr for coalesce var
653 654
                is_recompute=False,
                sync=False,
655
                op_namescope="data_parallel_overlap_dep",
656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679
            )
            depend_op.dist_attr.execution_stream = self.gradient_sync_stream
        block._sync_with_cpp()

        # remove naive synchronization & assign allreduce stream
        def remove_cond(op):
            if op.type != "c_wait_compute":
                return False
            if len(op.input_arg_names) != 0:
                return False
            if len(op.output_arg_names) != 0:
                return False
            return True

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                op._set_attr('use_calc_stream', True)
                op.dist_attr.execution_stream = self.gradient_sync_stream

            if remove_cond(op):
                block._remove_op(idx, sync=False)

        block._sync_with_cpp()

680 681 682
    def summary(self, grad_groups=[]):
        # TODO: add logger module
        import logging
683

684 685 686 687 688 689 690 691 692 693 694 695
        self._logger = logging.getLogger()
        self._logger.propagate = False
        if not self._logger.handlers:
            self._logger.setLevel(logging.INFO)
            log_handler = logging.StreamHandler()
            log_format = logging.Formatter(
                '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
            )
            log_handler.setFormatter(log_format)
            self._logger.addHandler(log_handler)

        if len(grad_groups) > 0:
696
            self._logger.info("Data Parallel Optimization: ")
697
            self._logger.info(
698
                " {} Allreduce ops are fused into {} coalesce allreduce ops.".format(
699 700 701
                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
                )
            )
702
            self._logger.debug("gradient fusing group are following: ")
703 704
            fused_grads = set()
            for i, group in enumerate(grad_groups):
705 706
                self._logger.debug(
                    "coalesce gradient [{}] is composed by: {}".format(
707 708 709
                        i, [grad.name for grad in group.gradients]
                    )
                )
710
                fused_grads.update([grad.name for grad in group.gradients])
711 712 713
            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
                fused_grads
            )
714
            self._logger.debug(
715
                "the following [{}] gradients are not fused: ".format(
716 717 718
                    len(individual_grads)
                )
            )
719 720 721
            self._logger.debug(
                "individual gradient {}".format(individual_grads)
            )
722 723


class GradientsGroup:
    """A run of data-parallel gradients to be fused into one coalesce tensor.

    Tracks the gradients, their total numel, dtype and ring id. It also
    tracks op indices into ``ops`` (the op list it was built from) that are
    either kept and renamed for the fused tensor or removed after fusion.
    """

    def __init__(self, ops, max_group_size):
        self.max_group_size = max_group_size
        self.ops = ops

        self.gradients = []
        self.numel = 0
        self.dtype = None
        self.ring_id = None
        self.coalesce_var = None
        # index of the op producing the first gradient (coalesce insert point)
        self.coalesce_op_idx = -1
        # allreduce / scale op kept for the fused tensor (set in finalize)
        self.allreduce_op_idx = -1
        self.scale_op_idx = -1
        # per-gradient auxiliary ops that become redundant after fusion
        self.remove_wait_op_indices = []
        self.remove_allreduce_op_indices = []
        self.remove_scale_op_indices = []

    def acceptable(self, grad_var, ring_id):
        """Return True if ``grad_var`` on ``ring_id`` may join this group.

        An empty group accepts anything. Otherwise the ring id and dtype must
        match, and the total numel must not exceed ``max_group_size``.
        """
        if len(self.gradients) == 0:
            return True
        if ring_id != self.ring_id:
            return False
        if get_var_numel(grad_var) + self.numel > self.max_group_size:
            return False
        if grad_var.dtype != self.dtype:
            return False

        return True

    def add(self, grad_var, ring_id, i):
        """Add ``grad_var``, whose allreduce op is ``self.ops[i]``."""
        self.gradients.append(grad_var)
        self.ring_id = ring_id
        self.dtype = grad_var.dtype
        self.numel += get_var_numel(grad_var)

        # remove auxiliary ops in non-fuse dp allreduce
        self.remove_allreduce_op_indices.append(i)

        # NOTE this pass rely on the original synchronization add in previous passes
        # (same stream or calc_wait_comm & comm_wait_calc)
        # to guarantee the correctness of comm_calc execution order.
        # so the calc_wait_comm should be keep.
        grad_op_idx = i - 1
        if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
            self.remove_wait_op_indices.append(i - 1)
            grad_op_idx -= 1
        # Check the op right after the allreduce: this is the index that is
        # bounds-checked and recorded (previously ops[i - 1] was tested).
        if i + 1 < len(self.ops) and is_data_parallel_scale_op(
            self.ops[i + 1]
        ):
            self.remove_scale_op_indices.append(i + 1)

        if len(self.gradients) == 1:
            # TODO Remove this is a temporary hack for Tensor Parallel. the logic
            # for find grad_op should be more general.
            if self.ops[grad_op_idx].type == "c_allreduce_sum":
                grad_op_idx -= 1

            grad_op = self.ops[grad_op_idx]
            assert (
                grad_var.name in grad_op.output_arg_names
            ), "grad [{}] should be output of {}".format(
                grad_var.name, str(grad_op)
            )
            self.coalesce_op_idx = grad_op_idx

    def finalize(self):
        """Choose which recorded ops are kept for the fused tensor.

        The last allreduce is always kept. The last wait / scale op is kept
        only when more than one was recorded. All remaining recorded indices
        are removed later.
        """
        self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
        if len(self.remove_wait_op_indices) > 1:
            self.remove_wait_op_indices.pop()
        if len(self.remove_scale_op_indices) > 1:
            self.scale_op_idx = self.remove_scale_op_indices.pop()