dgc_optimizer.py 20.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14
import logging
15
from functools import reduce
16

17 18
from .meta_optimizer_base import MetaOptimizerBase

19 20
# Empty export list: ``from <this module> import *`` exports no names.
__all__ = []

21 22
import paddle
from paddle.common_ops_import import LayerHelper
23
from paddle.fluid import framework
24
from paddle.fluid.dygraph import base as imperative_base
25
from paddle.framework import core, in_dynamic_mode
26
from paddle.nn.clip import ClipGradByNorm, append_gradient_clip_ops
27
from paddle.optimizer import Momentum, Optimizer
28
from paddle.regularizer import L1Decay, L2Decay
29
from paddle.static import create_global_var
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49


class DGCMomentumOptimizer(Optimizer):
    """Static-graph momentum optimizer with Deep Gradient Compression (DGC).

    Routing depends on ``_is_use_dgc``:

    * Accepted parameters: the gradient is processed by a ``dgc`` op,
      optionally preceded by ``dgc_clip_by_norm``. The parameter is then
      updated with a ``dgc_momentum`` op.
    * All other parameters: the gradient goes through the regular
      clip/regularization path and the parameter is updated with a plain
      ``momentum`` op.

    The optimizer needs a CUDA build and refuses to run in dygraph mode.
    """

    # Accumulator names for the two DGC velocity buffers (the dgc op's "U"
    # and "V" inputs). The U buffer doubles as the momentum "Velocity".
    _u_velocity_acc_str = "_dgc_u_"
    _v_velocity_acc_str = "_dgc_v_"

    def __init__(
        self,
        learning_rate,
        momentum,
        rampup_begin_step,
        rampup_step=1,
        sparsity=None,
        parameter_list=None,
        use_nesterov=False,
        num_trainers=None,
        regularization=None,
        grad_clip=None,
        name=None,
    ):
        """Create the optimizer.

        Args:
            learning_rate: Forwarded to ``Optimizer``. Must not be None.
            momentum: Momentum factor. Must not be None.
            rampup_begin_step: Value used for the ``rampup_begin_step`` attrs
                and variable. Must be >= 0.
            rampup_step: Forwarded to the dgc op's ``rampup_step`` attr.
            sparsity: Forwarded to the dgc op's ``sparsity`` attr. Defaults
                to ``[0.999]`` when None.
            parameter_list: Forwarded to ``Optimizer`` as ``parameters``.
            use_nesterov: Coerced to ``bool``.
            num_trainers: Initial value of the nranks variable. Must be a
                positive ``int`` when ``grad_clip`` is given.
            regularization: None, ``L1Decay`` or ``L2Decay``. Forwarded to
                ``Optimizer`` as ``weight_decay``.
            grad_clip: None or a ``ClipGradByNorm`` instance.
            name: Optional optimizer name.

        Raises:
            Exception: In dynamic (dygraph) mode.
            TypeError: If ``grad_clip`` is not a ``ClipGradByNorm``.
            AssertionError: If a CUDA build is missing or an argument
                check fails.
        """
        if in_dynamic_mode():
            raise Exception("In dygraph, don't support DGCMomentumOptimizer.")

        assert (
            core.is_compiled_with_cuda()
        ), "Paddle is not compiled with CUDA. DGC is only support GPU for now."

        assert learning_rate is not None
        assert momentum is not None
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameter_list,
            weight_decay=regularization,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "dgc_momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)

        assert rampup_begin_step >= 0, "rampup_begin_step must >= 0"
        self._rampup_begin_step = rampup_begin_step
        self._rampup_step = rampup_step
        # Build a fresh default per instance instead of sharing one mutable
        # list default across every optimizer object.
        self._sparsity = [0.999] if sparsity is None else sparsity

        # Graph variables; populated by _append_dgc_ops.
        self._rampup_begin_step_var = None
        self._global_step_var = None

        self._dgc_clip_norm = None
        self._num_trainers = num_trainers
        if grad_clip is not None:
            if not isinstance(grad_clip, ClipGradByNorm):
                raise TypeError(
                    "The type of grad_clip should be 'ClipGradByNorm', because DGCMomentumOptimizer only support ClipGradByNorm"
                )
            assert isinstance(num_trainers, int), (
                "The type of num_trainers should be 'int', but received %s"
                % type(num_trainers)
            )
            assert (
                num_trainers > 0
            ), "The value of num_trainers should be greater than 0!"

            # Per-trainer clip norm: clip_norm / sqrt(num_trainers).
            self._dgc_clip_norm = grad_clip.clip_norm * (num_trainers**-0.5)

        # NOTE(review): self.regularization is presumably set by the base
        # Optimizer from weight_decay -- confirm against Optimizer.__init__.
        self.regular_type, self.regular_coeff = self._get_regularization_param(
            self.regularization
        )

    def _get_regularization_param(self, regularization):
        """Map a regularizer to the ``(regular_type, regular_coeff)`` pair.

        Returns:
            ``(0, 0.0)`` for None, ``(1, coeff)`` for ``L1Decay`` and
            ``(2, coeff)`` for ``L2Decay``. ``coeff`` is the regularizer's
            ``_coeff`` attribute.

        Raises:
            AssertionError: For any other regularizer type.
        """
        regular_type = 0
        regular_coeff = 0.0

        if regularization is not None:
            regular_coeff = regularization._coeff

            if isinstance(regularization, L1Decay):
                regular_type = 1
            elif isinstance(regularization, L2Decay):
                regular_type = 2
            else:
                raise AssertionError(
                    "regularization must be None|L1Decay|L2Decay"
                )
        return regular_type, regular_coeff

    def _is_use_dgc(self, param_var, grad_var):
        """Return True if this (param, grad) pair is handled by DGC.

        The pair is excluded when any of the following holds:

        * the parameter has fewer than 16384 elements;
        * the parameter or the gradient is of type SELECTED_ROWS;
        * the parameter is not FP32.
        """
        # abs() guards against a negative product, presumably from a -1
        # (unknown) dimension in the shape.
        var_numel = abs(reduce(lambda x, y: x * y, param_var.shape, 1))
        if (
            var_numel < 16384
            or param_var.type == core.VarDesc.VarType.SELECTED_ROWS
            or grad_var.type == core.VarDesc.VarType.SELECTED_ROWS
            or param_var.dtype != core.VarDesc.VarType.FP32
        ):
            return False
        return True

    def _append_optimize_op(self, block, param_and_grad):
        """Append the update op for one (param, grad) pair to ``block``.

        Non-DGC pairs get a ``momentum`` op. DGC pairs get a ``dgc_momentum``
        op, which additionally reads the step counter and the nranks
        variable and writes back the gradient.

        Returns:
            The appended operator.
        """
        assert isinstance(block, paddle.framework.Block)
        # The U accumulator (shared with the dgc op) serves as the velocity.
        velocity_acc = self._get_accumulator(
            self._u_velocity_acc_str, param_and_grad[0]
        )
        assert velocity_acc is not None

        inputs = {
            "Param": param_and_grad[0],
            "Grad": param_and_grad[1],
            "Velocity": velocity_acc,
            "LearningRate": self._create_param_lr(param_and_grad),
        }
        outputs = {
            "ParamOut": param_and_grad[0],
            "VelocityOut": velocity_acc,
        }
        attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}

        if not self._is_use_dgc(param_and_grad[0], param_and_grad[1]):
            op_type = "momentum"
        else:
            op_type = "dgc_momentum"
            inputs.update(
                {
                    "current_step": self._global_step_var,
                    "nranks": self._nranks_var,
                }
            )
            outputs.update({'Grad_out': param_and_grad[1]})
            attrs.update({"rampup_begin_step": float(self._rampup_begin_step)})

        # create the dgc momentum optimize op
        dgc_momentum_op = block.append_op(
            type=op_type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )
        return dgc_momentum_op

    def _add_auto_increment_var(self, counter_name, begin, step=1):
        """Get or create a persistable float32 step counter.

        On first creation, the counter is initialized on CPU to
        ``begin - 1``. An ``increment`` op (by ``step``) is prepended to the
        main program's global block.

        Returns:
            The counter variable.
        """
        helper = LayerHelper('global_step_counter')
        counter, is_new_var = helper.create_or_get_global_variable(
            name=counter_name, dtype='float32', shape=[1], persistable=True
        )
        if is_new_var:
            helper.set_variable_initializer(
                counter,
                initializer=paddle.nn.initializer.ConstantInitializer(
                    value=float(begin - 1), force_cpu=True
                ),
            )
            helper.main_program.global_block()._prepend_op(
                type='increment',
                inputs={'X': [counter]},
                outputs={'Out': [counter]},
                attrs={'step': float(step)},
                stop_gradient=True,
            )
            counter.stop_gradient = True

        return counter

    def _add_nranks_var(self, name, value=-1):
        """Get or create a persistable float32 variable holding ``value``.

        On first creation, the variable is initialized on CPU to
        ``float(value)``.

        Returns:
            The variable.
        """
        helper = LayerHelper('global_step_counter')
        counter, is_new_var = helper.create_or_get_global_variable(
            name=name, dtype='float32', shape=[1], persistable=True
        )
        if is_new_var:
            helper.set_variable_initializer(
                counter,
                initializer=paddle.nn.initializer.ConstantInitializer(
                    value=float(value), force_cpu=True
                ),
            )
            counter.stop_gradient = True

        return counter

    def _append_dgc_ops(self, param_and_grads):
        """Create the DGC global variables and append a dgc op per DGC param.

        Steps:

        1. Mark the default main program with ``_enable_dgc``.
        2. Create the step counter, the nranks variable and the rampup-begin
           variable.
        3. For every DGC-eligible pair:

           * create the V accumulator and the k/encoded/gather buffers;
           * drop the pair from the op-role-var attribute of backward ops;
           * append the (optionally clipped) dgc op.

        A U accumulator is created for every pair, DGC-eligible or not.
        """
        main_program = paddle.static.default_main_program()
        main_program._enable_dgc = True

        # step counter
        self._global_step_var = self._add_auto_increment_var(
            counter_name=core.dgc.kDGCCounterName(), begin=0
        )

        self._nranks_var = self._add_nranks_var(
            name=core.dgc.kDGCNRanksName(), value=self._num_trainers
        )

        # rampup begin step var for all_reduce_op_handle
        self._rampup_begin_step_var = create_global_var(
            shape=[1],
            dtype=core.VarDesc.VarType.FP32,
            persistable=True,
            name=core.dgc.kDGCRampUpBeginStepName(),
            value=self._rampup_begin_step * 1.0,
            force_cpu=True,
        )

        self.helper = LayerHelper(self.__class__.__name__)

        # Loop-invariant: name of the op-role-var attribute.
        op_role_var_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()

        for param_var, grad_var in param_and_grads:
            # reuse velocity in dgc_op and dgc_momentum_op
            u_var = self._add_accumulator(self._u_velocity_acc_str, param_var)

            if not self._is_use_dgc(param_var, grad_var):
                continue

            v_var = self._add_accumulator(self._v_velocity_acc_str, param_var)

            k_var = create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCKName(),
                value=0.0,
                force_cpu=True,
            )

            encoded_var = create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCEncodedName(),
                value=0.0,
                force_cpu=False,
            )

            gather_var = create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCGatherName(),
                value=0.0,
                force_cpu=False,
            )

            # Remove this (param, grad) pair from the op-role-var attribute
            # of backward ops. The dgc op (see _dgc_op) re-declares it for
            # itself.
            for op in main_program.global_block().ops:
                if not self._is_the_backward_op(op):
                    continue

                var_attr = op.all_attrs()[op_role_var_name]
                if param_var.name not in var_attr:
                    continue

                var_attr.remove(param_var.name)
                var_attr.remove(grad_var.name)
                if len(var_attr) > 1:
                    op._set_attr(op_role_var_name, var_attr)
                else:
                    op._remove_attr(op_role_var_name)

            clip_var = grad_var
            if self._dgc_clip_norm is not None:
                clip_var = self._append_clip_norm(grad_var, self._dgc_clip_norm)
            self._dgc_op(
                param_var,
                clip_var,
                grad_var,
                u_var,
                v_var,
                k_var,
                encoded_var,
                gather_var,
            )

    def _is_the_backward_op(self, op):
        """Return True if ``op`` is a backward op with an op-role-var attr.

        Both conditions must hold: the op has the op-role-var attribute, and
        its op-role attribute equals ``OpRole.Backward``.
        """
        op_maker = core.op_proto_and_checker_maker
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        if op_maker.kOpRoleVarAttrName() in op.attr_names and int(
            op.all_attrs()[op_maker.kOpRoleAttrName()]
        ) == int(backward):
            return True
        return False

    def _clip_by_norm(self, x, max_norm, name=None):
        """Append a ``dgc_clip_by_norm`` op on ``x`` and return its output.

        The op also reads the global step counter and gets the
        ``rampup_begin_step`` attribute. When ``name`` is None, a unique
        temporary name is generated.
        """
        args = {'x': x, 'max_norm': max_norm, 'name': name}

        helper = LayerHelper("dgc_clip_by_norm_op", **args)

        if name is None:
            name = paddle.fluid.unique_name.generate_with_ignorable_key(
                ".".join([helper.name, 'tmp'])
            )

        out = helper.create_variable(
            type=x.type, name=name, dtype=x.dtype, persistable=False
        )

        helper.append_op(
            type="dgc_clip_by_norm",
            inputs={"X": x, "current_step": self._global_step_var},
            attrs={
                "max_norm": max_norm,
                "rampup_begin_step": float(self._rampup_begin_step),
            },
            outputs={"Out": out},
        )
        return out

    def _append_clip_norm(self, grad_var, clip_norm):
        """Clip ``grad_var`` under the program's backward-role guard.

        The gradient's own name is passed as the output name.
        NOTE(review): presumably the clipped result therefore reuses the
        gradient variable -- confirm.
        """
        with grad_var.block.program._backward_role_guard():
            return self._clip_by_norm(
                x=grad_var, max_norm=clip_norm, name=grad_var.name
            )

    def _dgc_op(
        self,
        param_var,
        clip_var,
        grad_var,
        u_var,
        v_var,
        k_var,
        encoded_var,
        gather_var,
    ):
        """Append a ``dgc`` op to the default main program's global block.

        A per-parameter regularizer takes precedence over the optimizer-level
        one. The op is tagged with the Backward op role and declares
        ``[param, grad]`` as its op-role vars.
        """
        block = paddle.static.default_main_program().global_block()
        op_maker = core.op_proto_and_checker_maker

        regular_type = self.regular_type
        regular_coeff = self.regular_coeff
        # The regularizer of the Parameters have higher priority
        if param_var.regularizer is not None:
            regular_type, regular_coeff = self._get_regularization_param(
                param_var.regularizer
            )

        dgc_op = block.append_op(
            type="dgc",
            inputs={
                "U": u_var,
                "V": v_var,
                "Grad": clip_var,
                "Param": param_var,
                "current_step": self._global_step_var,
                "nranks": self._nranks_var,
            },
            outputs={
                "U_out": u_var,
                "V_out": v_var,
                "EncodeGrad": encoded_var,
                "k": k_var,
                "Grad_out": grad_var,
                "GatherBuff": gather_var,
            },
            attrs={
                "m": self._momentum,
                "sparsity": self._sparsity,
                "use_nesterov": self._use_nesterov,
                "rampup_begin_step": float(self._rampup_begin_step),
                "rampup_step": float(self._rampup_step),
                "regular_coeff": float(regular_coeff),
                "regular_type": int(regular_type),
            },
            stop_gradient=True,
        )

        backward = op_maker.OpRole.Backward
        dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward)
        dgc_op._set_attr(
            op_maker.kOpRoleVarAttrName(), [param_var.name, grad_var.name]
        )

    def _process_distribute_lookuptable(self, param_grads):
        """Split out the distributed lookup table and give it an SGD update.

        Distributed lookup tables only support SGD, without other optimizers
        or regularization. So the table parameter is removed from
        ``param_grads`` and an ``sgd`` op is appended for it separately.

        Args:
            param_grads (list((Var, Var))): list of (param, grad) pairs.

        Returns:
            A tuple ``(new_param_grads, (table_param, table_grad), sgd_op)``
            with these parts:

            * ``new_param_grads``: the pairs without the table.
            * ``(table_param, table_grad)``: the table pair, or
              ``(None, None)`` if no table was found.
            * ``sgd_op``: the appended op, or None if no table was found.

        Raises:
            RuntimeError: If more than one pair matches the table name.
        """
        from paddle.distributed.distribute_lookup_table import (
            find_distributed_lookup_table,
        )

        program = framework.default_main_program()
        global_block = program.global_block()
        table_name = find_distributed_lookup_table(program)
        table_param = None
        table_grad = None
        new_param_grads = []
        for p, g in param_grads:
            if p.name == table_name:
                if table_param is not None:
                    raise RuntimeError(
                        "multi dist table var found, only support one now!"
                    )
                table_param = p
                table_grad = g
            else:
                new_param_grads.append((p, g))
        sgd_op = None
        if table_param is not None:
            param_and_grad = [table_param, table_grad]
            with table_param.block.program._optimized_guard(
                param_and_grad
            ), framework.name_scope("optimizer"):
                self._create_global_learning_rate()
                # create the optimize op
                sgd_op = global_block.append_op(
                    type='sgd',
                    inputs={
                        "Param": table_param,
                        "Grad": table_grad,
                        "LearningRate": self._create_param_lr(param_and_grad),
                    },
                    outputs={"ParamOut": param_and_grad[0]},
                )
        return new_param_grads, (table_param, table_grad), sgd_op

    @imperative_base.no_grad()
    def apply_gradients(self, params_grads):
        """Append DGC ops, clip/regularize non-DGC grads, and optimize.

        Steps:

        1. DGC ops are appended first, before any other processing.
        2. The distributed lookup table (if any) is split off and given an
           SGD update.
        3. Non-DGC gradients go through the configured grad clip (or
           ``append_gradient_clip_ops``) and the optimizer's regularization.
           DGC gradients are passed through unchanged here; they were
           already handled by the dgc op.

        Returns:
            The list of optimize ops, including the table's SGD op if any.
        """
        # Note: since we can't use all_reduce_op now,
        # dgc_op should be the last op of one grad.
        # Maybe need a grad allreduce pass.
        self._append_dgc_ops(params_grads)

        params_grads = sorted(params_grads, key=lambda x: x[0].name)
        (
            params_grads,
            table_param_and_grad,
            table_optimize_op,
        ) = self._process_distribute_lookuptable(params_grads)

        not_dgc_params_grads = []
        dgc_params_grads = []
        # DGC clip and regularization in optimizer.backward
        for param, grad in params_grads:
            if not self._is_use_dgc(param, grad):
                not_dgc_params_grads.append((param, grad))
            else:
                dgc_params_grads.append((param, grad))

        # 'optimizer(grad_clip)' or 'set_gradient_clip'
        if self._grad_clip is not None:
            not_dgc_params_grads = self._grad_clip(not_dgc_params_grads)
        else:
            not_dgc_params_grads = append_gradient_clip_ops(
                not_dgc_params_grads
            )

        not_dgc_params_grads = self.append_regularization_ops(
            not_dgc_params_grads, self.regularization
        )

        params_grads = not_dgc_params_grads + dgc_params_grads
        params_grads = sorted(params_grads, key=lambda x: x[0].name)

        optimize_ops = self._create_optimization_pass(params_grads)
        if table_optimize_op is not None:
            optimize_ops.append(table_optimize_op)
            # NOTE(review): params_grads is not returned from this method, so
            # this append only matters if _create_optimization_pass kept a
            # reference to the list -- confirm whether it is needed.
            params_grads.append(table_param_and_grad)

        return optimize_ops

