#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.static as static
from paddle.fluid import core
from paddle.utils import unique_name

from .common import (
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
    CollectiveHelper,
    OpRole,
    is_backward_op,
    is_loss_grad_op,
    is_optimizer_op,
)
from .meta_optimizer_base import MetaOptimizerBase


class RawProgramOptimizer(MetaOptimizerBase):
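    """Meta optimizer that runs data-parallel training on the raw program.

    Rather than relying on graph optimization passes, it transpiles the main
    program directly: the loss gradient is scaled by 1 / nranks and
    c_allreduce_sum ops (optionally fused) are inserted for every
    non-distributed gradient on the global communication ring.
    """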
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
            "GradientMergeOptimizer",
            "LambOptimizer",
            "LarsOptimizer",
            "DGCOptimizer",
            "LocalSGDOptimizer",
        ]
        self.meta_optimizers_black_list = []
        self.global_ring_id = 0

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )
        self.without_graph_optimization = (
            user_defined_strategy.without_graph_optimization
        )
        self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
        if self.fuse_all_reduce_ops:
            self.fuse_grad_size_in_num = (
                user_defined_strategy.fuse_grad_size_in_num
            )
            self.calc_comm_same_stream = (
                user_defined_strategy._calc_comm_same_stream
            )

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
        if self.user_defined_strategy.tensor_parallel:
            return False
        if self.user_defined_strategy.sharding:
            return False

        if self.without_graph_optimization:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.without_graph_optimization = False

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.without_graph_optimization = True

    def _broadcast_params(self, ring_id):
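        # Broadcast every non-distributed parameter from rank 0 over the given
        # ring so that all workers start from identical weights, then append a
        # c_sync_comm_stream op to wait for the broadcasts to finish.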
        block = self.startup_program.global_block()
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )

        if not param:
            return  # no parameter on this device
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward},
        )

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

    def _init_process_group(self):
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Create global ring for all gpus (ring_id = 0)
        collective_helper._init_communicator(
            self.startup_program,
            self.current_endpoint,
            self.global_endpoints,
            self.global_rank,
            self.global_ring_id,
            True,
            self.global_ring_id,
            True,
        )
        self._broadcast_params(self.global_ring_id)

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        if startup_program is None:
            startup_program = static.default_startup_program()
        self.startup_program = startup_program

        block = loss.block
        program = block.program
        self.main_program = program

        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        if self.nranks == 1:
            return optimize_ops, params_grads
        self._init_process_group()

        self.main_program = program
        if self.nranks > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

    def _find_gradient_merge_block(self):
        GRAD_MERGE_COND_NAME = "grad_merge_cond_name"
        gm_cond_var_name = None
        for op in self.main_program.global_block().ops:
            if GRAD_MERGE_COND_NAME not in op.attr_names:
                continue
            if gm_cond_var_name is None:
                gm_cond_var_name = op.attr(GRAD_MERGE_COND_NAME)
            else:
                assert gm_cond_var_name == op.attr(
                    GRAD_MERGE_COND_NAME
                ), "multiple gradient merge condition found"
        if gm_cond_var_name is None:
            return None

        # false_fn of gradient merge is None, so we should only find one block
        cond_op = None
        for op in self.main_program.global_block().ops:
            if op.type != 'conditional_block' or 'Cond' not in op.input_names:
                continue
            cond_vars = op.input('Cond')
            if not cond_vars or cond_vars[0] != gm_cond_var_name:
                continue
            assert cond_op is None, "multiple gradient merge block found"
            cond_op = op
        assert cond_op is not None, "cannot find gradient merge block"
        return cond_op._block_attr("sub_block")

    def _insert_allreduce_ops_for_gm(self, gm_block):
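        # Inside the gradient-merge conditional block: sync the calc stream
        # once, allreduce every non-distributed gradient on the global ring,
        # then sync the comm stream before the optimizer ops run.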
        block = self.main_program.global_block()

        first_optimize_op_idx = None
        for i, op in reversed(list(enumerate(gm_block.ops))):
            if is_backward_op(op) and first_optimize_op_idx is None:
                first_optimize_op_idx = i + 1
                break
        if first_optimize_op_idx is None:
            first_optimize_op_idx = 0

        param_vars = []
        grad_vars = []
        for op in block.ops:
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                assert len(op_role_var) % 2 == 0
                for i in range(0, len(op_role_var), 2):
                    param = block.var(op_role_var[i])
                    grad = block.var(op_role_var[i + 1])
                    if param.is_distributed:
                        continue
                    param_vars.append(param)
                    grad_vars.append(grad)

        if not grad_vars:
            return

        gm_block._insert_op(
            first_optimize_op_idx,
            type="c_sync_calc_stream",
            inputs={'X': grad_vars[0]},
            outputs={'Out': grad_vars[0]},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )

        insert_op_num = 1
        ring_id = self.global_ring_id

        # NOTE: fused allreduce could be performed inside this loop in the future
        for i, (p, g) in enumerate(zip(param_vars, grad_vars)):
            gm_block._insert_op(
                first_optimize_op_idx + insert_op_num,
                type="c_allreduce_sum",
                inputs={'X': g},
                outputs={'Out': g},
                attrs={
                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            insert_op_num += 1

        gm_block._insert_op(
            first_optimize_op_idx + insert_op_num,
            type="c_sync_comm_stream",
            inputs={'X': grad_vars},
            outputs={'Out': grad_vars},
            attrs={
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )

    def _transpile_main_program(self, loss):
        self._insert_loss_grad_ops(loss)
        gm_block = self._find_gradient_merge_block()
        if gm_block is not None:
            # TODO(zjl): support fuse allreduce
            self._insert_allreduce_ops_for_gm(gm_block)
            return

        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
            self._allreduce_fusion_program()
        else:
            self._insert_allreduce_ops()

    def _insert_loss_grad_ops(self, loss):
        """
        To keep the effective learning rate consistent across different numbers
        of training workers, scale the loss gradient by 1 / nranks so that the
        subsequent allreduce-sum of gradients amounts to an average.
        """
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )

    def _insert_allreduce_ops(self):
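        # For every non-distributed gradient recorded in a backward op's
        # OP_ROLE_VAR attribute, insert a c_sync_calc_stream op followed by a
        # c_allreduce_sum op right after that backward op, and insert a single
        # c_sync_comm_stream op before the first optimizer op.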
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        grad = None
        grad_vars = []
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 1
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue

                    grad_vars.append(grad)
                    block._insert_op(
                        idx + offset,
                        type='c_sync_calc_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )
                    offset += 1
                    block._insert_op(
                        idx + offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Backward,
                        },
                    )

        if grad is None:
            return

        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': grad_vars},
                    outputs={'Out': grad_vars},
                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
                )
                break

    # This function reduces the number of allreduce calls by fusing gradients
    # into larger tensors, which saves communication time.
    # To enable allreduce fusion, configure the strategy as follows:
    # strategy = paddle.distributed.fleet.DistributedStrategy()
    # strategy.without_graph_optimization = True
    # strategy.fuse_all_reduce_ops = True
    # strategy.calc_comm_same_stream = False
    # strategy.fuse_grad_size_in_num = 8
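    #
    # A minimal sketch of how such a strategy is typically wired up (assuming
    # the usual paddle.distributed.fleet entry points):
    #   import paddle.distributed.fleet as fleet
    #   fleet.init(is_collective=True)
    #   optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    #   optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    #   optimizer.minimize(loss)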
    def _allreduce_fusion_program(self):
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        param_grads = []
        first_backward_idx = -1

        # collect all (param, grad) pairs produced by backward ops
        for idx, op in enumerate(block.ops):
            if first_backward_idx == -1 and is_backward_op(op):
                first_backward_idx = idx
            if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0, (
                    "vars need to be one param var followed by one grad var, "
                    "but got odd number of vars"
                )
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue
                    param_grads.append((param, grad))

        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )

        # structure of grad_param_segments is
        # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
        # each entry of the list is a tuple that stores a segment of grads and
        # the corresponding segment of params
        grad_param_segments = []
        last_dtype = None
        # split the grads into segments based on dtype and fuse size
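        # Illustration (hypothetical grads): with fuse_grad_size_in_num = 2 and
        # grads [g0(fp32), g1(fp32), g2(fp32), g3(fp16)], the segments become
        # [g0, g1], [g2], [g3]: a new segment starts whenever the current one
        # is full or the dtype changes.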
        for param, grad in param_grads:
            if (
                len(grad_param_segments) == 0
                or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num
                or grad.dtype != last_dtype
            ):
                grad_param_segments.append(([grad], [param]))
                last_dtype = grad.dtype
            else:
                grad_param_segments[-1][0].append(grad)
                grad_param_segments[-1][1].append(param)

        if len(grad_param_segments) == 0:
            return

        fused_vars = [None] * len(grad_param_segments)
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # traverse grad_param_segments in reverse order;
            # do not use reversed() because the absolute index value is needed
            grad_segment, param_segment = grad_param_segments[i]
            # insert coalesce tensor
            fused_var = block.create_var(
                name=unique_name.generate(
                    'FusedOutput_{}'.format(grad_segment[0].name)
                ),
                dtype=grad_segment[0].dtype,
                persistable=False,
                stop_gradient=True,
            )
            fused_vars[i] = fused_var
            after_idx = outputs_name_to_idx[grad_segment[-1]][1]
            block._insert_op_without_sync(
                after_idx + 1,
                type='c_allreduce_sum',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': self.calc_comm_same_stream,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            if not self.calc_comm_same_stream:
                block._insert_op_without_sync(
                    after_idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': fused_var},
                    outputs={'Out': fused_var},
                    attrs={OP_ROLE_KEY: OpRole.Backward},
                )

        # update the outputs_name_to_idx after insertion of sync/allreduce ops
        outputs_name_to_idx = self.__get_ouputs_name_to_idx(
            first_backward_idx, block
        )
        # the before_idx values are not guaranteed to be sorted, so determine
        # the insertion order of the coalesce ops explicitly
        pos_for_coalesce = {}
        for i in range(len(grad_param_segments) - 1, -1, -1):
            # The insertion of the coalesce ops is separated from the insertion
            # of the sync/allreduce ops because inserting a coalesce op may
            # invalidate outputs_name_to_idx
            grad_segment, param_segment = grad_param_segments[i]
            before_idx = len(block.ops)
            for grad in grad_segment:
                before_idx = min(before_idx, outputs_name_to_idx[grad][0])
            pos_for_coalesce[i] = before_idx

        # insert the coalesce op based on the sorted before_idx
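        # Illustration (hypothetical indices): if pos_for_coalesce is
        # {0: 30, 1: 18, 2: 25}, the reverse sort yields
        # [(0, 30), (2, 25), (1, 18)], so coalesce ops are inserted at the
        # largest index first and earlier insertions never shift the remaining
        # target positions.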
        pos_for_coalesce = sorted(
            pos_for_coalesce.items(),
            key=lambda kv: (kv[1], kv[0]),
            reverse=True,
        )
        for i, before_idx in pos_for_coalesce:
            grad_segment, param_segment = grad_param_segments[i]
            fused_var = fused_vars[i]
            block._insert_op_without_sync(
                before_idx,
                type="coalesce_tensor",
                inputs={"Input": param_segment},
                outputs={"Output": grad_segment, "FusedOutput": fused_var},
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": grad_segment[0].dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        if self.calc_comm_same_stream:
            block._sync_with_cpp()
            return

        # insert the sync comm op
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op_without_sync(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': fused_vars},
                    outputs={'Out': fused_vars},
                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
                )
                break
        block._sync_with_cpp()

    def __get_ouputs_name_to_idx(self, first_backward_idx, block):
        # Each value of outputs_name_to_idx is a pair of indices.
        # The first entry of the pair is the idx of the first op that generates
        # the grad, which indicates where to insert the coalesce op.
        # The second entry of the pair is the idx of the last op that generates
        # the grad, which indicates where to insert the sync and allreduce ops.
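        # For example (hypothetical indices): a grad first produced by an op at
        # idx 12 and last written by a grad-accumulation sum op at idx 15 gets
        # the entry (12, 15); a grad written by a single op at idx 12 gets
        # (12, 12).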
        outputs_name_to_idx = {}
        for idx in range(first_backward_idx, len(block.ops)):
            op = block.ops[idx]
            if is_optimizer_op(op):
                break
            for name in op.output_arg_names:
                if name == core.kEmptyVarName():
                    continue
                var = block.var(name)
                if var not in outputs_name_to_idx:
                    # if the grad is generated by only one op,
                    # the first idx and the last idx are identical
                    outputs_name_to_idx[var] = (idx, idx)
                else:
                    outputs_name_to_idx[var] = (
                        outputs_name_to_idx[var][0],
                        idx,
                    )
        return outputs_name_to_idx