auto_parallel_data_parallel_optimization.py 28.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
18 19
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
20 21 22
    OperatorDistAttr,
    TensorDistAttr,
)
23
from paddle.distributed.auto_parallel.static.operators.common import (
24
    is_data_parallel_reduce_op,
25
    is_data_parallel_scale_op,
26
)
27
from paddle.distributed.auto_parallel.static.utils import (
28
    find_higher_order_backward_op,
29
    get_var_numel,
30
    insert_dependencies_for_vars,
31
    is_forward_op,
32 33 34 35
    is_loss_grad_op,
    is_optimize_op,
    ring_id_to_process_group,
)
36
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
37 38
from paddle.static import default_main_program
from paddle.utils import unique_name
39 40

from .pass_base import PassBase, PassType, register_pass
41 42 43

# Optimizer op types that carry a ``rescale_grad`` attribute. When the
# network's optimizer op is one of these, the data-parallel gradient scaling
# can be folded into ``rescale_grad``. Register new such optimizers here.
__rescale_grad_supported_opts__ = [
    'lars_momentum', 'sparse_momentum', 'dgc_momentum',
    'momentum', 'merge_momentum',
]

# Heuristic cap on the number of distinct data-parallel groups (each one
# driving its own communication stream) for which comm/calc overlap is used.
__max_stream_num_allow__ = 16


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply optimizations specialized for data parallelism in Auto Parallel.

    1. prune grad scaling: remove the data-parallel gradient scale ops and
       instead divide by the dp degree once, either in the initial loss
       gradient or in the optimizer's ``rescale_grad`` attribute.
    2. overlap comm and calc: run the data-parallel allreduce ops off the
       calculation stream (``use_calc_stream=False``).
    3. fuse allreduce: coalesce gradients into larger tensors so that fewer
       allreduce ops are issued.

    Pass attributes (see ``__init__`` / ``_check_self``):
        dist_context: the distributed context; must not be None.
        global_rank (int): rank of the current process; must be >= 0.
        use_sharding (bool): when True, overlapping and fusion are skipped.
    """

    def __init__(self):
        super().__init__()
        # NOTE this pass does not depend on the loss and param_grads attrs.
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        self.set_attr("use_sharding", False)
        # {grad1: group1, grad2: group1, grad3: group2}
        # insertion order is recorded for fusing gradient memory
        self._grad_name_to_group_map = OrderedDict()
        # {group1: [grad1, grad2], group2: [grad3]}
        self._group_to_grad_name_map = OrderedDict()
        # set by _analyze_program when an optimizer op supports rescale_grad
        self._support_rescale_grad = False

    def _check_self(self):
        """Return True iff dist_context is set and global_rank is an int >= 0."""
        if self.get_attr("dist_context") is None:
            return False
        global_rank = self.get_attr("global_rank")
        if (not isinstance(global_rank, int)) or global_rank < 0:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        return PassType.COMM_OPT

    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))
        self.use_sharding = self.get_attr("use_sharding")
        # name prefix of fused gradient vars; _add_dependencies relies on it
        # to recognize allreduce ops that operate on a fused tensor.
        self.coalesce_prefix = 'coalesce_grad'
        self.gradient_sync_stream = "gradient_sync_stream"

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()

            # TODO refactor here to first fuse then overlap
            if self.is_data_parallel_applied():
                self._prune_grad_scaling()
                self._calc_comm_overlap()
                grad_group = self._fuse_allreduce()
                self._add_dependencies(grad_group)
                self.summary(grad_group)

    def _prune_grad_scaling(self):
        """Replace per-gradient dp scale ops by a single division elsewhere.

        If all dp groups share one degree, the initial loss gradient is
        divided by it; otherwise each supported optimizer op's
        ``rescale_grad`` is divided by its gradient's dp degree. The scale
        ops are then removed.
        """
        if not self._could_be_prune():
            return

        if self._all_dp_groups_same_degree():
            self._scale_backward_initial_grad()
        else:
            self._update_opt_rescale_grad()

        self._remove_grad_scaling()

    def _calc_comm_overlap(self):
        """Let dp allreduce ops overlap with calculation when allowed."""
        if not self._could_be_overlap():
            return
        self._comms_overlap_calc()
        self._calc_wait_comms()

    def _fuse_allreduce(self):
        """Group gradients and rewrite the program to allreduce fused tensors.

        Returns:
            list[GradientsGroup]: the gradient groups, or ``[]`` when fusion
            is not applicable.
        """
        if not self._could_be_fuse():
            return []

        grad_group = self._group_grads()
        self._update_program(grad_group)

        return grad_group

    def _analyze_program(self):
        """
        Scan the main program's global block and build two maps:
            {param_grad_name: data_parallel_group}
            {data_parallel_group: [param_grad_name, ...]}

        Also records whether an optimizer op supports ``rescale_grad`` and
        asserts that every data-parallel scaled gradient is also
        synchronized by a data-parallel reduce op.
        """
        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []

        for op in ops:
            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), f"Unexpected: comm op [{str(op)}] has NOT ring id."
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert (
                    group is not None
                ), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op)
                )

                self._grad_name_to_group_map[grad_name] = group
                self._group_to_grad_name_map.setdefault(group, []).append(
                    grad_name
                )

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

            # TODO support multiple optimizers in one network in future.
            # here we assume that the optimizer is unique in network.
            elif (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                self._support_rescale_grad = True

        # every scaled gradient must also be synchronized (allreduced)
        not_synchronized_grads = [
            grad_name
            for grad_name in scaled_grads
            if grad_name not in self._grad_name_to_group_map
        ]
        assert (
            len(not_synchronized_grads) == 0
        ), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads
        )

    def is_data_parallel_applied(self):
        """Return True if any data-parallel allreduce was found."""
        return len(self._group_to_grad_name_map) > 0

    def _could_be_prune(self):
        """Gradient scaling can be pruned only if it is enabled and it can be
        folded either into rescale_grad or into the initial loss gradient."""
        return self.dist_context.gradient_scale and (
            self._support_rescale_grad or self._all_dp_groups_same_degree()
        )

    def _all_dp_groups_same_degree(self):
        """Return True iff all dp groups have exactly one common size."""
        return (
            len(
                {
                    len(group.ranks)
                    for group in self._group_to_grad_name_map.keys()
                }
            )
            == 1
        )

    def _scale_backward_initial_grad(self):
        """Divide the loss-gradient fill_constant ``value`` by the dp degree.

        Only called when all dp groups share one degree, so the degree of
        the first group is used. The last loss-grad op in the block is
        updated.
        """
        block = default_main_program().global_block()
        dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)

        for op in reversed(list(block.ops)):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', (
                    "loss_grad_op must be fill_constant op, "
                    "but this op is {}".format(op.type)
                )
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree
                op._set_attr('value', loss_scale)
                break

    def _remove_grad_scaling(self):
        """Remove every data-parallel scale op from the global block."""
        block = default_main_program().global_block()

        # iterate backwards so earlier indices stay valid after removal
        for op_idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_scale_op(op):
                block._remove_op(op_idx, False)

        block._sync_with_cpp()

    def _update_opt_rescale_grad(self):
        """Divide each supported optimizer op's ``rescale_grad`` by the dp
        degree of its gradient, asserting that all dp gradients are covered."""
        block = default_main_program().global_block()
        scaled_grads = set()

        for op in reversed(list(block.ops)):
            if (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                assert op.has_attr(
                    'rescale_grad'
                ), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format(
                    str(op)
                )
                assert (
                    len(op.input("Grad")) == 1
                ), "Unexpected: op [{}] is supported to have only one input grad var.".format(
                    str(op)
                )

                grad_name = op.input("Grad")[0]
                dp_degree = len(
                    list(self._grad_name_to_group_map[grad_name].ranks)
                )
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(
            self._grad_name_to_group_map.keys()
        ), "Unexpected: gradients [{}] are unscaled.".format(
            set(self._grad_name_to_group_map.keys()) - scaled_grads
        )

    def _could_be_overlap(self):
        # NOTE current different nccl comm will use different cuda stream
        # so if there are too many dp groups there will be too many streams
        # that need to be created and synced.
        # revise here when framework supports custom stream in static graph mode.
        num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
        if num_dp_comm_stream > __max_stream_num_allow__:
            return False
        if self.use_sharding:
            return False
        return True

    def _comms_overlap_calc(self):
        """Move dp allreduce ops off the calc stream; each one first waits
        for calculation via an inserted ``c_wait_compute`` op."""
        # TODO support InterpreterCore executor for overlap.
        # InterpreterCore has a different logic for overlapping
        # which is different from use_calc_stream
        block = default_main_program().global_block()

        # comm wait calc to finish
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                assert op.has_attr('use_calc_stream')
                assert op.has_attr('ring_id')

                op._set_attr('use_calc_stream', False)
                ring_id = op.attr("ring_id")
                block._insert_op_without_sync(
                    idx,
                    type='c_wait_compute',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )

        block._sync_with_cpp()

    def _calc_wait_comms(self):
        """Insert ``c_wait_comm`` ops before ops that read communicating grads.

        NOTE currently disabled: the method returns immediately and the
        analysis below is unreachable; it is kept for reference.
        """
        return

        block = default_main_program().global_block()

        # NOTE the naive overlap implement in static hybrid parallel only sync comm stream
        # at the end of Backward phase, based on a strong constraint that
        # all communicating gradient would NOT be used after communication in Backward phase.
        # BUT this constraint will fail for scenario like Weight-Sharing and Higher-Order Differentiation,
        # where gradient will be involved in other calculation between data-parallel allreduce kernel submitted
        # into comm streams and the synchronization of comm stream at the end of Backward phase.
        # synchronization of comm stream should add according to the usage of communicating gradients
        # to support Overlapping for Weight-Sharing and Higher-Order Differentiation.

        ring_id_to_un_sync_grad_map = {}
        op_idx_to_sync_ring_id_map = {}
        for group in self._group_to_grad_name_map.keys():
            ring_id_to_un_sync_grad_map[group.id] = []

        # analyze where we need to sync
        for i, op in enumerate(block.ops):
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                ring_id_to_un_sync_grad_map[ring_id].append(grad_name)
            elif is_data_parallel_scale_op(op):
                continue
            # other ops that might use communicating grad
            else:
                for input_var_name in op.input_arg_names:
                    for (
                        ring_id,
                        unsync_grad_names,
                    ) in ring_id_to_un_sync_grad_map.items():
                        if input_var_name in unsync_grad_names:
                            # need to sync before op_i
                            if i in op_idx_to_sync_ring_id_map:
                                op_idx_to_sync_ring_id_map[i].append(ring_id)
                            else:
                                op_idx_to_sync_ring_id_map[i] = [ring_id]
                            # all grads in this comm stream are synced
                            ring_id_to_un_sync_grad_map[ring_id] = []

        # insert synchronization
        indices = list(op_idx_to_sync_ring_id_map.keys())
        # TODO the synchronization could be optimized
        # we should record the event of a gradient is communicating and
        # only wait for that event to be completed.
        # BUT paddle static currently not support op api for event record only, so
        # here we try to wait for all kernel in that comm stream to be finish which is not that optimized.
        for i in sorted(indices, reverse=True):
            for ring_id in op_idx_to_sync_ring_id_map[i]:
                block._insert_op_without_sync(
                    i,
                    type='c_wait_comm',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )
        block._sync_with_cpp()

    def _could_be_fuse(self):
        # TODO support gradient fuse higher order gradient.
        # should analyse the dependencies of gradient in backward.
        if find_higher_order_backward_op(default_main_program()):
            return False
        if self.use_sharding:
            return False
        return True

    def _group_grads(self):
        """
        conditions for gradients to be grouped:
        1. group size < max_fuse_numel
        2. same dp group
        3. same dtype
        4. dependency: grad would NOT be used by other ops within group segment

        gradients inside same group would be fused into one coalesce tensor

        Returns:
            list[GradientsGroup]: the finalized, non-empty groups in program
            order.
        """
        block = default_main_program().global_block()
        ops = block.ops

        # group individual grad vars
        # TODO consider fuse gradient for sharding reduce
        # TODO let user to set fuse_grad_size
        # emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
        h = 2048
        ffn_numel = 2 * (4 * h) * h
        mha_numel = 3 * h * h + h * h
        max_fuse_numel = ffn_numel + mha_numel
        grad_groups = []
        cur_group = GradientsGroup(ops, max_fuse_numel)
        grouped_grad_names = set()

        def collect_group(cur_group, grad_var, ring_id, i):
            # Close cur_group (kept only when non-empty) and open a new
            # group, seeded with grad_var when one is given.
            if len(cur_group.gradients) > 0:
                cur_group.finalize()
                grad_groups.append(cur_group)

            new_group = GradientsGroup(ops, max_fuse_numel)
            if grad_var:
                new_group.add(grad_var, ring_id, i)
                grouped_grad_names.add(grad_var.name)
            return new_group

        def op_depend_on_group(op, group):
            # True if op reads or writes any gradient already in the group.
            vars_ = set(op.input_arg_names + op.output_arg_names)
            grad_names = {grad.name for grad in group.gradients}
            return len(vars_.intersection(grad_names)) > 0

        for i, op in enumerate(ops):
            if is_data_parallel_reduce_op(op):
                ring_id = op.attr("ring_id")
                grad_name = op.output_arg_names[0]
                grad_var = block.var(grad_name)

                if cur_group.acceptable(grad_var, ring_id):
                    assert grad_name not in grouped_grad_names
                    grouped_grad_names.add(grad_name)
                    cur_group.add(grad_var, ring_id, i)
                else:
                    cur_group = collect_group(cur_group, grad_var, ring_id, i)
            else:
                if op_depend_on_group(op, cur_group):
                    cur_group = collect_group(cur_group, None, None, None)

        # collect last group
        collect_group(cur_group, None, None, None)

        return grad_groups

    def _update_program(self, grad_groups):
        """Rewrite the program so each multi-gradient group is allreduced as
        one coalesced tensor.

        For each group (processed last to first so that earlier op indices
        stay valid): create the coalesce var, retarget the kept allreduce
        (and scale) op to it, remove the now-redundant per-gradient ops, and
        insert a ``coalesce_tensor`` op at the group's coalesce index.
        Single-gradient groups are left unchanged and use the gradient
        itself as their coalesce var.
        """
        block = default_main_program().global_block()

        remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']

        for i, group in enumerate(grad_groups[::-1]):
            # skip unfused big tensor
            if len(group.gradients) <= 1:
                group.coalesce_var = group.gradients[0]
                continue

            ref_process_mesh = set()
            concated_shapes = []
            concated_ranks = []
            for grad_ in group.gradients:
                grad_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(grad_)
                )
                ref_process_mesh.update(
                    set(grad_dist_attr.process_mesh.process_ids)
                )

                shape = grad_.shape
                concated_shapes.extend(shape)
                concated_ranks.append(len(shape))

            # create coalesce tensor
            group.coalesce_var = block.create_var(
                name=unique_name.generate(self.coalesce_prefix + f'_{i}'),
                dtype=group.dtype,
                persistable=False,
                stop_gradient=True,
            )

            tensor_dist_attr = TensorDistAttr()
            tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
            tensor_dist_attr.dims_mapping = []
            self.dist_context.set_tensor_dist_attr_for_program(
                group.coalesce_var, tensor_dist_attr
            )

            # update allreduce & scale op
            if group.scale_op_idx != -1:
                scale_op = block.ops[group.scale_op_idx]
                assert (
                    scale_op.type == 'scale'
                ), f"should found scale op but found {str(scale_op)}"
                scale_op._rename_input(
                    scale_op.input_arg_names[0], group.coalesce_var.name
                )
                scale_op._rename_output(
                    scale_op.output_arg_names[0], group.coalesce_var.name
                )

            allreduce_op = block.ops[group.allreduce_op_idx]
            assert (
                allreduce_op.type == 'c_allreduce_sum'
            ), "should found c_allreduce_sum op but found {}".format(
                str(allreduce_op)
            )
            allreduce_op_dist_attr = (
                self.dist_context.get_op_dist_attr_for_program(allreduce_op)
            )
            old_in_name = allreduce_op.input_arg_names[0]
            new_in_name = group.coalesce_var.name
            allreduce_op._rename_input(old_in_name, new_in_name)
            input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr(
                old_in_name
            )
            allreduce_op_dist_attr.set_input_dist_attr(
                new_in_name, input_dist_attr
            )

            old_out_name = allreduce_op.output_arg_names[0]
            new_out_name = group.coalesce_var.name
            allreduce_op._rename_output(old_out_name, new_out_name)
            out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr(
                old_out_name
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                new_out_name, out_dist_attr
            )

            # remove un-used ops (highest index first so indices stay valid)
            remove_op_indices = (
                group.remove_wait_op_indices
                + group.remove_allreduce_op_indices
                + group.remove_scale_op_indices
            )
            for idx in sorted(remove_op_indices, reverse=True):
                assert (
                    block.ops[idx].type in remove_op_types
                ), f"Unexpected: try to remove op {str(block.ops[idx])}"
                block._remove_op(idx, False)

            # insert coalesce op
            grad_names = [grad.name for grad in group.gradients]
            coalesce_op = block._insert_op_without_sync(
                group.coalesce_op_idx,
                type="coalesce_tensor",
                inputs={"Input": grad_names},
                outputs={
                    "Output": grad_names,
                    "FusedOutput": group.coalesce_var,
                },
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": group.dtype,
                    "concated_shapes": concated_shapes,
                    "concated_ranks": concated_ranks,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

            op_dist_attr = OperatorDistAttr()
            op_dist_attr.impl_idx = 0
            op_dist_attr.impl_type = "default"
            op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
            for in_name in coalesce_op.input_arg_names:
                in_var = block.var(in_name)
                in_var_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(in_var)
                )
                op_dist_attr.set_input_dims_mapping(
                    in_name, in_var_dist_attr.dims_mapping
                )
            for out_name in coalesce_op.output_arg_names:
                out_var = block.var(out_name)
                out_var_dist_attr = (
                    self.dist_context.get_tensor_dist_attr_for_program(out_var)
                )
                op_dist_attr.set_output_dims_mapping(
                    out_name, out_var_dist_attr.dims_mapping
                )

            self.dist_context.set_op_dist_attr_for_program(
                coalesce_op, op_dist_attr
            )

        block._sync_with_cpp()

    def _add_dependencies(self, grad_groups):
        """Insert dependency ops around fused allreduce ops and move dp
        allreduce ops onto ``self.gradient_sync_stream``.

        Also removes the input/output-less ``c_wait_compute`` ops and sets
        ``use_calc_stream`` back to True on dp allreduce ops.
        """
        # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based
        # multiple stream executor(standalone exe). This function just for standalone exe. Refactor here
        # in future when only one executor stay.

        if len(grad_groups) == 0:
            return
        block = default_main_program().global_block()

        # Build maps
        coalesce_to_vars_map = {}
        for group in grad_groups:
            coalesce_to_vars_map[group.coalesce_var.name] = group

        # analyze dependencies (backward section only, scanning backwards)
        dep_map = {}
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_forward_op(op):
                break
            if is_optimize_op(op):
                continue

            if is_data_parallel_reduce_op(op):
                coalesce_var_name = op.output_arg_names[0]
                if self.coalesce_prefix in coalesce_var_name:
                    group = coalesce_to_vars_map[coalesce_var_name]
                    dep_map[idx] = [
                        (
                            idx,
                            group.gradients[-1],
                            group.coalesce_var,
                            op.attr(OP_ROLE_KEY),
                        )
                    ]
                    dep_map[idx].append(
                        (
                            idx + 1,
                            group.coalesce_var,
                            group.gradients,
                            op.attr(OP_ROLE_KEY),
                        )
                    )

        # insert dependency op
        indice = sorted(dep_map.keys(), reverse=True)
        for i in indice:
            for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]:
                depend_op = insert_dependencies_for_vars(
                    block,
                    idx,
                    prior_vars,
                    post_vars,
                    self.dist_context,
                    op_role,
                    is_recompute=False,
                    sync=False,
                    op_namescope="data_parallel_overlap_dep",
                )
                depend_op.dist_attr.execution_stream = self.gradient_sync_stream
        block._sync_with_cpp()

        # remove naive synchronization & assign allreduce stream
        def remove_cond(op):
            if op.type != "c_wait_compute":
                return False
            if len(op.input_arg_names) != 0:
                return False
            if len(op.output_arg_names) != 0:
                return False
            return True

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_reduce_op(op):
                op._set_attr('use_calc_stream', True)
                op.dist_attr.execution_stream = self.gradient_sync_stream

            if remove_cond(op):
                block._remove_op(idx, sync=False)

        block._sync_with_cpp()

    def summary(self, grad_groups=None):
        """Log how many allreduce ops were fused and into which groups.

        Args:
            grad_groups (list[GradientsGroup] | None): the groups returned by
                ``_fuse_allreduce``. ``None`` (the default) means no groups,
                in which case nothing is logged.
        """
        # TODO: add logger module
        import logging

        # avoid a shared mutable default argument
        if grad_groups is None:
            grad_groups = []

        # Use a logger named after this module instead of the root logger:
        # setting the level / adding a handler below must not reconfigure
        # logging for every other logger in the process.
        self._logger = logging.getLogger(__name__)
        self._logger.propagate = False
        if not self._logger.handlers:
            self._logger.setLevel(logging.INFO)
            log_handler = logging.StreamHandler()
            log_format = logging.Formatter(
                '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
            )
            log_handler.setFormatter(log_format)
            self._logger.addHandler(log_handler)

        if len(grad_groups) > 0:
            self._logger.info("Data Parallel Optimization: ")
            self._logger.info(
                " {} Allreduce ops are fused into {} coalesce allreduce ops.".format(
                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
                )
            )
            self._logger.debug("gradient fusing group are following: ")
            fused_grads = set()
            for i, group in enumerate(grad_groups):
                self._logger.debug(
                    "coalesce gradient [{}] is composed by: {}".format(
                        i, [grad.name for grad in group.gradients]
                    )
                )
                fused_grads.update([grad.name for grad in group.gradients])
            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
                fused_grads
            )
            self._logger.debug(
                "the following [{}] gradients are not fused: ".format(
                    len(individual_grads)
                )
            )
            self._logger.debug(f"individual gradient {individual_grads}")


class GradientsGroup:
    """A run of data-parallel allreduce'd gradients to fuse into one tensor.

    While gradients are added, the group records indices into ``ops`` of the
    auxiliary ops that become redundant after fusion, plus the indices of
    the ops that are kept and rewritten to use the fused tensor.

    Args:
        ops: the op list the indices refer to (the global block's ops).
        max_group_size: upper bound on the summed element count of the
            gradients in this group.
    """

    def __init__(self, ops, max_group_size):
        self.max_group_size = max_group_size
        self.ops = ops

        self.gradients = []
        self.numel = 0
        self.dtype = None
        self.ring_id = None
        self.coalesce_var = None
        # index at which the coalesce_tensor op is inserted
        self.coalesce_op_idx = -1
        # indices of the allreduce / scale op kept and retargeted on fusion
        self.allreduce_op_idx = -1
        self.scale_op_idx = -1
        # indices of the per-gradient ops that are removed on fusion
        self.remove_wait_op_indices = []
        self.remove_allreduce_op_indices = []
        self.remove_scale_op_indices = []

    def acceptable(self, grad_var, ring_id):
        """Return True if ``grad_var`` on ``ring_id`` may join this group.

        An empty group accepts anything; otherwise the ring id and dtype
        must match and the total element count must stay within
        ``max_group_size``.
        """
        if len(self.gradients) == 0:
            return True
        if ring_id != self.ring_id:
            return False
        if get_var_numel(grad_var) + self.numel > self.max_group_size:
            return False
        if grad_var.dtype != self.dtype:
            return False

        return True

    def add(self, grad_var, ring_id, i):
        """Add ``grad_var``, whose dp allreduce op is ``self.ops[i]``."""
        self.gradients.append(grad_var)
        self.ring_id = ring_id
        self.dtype = grad_var.dtype
        self.numel += get_var_numel(grad_var)

        # remove auxiliary ops in non-fuse dp allreduce
        self.remove_allreduce_op_indices.append(i)

        # NOTE this pass relies on the original synchronization added in previous passes
        # (same stream or calc_wait_comm & comm_wait_calc)
        # to guarantee the correctness of comm_calc execution order.
        # so the calc_wait_comm should be kept.
        grad_op_idx = i - 1
        if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
            self.remove_wait_op_indices.append(i - 1)
            grad_op_idx -= 1
        # The scale op, if any, is expected right after the allreduce op:
        # inspect ops[i + 1], the same index that is bounds-checked and
        # recorded (this previously looked at ops[i - 1], the op *before*
        # the allreduce).
        if i + 1 < len(self.ops) and is_data_parallel_scale_op(
            self.ops[i + 1]
        ):
            self.remove_scale_op_indices.append(i + 1)

        if len(self.gradients) == 1:
            # TODO Remove this is a temporary hack for Tensor Parallel. the logic
            # for find grad_op should be more general.
            if self.ops[grad_op_idx].type == "c_allreduce_sum":
                grad_op_idx -= 1

            grad_op = self.ops[grad_op_idx]
            assert (
                grad_var.name in grad_op.output_arg_names
            ), "grad [{}] should be output of {}".format(
                grad_var.name, str(grad_op)
            )
            # the fused buffer is created where the first gradient is produced
            self.coalesce_op_idx = grad_op_idx

    def finalize(self):
        """Pick the ops kept after fusion.

        The last allreduce op is always kept. The last wait op and the last
        scale op are kept only when more than one was recorded; otherwise
        the recorded indices are left in their remove lists.
        """
        self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
        if len(self.remove_wait_op_indices) > 1:
            self.remove_wait_op_indices.pop()
        if len(self.remove_scale_op_indices) > 1:
            self.scale_op_idx = self.remove_scale_op_indices.pop()