493 494 495

class DGCOptimizer(MetaOptimizerBase):
    """Meta optimizer that replaces a ``Momentum`` inner optimizer with DGC.

    A ``DGCMomentumOptimizer`` is built lazily by ``_init_dgc_opt``. It is
    only built in collective mode and only when the inner optimizer is a
    ``Momentum``. ``backward``, ``apply_gradients``, ``apply_optimize`` and
    ``minimize_impl`` all delegate to that DGC optimizer.
    """

    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.dgc_opt = None
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.meta_optimizers_black_list = []

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        """Forward unchanged to ``MetaOptimizerBase._set_basic_info``."""
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )

    def _init_dgc_opt(self):
        """Build ``self.dgc_opt`` from the inner optimizer, at most once.

        Nothing is built (``self.dgc_opt`` stays None) when the role maker is
        not collective or when the inner optimizer is not ``Momentum``.

        Inputs to the DGC optimizer:

        * learning rate, momentum, parameters, nesterov flag, regularization,
          grad clip and name come from the inner optimizer;
        * ramp-up settings and sparsity come from
          ``user_defined_strategy.dgc_configs``;
        * the trainer count comes from the role maker.
        """
        if self.dgc_opt is not None:
            return

        opt = self.inner_opt

        if not self.role_maker._is_collective:
            return

        if not isinstance(opt, Momentum):
            return

        configs = self.user_defined_strategy.dgc_configs
        # An empty or missing sparsity schedule falls back to [0.999]. The
        # value is kept in a local instead of being written back into the
        # strategy's configs.
        sparsity = configs.get('sparsity') or [0.999]

        self.dgc_opt = DGCMomentumOptimizer(
            learning_rate=opt._learning_rate,
            momentum=opt._momentum,
            rampup_begin_step=configs['rampup_begin_step'],
            rampup_step=configs['rampup_step'],
            sparsity=sparsity,
            parameter_list=opt._parameter_list,
            use_nesterov=opt._use_nesterov,
            num_trainers=self.role_maker._worker_num(),
            regularization=opt.regularization,
            grad_clip=opt._grad_clip,
            name=opt._name,
        )

    def _can_apply(self):
        """Return True if DGC can be applied.

        All of the following must hold:

        * the job is collective;
        * the strategy enables ``dgc``;
        * the inner optimizer is ``Momentum``;
        * there is more than one worker.

        A warning is logged when ``dgc`` is enabled but one of the last two
        conditions fails.
        """
        if not self.role_maker._is_collective:
            return False

        if self.user_defined_strategy.dgc:
            if not isinstance(self.inner_opt, Momentum):
                logging.warning("dgc only works on Momentum optimizer")
                return False
            if self.role_maker._worker_num() <= 1:
                logging.warning("dgc only works on multi cards")
                return False

            return True

        return False

    def _disable_strategy(self, dist_strategy):
        """Turn DGC off and clear its configs in ``dist_strategy``."""
        dist_strategy.dgc = False
        dist_strategy.dgc_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        """Turn DGC on with ``rampup_begin_step=0`` and ``rampup_step=1``."""
        dist_strategy.dgc = True
        dist_strategy.dgc_configs = {"rampup_begin_step": 0, "rampup_step": 1}

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """Delegate ``backward`` to the DGC optimizer (built on demand)."""
        self._init_dgc_opt()
        return self.dgc_opt.backward(
            loss, startup_program, parameter_list, no_grad_set, callbacks
        )

    def apply_gradients(self, params_grads):
        """Delegate ``apply_gradients`` to the DGC optimizer."""
        self._init_dgc_opt()
        return self.dgc_opt.apply_gradients(params_grads=params_grads)

    def apply_optimize(self, loss, startup_program, params_grads):
        """Delegate to the DGC optimizer's ``_apply_optimize``."""
        self._init_dgc_opt()
        return self.dgc_opt._apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads
        )

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """Run the DGC optimizer's ``minimize``.

        Returns:
            ``(optimize_ops, params_grads)``.
        """
        self._init_dgc_opt()
        optimize_ops, params_grads = self.dgc_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        return optimize_ops, params_grads