#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.static as static
from paddle.fluid import core
from paddle.utils import unique_name

from .common import (
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
    CollectiveHelper,
    OpRole,
    is_backward_op,
    is_loss_grad_op,
    is_optimizer_op,
)
from .meta_optimizer_base import MetaOptimizerBase


class RawProgramOptimizer(MetaOptimizerBase):
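    """
    Meta optimizer that runs the raw (non-graph-optimized) program for
    collective training: it creates the global communication ring, broadcasts
    the initial parameters, scales the loss gradient by 1 / nranks, and
    inserts (optionally fused) allreduce ops for every gradient.

    A minimal usage sketch, assuming a standard fleet collective-training
    script (the surrounding model and optimizer setup are illustrative only):

        import paddle.distributed.fleet as fleet

        strategy = fleet.DistributedStrategy()
        strategy.without_graph_optimization = True
        fleet.init(is_collective=True, strategy=strategy)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(loss)
    """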
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
            "GradientMergeOptimizer",
            "LambOptimizer",
            "LarsOptimizer",
            "DGCOptimizer",
            "LocalSGDOptimizer",
        ]
        self.meta_optimizers_black_list = [
            "GraphExecutionOptimizer",
        ]
        self.global_ring_id = 0

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )
        self.without_graph_optimization = (
            user_defined_strategy.without_graph_optimization
        )
        self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
        if self.fuse_all_reduce_ops:
            self.fuse_grad_size_in_num = (
                user_defined_strategy.fuse_grad_size_in_num
            )
            self.calc_comm_same_stream = (
                user_defined_strategy._calc_comm_same_stream
            )

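    # This optimizer applies only in collective mode and only when the user
    # explicitly turns graph optimization off.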
    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False

        if self.without_graph_optimization:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.without_graph_optimization = False

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.without_graph_optimization = True

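    # Broadcast every non-distributed parameter from rank 0 over the given
    # ring so that all workers start from identical parameter values.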
    def _broadcast_params(self, ring_id):
        block = self.startup_program.global_block()
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )

        if not param:
            return  # no parameter on this device
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward},
        )

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

    def _init_process_group(self):
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Create global ring for all gpus (ring_id = 0)
        collective_helper._init_communicator(
            self.startup_program,
            self.current_endpoint,
            self.global_endpoints,
            self.global_rank,
            self.global_ring_id,
            True,
            self.global_ring_id,
            True,
        )
        self._broadcast_params(self.global_ring_id)

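    # Run the inner optimizer first; for multi-worker runs, build the global
    # communication ring and then rewrite the main program with allreduce ops.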
    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        if startup_program is None:
            startup_program = static.default_startup_program()
        self.startup_program = startup_program

        block = loss.block
        program = block.program
        self.main_program = program

        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        if self.nranks == 1:
            return optimize_ops, params_grads
        self._init_process_group()

        self.main_program = program
        if self.nranks > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

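    # Locate the sub-block created by gradient merge: it is the
    # conditional_block op whose Cond input is the gradient-merge condition var.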
    def _find_gradient_merge_block(self):
        GRAD_MERGE_COND_NAME = "grad_merge_cond_name"
        gm_cond_var_name = None
        for op in self.main_program.global_block().ops:
            if GRAD_MERGE_COND_NAME not in op.attr_names:
                continue
            if gm_cond_var_name is None:
                gm_cond_var_name = op.attr(GRAD_MERGE_COND_NAME)
            else:
                assert gm_cond_var_name == op.attr(
                    GRAD_MERGE_COND_NAME
                ), "multiple gradient merge condition found"
        if gm_cond_var_name is None:
            return None

        # false_fn of gradient merge is None, so only one block should be found
        cond_op = None
        for op in self.main_program.global_block().ops:
            if op.type != 'conditional_block' or 'Cond' not in op.input_names:
                continue
            cond_vars = op.input('Cond')
            if not cond_vars or cond_vars[0] != gm_cond_var_name:
                continue
            assert cond_op is None, "multiple gradient merge block found"
            cond_op = op
        assert cond_op is not None, "cannot find gradient merge block"
        return cond_op._block_attr("sub_block")

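    # Inside the gradient-merge block, insert one c_sync_calc_stream op plus a
    # c_allreduce_sum op per gradient, followed by a single c_sync_comm_stream.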
    def _insert_allreduce_ops_for_gm(self, gm_block):
        block = self.main_program.global_block()

        first_optimize_op_idx = None
        for i, op in reversed(list(enumerate(gm_block.ops))):
            if is_backward_op(op) and first_optimize_op_idx is None:
                first_optimize_op_idx = i + 1
                break
        if first_optimize_op_idx is None:
            first_optimize_op_idx = 0

        param_vars = []
        grad_vars = []
        for op in block.ops:
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                assert len(op_role_var) % 2 == 0
                for i in range(0, len(op_role_var), 2):
                    param = block.var(op_role_var[i])
                    grad = block.var(op_role_var[i + 1])
                    if param.is_distributed:
                        continue
                    param_vars.append(param)
                    grad_vars.append(grad)

        if not grad_vars:
            return

        gm_block._insert_op(
            first_optimize_op_idx,
            type="c_sync_calc_stream",
            inputs={'X': grad_vars[0]},
            outputs={'Out': grad_vars[0]},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )

        insert_op_num = 1
        ring_id = self.global_ring_id

        # NOTE: fused allreduce could be performed inside this loop in the future
        for i, (p, g) in enumerate(zip(param_vars, grad_vars)):
            gm_block._insert_op(
                first_optimize_op_idx + insert_op_num,
                type="c_allreduce_sum",
                inputs={'X': g},
                outputs={'Out': g},
                attrs={
                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            insert_op_num += 1

        gm_block._insert_op(
            first_optimize_op_idx + insert_op_num,
            type="c_sync_comm_stream",
            inputs={'X': grad_vars},
            outputs={'Out': grad_vars},
            attrs={
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )

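    # Entry point for rewriting the main program: scale the loss gradient,
    # then insert allreduce ops (fused, per-gradient, or gradient-merge aware).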
    def _transpile_main_program(self, loss):
        self._insert_loss_grad_ops(loss)
        gm_block = self._find_gradient_merge_block()
        if gm_block is not None:
            # TODO(zjl): support fuse allreduce
            self._insert_allreduce_ops_for_gm(gm_block)
            return

        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
            self._allreduce_fusion_program()
        else:
            self._insert_allreduce_ops()

    def _insert_loss_grad_ops(self, loss):
        """
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers
        """
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )

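    # Insert a c_sync_calc_stream + c_allreduce_sum pair right after each
    # backward op that produces gradients, and a single c_sync_comm_stream
    # before the first optimizer op.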
    def _insert_allreduce_ops(self):
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        grad = None
        grad_vars = []
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 1
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue

                    grad_vars.append(grad)
                    block._insert_op(
                        idx + offset,
                        type='c_sync_calc_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )
                    offset += 1
                    block._insert_op(
                        idx + offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )

        if grad is None:
            return

        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': grad_vars},
                    outputs={'Out': grad_vars},
                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
                )
                break

    # This function reduces the number of allreduce ops by fusing gradients
    # into larger tensors, which saves communication time.
    # To enable allreduce fusion, configure the strategy as follows:
    # strategy = paddle.distributed.fleet.DistributedStrategy()
    # strategy.without_graph_optimization = True
    # strategy.fuse_all_reduce_ops = True
    # strategy.calc_comm_same_stream = False
    # strategy.fuse_grad_size_in_num = 8
    def _allreduce_fusion_program(self):
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        param_grads = []
        first_backward_idx = -1

        # find all grad params
        for idx, op in enumerate(block.ops):
            if first_backward_idx == -1 and is_backward_op(op):
                first_backward_idx = idx
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0, (
                    "vars need to be one param var followed by one grad var, "
                    "but got odd number of vars"
                )
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue
                    param_grads.append((param, grad))

        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )

        # structure of grad_param_segments is
        # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
        # each entry of the list is a tuple that stores the grad segment list and
        # the corresponding params segment list
        grad_param_segments = []
        last_dtype = None
        # split the grads into segments based on dtype and fuse size
        for param, grad in param_grads:
            if (
                len(grad_param_segments) == 0
                or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num
                or grad.dtype != last_dtype
            ):
                grad_param_segments.append(([grad], [param]))
                last_dtype = grad.dtype
            else:
                grad_param_segments[-1][0].append(grad)
                grad_param_segments[-1][1].append(param)

        if len(grad_param_segments) == 0:
            return

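        # For each segment, create a FusedOutput var, allreduce it as a whole,
        # and later insert a coalesce_tensor op that maps the segment's grads
        # into that fused buffer.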
        fused_vars = [None] * len(grad_param_segments)
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # traverse grad_param_segments backward;
            # reversed() is not used because the absolute index value is needed
            grad_segment, param_segment = grad_param_segments[i]
            # create the fused output var for this segment
            fused_var = block.create_var(
                name=unique_name.generate(
                    'FusedOutput_{}'.format(grad_segment[0].name)
                ),
                dtype=grad_segment[0].dtype,
                persistable=False,
                stop_gradient=True,
            )
            fused_vars[i] = fused_var
            after_idx = outputs_name_to_idx[grad_segment[-1]][1]
            block._insert_op_without_sync(
                after_idx + 1,
                type='c_allreduce_sum',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': self.calc_comm_same_stream,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            if not self.calc_comm_same_stream:
                block._insert_op_without_sync(
                    after_idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': fused_var},
                    outputs={'Out': fused_var},
                    attrs={OP_ROLE_KEY: OpRole.Backward},
                )

        # update the outputs_name_to_idx after insertion of sync/allreduce ops
        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )
        # before_idx is not guaranteed to be sorted, so collect the insertion
        # positions first, then insert the coalesce ops in sorted order
        pos_for_coalesce = {}
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # The insertion of the coalesce op is kept separate from the insertion
            # of the sync/allreduce ops, since inserting the coalesce op may
            # invalidate outputs_name_to_idx
            grad_segment, param_segment = grad_param_segments[i]
            before_idx = len(block.ops)
            for grad in outputs_name_to_idx:
                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
            pos_for_coalesce[i] = before_idx

        # insert the coalesce op based on the sorted before_idx
        pos_for_coalesce = sorted(
            pos_for_coalesce.items(),
            key=lambda kv: (kv[1], kv[0]),
            reverse=True,
        )
        for i, before_idx in pos_for_coalesce:
            grad_segment, param_segment = grad_param_segments[i]
            fused_var = fused_vars[i]
            block._insert_op_without_sync(
                before_idx,
                type="coalesce_tensor",
                inputs={"Input": param_segment},
                outputs={"Output": grad_segment, "FusedOutput": fused_var},
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": grad_segment[0].dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        if self.calc_comm_same_stream:
            block._sync_with_cpp()
            return

        # insert the sync comm op
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op_without_sync(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': fused_vars},
                    outputs={'Out': fused_vars},
                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
                )
                break
        block._sync_with_cpp()

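    # Map every output var produced in the backward section to the indices of
    # the first and last ops that write it; optimizer ops are not scanned.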
    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
        # Each item of outputs_name_to_idx is a pair of idx.
        # The first entry of this pair is the idx of the first op that generates the grad,
        # which is used to indicate the position to insert the coalesce op.
        # The second entry of this pair is the idx of the last op that generates the grad,
        # which is used to indicate the position to insert the sync and allreduce ops.
        outputs_name_to_idx = {}
        for idx in range(first_backward_idx, len(block.ops)):
            op = block.ops[idx]
            if is_optimizer_op(op):
                break
            for name in op.output_arg_names:
                if name == core.kEmptyVarName():
                    continue
                var = block.var(name)
                if var not in outputs_name_to_idx:
                    # if the grad is generated by only one op,
                    # the first idx and the last idx are identical
                    outputs_name_to_idx[var] = (idx, idx)
                else:
                    outputs_name_to_idx[var] = (
                        outputs_name_to_idx[var][0],
                        idx,
                    )
        return outputs_name_to_idx