#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddle import static
from paddle.fluid import core
from paddle.framework.ir import apply_build_strategy
from paddle.utils import unique_name

from .common import (
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
    CollectiveHelper,
    OpRole,
    is_backward_op,
    is_loss_grad_op,
    is_optimizer_op,
)
from .meta_optimizer_base import MetaOptimizerBase


def evaluate_flag_apply_pass_to_program(val: str) -> bool:
    val = val.lower()
    if val in ('false', 'off', '0'):
        return False
    else:
        return True


class RawProgramOptimizer(MetaOptimizerBase):
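    """Meta optimizer that applies data parallelism directly on the raw program.

    Instead of relying on graph-level optimization, it inserts the collective
    ops itself: parameters are broadcast from rank 0 at startup, the loss
    gradient is scaled by 1 / nranks, and a c_allreduce_sum (optionally fused)
    is inserted for every non-distributed gradient. Enabled through
    strategy.without_graph_optimization = True.
    """
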
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
            "GradientMergeOptimizer",
            "LambOptimizer",
            "LarsOptimizer",
            "DGCOptimizer",
            "LocalSGDOptimizer",
        ]
        self.meta_optimizers_black_list = []
        self.global_ring_id = 0

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )
        self.without_graph_optimization = (
            user_defined_strategy.without_graph_optimization
        )
        self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
        if self.fuse_all_reduce_ops:
            self.fuse_grad_size_in_num = (
                user_defined_strategy.fuse_grad_size_in_num
            )
            self.calc_comm_same_stream = (
                user_defined_strategy._calc_comm_same_stream
            )
            self.sync_before_allreduce = os.environ.get(
                'FLAGS_sync_before_allreduce', None
            )

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
        if self.user_defined_strategy.tensor_parallel:
            return False
        if self.user_defined_strategy.sharding:
            return False

        if self.without_graph_optimization:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.without_graph_optimization = False

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.without_graph_optimization = True

    def _broadcast_params(self, ring_id):
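        """Broadcast all non-distributed parameters from rank 0 over `ring_id`.

        Appends a c_broadcast op per parameter to the startup program, followed
        by a c_sync_comm_stream op, so every worker starts from identical
        weights.
        """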
        block = self.startup_program.global_block()
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )

        if not param:
            return  # no parameter on this device
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward},
        )

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

    def _init_process_group(self):
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Create global ring for all gpus (ring_id = 0)
        collective_helper._init_communicator(
            self.startup_program,
            self.current_endpoint,
            self.global_endpoints,
            self.global_rank,
            self.global_ring_id,
            True,
            self.global_ring_id,
            True,
        )
        self._broadcast_params(self.global_ring_id)

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
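        """Run the inner optimizer, then rewrite the programs for collective training.

        Calls inner_opt.minimize, optionally applies the build-strategy passes
        (controlled by FLAGS_apply_pass_to_program), and, when running with
        more than one worker, initializes the global communication ring and
        inserts the allreduce ops via _transpile_main_program.
        """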
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        if startup_program is None:
            startup_program = static.default_startup_program()
        self.startup_program = startup_program

        block = loss.block
        program = block.program
        self.main_program = program

        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        # Only skip applying the passes when FLAGS_apply_pass_to_program is
        # explicitly set to False
        is_apply_pass_to_program = os.environ.get(
            'FLAGS_apply_pass_to_program', '1'
        )
        if evaluate_flag_apply_pass_to_program(is_apply_pass_to_program):
            pass_attrs = {"use_cuda": True}
            build_strategy = self.user_defined_strategy.build_strategy._copy()
            build_strategy.fuse_all_optimizer_ops = False
            build_strategy.fuse_all_reduce_ops = False
            apply_build_strategy(
                self.main_program,
                self.startup_program,
                build_strategy,
                pass_attrs,
            )
            self.main_program._pass_applied = True
        if self.nranks == 1:
            return optimize_ops, params_grads
        self._init_process_group()

        self.main_program = program
        if self.nranks > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

    def _find_gradient_merge_block(self):
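        """Locate the sub-block created by gradient merge, if any.

        Finds the gradient merge condition variable through the
        grad_merge_cond_name op attribute and returns the sub_block of the
        matching conditional_block op, or None when gradient merge is not used.
        """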
        GRAD_MERGE_COND_NAME = "grad_merge_cond_name"
        gm_cond_var_name = None
        for op in self.main_program.global_block().ops:
            if GRAD_MERGE_COND_NAME not in op.attr_names:
                continue
            if gm_cond_var_name is None:
                gm_cond_var_name = op.attr(GRAD_MERGE_COND_NAME)
            else:
                assert gm_cond_var_name == op.attr(
                    GRAD_MERGE_COND_NAME
                ), "multiple gradient merge condition found"
        if gm_cond_var_name is None:
            return None

        cond_op = (
            None  # false_fn of gm is None, so we should only find one block
        )
        for op in self.main_program.global_block().ops:
            if op.type != 'conditional_block' or 'Cond' not in op.input_names:
                continue
            cond_vars = op.input('Cond')
            if not cond_vars or cond_vars[0] != gm_cond_var_name:
                continue
            assert cond_op is None, "multiple gradient merge block found"
            cond_op = op
        assert cond_op is not None, "cannot find gradient merge block"
        return cond_op._block_attr("sub_block")

    def _insert_allreduce_ops_for_gm(self, gm_block):
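        """Insert allreduce ops into the gradient merge sub-block.

        The accumulated gradients only need to be synchronized when the
        optimize part of `gm_block` actually runs, so the c_allreduce_sum ops
        (with the surrounding calc/comm stream sync ops) are inserted at the
        head of that part instead of into the main block.
        """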
        block = self.main_program.global_block()

        first_optimize_op_idx = None
        for i, op in reversed(list(enumerate(gm_block.ops))):
            if is_backward_op(op) and first_optimize_op_idx is None:
                first_optimize_op_idx = i + 1
                break
        if first_optimize_op_idx is None:
            first_optimize_op_idx = 0

        param_vars = []
        grad_vars = []
        for op in block.ops:
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                assert len(op_role_var) % 2 == 0
                for i in range(0, len(op_role_var), 2):
                    param = block.var(op_role_var[i])
                    grad = block.var(op_role_var[i + 1])
                    if param.is_distributed:
                        continue
                    param_vars.append(param)
                    grad_vars.append(grad)

        if not grad_vars:
            return

        gm_block._insert_op(
            first_optimize_op_idx,
            type="c_sync_calc_stream",
            inputs={'X': grad_vars[0]},
            outputs={'Out': grad_vars[0]},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )

        insert_op_num = 1
        ring_id = self.global_ring_id

        # NOTE: fused allreduce could be performed inside this loop in the future
        for i, (p, g) in enumerate(zip(param_vars, grad_vars)):
            gm_block._insert_op(
                first_optimize_op_idx + insert_op_num,
                type="c_allreduce_sum",
                inputs={'X': g},
                outputs={'Out': g},
                attrs={
                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            insert_op_num += 1

        gm_block._insert_op(
            first_optimize_op_idx + insert_op_num,
            type="c_sync_comm_stream",
            inputs={'X': grad_vars},
            outputs={'Out': grad_vars},
            attrs={
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )

    def _transpile_main_program(self, loss):
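        """Rewrite the main program for data-parallel training.

        Scales the loss gradient first, then inserts allreduce ops: into the
        gradient merge sub-block if one exists, as fused allreduce when
        fuse_all_reduce_ops is enabled, or one c_allreduce_sum per gradient
        otherwise.
        """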
        self._insert_loss_grad_ops(loss)
        gm_block = self._find_gradient_merge_block()
        if gm_block is not None:
            # TODO(zjl): support fuse allreduce
            self._insert_allreduce_ops_for_gm(gm_block)
            return

        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
            self._allreduce_fusion_program()
        else:
            self._insert_allreduce_ops()

    def _insert_loss_grad_ops(self, loss):
        """
        In order to keep the learning rate consistent across different numbers
        of training workers, we scale the loss gradient by 1 / nranks
        (the number of workers)
        """
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )

    def _insert_allreduce_ops(self):
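        """Insert one c_allreduce_sum per non-distributed gradient.

        Walks the ops in reverse and, for every backward op that records
        param/grad pairs in its op_role_var attribute, inserts an allreduce on
        the gradient right after the op that produces it.
        """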
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        grad = None
        grad_vars = []
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 1
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue

                    block._insert_op(
                        idx + offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )

        if grad is None:
            return

    # This function helps reduce the number of allreduce ops by fusing gradients,
    # which can save communication time.
    # To use allreduce fusion, configure the strategy as follows:
    # strategy = paddle.distributed.fleet.DistributedStrategy()
    # strategy.without_graph_optimization = True
    # strategy.fuse_all_reduce_ops = True
    # strategy.calc_comm_same_stream = False
    # strategy.fuse_grad_size_in_num = 8
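    # A typical surrounding setup is sketched below (the standard fleet
    # workflow; `optimizer` and `loss` are placeholders, not names from this
    # file):
    # import paddle.distributed.fleet as fleet
    # fleet.init(is_collective=True, strategy=strategy)
    # optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss)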
    def _allreduce_fusion_program(self):
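        """Insert fused allreduce ops for all non-distributed gradients.

        Gradients are grouped by dtype into segments of at most
        fuse_grad_size_in_num, each segment is packed into a FusedOutput
        variable by a coalesce_tensor op, a single c_allreduce_sum is run on
        the fused variable, and the communication stream is synchronized
        before the optimizer ops when needed.
        """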
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        param_grads = []
        first_backward_idx = -1

        # find all grad params
        for idx, op in enumerate(block.ops):
            if first_backward_idx == -1 and is_backward_op(op):
                first_backward_idx = idx
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0, (
                    "vars need to be one param var followed by one grad var, "
                    "but got odd number of vars"
                )
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue
                    param_grads.append((param, grad))

        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )

        # structure of grad_param_segments is
        # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
        # each entry of the list is a tuple that stores the grad segment list
        # and the corresponding param segment list

        # its type is: dict[dtype, list[tuple[list[grad], list[param]]]]
        grad_param_segments_by_dtype = {}
        # split the grad based on dtype and fused size
        for param, grad in param_grads:
            if grad.dtype not in grad_param_segments_by_dtype:
                grad_param_segments_by_dtype[grad.dtype] = [([], [])]
            grad_segment, param_segment = grad_param_segments_by_dtype[
                grad.dtype
            ][-1]
            if len(param_segment) == self.fuse_grad_size_in_num:
                grad_param_segments_by_dtype[grad.dtype].append(([], []))
                grad_segment, param_segment = grad_param_segments_by_dtype[
                    grad.dtype
                ][-1]
            param_segment.append(param)
            grad_segment.append(grad)

        grad_param_segments = []
        for _, group in grad_param_segments_by_dtype.items():
            grad_param_segments.extend(group)

        if len(grad_param_segments) == 0:
            return

        # because the regroup operation makes the relative order invalid,
        # we need to reorder these fuse groups by after_idx
        def get_after_idx_of_fuse_group(grad_param_segments):
            grad_segment, param_segment = grad_param_segments
            return max([outputs_name_to_idx[grad][1] for grad in grad_segment])

        grad_param_segments.sort(key=get_after_idx_of_fuse_group)

        fused_vars = [None] * len(grad_param_segments)
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # traverse the grad_param_segments backward;
            # do not use reversed() since the absolute index value is needed
            grad_segment, param_segment = grad_param_segments[i]
            # insert coalesce tensor
            fused_var = block.create_var(
                name=unique_name.generate(
                    f'FusedOutput_{grad_segment[0].name}'
                ),
                dtype=grad_segment[0].dtype,
                persistable=False,
                stop_gradient=True,
            )
            fused_vars[i] = fused_var
            after_idx = max(
                [outputs_name_to_idx[grad][1] for grad in grad_segment]
            )
            block._insert_op_without_sync(
                after_idx + 1,
                type='c_allreduce_sum',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': self.calc_comm_same_stream,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            if not self.calc_comm_same_stream and self.sync_before_allreduce:
                block._insert_op_without_sync(
                    after_idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': fused_var},
                    outputs={'Out': fused_var},
                    attrs={OP_ROLE_KEY: OpRole.Backward},
                )
        idx = 0
        if not self.calc_comm_same_stream and not self.sync_before_allreduce:
            for i in range(len(grad_param_segments)):
                while (
                    block.ops[idx].type != 'c_allreduce_sum'
                    or fused_vars[i].name not in block.ops[idx].input_arg_names
                ):
                    idx += 1
                grad_segment, param_segment = grad_param_segments[i]
                for grad in grad_segment:
                    block._insert_op_without_sync(
                        idx + 1,
                        type='depend',
                        inputs={'X': grad, 'Dep': fused_vars[i]},
                        outputs={'Out': grad},
                    )
                    idx += 1

        # update the outputs_name_to_idx after insertion of sync/allreduce ops
        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )
        # the before_idx values are not guaranteed to be sorted, therefore we
        # have to find the right order in which to insert the coalesce ops
        pos_for_coalesce = {}
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # We separate the insertion of the coalesce op from the insertion of
            # the sync/allreduce ops, since inserting the coalesce op may
            # invalidate the outputs_name_to_idx
            grad_segment, param_segment = grad_param_segments[i]
            before_idx = len(block.ops)
            for grad in grad_segment:
                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
            pos_for_coalesce[i] = before_idx

        # insert the coalesce op based on the sorted before_idx
        pos_for_coalesce = sorted(
            pos_for_coalesce.items(),
            key=lambda kv: (kv[1], kv[0]),
            reverse=True,
        )
        for i, before_idx in pos_for_coalesce:
            grad_segment, param_segment = grad_param_segments[i]
            fused_var = fused_vars[i]
            block._insert_op_without_sync(
                before_idx,
                type="coalesce_tensor",
                inputs={"Input": param_segment},
                outputs={"Output": grad_segment, "FusedOutput": fused_var},
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": grad_segment[0].dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        if self.calc_comm_same_stream or not self.sync_before_allreduce:
            block._sync_with_cpp()
            return

        # insert the sync comm op
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op_without_sync(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': fused_vars},
                    outputs={'Out': fused_vars},
                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
                )
                break
        block._sync_with_cpp()

    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
        # Each item of outputs_name_to_idx is a pair of idx.
        # The first entry of this pair is the idx of the first op that generates
        # the grad, which is used to indicate where to insert the coalesce op.
        # The second entry of this pair is the idx of the last op that generates
        # the grad, which is used to indicate where to insert the sync and
        # allreduce ops.
        outputs_name_to_idx = {}
        for idx in range(first_backward_idx, len(block.ops)):
            op = block.ops[idx]
            if is_optimizer_op(op):
                break
            for name in op.output_arg_names:
                if name == core.kEmptyVarName():
                    continue
                var = block.var(name)
                if not outputs_name_to_idx.get(var):
                    # if the grad is only generated by one op,
                    # the first idx and the last idx are identical
                    outputs_name_to_idx[var] = (idx, idx)
                else:
                    outputs_name_to_idx[var] = (
                        outputs_name_to_idx[var][0],
                        idx,
                    )
        return outputs_name_to_idx