#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import static
from paddle.fluid import core
from paddle.framework import _global_flags
from paddle.framework.ir import apply_build_strategy
from paddle.utils import unique_name

from .common import (
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
    CollectiveHelper,
    OpRole,
    is_backward_op,
    is_loss_grad_op,
    is_optimizer_op,
)
from .meta_optimizer_base import MetaOptimizerBase


class RawProgramOptimizer(MetaOptimizerBase):
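    """Meta optimizer used when ``without_graph_optimization`` is enabled.

    It wraps the inner (user-defined) optimizer and rewrites the programs
    directly instead of relying on graph-level passes: it creates the global
    communication ring, broadcasts parameters from rank 0, scales the loss
    gradient by ``1 / nranks``, and inserts ``c_allreduce_sum`` ops for the
    gradients, optionally fusing them via ``coalesce_tensor``.
    """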
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
            "GradientMergeOptimizer",
            "LambOptimizer",
            "LarsOptimizer",
            "DGCOptimizer",
            "LocalSGDOptimizer",
        ]
        self.meta_optimizers_black_list = []
        self.global_ring_id = 0

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )
        self.without_graph_optimization = (
            user_defined_strategy.without_graph_optimization
        )
        self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
        if self.fuse_all_reduce_ops:
            self.fuse_grad_size_in_num = (
                user_defined_strategy.fuse_grad_size_in_num
            )
            self.calc_comm_same_stream = (
                user_defined_strategy._calc_comm_same_stream
            )

    def _can_apply(self):
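        # This optimizer only applies to collective training and is mutually
        # exclusive with tensor parallelism and sharding.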
        if not self.role_maker._is_collective:
            return False
        if self.user_defined_strategy.tensor_parallel:
            return False
        if self.user_defined_strategy.sharding:
            return False

        if self.without_graph_optimization:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.without_graph_optimization = False

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.without_graph_optimization = True

    def _broadcast_params(self, ring_id):
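        # Broadcast every non-distributed parameter from rank 0 over the given
        # ring so that all workers start from identical values, then sync the
        # communication stream.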
        block = self.startup_program.global_block()
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )

        if not param:
            return  # no parameter on this device
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward},
        )

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

    def _init_process_group(self):
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Create global ring for all gpus (ring_id = 0)
        collective_helper._init_communicator(
            self.startup_program,
            self.current_endpoint,
            self.global_endpoints,
            self.global_rank,
            self.global_ring_id,
            True,
            self.global_ring_id,
            True,
        )
        self._broadcast_params(self.global_ring_id)

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
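        # Run the inner optimizer first; optionally apply build-strategy passes
        # to the program, and for multi-worker runs set up the global ring and
        # insert the communication ops.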
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        if startup_program is None:
            startup_program = static.default_startup_program()
        self.startup_program = startup_program

        block = loss.block
        program = block.program
        self.main_program = program

        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        if _global_flags()['FLAGS_apply_pass_to_program']:
            pass_attrs = {"use_cuda": True}
            build_strategy = self.user_defined_strategy.build_strategy._copy()
            build_strategy.fuse_all_optimizer_ops = False
            build_strategy.fuse_all_reduce_ops = False
            apply_build_strategy(
                self.main_program,
                self.startup_program,
                build_strategy,
                pass_attrs,
            )
            self.main_program._pass_applied = True
        if self.nranks == 1:
            return optimize_ops, params_grads
        self._init_process_group()

        self.main_program = program
        if self.nranks > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

    def _find_gradient_merge_block(self):
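        # Locate the sub-block created by gradient merge: find the gradient
        # merge condition variable first, then the conditional_block op whose
        # Cond input is that variable.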
        GRAD_MERGE_COND_NAME = "grad_merge_cond_name"
        gm_cond_var_name = None
        for op in self.main_program.global_block().ops:
            if GRAD_MERGE_COND_NAME not in op.attr_names:
                continue
            if gm_cond_var_name is None:
                gm_cond_var_name = op.attr(GRAD_MERGE_COND_NAME)
            else:
                assert gm_cond_var_name == op.attr(
                    GRAD_MERGE_COND_NAME
                ), "multiple gradient merge condition found"
        if gm_cond_var_name is None:
            return None

        cond_op = (
            None  # false_fn of gm is None, so we should only find one block
        )
        for op in self.main_program.global_block().ops:
            if op.type != 'conditional_block' or 'Cond' not in op.input_names:
                continue
            cond_vars = op.input('Cond')
            if not cond_vars or cond_vars[0] != gm_cond_var_name:
                continue
            assert cond_op is None, "multiple gradient merge block found"
            cond_op = op
        assert cond_op is not None, "cannot find gradient merge block"
        return cond_op._block_attr("sub_block")

    def _insert_allreduce_ops_for_gm(self, gm_block):
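        # With gradient merge, gradients are communicated inside the gradient
        # merge sub-block: insert a calc-stream sync, one c_allreduce_sum per
        # gradient, and a comm-stream sync right before the optimizer ops of
        # that block.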
        block = self.main_program.global_block()

        first_optimize_op_idx = None
        for i, op in reversed(list(enumerate(gm_block.ops))):
            if is_backward_op(op) and first_optimize_op_idx is None:
                first_optimize_op_idx = i + 1
                break
        if first_optimize_op_idx is None:
            first_optimize_op_idx = 0

        param_vars = []
        grad_vars = []
        for op in block.ops:
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                assert len(op_role_var) % 2 == 0
                for i in range(0, len(op_role_var), 2):
                    param = block.var(op_role_var[i])
                    grad = block.var(op_role_var[i + 1])
                    if param.is_distributed:
                        continue
                    param_vars.append(param)
                    grad_vars.append(grad)

        if not grad_vars:
            return

        gm_block._insert_op(
            first_optimize_op_idx,
            type="c_sync_calc_stream",
            inputs={'X': grad_vars[0]},
            outputs={'Out': grad_vars[0]},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )

        insert_op_num = 1
        ring_id = self.global_ring_id

        # NOTE: fused allreduce could be performed inside this loop in the future
        for i, (p, g) in enumerate(zip(param_vars, grad_vars)):
            gm_block._insert_op(
                first_optimize_op_idx + insert_op_num,
                type="c_allreduce_sum",
                inputs={'X': g},
                outputs={'Out': g},
                attrs={
                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            insert_op_num += 1

        gm_block._insert_op(
            first_optimize_op_idx + insert_op_num,
            type="c_sync_comm_stream",
            inputs={'X': grad_vars},
            outputs={'Out': grad_vars},
            attrs={
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )

    def _transpile_main_program(self, loss):
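        # Scale the loss gradient first, then either handle the gradient merge
        # sub-block or insert (fused or per-gradient) allreduce ops.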
        self._insert_loss_grad_ops(loss)
        gm_block = self._find_gradient_merge_block()
        if gm_block is not None:
            # TODO(zjl): support fuse allreduce
            self._insert_allreduce_ops_for_gm(gm_block)
            return

        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
            self._allreduce_fusion_program()
        else:
            self._insert_allreduce_ops()

    def _insert_loss_grad_ops(self, loss):
        """
        In order to keep the learning rate consistent across different numbers
        of training workers, we scale the loss gradient by the number of workers
        """
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )

    def _insert_allreduce_ops(self):
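        # Unfused path: walk the ops in reverse, read the (param, grad) pairs
        # recorded in OP_ROLE_VAR, and insert a c_allreduce_sum on each gradient
        # right after the backward op that produced it, skipping distributed
        # parameters.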
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        grad = None
        grad_vars = []
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 1
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue

                    block._insert_op(
                        idx + offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )

        if grad is None:
            return

    # This function reduces the number of allreduce calls by fusing gradients
    # into larger buffers, which can save communication time.
    # To use allreduce fusion, configure the strategy as follows:
    # strategy = paddle.distributed.fleet.DistributedStrategy()
    # strategy.without_graph_optimization = True
    # strategy.fuse_all_reduce_ops = True
    # strategy.calc_comm_same_stream = False
    # strategy.fuse_grad_size_in_num = 8
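    #
    # A minimal end-to-end sketch (assuming a collective fleet launch; the
    # `optimizer` and `loss` objects below are user-defined and only
    # illustrative):
    #   import paddle.distributed.fleet as fleet
    #   fleet.init(is_collective=True, strategy=strategy)
    #   optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    #   optimizer.minimize(loss)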
    def _allreduce_fusion_program(self):
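        # Overview: group gradients by dtype into fixed-size segments, insert a
        # fused c_allreduce_sum per segment, add depend ops when communication
        # runs on a separate stream, and finally insert the coalesce_tensor ops
        # that build each fused buffer.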
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        param_grads = []
        first_backward_idx = -1

        # find all grad params
        for idx, op in enumerate(block.ops):
            if first_backward_idx == -1 and is_backward_op(op):
                first_backward_idx = idx
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0, (
                    "vars need to be one param var followed by one grad var, "
                    "but got odd number of vars"
                )
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue
                    param_grads.append((param, grad))

        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )

        # structure of grad_param_segments is
        # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
        # each entry of the list is a tuple that stores the grad segment list
        # and the corresponding param segment list

        # its type is: dict[dtype, list[tuple[list[grad], list[param]]]]
        grad_param_segments_by_dtype = {}
        # split the grads into segments based on dtype and fuse size
        for param, grad in param_grads:
            if grad.dtype not in grad_param_segments_by_dtype:
                grad_param_segments_by_dtype[grad.dtype] = [([], [])]
            grad_segment, param_segment = grad_param_segments_by_dtype[
                grad.dtype
            ][-1]
            if len(param_segment) == self.fuse_grad_size_in_num:
                grad_param_segments_by_dtype[grad.dtype].append(([], []))
                grad_segment, param_segment = grad_param_segments_by_dtype[
                    grad.dtype
                ][-1]
            param_segment.append(param)
            grad_segment.append(grad)

        grad_param_segments = []
        for _, group in grad_param_segments_by_dtype.items():
            grad_param_segments.extend(group)

        if len(grad_param_segments) == 0:
            return

        # because the regrouping above invalidates the original relative order,
        # we need to reorder these fused groups by after_idx
        def get_after_idx_of_fuse_group(grad_param_segments):
            grad_segment, param_segment = grad_param_segments
            return max([outputs_name_to_idx[grad][1] for grad in grad_segment])

        grad_param_segments.sort(key=get_after_idx_of_fuse_group)

        fused_vars = [None] * len(grad_param_segments)
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # traverse grad_param_segments backward;
            # reversed() is not used since the absolute index value is needed
            grad_segment, param_segment = grad_param_segments[i]
            # insert coalesce tensor
            fused_var = block.create_var(
                name=unique_name.generate(
                    f'FusedOutput_{grad_segment[0].name}'
                ),
                dtype=grad_segment[0].dtype,
                persistable=False,
                stop_gradient=True,
            )
            fused_vars[i] = fused_var
            after_idx = max(
                [outputs_name_to_idx[grad][1] for grad in grad_segment]
            )
            block._insert_op_without_sync(
                after_idx + 1,
                type='c_allreduce_sum',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': self.calc_comm_same_stream,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
        idx = 0
        if not self.calc_comm_same_stream:
            # insert a depend op after each fused allreduce so that every grad
            # sliced out of the fused buffer is only consumed after its
            # segment's communication has finished
            for i in range(len(grad_param_segments)):
                while block.ops[idx].type != 'c_allreduce_sum':
                    idx += 1
                grad_segment, param_segment = grad_param_segments[i]
                for grad in grad_segment:
                    block._insert_op_without_sync(
                        idx + 1,
                        type='depend',
                        inputs={'X': grad, 'Dep': fused_vars[i]},
                        outputs={'Out': grad},
                    )
                    idx += 1

        # update the outputs_name_to_idx after insertion of sync/allreduce ops
        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )
        # the before_idx values are not guaranteed to be sorted, therefore we
        # have to sort them to find the order in which to insert the coalesce ops
        pos_for_coalesce = {}
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # We separate the insertion of the coalesce op from the insertion of
            # the allreduce op, since the coalesce op's insertion may invalidate
            # the outputs_name_to_idx
            grad_segment, param_segment = grad_param_segments[i]
            before_idx = len(block.ops)
            for grad in grad_segment:
                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
            pos_for_coalesce[i] = before_idx

        # insert the coalesce op based on the sorted before_idx
        pos_for_coalesce = sorted(
            pos_for_coalesce.items(),
            key=lambda kv: (kv[1], kv[0]),
            reverse=True,
        )
        for i, before_idx in pos_for_coalesce:
            grad_segment, param_segment = grad_param_segments[i]
            fused_var = fused_vars[i]
            block._insert_op_without_sync(
                before_idx,
                type="coalesce_tensor",
                inputs={"Input": param_segment},
                outputs={"Output": grad_segment, "FusedOutput": fused_var},
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": grad_segment[0].dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        block._sync_with_cpp()

    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
        # Each item of outputs_name_to_idx is a pair of idx.
        # The first entry of the pair is the idx of the first op that generates
        # the grad, used to indicate the position at which to insert the coalesce op.
        # The second entry of the pair is the idx of the last op that generates
        # the grad, used to indicate the position at which to insert the sync and allreduce ops.
        outputs_name_to_idx = {}
        for idx in range(first_backward_idx, len(block.ops)):
            op = block.ops[idx]
            if is_optimizer_op(op):
                break
            for name in op.output_arg_names:
                if name == core.kEmptyVarName():
                    continue
                var = block.var(name)
                if not outputs_name_to_idx.get(var):
                    # if the grad is only generated by one op,
                    # the first idx and the last idx are identical
                    outputs_name_to_idx[var] = (idx, idx)
                else:
                    outputs_name_to_idx[var] = (
                        outputs_name_to_idx[var][0],
                        idx,
                    )
        return outputs_name_to_idx