optimizer.py 167.4 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

from __future__ import print_function
16

17
import numpy as np
18
from collections import defaultdict
19

Q
Qiao Longfei 已提交
20
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
21
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
22

23 24
from . import framework
from . import layers
25
from . import unique_name
26
from .backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name
27
from .clip import GradientClipBase, error_clip_callback, append_gradient_clip_ops
28 29 30
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
S
sneaxiy 已提交
31
from .layers import ops
32
from .regularizer import append_regularization_ops
33
from .dygraph import base as imperative_base
34
from .dygraph import no_grad
35 36 37 38
from .dygraph.learning_rate_scheduler import LearningRateDecay
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
39
from .wrapped_decorator import signature_safe_contextmanager
M
mapingshuo 已提交
40
from .. import compat as cpt
41

42
__all__ = [
43 44 45 46
    'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'Dpsgd', 'DecayedAdagrad',
    'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer',
    'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer',
    'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta',
Z
Zeng Jinle 已提交
47 48 49 50
    'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum',
    'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
    'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer',
    'RecomputeOptimizer'
51
]
Q
Qiao Longfei 已提交
52 53 54 55 56 57


class Optimizer(object):
    """Optimizer Base class.

    Define the common interface of an optimizer.
58 59
    Users should not use this class directly;
    use one of its implementations instead.
Q
Qiao Longfei 已提交
60 61
    """

62
    @imperative_base.no_grad
    def __init__(self,
                 learning_rate,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        """Initialize the state shared by every optimizer.

        Args:
            learning_rate (float|Variable|LearningRateDecay): the learning
                rate. In dygraph mode it must be a float or a
                ``LearningRateDecay``; in static-graph mode it must be a float
                or a ``Variable``.
            parameter_list (iterable, optional): parameters to optimize.
                Required in dygraph mode. Any iterable is accepted; it is
                copied into a list once so that a one-shot iterator (e.g. a
                generator) is not exhausted after the first training step.
            regularization (optional): regularization strategy, stored as
                ``self.regularization``.
            name (str, optional): name of the optimizer. In dygraph mode a
                unique name is generated from it (or from the class name when
                it is None); in static-graph mode it is stored as given.

        Raises:
            TypeError: if ``learning_rate`` has an unsupported type for the
                current mode.
            AttributeError: if ``parameter_list`` is None in dygraph mode.
        """
        # Materialize once: backward() iterates this list on every call in
        # dygraph mode, and a generator would be silently empty afterwards.
        self._parameter_list = list(
            parameter_list) if parameter_list is not None else None
        if framework.in_dygraph_mode():
            if not isinstance(learning_rate, float) and \
                    not isinstance(learning_rate, LearningRateDecay):
                raise TypeError(
                    "learning rate should be float or LearningRateDecay, got %s here"
                    % type(learning_rate))
            if name is not None:
                self._name = unique_name.generate(name)
            else:
                self._name = unique_name.generate(self.__class__.__name__)
            if self._parameter_list is None:
                raise AttributeError(
                    "parameter_list argument given to the Optimizer should not be None in dygraph mode."
                )
        else:
            if not isinstance(learning_rate, float) and \
                    not isinstance(learning_rate, framework.Variable):
                raise TypeError(
                    "learning rate should be float or Variable, got %s here" %
                    type(learning_rate))
            self._name = name

        self.regularization = regularization
        self._learning_rate = learning_rate
        # dtype of the learning-rate var; set from ``loss.dtype`` in backward()
        self._dtype = None
        # Each program gets its own learning-rate Variable:
        # program -> Variable(learning_rate)
        self._learning_rate_map = dict()
        if isinstance(self._learning_rate, framework.Variable):
            self._learning_rate_map[framework.default_main_program(
            )] = self._learning_rate
        # Dictionary of accumulators. Some optimizer subclasses need to
        # allocate and manage extra variables associated with the parameters
        # to train. These variables are called accumulators.
        # {accum_name : { parameter_name : accumulator_for_parameter, ...}, ...}
        self._accumulators = defaultdict(dict)
        self.helper = None
        # names of every accumulator variable created by _add_accumulator
        self._opti_name_list = []
        # state dict loaded by set_dict; used to restore accumulators lazily
        self._accumulators_holder = {}
        # parameter name -> device attribute of the first op reading it
        self._param_device_map = dict()
        # if pass grad_clip into minimize, it will not be None
        self._grad_clip = None
H
hong 已提交
112 113 114 115

    @framework.dygraph_only
    def state_dict(self):
        '''
        Get the state dict of the optimizer. It contains all the variables used by the optimizer, e.g. beta1, beta2 and the moments for Adam. If a LearningRateDecay is used, ``global_step`` is included as well.
        If the optimizer has never been called (``minimize``), no accumulators exist yet, so the dict is empty apart from ``global_step`` when a LearningRateDecay is used.

        Args: None
        Returns:
            state_dict(dict) : dict containing all the variables used by the optimizer

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    adam = fluid.optimizer.Adam(0.001, parameter_list=emb.parameters())
                    state_dict = adam.state_dict()

        '''
        state_dict = {}
        # Every accumulator is keyed by its own (unique) variable name.
        for per_param in self._accumulators.values():
            for accum_var in per_param.values():
                state_dict[accum_var.name] = accum_var
        # global step if use lr decay
        if isinstance(self._learning_rate, LearningRateDecay):
            if framework.in_dygraph_mode():
                global_step = framework._varbase_creator(
                    None, name='global_step', dtype='int32')
            else:
                global_step = Variable(None, name='global_step', dtype='int32')

            # Snapshot the decay's current step counter into a [1] int32 var.
            tensor.fill_constant(
                [1], "int32", self._learning_rate.step_num, out=global_step)

            state_dict['global_step'] = global_step
        return state_dict

    @framework.dygraph_only
    def set_dict(self, state_dict):
        '''
        Load the optimizer state dict, e.g. beta1, beta2 and the moments for Adam. If a LearningRateDecay is used, its ``global_step`` is restored as well.

        Args: 
            state_dict(dict) : Dict containing all the Variables needed by the optimizer
        Return:
            None

        Raises:
            AssertionError: if ``global_step`` or an accumulator is missing from
                ``state_dict``, or if a loaded value has the wrong shape or dtype.
            RuntimeError: if a value in ``state_dict`` is not a VarBase,
                Variable or numpy array.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    state_dict = emb.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

                    adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000), 
                                                parameter_list=emb.parameters())
                    state_dict = adam.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

                    para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")

                    adam.set_dict(opti_state_dict)

        '''

        if isinstance(self._learning_rate, LearningRateDecay):
            assert 'global_step' in state_dict, \
                    'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
            global_step = state_dict['global_step']

            if isinstance(global_step, core.VarBase):
                step_np = np.array(global_step.value().get_tensor())
            elif isinstance(global_step, Variable):
                step_np = global_step.numpy()
            elif isinstance(global_step, np.ndarray):
                step_np = global_step
            else:
                # Format the type into the message (it used to be passed as a
                # second exception argument and never shown in the text).
                raise RuntimeError(
                    "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is {}".
                    format(type(global_step)))

            assert step_np.shape == (1,),  \
                    "global step shape is (1,), the shape is {}".format( step_np.shape )
            # Always store a plain int, whatever container the step came in.
            self._learning_rate.step_num = int(step_np[0])

        # Kept so that accumulators created later (_add_accumulator) can be
        # initialized from the loaded values.
        self._accumulators_holder = state_dict
        for per_param in self._accumulators.values():
            for var_tmp in per_param.values():
                assert var_tmp.name in state_dict, \
                        "optimizer variable {} not found".format( var_tmp.name )
                var_tensor = var_tmp.value().get_tensor()
                model_np = np.array(var_tensor)

                load_para = state_dict[var_tmp.name]

                if isinstance(load_para, (Variable, core.VarBase)):
                    load_para_np = load_para.numpy()
                elif isinstance(load_para, np.ndarray):
                    load_para_np = load_para
                else:
                    raise RuntimeError("State dict type {} not supprt".format(
                        str(type(load_para))))

                # Messages used to reference an undefined name (``item``),
                # turning a failed check into a NameError.
                assert model_np.shape == load_para_np.shape,  \
                    "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format(
                        var_tmp.name, model_np.shape, load_para_np.shape)

                assert model_np.dtype == load_para_np.dtype, \
                    "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {}  but load tensor with dtype {}".format(
                        var_tmp.name, model_np.dtype, load_para_np.dtype)

                var_tensor.set(load_para_np, framework._current_expected_place())
240

241 242
    def get_opti_var_name_list(self):
        """Return the names of the accumulator variables created so far.

        The names are appended by ``_add_accumulator``. The optimizer's own
        list object is returned (not a copy).
        """
        created_names = self._opti_name_list
        return created_names
Q
Qiao Longfei 已提交
243

Q
Qiao Longfei 已提交
244
    def _create_global_learning_rate(self):
        """Make sure the default main program has a learning-rate Variable.

        The Variable is stored in ``self._learning_rate_map`` under
        ``framework.default_main_program()``.

        Dygraph mode (``imperative_base.enabled()``):
          * float: a persistable global var holding the value is created the
            first time; later calls find it and return early.
          * LearningRateDecay: the decay object is called and its result is
            stored, on every call.
          * any other type raises TypeError.

        Static-graph mode: returns early if a Variable is already registered
        for the program; otherwise a float learning rate becomes a persistable
        global var and any other type raises TypeError.

        The created var uses ``self._dtype`` when set, else 'float32'.
        """

        def _new_lr_var():
            # Persistable [1]-shaped global var holding the float rate.
            return layers.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._learning_rate),
                dtype=self._dtype if self._dtype is not None else 'float32',
                persistable=True)

        if not imperative_base.enabled():
            if isinstance(self._global_learning_rate(), framework.Variable):
                return
            if not isinstance(self._learning_rate, float):
                raise TypeError(
                    "learning rate variable is create outside optimizer,"
                    "can not create new learning rate variable for new program"
                )
            # create learning rate in the current main program
            lr_var = _new_lr_var()
            self._learning_rate_map[framework.default_main_program()] = lr_var
            return

        if isinstance(self._learning_rate, float):
            if isinstance(self._global_learning_rate(), framework.Variable):
                return
            lr_var = _new_lr_var()
        elif isinstance(self._learning_rate, LearningRateDecay):
            # get learning rate Variable from LearningRateDecay
            lr_var = self._learning_rate()
        else:
            raise TypeError(
                "optimizer's learning rate must be float or LearningRateDecay"
            )
        self._learning_rate_map[framework.default_main_program()] = lr_var
288

289 290 291 292
    @framework.dygraph_only
    def current_step_lr(self):
        """
        .. note::
          **This API is ONLY available in Dygraph mode**

        Get current step learning rate. The return value is all the same When LearningRateDecay is not used,
        otherwise return the step learning rate.

        Returns:
            float: The learning rate of the current step.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                # example1: LearningRateDecay is not used, return value is all the same
                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])
                    adam = fluid.optimizer.Adam(0.001, parameter_list = emb.parameters())
                    lr = adam.current_step_lr()
                    print(lr) # 0.001

                # example2: PiecewiseDecay is used, return the step learning rate
                with fluid.dygraph.guard():
                    inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
                    linear = fluid.dygraph.nn.Linear(10, 10)
                    inp = fluid.dygraph.to_variable(inp)
                    out = linear(inp)
                    loss = fluid.layers.reduce_mean(out)

                    bd = [2, 4, 6, 8]
                    value = [0.2, 0.4, 0.6, 0.8, 1.0]
                    adam = fluid.optimizer.Adam(fluid.dygraph.PiecewiseDecay(bd, value, 0),
                                           parameter_list=linear.parameters())

                    # first step: learning rate is 0.2
                    np.allclose(adam.current_step_lr(), 0.2, rtol=1e-06, atol=0.0) # True

                    # learning rate for different steps
                    ret = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6, 0.8, 0.8, 1.0, 1.0, 1.0, 1.0]
                    for i in range(12):
                        adam.minimize(loss)
                        lr = adam.current_step_lr()
                        np.allclose(lr, ret[i], rtol=1e-06, atol=0.0) # True

        """
        current_lr = self._global_learning_rate()
        # Test for existence explicitly: truth-testing a Variable may look at
        # its value (or raise) instead of telling whether it was created.
        if current_lr is not None:
            return current_lr.numpy()[0]

        if isinstance(self._learning_rate, float):
            return self._learning_rate

        step_lr = self._learning_rate.step()
        if isinstance(step_lr, (float, int)):
            return step_lr
        return step_lr.numpy()[0]

Y
yuyang18 已提交
351
    def _global_learning_rate(self, program=None):
        """Look up the learning-rate Variable registered for a program.

        Args:
            program (Program, optional): program to look up; defaults to
                ``framework.default_main_program()``.

        Returns:
            The learning-rate Variable for ``program``, or None if none has
            been registered in ``self._learning_rate_map``.
        """
        key = program if program is not None else framework.default_main_program()
        return self._learning_rate_map.get(key)
Q
Qiao Longfei 已提交
359

Q
Qiao Longfei 已提交
360 361 362 363 364
    def _append_optimize_op(self, block, param_and_grad):
        """Append the update operator for one (param, grad) pair to ``block``.

        Abstract hook: concrete optimizers must override it and return the
        optimize op(s) they appended.

        Args:
            block: the block the optimize op is appended to.
            param_and_grad: the (parameter, gradient) pair to update.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

365 366 367 368
    def _create_param_lr(self, param_and_grad):
        """Return the learning rate to use for a single parameter.

        The per-parameter setting is ``param.optimize_attr['learning_rate']``:

        * a Variable is returned unchanged;
        * ``1.0`` means "use the global learning rate as-is";
        * any other number scales the global learning rate. The multiply op
          is built under the main program's lr-schedule guard and the
          ``'scale_with_param_lr'`` name scope.

        Args:
            param_and_grad: (parameter, gradient) pair; only the parameter is
                read.

        Returns:
            The learning-rate Variable for this parameter.
        """
        param = param_and_grad[0]
        param_lr = param.optimize_attr['learning_rate']
        # isinstance rather than ``type(...) ==`` so Variable subclasses are
        # returned as-is. Before, they fell through to ``param_lr == 1.0``,
        # which compares a Variable instead of yielding a plain bool.
        if isinstance(param_lr, Variable):
            return param_lr
        if param_lr == 1.0:
            return self._global_learning_rate()
        with default_main_program()._lr_schedule_guard(
                is_with_opt=True), framework.name_scope(
                    'scale_with_param_lr'):
            return self._global_learning_rate() * param_lr
379 380 381 382 383 384 385

    def _create_accumulators(self, block, parameters):
        """Hook for creating per-parameter accumulators; a no-op here.

        ``_create_optimization_pass`` calls it with the trainable parameters
        before any optimize op is appended. Subclasses that keep extra
        per-parameter state can override it; ``_add_accumulator`` is the
        helper provided for creating such variables.

        Args:
            block: the block the optimize ops will be appended to.
            parameters: list of trainable parameter variables.
        """

389
    def _finish_update(self, block, parameters_and_grads):
        """Hook run after all per-parameter optimize ops have been appended.

        The base class appends nothing. ``_create_optimization_pass`` calls it
        once per pass, so subclasses can override it to append ops that must
        follow the parameter updates.

        Args:
            block: the block the optimize ops are appended to.
            parameters_and_grads: list of (parameter, gradient) pairs being
                optimized.

        Returns:
            None
        """

402 403 404 405 406
    def _add_accumulator(self,
                         name,
                         param,
                         dtype=None,
                         fill_value=0.0,
                         shape=None,
                         type=None,
                         device=None):
        """Create (or, in dygraph mode, reuse) an accumulator for a parameter.

        Args:
            name: name of the accumulator; prefixed with ``self._name + "_"``
                when the optimizer has a name.
            param: parameter variable the accumulator belongs to.
            dtype: data type of the accumulator; ``param.dtype`` is used when
                it is falsy.
            fill_value: constant used to initialize the accumulator.
            shape: shape of the accumulator; defaults to ``param.shape``.
            type: variable type; defaults to ``param.type``.
            device: device for the initializer; defaults to the device
                recorded for ``param`` by ``_update_param_device_map``.

        Returns:
            The accumulator variable.

        Raises:
            Exception: if the accumulator already exists for ``param`` and
                not in dygraph mode (dygraph returns the existing one).
        """
        if self._name is not None:
            name = self._name + "_" + name
        if (name in self._accumulators and
                param.name in self._accumulators[name]):
            if framework.in_dygraph_mode():
                return self._accumulators[name][param.name]
            raise Exception("Accumulator {} already exists for parameter {}".
                            format(name, param.name))
        if shape is None:
            shape = param.shape
        assert isinstance(self.helper, LayerHelper)

        var_name = param.name + "_" + name
        var_name = unique_name.generate(var_name)
        self._opti_name_list.append(var_name)

        var = self.helper.create_global_variable(
            name=var_name,
            persistable=True,
            dtype=dtype or param.dtype,
            type=param.type if type is None else type,
            shape=shape,
            belong_to_optimizer=True)
        if device is None:
            device = self._get_device_for_param(param.name)
        with device_guard(device):
            self.helper.set_variable_initializer(
                var, initializer=Constant(value=float(fill_value)))

        if framework.in_dygraph_mode():
            # A state dict loaded via set_dict: restore this accumulator's
            # value from it, matched by the generated variable name.
            if len(self._accumulators_holder) > 0:
                assert var_name in self._accumulators_holder, \
                        "Optimizer set error, {} should in state dict".format( var_name )
                var.set_value(self._accumulators_holder[var_name])

        self._accumulators[name][param.name] = var
        return var
456 457 458 459 460 461 462 463 464 465 466

    def _get_accumulator(self, name, param):
        """Fetch the accumulator ``name`` that belongs to ``param``.

        Args:
            name: accumulator name without the optimizer prefix; the prefix
                ``self._name + "_"`` is added here when the optimizer is named.
            param: parameter variable owning the accumulator.

        Returns:
            The accumulator variable.

        Raises:
            Exception: if no such accumulator exists for ``param``.
        """
        full_name = name if self._name is None else self._name + "_" + name
        # ``.get`` (not ``[]``) so the defaultdict does not grow on a miss.
        per_param = self._accumulators.get(full_name, {})
        if param.name not in per_param:
            raise Exception("Accumulator {} does not exist for parameter {}".
                            format(full_name, param.name))
        return per_param[param.name]

475 476 477 478 479 480 481 482 483 484 485 486
    def _update_param_device_map(self, parameters_and_grads, target_block):
        """Record the device of the first op that reads each trainable param.

        For every trainable parameter, the ops of ``target_block`` are scanned
        in order. The device attribute of the first op whose inputs include
        the parameter is stored in ``self._param_device_map``. Parameters that
        no op reads keep whatever entry they already had.
        """
        for pair in parameters_and_grads:
            param = pair[0]
            if param.trainable is not True:
                continue
            pname = param.name
            block_ops = target_block.ops
            attr_key = core.op_proto_and_checker_maker.kOpDeviceAttrName()
            for op in block_ops:
                if pname in op.input_arg_names:
                    self._param_device_map[pname] = op.attr(attr_key)
                    break
488 489 490 491 492 493 494

    def _get_device_for_param(self, param_name):
        """Return the device recorded for ``param_name``, or None if unknown.

        Entries are filled in by ``_update_param_device_map``.
        """
        return self._param_device_map.get(param_name)

495
    def _create_optimization_pass(self, parameters_and_grads):
        """Add optimization operators to update gradients to variables.

        Args:
          parameters_and_grads(list(tuple(Variable, Variable))):
            a list of (variable, gradient) pair to update.

        Returns:
          return_op_list: a list of operators that will complete one step of
            optimization. This will include parameter update ops, global step
            update ops and any other custom ops required by subclasses to manage
            their internal state.
        """
        # This is a default implementation of create_optimization_pass that
        # can be shared by most optimizers. This implementation assumes that
        # the subclass will implement the _append_optimize_op method and the
        # _initialize_tensors method. The subclass can extend the
        # _create_accumulators method if it needs to create accumulators
        # for parameters and extend _finish_update method to add custom ops.

        # Always called under program_guard; use the global block as the loss
        # block. But if the current block is inside control flow, append the
        # optimize ops to the grad (backward) block of the current block.
        program = framework.default_main_program()
        global_block = program.global_block()
        target_block = global_block
        current_block = program.current_block()
        if current_block.idx != global_block.idx:
            assert current_block.backward_block_idx != -1, \
                "current block is not global_block, but it doesn't have backward block."
            target_block = program.blocks[current_block.backward_block_idx]

        # Remember where this pass starts so exactly its ops are returned.
        start = len(target_block.ops)
        self.helper = LayerHelper(self.__class__.__name__)
        self._update_param_device_map(parameters_and_grads, target_block)
        self._create_accumulators(
            target_block,
            [p[0] for p in parameters_and_grads if p[0].trainable])
        self._create_global_learning_rate()

        if framework.in_dygraph_mode():
            # No optimized_guard / device_guard in dygraph mode.
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].trainable is True:
                    self._append_optimize_op(target_block, param_and_grad)
        else:
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue
                with param_and_grad[0].block.program._optimized_guard(
                        param_and_grad), name_scope("optimizer"):
                    if param_and_grad[0].trainable is True:
                        # Place the update on the device of the param's first
                        # consumer op (None when unknown).
                        device = self._get_device_for_param(param_and_grad[0]
                                                            .name)
                        with device_guard(device):
                            self._append_optimize_op(target_block,
                                                     param_and_grad)

        # Get custom finish ops for subclasses
        # FIXME: Need to fix this once we figure out how to handle dependencies
        self._finish_update(target_block, parameters_and_grads)

        end = len(target_block.ops)
        return target_block._slice_ops(start, end)
561 562

    def _process_distribute_lookuptable(self, param_grads):
        """
        Split out the distributed lookup-table parameter and give it a plain
        SGD update.

        The distributed lookup table currently supports only the SGD optimizer
        and no regularization. So its parameter is removed from the list that
        receives regularization and the regular optimize ops, and an ``sgd`` op
        is appended for it separately in the global block.

        :param param_grads(list((Var, Var))): list of (param, grad) pairs.
        :return: ``(new_param_grads, (table_param, table_grad), sgd_op)``.
            ``new_param_grads`` contains every pair except the table's.
            ``table_param``, ``table_grad`` and ``sgd_op`` are None when no
            parameter matches the table.
        :raises RuntimeError: if more than one parameter matches the table.
        """
        program = framework.default_main_program()
        global_block = program.global_block()
        table_name = find_distributed_lookup_table(program)
        table_param = None
        table_grad = None
        new_param_grads = []
        for p, g in param_grads:
            if p.name == table_name:
                if table_param is not None:
                    raise RuntimeError(
                        "multi dist table var found, only support one now!")
                table_param = p
                table_grad = g
            else:
                new_param_grads.append((p, g))
        sgd_op = None
        if table_param is not None:
            param_and_grad = [table_param, table_grad]
            with table_param.block.program._optimized_guard(param_and_grad), \
                    framework.name_scope("optimizer"):
                self._create_global_learning_rate()
                # create the optimize op
                sgd_op = global_block.append_op(
                    type='sgd',
                    inputs={
                        "Param": table_param,
                        "Grad": table_grad,
                        "LearningRate": self._create_param_lr(param_and_grad)
                    },
                    outputs={"ParamOut": param_and_grad[0]})
        return new_param_grads, (table_param, table_grad), sgd_op

604 605 606
    def _append_dgc_ops(self, param_and_grad):
        """Hook for appending DGC ops after backward; a no-op in the base class.

        In static-graph mode ``backward`` calls it with the full list of
        (param, grad) pairs right after ``append_backward``. Subclasses that
        need DGC ops can override it.
        """

607 608 609 610 611 612 613
    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """
        The first part of ``minimize``: run auto-diff to append backward
        operations for the current program.

        Args:
            loss (Variable): ``loss`` variable to run optimizations.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. When it is empty or None, the optimizer's own
                ``parameter_list`` is used; if that is None too, all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable``  or ``Variable.name`` that don't need
                to be updated. The default value is None.
            callbacks (list, optional): list of callable objects to run when appending backward
                operator for one parameter. The default value is None, which
                means ``[error_clip_callback]``.

        Note:
            In dygraph mode only ``loss`` is consulted: the (param, grad) pairs
            come from the trainable parameters in the optimizer's own
            ``parameter_list`` that already hold a gradient.

        Returns:
            list: list of (param, grad) variable pairs, param is ``Parameter``,
                grad is the gradient value corresponding to the parameter.

        Examples:
            See examples in ``apply_gradients``.
        """
        if framework.in_dygraph_mode():
            self._dtype = loss.dtype
            return [(param, param._grad_ivar())
                    for param in self._parameter_list
                    if param.trainable and param._grad_ivar() is not None]

        act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
        self._dtype = loss.dtype

        if callbacks is None:
            callbacks = [error_clip_callback]
        else:
            assert isinstance(callbacks, list)
        program = loss.block.program
        assert len(loss.shape) == 1 and loss.shape[0] == 1, \
            "The loss.shape should be (1L,), but the current loss.shape is {}. " \
            "Maybe that you should call fluid.layers.mean to process the current loss.".format(
                loss.shape)
        parameter_list = parameter_list or self._parameter_list
        with program_guard(program, startup_program):
            params_grads = append_backward(loss, parameter_list,
                                           act_no_grad_set, callbacks)
            # Note: since we can't use all_reduce_op now,
            #  dgc_op should be the last op of one grad.
            self._append_dgc_ops(params_grads)
        return params_grads
672 673 674 675 676 677 678 679

    def apply_gradients(self, params_grads):
        """
        Second half of ``minimize``: append the optimization operators for the
        given ``(param, grad)`` pairs to the current program.

        The pairs are sorted by parameter name, distributed lookup-table pairs
        are split out, gradient clipping is applied (``self._grad_clip`` when
        set, otherwise ``append_gradient_clip_ops``), then regularization, and
        finally the per-parameter optimize ops are created.

        Args:
            params_grads (list): list of (param, grad) pairs to optimize.

        Returns:
            list: A list of operators appended to the current program.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                loss = network()
                optimizer = fluid.optimizer.SGD(learning_rate=0.1)
                params_grads = optimizer.backward(loss)
                # you may append operations for params_grads here
                # ...
                optimizer.apply_gradients(params_grads)
        """
        # Sort by parameter name so the optimize ops are appended in a
        # deterministic order.
        ordered = sorted(params_grads, key=lambda pair: pair[0].name)

        ordered, table_pair, table_op = \
            self._process_distribute_lookuptable(ordered)

        # A clip strategy given through 'minimize(grad_clip)' wins over the
        # one registered with 'set_gradient_clip'.
        clip = self._grad_clip
        if clip is None:
            ordered = append_gradient_clip_ops(ordered)
        else:
            ordered = clip(ordered)

        ordered = append_regularization_ops(ordered, self.regularization)

        ops_appended = self._create_optimization_pass(ordered)
        if table_op is not None:
            ops_appended.append(table_op)
            # NOTE(review): `ordered` is not returned from here; this append
            # only matters if the list object is shared elsewhere -- confirm.
            ordered.append(table_pair)

        return ops_appended

C
chengduo 已提交
718 719 720 721 722 723 724 725 726 727 728 729
    def apply_optimize(self, loss, startup_program, params_grads):
        """
        Second half of ``minimize``: append optimization operators for the
        given ``params_grads`` pairs.

        Static mode delegates to ``apply_gradients`` under ``loss``'s program.
        Dygraph mode applies ``self._grad_clip`` (only when set) and
        regularization under the default main/startup programs, then runs the
        optimization pass directly.

        Args:
            loss (Variable): loss variable to run optimizations.
            startup_program (Program): startup_program for initializing
                parameters in `parameter_list`.
            params_grads (list): list of (param, grad) pair to do optimization.

        Returns:
            list: A list of operators appended to the current program.
        """
        if not framework.in_dygraph_mode():
            with program_guard(loss.block.program, startup_program):
                return self.apply_gradients(params_grads)

        main_prog = framework.default_main_program()
        startup_prog = framework.default_startup_program()
        with program_guard(main_prog, startup_prog):
            if self._grad_clip is not None:
                params_grads = self._grad_clip(params_grads)
            params_grads = append_regularization_ops(params_grads,
                                                     self.regularization)
            return self._create_optimization_pass(params_grads)

G
gongweibao 已提交
744
    def _get_no_grad_set(self, loss, no_grad_set=None):
        """Return the set of variable names that must not receive gradients.

        The result is the caller-supplied ``no_grad_set`` (normalized to names
        by ``_get_no_grad_set_name``) plus every parameter of ``loss``'s
        program whose ``trainable`` flag is ``False``.
        """
        names = _get_no_grad_set_name(no_grad_set)
        all_params = loss.block.program.global_block().all_parameters()
        # A non-trainable parameter should never get a gradient.
        names.update({p.name for p in all_params if p.trainable is False})
        return names

754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784
    @framework.dygraph_only
    def clear_gradients(self):
        """
        Reset the gradient of every trainable parameter handled by this
        optimizer. Only available in dygraph mode.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                with fluid.dygraph.guard():
                    value = np.arange(26).reshape(2, 13).astype("float32")
                    a = fluid.dygraph.to_variable(value)
                    linear = fluid.Linear(13, 5, dtype="float32")
                    # This can be any optimizer supported by dygraph.
                    adam = fluid.optimizer.Adam(learning_rate = 0.01,
                                                parameter_list = linear.parameters())
                    out = linear(a)
                    out.backward()
                    adam.minimize(out)
                    adam.clear_gradients()

        """
        # Non-trainable parameters are left untouched.
        trainable_params = (p for p in self._parameter_list if p.trainable)
        for param in trainable_params:
            param.clear_gradient()

785
    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 grad_clip=None):
        """
        Add operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Variable): A ``Variable`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
                to be updated. The default value is None.
            grad_clip (GradientClipBase, optional): Gradient clipping strategy, an instance of
                some derived class of ``GradientClipBase``. There are three clipping strategies
                ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
                :ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no
                gradient clipping. When given, it is stored on the optimizer and
                also applies to later calls.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) variable pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Raises:
            TypeError: if ``grad_clip`` is not a ``GradientClipBase`` instance.

        Examples:
            Please refer to the example of current Optimizer.
        """
        assert isinstance(loss, Variable), "The loss should be an Variable."

        has_clip = grad_clip is not None
        if has_clip and not isinstance(grad_clip, GradientClipBase):
            raise TypeError(
                "'grad_clip' should be an instance of GradientClipBase's derived class"
            )
        if has_clip:
            self._grad_clip = grad_clip

        # An empty/None parameter_list falls back to the optimizer's own list.
        params_to_update = parameter_list or self._parameter_list

        params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=params_to_update,
            no_grad_set=no_grad_set)

        optimize_ops = self.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

        return optimize_ops, params_grads
Q
Qiao Longfei 已提交
841 842 843


class SGDOptimizer(Optimizer):
    r"""
    Optimizer of the stochastic gradient descent algorithm.

    .. math::

        param\_out = param - learning\_rate * grad

    Parameters:
        learning_rate (float|Variable): The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        parameter_list (list, optional): List of ``Variable`` names to update to minimize ``loss``.
            This parameter is required in dygraph mode.
            The default value is None in static mode, at this time all parameters will be updated.
        regularization: A Regularizer, such as :ref:`api_fluid_regularizer_L2DecayRegularizer`.
            Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information.
            For details, please refer to :ref:`api_guide_Name`. Default is None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
                x = fluid.layers.data(name='x', shape=[13], dtype='float32')
                y = fluid.layers.data(name='y', shape=[1], dtype='float32')
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
                sgd_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    """

    def __init__(self,
                 learning_rate,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        assert learning_rate is not None
        super(SGDOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "sgd"

    @no_grad
    def _append_optimize_op(self, block, param_and_grad):
        """Apply one SGD step to ``param_and_grad``.

        In dygraph mode the kernel runs immediately and ``None`` is returned.
        Otherwise an ``sgd`` op is appended to ``block`` and returned.
        """
        param = param_and_grad[0]
        grad = param_and_grad[1]
        lr = self._create_param_lr(param_and_grad)

        if not framework.in_dygraph_mode():
            assert isinstance(block, framework.Block)
            # ParamOut is the same variable as Param, so the update is in place.
            return block.append_op(
                type=self.type,
                inputs={"Param": param,
                        "Grad": grad,
                        "LearningRate": lr},
                outputs={"ParamOut": param},
                stop_gradient=True)

        # Eager execution: param is passed again as the output tensor.
        core.ops.sgd(param, lr, grad, param)
        return None
926 927 928


class MomentumOptimizer(Optimizer):
    r"""

    Simple Momentum optimizer with velocity state

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Parameters:
        learning_rate (float|Variable): The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float): Momentum factor
        parameter_list (list, optional): List of ``Variable`` names to update to minimize ``loss``.
            This parameter is required in dygraph mode.
            The default value is None in static mode, at this time all parameters will be updated.
        use_nesterov (bool, optional): Enables Nesterov momentum, default is false.
        regularization: A Regularizer, such as :ref:`api_fluid_regularizer_L2DecayRegularizer`.
            Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information.
            For details, please refer to :ref:`api_guide_Name`. Default is None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
                x = fluid.layers.data(name='x', shape=[13], dtype='float32')
                y = fluid.layers.data(name='y', shape=[1], dtype='float32')
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                moment_optimizer = fluid.optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
                moment_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    """
    _velocity_acc_str = "velocity"

    def __init__(self,
                 learning_rate,
                 momentum,
                 parameter_list=None,
                 use_nesterov=False,
                 regularization=None,
                 name=None):
        assert learning_rate is not None
        assert momentum is not None
        super(MomentumOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)

    def _create_accumulators(self, block, parameters):
        """Register one velocity accumulator per parameter."""
        assert isinstance(block, framework.Block)

        for param in parameters:
            self._add_accumulator(self._velocity_acc_str, param)

    def _append_optimize_op(self, block, param_and_grad):
        """Apply one momentum step to ``param_and_grad``.

        In dygraph mode the kernel runs immediately and ``None`` is returned.
        Otherwise a ``momentum`` op is appended to ``block`` and returned.
        """
        assert isinstance(block, framework.Block)

        param = param_and_grad[0]
        grad = param_and_grad[1]
        velocity_acc = self._get_accumulator(self._velocity_acc_str, param)
        lr = self._create_param_lr(param_and_grad)

        if framework.in_dygraph_mode():
            # param and velocity_acc are passed again as the output tensors.
            _, _ = core.ops.momentum(param, grad, velocity_acc, lr, param,
                                     velocity_acc, 'mu', self._momentum,
                                     'use_nesterov', self._use_nesterov)
            return None

        # create the momentum optimize op
        return block.append_op(
            type=self.type,
            inputs={
                "Param": [param],
                "Grad": [grad],
                "Velocity": [velocity_acc],
                "LearningRate": [lr]
            },
            outputs={"ParamOut": [param],
                     "VelocityOut": [velocity_acc]},
            attrs={"mu": self._momentum,
                   "use_nesterov": self._use_nesterov},
            stop_gradient=True)
1052 1053


1054
class DGCMomentumOptimizer(Optimizer):
1055
    """
1056
    DGC (Deep Gradient Compression) Momentum Optimizer. Original paper is https://arxiv.org/abs/1712.01887
1057

G
gongweibao 已提交
1058
    DGC reduces the communication bandwidth by sending only the important gradients (sparse update):\
1059 1060
        only gradients larger than a threshold are transmitted.

G
gongweibao 已提交
1061
    To avoid losing information, DGC accumulates the rest of the gradients locally.
1062 1063 1064

    Eventually, these gradients become large enough to be transmitted.

1065
    Thus, DGC sends the large gradients immediately but eventually sends all of the gradients over time.
1066

G
gongweibao 已提交
1067
    To ensure no loss of accuracy, DGC employs momentum correction and local gradient clipping on top of the gradient sparsification to maintain model performance.
1068 1069 1070 1071

    DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.

    This optimizer will do two things:
1072

1073 1074
        1. Compress the gradient by getting the TopK important values from the tensor \
            and using them for allreduce to reduce network bandwidth.
1075

1076
        2. Call momentum to optimize the cost.
1077 1078

    Args:
1079 1080
        learning_rate (float|Variable): The learning rate used to update parameters. \
            It can be a float value or a Variable with one float value as a data element.
1081
        momentum (float): Momentum factor.
G
gongweibao 已提交
1082
        rampup_begin_step (int): The beginning step from which gradient compression is implemented.
1083 1084 1085 1086 1087 1088 1089
        rampup_step (int): Time steps used in sparsity warm-up periods. Default is 1.
            For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 100, \
                it will use 0.75 at 0~19 steps, and 0.9375 at 20~39 steps, and so on. \
                And when reach sparsity array ends, it will use 0.999 then and after.
        sparsity (list[float]): Get top important element from gradient tensor, the ratio is (1 - current sparsity). \
            Default is [0.999]. For example, if the sparsity is [0.99, 0.999], \
                the top [1%, 0.1%] important element will be transmitted.
1090 1091 1092
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
1093 1094 1095 1096 1097 1098 1099
        use_nesterov (bool): Enables Nesterov momentum. True means use Nesterov. Default is False.
        local_grad_clip_norm (float, optional): Local gradient clip norm value. Optional, default is None, which means no local gradient clipping.
        num_trainers (int, optional): The number of training nodes. Optional, default is None.
        regularization (WeightDecayRegularizer, optional): A Regularizer, such as \
            :ref:`api_fluid_regularizer_L2DecayRegularizer`. Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information. \
            For details, please refer to :ref:`api_guide_Name`. Default is None.
1100 1101 1102 1103

    Examples:
        .. code-block:: python

1104
            import paddle.fluid as fluid
1105
            optimizer = fluid.optimizer.DGCMomentumOptimizer(
G
gongweibao 已提交
1106 1107 1108 1109 1110
                        learning_rate=0.0001,
                        momentum=0.9,
                        rampup_step=1000,
                        rampup_begin_step=1252,
                        sparsity=[0.999, 0.999])
1111 1112

    """
1113 1114
    _u_velocity_acc_str = "_dgc_u_"
    _v_velocity_acc_str = "_dgc_v_"
1115 1116 1117 1118 1119 1120 1121

    def __init__(self,
                 learning_rate,
                 momentum,
                 rampup_begin_step,
                 rampup_step=1,
                 sparsity=[0.999],
                 parameter_list=None,
                 use_nesterov=False,
                 local_grad_clip_norm=None,
                 num_trainers=None,
                 regularization=None,
                 name=None):
        """Construct a DGC momentum optimizer.

        Raises:
            Exception: when constructed in dygraph mode.
            AssertionError: when Paddle is not compiled with CUDA, or when an
                argument fails one of the checks below.
        """
        if framework.in_dygraph_mode():
            raise Exception("In dygraph, don't support DGCMomentumOptimizer.")

        assert core.is_compiled_with_cuda(), \
            "Paddle is not compiled with CUDA. DGC is only support GPU for now."

        assert learning_rate is not None
        assert momentum is not None
        super(DGCMomentumOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "dgc_momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)

        assert rampup_begin_step >= 0, "rampup_begin_step must >= 0"
        self._rampup_begin_step = rampup_begin_step
        self._rampup_step = rampup_step
        # Store a private copy. The signature default is one list object shared
        # by every call, and the caller's list may be mutated later; neither
        # may leak into (or be changed through) this optimizer.
        self._sparsity = list(sparsity)

        # Created later by _append_dgc_ops.
        self._rampup_begin_step_var = None
        self._global_step_var = None

        self._local_grad_clip_norm = None
        self._clip_norm = None
        if local_grad_clip_norm is not None:
            assert isinstance(num_trainers, int)
            assert isinstance(local_grad_clip_norm, float)
            assert num_trainers > 0

            self._local_grad_clip_norm = local_grad_clip_norm
            # NOTE(review): _num_trainers is only set on this branch.
            self._num_trainers = num_trainers
            # Local clip norm scaled by 1/sqrt(num_trainers).
            self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5)

        self.regular_type, self.regular_coeff = self._get_regularization_param(
            self.regularization)
        self._grad_clip = None
1167

1168 1169 1170
    def _get_regularization_param(self, regularization):
        """Map a regularizer to a ``(regular_type, regular_coeff)`` pair.

        Returns ``(0, 0.0)`` for ``None``, ``(1, coeff)`` for ``L1Decay`` and
        ``(2, coeff)`` for ``L2Decay``, where ``coeff`` is the regularizer's
        ``_regularization_coeff``.

        Raises:
            AssertionError: for any other regularizer type.
        """
        if regularization is None:
            return 0, 0.0

        from .regularizer import L1Decay, L2Decay
        if isinstance(regularization, L1Decay):
            regular_type = 1
        elif isinstance(regularization, L2Decay):
            regular_type = 2
        else:
            # Raise explicitly rather than `assert False`, so the check still
            # runs under `python -O`. The exception type is kept for
            # compatibility with existing callers.
            raise AssertionError(
                'regularization must be None|L1Decay|L2Decay')
        # Read the coefficient only after validating the type, so unsupported
        # objects fail with the message above, not an AttributeError.
        return regular_type, regularization._regularization_coeff
1182

1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193
    def _is_use_dgc(self, param_var, grad_var):
        """Return True if DGC compression should be applied to this pair.

        DGC is skipped (False) when the parameter has fewer than 16384
        elements, when either variable is SELECTED_ROWS, or when the
        parameter dtype is not FP32.
        """
        # abs() presumably guards against -1 (unknown) dims making the product
        # negative. The initializer 1 makes a 0-d shape count as one element
        # instead of raising TypeError from reduce() on an empty sequence.
        var_numel = abs(reduce(lambda x, y: x * y, param_var.shape, 1))
        if var_numel < 16384:
            return False
        selected_rows = core.VarDesc.VarType.SELECTED_ROWS
        if param_var.type == selected_rows or grad_var.type == selected_rows:
            return False
        if param_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    def _append_optimize_op(self, block, param_and_grad):
        """Append the update op for one (param, grad) pair to ``block``.

        Pairs accepted by ``_is_use_dgc`` get a ``dgc_momentum`` op with the
        extra step/nranks inputs and a ``Grad_out`` output. All others get a
        plain ``momentum`` op. Both reuse the ``_u_velocity_acc_str``
        accumulator as velocity.
        """
        assert isinstance(block, framework.Block)

        param = param_and_grad[0]
        grad = param_and_grad[1]
        velocity_acc = self._get_accumulator(self._u_velocity_acc_str, param)
        assert velocity_acc is not None

        inputs = {
            "Param": param,
            "Grad": grad,
            "Velocity": velocity_acc,
            "LearningRate": self._create_param_lr(param_and_grad),
        }
        outputs = {"ParamOut": param, "VelocityOut": velocity_acc}
        attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}

        use_dgc = self._is_use_dgc(param, grad)
        if use_dgc:
            inputs["current_step"] = self._global_step_var
            inputs["nranks"] = self._nranks_var
            outputs['Grad_out'] = grad
            attrs["rampup_begin_step"] = float(self._rampup_begin_step)
        op_type = "dgc_momentum" if use_dgc else "momentum"

        # create the dgc momentum optimize op
        return block.append_op(
            type=op_type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True)

1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248
    def _add_auto_increment_var(self, counter_name, begin, step=1):
        """Get or create a persistable float32 counter named ``counter_name``.

        On first creation the counter is initialized on CPU to ``begin - 1``.
        An ``increment`` op (by ``step``) is prepended to the main program's
        global block, so after the first increment it holds
        ``begin - 1 + step``. An existing counter is returned unchanged.
        """
        helper = LayerHelper('global_step_counter')
        counter, is_new_var = helper.create_or_get_global_variable(
            name=counter_name, dtype='float32', shape=[1], persistable=True)
        if not is_new_var:
            return counter

        start_value = float(begin - 1)
        helper.set_variable_initializer(
            counter, initializer=Constant(
                value=start_value, force_cpu=True))
        global_block = helper.main_program.global_block()
        global_block._prepend_op(
            type='increment',
            inputs={'X': [counter]},
            outputs={'Out': [counter]},
            attrs={'step': float(step)},
            stop_gradient=True)
        counter.stop_gradient = True
        return counter

1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261
    def _add_nranks_var(self, name, value=-1):
        """Get or create a persistable float32 variable ``name``.

        A newly created variable is initialized on CPU to ``float(value)``.
        No increment op is attached.
        """
        helper = LayerHelper('global_step_counter')
        var, created = helper.create_or_get_global_variable(
            name=name, dtype='float32', shape=[1], persistable=True)
        if not created:
            return var

        helper.set_variable_initializer(
            var, initializer=Constant(
                value=float(value), force_cpu=True))
        var.stop_gradient = True
        return var

1262 1263 1264 1265 1266 1267
    def _append_dgc_ops(self, param_and_grads):
        """Create the DGC bookkeeping variables and append DGC ops.

        Marks the default main program with ``_enable_dgc``. Creates the
        global step counter, the nranks variable and the rampup-begin-step
        variable. Then, for every (param, grad) pair:

        * registers the ``u`` velocity accumulator (also used by the momentum
          op, even for pairs that skip DGC);
        * for pairs accepted by ``_is_use_dgc``, creates the ``v``
          accumulator and the k/encoded/gather variables, strips the pair from
          the op_role_var attribute of backward ops, optionally clips the
          gradient, and appends the DGC op via ``_dgc_op``.
        """
        main_program = default_main_program()
        main_program._enable_dgc = True

        # step counter
        self._global_step_var = self._add_auto_increment_var(
            counter_name=core.dgc.kDGCCounterName(), begin=0)

        self._nranks_var = self._add_nranks_var(
            name=core.dgc.kDGCNRanksName(), value=-1)

        # rampup begin step var for all_reduce_op_handle
        self._rampup_begin_step_var = tensor.create_global_var(
            shape=[1],
            dtype=core.VarDesc.VarType.FP32,
            persistable=True,
            name=core.dgc.kDGCRampUpBeginStepName(),
            value=self._rampup_begin_step * 1.0,
            force_cpu=True)

        self.helper = LayerHelper(self.__class__.__name__)

        # Loop-invariant: the op_role_var attribute name.
        op_maker = core.op_proto_and_checker_maker
        role_var_attr_name = op_maker.kOpRoleVarAttrName()

        for param_var, grad_var in param_and_grads:
            # reuse velocity in dgc_op and dgc_momentum_op
            u_var = self._add_accumulator(self._u_velocity_acc_str, param_var)

            if not self._is_use_dgc(param_var, grad_var):
                continue

            v_var = self._add_accumulator(self._v_velocity_acc_str, param_var)

            k_var = tensor.create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCKName(),
                value=0.0,
                force_cpu=True)

            encoded_var = tensor.create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCEncodedName(),
                value=0.0,
                force_cpu=False)

            gather_var = tensor.create_global_var(
                shape=[1],
                dtype=param_var.dtype,
                persistable=True,
                name=param_var.name + core.dgc.kDGCGatherName(),
                value=0.0,
                force_cpu=False)

            # Remove (param, grad) from the op_role_var attribute of every
            # backward op that lists this param. NOTE(review): presumably so
            # the regular all-reduce path skips gradients DGC communicates
            # itself -- confirm.
            for op in main_program.global_block().ops:
                if not self._is_the_backward_op(op):
                    continue

                var_attr = op.all_attrs()[role_var_attr_name]
                if param_var.name not in var_attr:
                    continue

                var_attr.remove(param_var.name)
                var_attr.remove(grad_var.name)
                # Keep the attribute only while another pair remains.
                if len(var_attr) > 1:
                    op._set_attr(role_var_attr_name, var_attr)
                else:
                    op._remove_attr(role_var_attr_name)

            clip_var = grad_var
            if self._local_grad_clip_norm is not None:
                clip_var = self._append_clip_norm(grad_var, self._clip_norm)
            self._dgc_op(param_var, clip_var, grad_var, u_var, v_var, k_var,
                         encoded_var, gather_var)
1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354

    def _is_the_backward_op(self, op):
        op_maker = core.op_proto_and_checker_maker
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        if op_maker.kOpRoleVarAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
            return True
        return False

    def _clip_by_norm(self, x, max_norm, name=None):
        args = {'x': x, 'max_norm': max_norm, 'name': name}

        helper = LayerHelper("dgc_clip_by_norm_op", **args)

        if name is None:
1355 1356
            name = unique_name.generate_with_ignorable_key(".".join(
                [helper.name, 'tmp']))
1357 1358 1359 1360 1361

        out = helper.create_variable(
            type=x.type, name=name, dtype=x.dtype, persistable=False)

        helper.append_op(
G
gongweibao 已提交
1362
            type="dgc_clip_by_norm",
1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374
            inputs={"X": x,
                    "current_step": self._global_step_var},
            attrs={
                "max_norm": max_norm,
                "rampup_begin_step": float(self._rampup_begin_step)
            },
            outputs={"Out": out})
        return out

    def _append_clip_norm(self, grad_var, clip_norm):
        with grad_var.block.program._backward_role_guard():
            return self._clip_by_norm(
G
gongweibao 已提交
1375
                x=grad_var, max_norm=clip_norm, name=grad_var.name)
1376 1377

    def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
1378
                encoded_var, gather_var):
1379 1380
        block = framework.default_main_program().global_block()
        op_maker = core.op_proto_and_checker_maker
1381

1382 1383 1384 1385 1386 1387 1388
        regular_type = self.regular_type
        regular_coeff = self.regular_coeff
        # The regularizer of the Parameters have higher priority
        if param_var.regularizer is not None:
            regular_type, regular_coeff = self._get_regularization_param(
                param_var.regularizer)

1389 1390 1391 1392 1393 1394
        dgc_op = block.append_op(
            type="dgc",
            inputs={
                "U": u_var,
                "V": v_var,
                "Grad": clip_var,
1395
                "Param": param_var,
1396 1397
                "current_step": self._global_step_var,
                "nranks": self._nranks_var,
1398 1399 1400 1401 1402 1403
            },
            outputs={
                "U_out": u_var,
                "V_out": v_var,
                "EncodeGrad": encoded_var,
                "k": k_var,
1404 1405
                "Grad_out": grad_var,
                "GatherBuff": gather_var,
1406 1407 1408 1409 1410 1411
            },
            attrs={
                "m": self._momentum,
                "sparsity": self._sparsity,
                "use_nesterov": self._use_nesterov,
                "rampup_begin_step": float(self._rampup_begin_step),
1412
                "rampup_step": float(self._rampup_step),
1413 1414
                "regular_coeff": float(regular_coeff),
                "regular_type": int(regular_type),
1415 1416 1417 1418 1419 1420 1421 1422
            },
            stop_gradient=True)

        backward = op_maker.OpRole.Backward
        dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward)
        dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
                         [param_var.name, grad_var.name])

1423
    @imperative_base.no_grad
1424 1425 1426 1427 1428 1429 1430
    def apply_gradients(self, params_grads):
        params_grads = sorted(params_grads, key=lambda x: x[0].name)
        params_grads, table_param_and_grad, table_optimize_op = \
            self._process_distribute_lookuptable(params_grads)

        not_dgc_params_grads = []
        dgc_params_grads = []
1431
        # DGC clip and regularization in optimizer.backward
1432 1433 1434 1435 1436 1437
        for param, grad in params_grads:
            if not self._is_use_dgc(param, grad):
                not_dgc_params_grads.append((param, grad))
            else:
                dgc_params_grads.append((param, grad))

1438 1439 1440 1441 1442 1443
        # 'minimize(grad_clip)' or 'set_gradient_clip'
        if self._grad_clip is not None:
            not_dgc_params_grads = self._grad_clip(not_dgc_params_grads)
        else:
            not_dgc_params_grads = append_gradient_clip_ops(
                not_dgc_params_grads)
1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457

        not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads,
                                                         self.regularization)

        params_grads = not_dgc_params_grads + dgc_params_grads
        params_grads = sorted(params_grads, key=lambda x: x[0].name)

        optimize_ops = self._create_optimization_pass(params_grads)
        if table_optimize_op is not None:
            optimize_ops.append(table_optimize_op)
            params_grads.append(table_param_and_grad)

        return optimize_ops

1458

1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473
class LarsMomentumOptimizer(Optimizer):
    """
    Momentum optimizer with LARS support

    The update equations are as follows:

    .. math::

        & local\_learning\_rate = learning\_rate * lars\_coeff * \\
          \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||}

        & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param)

        & param = param - velocity

1474 1475 1476 1477 1478 1479
    Parameters:
        learning_rate (float|Variable): The learning rate used to update parameters. \
            Can be a float value or a Variable with one float value as data element. \
            momentum (float): momentum factor
        lars_coeff (float): Defines how much we trust the layer to change its weights.
        lars_weight_decay (float): Weight decay coefficient for decaying using LARS.
1480 1481 1482
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
1483 1484 1485 1486
        regularization: A Regularizer, such as :ref:`api_fluid_regularizer_L2DecayRegularizer`.
            Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information. \
            For details, please refer to :ref:`api_guide_Name`. Default is None.
1487 1488 1489 1490

    Examples:
        .. code-block:: python

1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506
            import paddle.fluid as fluid
            import numpy as np

            np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
            inp = fluid.layers.data(
                name="inp", shape=[2, 2], append_batch_size=False)
            out = fluid.layers.fc(inp, size=3)
            out = fluid.layers.reduce_sum(out)
            optimizer = fluid.optimizer.LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9)
            optimizer.minimize(out)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            exe.run(
                feed={"inp": np_inp},
                fetch_list=[out.name])
1507 1508 1509 1510 1511 1512 1513 1514
    """
    _velocity_acc_str = "velocity"

    def __init__(self,
                 learning_rate,
                 momentum,
                 lars_coeff=0.001,
                 lars_weight_decay=0.0005,
1515
                 parameter_list=None,
1516 1517 1518 1519 1520 1521
                 regularization=None,
                 name=None):
        assert learning_rate is not None
        assert momentum is not None
        super(LarsMomentumOptimizer, self).__init__(
            learning_rate=learning_rate,
1522
            parameter_list=parameter_list,
1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557
            regularization=regularization,
            name=name)
        self.type = "lars_momentum"
        self._momentum = momentum
        self._lars_coeff = float(lars_coeff)
        self._lars_weight_decay = float(lars_weight_decay)

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        for p in parameters:
            self._add_accumulator(self._velocity_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)

        velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                             param_and_grad[0])
        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Velocity": velocity_acc,
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "VelocityOut": velocity_acc
            },
            attrs={
                "mu": self._momentum,
                "lars_coeff": self._lars_coeff,
                "lars_weight_decay": self._lars_weight_decay
M
minqiyang 已提交
1558 1559
            },
            stop_gradient=True)
1560 1561 1562 1563

        return momentum_op


1564
class AdagradOptimizer(Optimizer):
Q
qiaolongfei 已提交
1565
    """
1566 1567
    The Adaptive Gradient optimizer (Adagrad for short) can adaptively assign
    different learning rates to individual parameters.
Q
qiaolongfei 已提交
1568

1569
    The parameter ``param_out`` update rule with gradient ``grad``:
Q
qiaolongfei 已提交
1570 1571 1572 1573 1574 1575 1576

    .. math::

        moment\_out &= moment + grad * grad

        param\_out &= param - \\frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}

1577 1578 1579 1580 1581 1582
    Related paper: `Adaptive Subgradient Methods for Online Learning and
    Stochastic Optimization <http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf>`_.

    The original paper does not have the ``epsilon`` attribute. It is added here
    in our implementation as also proposed `Per-parameter adaptive learning rate
    methods <http://cs231n.github.io/neural-networks-3/#ada>`_
Q
qiaolongfei 已提交
1583 1584 1585
    for numerical stability to avoid the division by zero error.

    Args:
1586 1587 1588 1589
        learning_rate (float|Variable): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Variable`` with a float type.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-06.
1590 1591 1592
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
1593 1594 1595 1596 1597 1598 1599
        regularization (WeightDecayRegularizer, optional): A ``Regularizer``, such as
             :ref:`api_fluid_regularizer_L2DecayRegularizer`. The default value is None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
        initial_accumulator_value (float, optional): Initial value for moment accumulator.
            The default value is 0.0.
Q
qiaolongfei 已提交
1600 1601 1602 1603

    Examples:
        .. code-block:: python

1604
            import numpy as np
1605
            import paddle.fluid as fluid
1606 1607

            np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
1608
            inp = fluid.data(name="inp", shape=[2, 2])
1609 1610
            out = fluid.layers.fc(inp, size=3)
            out = fluid.layers.reduce_sum(out)
1611
            optimizer = fluid.optimizer.AdagradOptimizer(learning_rate=0.2)
1612 1613 1614 1615 1616 1617 1618
            optimizer.minimize(out)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            exe.run(
                feed={"inp": np_inp},
                fetch_list=[out.name])
1619 1620 1621
    """
    _moment_acc_str = "moment"

X
Xin Pan 已提交
1622 1623 1624
    def __init__(self,
                 learning_rate,
                 epsilon=1.0e-6,
1625
                 parameter_list=None,
X
Xin Pan 已提交
1626
                 regularization=None,
1627
                 name=None,
X
xuezhong 已提交
1628
                 initial_accumulator_value=0.0):
1629 1630
        assert learning_rate is not None
        assert epsilon is not None
Q
Qiao Longfei 已提交
1631
        super(AdagradOptimizer, self).__init__(
X
Xin Pan 已提交
1632
            learning_rate=learning_rate,
1633
            parameter_list=parameter_list,
X
Xin Pan 已提交
1634 1635
            regularization=regularization,
            name=name)
1636 1637
        self.type = "adagrad"
        self._epsilon = epsilon
1638
        self.initial_accumulator_value = initial_accumulator_value
1639 1640 1641 1642 1643

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        for p in parameters:
Z
zhongpu 已提交
1644 1645 1646 1647
            self._add_accumulator(
                self._moment_acc_str,
                p,
                fill_value=self.initial_accumulator_value)
1648 1649 1650 1651 1652 1653

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)

        moment_acc = self._get_accumulator(self._moment_acc_str,
                                           param_and_grad[0])
1654
        # Create the adagrad optimizer op
1655 1656 1657 1658 1659 1660
        adagrad_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": moment_acc,
1661
                "LearningRate": self._create_param_lr(param_and_grad)
1662 1663 1664
            },
            outputs={"ParamOut": param_and_grad[0],
                     "MomentOut": moment_acc},
M
minqiyang 已提交
1665 1666
            attrs={"epsilon": self._epsilon},
            stop_gradient=True)
1667 1668

        return adagrad_op
1669 1670 1671


class AdamOptimizer(Optimizer):
Q
qiaolongfei 已提交
1672
    """
T
tianshuo78520a 已提交
1673
    The Adam optimizer uses an optimization described at the end
1674 1675 1676 1677 1678
    of section 2 of `Adam paper <https://arxiv.org/abs/1412.6980>`_ ,
    it can dynamically adjusts the learning rate of each parameter using
    the 1st moment estimates and the 2nd moment estimates of the gradient.
    
    The parameter ``param_out`` update rule with gradient ``grad``:
Q
qiaolongfei 已提交
1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692

    .. math::

        t & = t + 1

        moment\_1\_out & = {\\beta}_1 * moment\_1 + (1 - {\\beta}_1) * grad

        moment\_2\_out & = {\\beta}_2 * moment\_2 + (1 - {\\beta}_2) * grad * grad

        learning\_rate & = learning\_rate * \\
                          \\frac{\sqrt{1 - {\\beta}_2^t}}{1 - {\\beta}_1^t}

        param\_out & = param - learning\_rate * \\frac{moment\_1}{\sqrt{moment\_2} + \epsilon}

1693 1694
    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

Q
qiaolongfei 已提交
1695
    Args:
1696 1697
        learning_rate (float|Variable, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Variable`` with a float type. The default value is 0.001.
1698 1699
        beta1 (float|Variable, optional): The exponential decay rate for the 1st moment estimates.
            It should be a float number or a Variable with shape [1] and data type as float32.
1700
            The default value is 0.9.
1701 1702
        beta2 (float|Variable, optional): The exponential decay rate for the 2nd moment estimates.
            It should be a float number or a Variable with shape [1] and data type as float32.
1703 1704 1705
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
1706 1707 1708
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720
        regularization (WeightDecayRegularizer, optional): A ``Regularizer``, such as
             :ref:`api_fluid_regularizer_L2DecayRegularizer`. The default value is None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
            then the update may be very slow. The lazy mode only update the element that has
            gradient in current mini-batch, so it will be much more faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
Q
qiaolongfei 已提交
1721 1722 1723 1724

    Examples:
        .. code-block:: python

1725 1726 1727 1728 1729 1730
            import paddle
            import paddle.fluid as fluid

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
1731 1732
                x = fluid.data(name='x', shape=[None, 13], dtype='float32')
                y = fluid.data(name='y', shape=[None, 1], dtype='float32')
1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
                adam_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
Q
qiaolongfei 已提交
1748

1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765
        .. code-block:: python

            # Adam with beta1/beta2 as Variable
            import paddle
            import paddle.fluid as fluid
            import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
                x = fluid.data(name='x', shape=[None, 13], dtype='float32')
                y = fluid.data(name='y', shape=[None, 1], dtype='float32')
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                # define beta decay variable
1766
                def get_decayed_betas(beta1_init, beta2_init, decay_steps, decay_rate):
1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794
                    global_step = lr_scheduler._decay_step_counter()

                    beta1 = fluid.layers.create_global_var(
                        shape=[1],
                        value=float(beta1_init),
                        dtype='float32',
                        # set persistable for save checkpoints and resume
                        persistable=True,
                        name="beta1")
                    beta2 = fluid.layers.create_global_var(
                        shape=[1],
                        value=float(beta2_init),
                        dtype='float32',
                        # set persistable for save checkpoints and resume
                        persistable=True,
                        name="beta2")

                    div_res = global_step / decay_steps
                    decayed_beta1 = beta1_init * (decay_rate**div_res)
                    decayed_beta2 = beta2_init * (decay_rate**div_res)
                    fluid.layers.assign(decayed_beta1, beta1)
                    fluid.layers.assign(decayed_beta2, beta2)

                    return beta1, beta2

                beta1, beta2 = get_decayed_betas(0.9, 0.99, 1e5, 0.9)
                adam_optimizer = fluid.optimizer.AdamOptimizer(
                                                    learning_rate=0.01,
1795
                                                    beta1=beta1,
1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806
                                                    beta2=beta2)
                adam_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
1807 1808 1809
    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
Q
qiaolongfei 已提交
1810 1811
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"
1812 1813 1814 1815 1816

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
1817
                 epsilon=1e-8,
1818
                 parameter_list=None,
X
Xin Pan 已提交
1819
                 regularization=None,
Q
Qiao Longfei 已提交
1820
                 name=None,
Q
Qiao Longfei 已提交
1821
                 lazy_mode=False):
1822 1823 1824 1825
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
Q
Qiao Longfei 已提交
1826
        super(AdamOptimizer, self).__init__(
X
Xin Pan 已提交
1827
            learning_rate=learning_rate,
1828
            parameter_list=parameter_list,
X
Xin Pan 已提交
1829 1830
            regularization=regularization,
            name=name)
1831 1832 1833 1834
        self.type = "adam"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
Q
Qiao Longfei 已提交
1835
        self._lazy_mode = lazy_mode
1836 1837 1838 1839 1840 1841

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        # Create accumulator tensors for first and second moments
        for p in parameters:
Q
Qiao Longfei 已提交
1842 1843
            self._add_accumulator(self._moment1_acc_str, p)
            self._add_accumulator(self._moment2_acc_str, p)
Q
qiaolongfei 已提交
1844 1845 1846
            self._add_accumulator(
                name=self._beta1_pow_acc_str,
                param=p,
1847 1848
                fill_value=0.9 if isinstance(self._beta1, Variable) \
                        else self._beta1,
1849
                shape=[1],
1850
                type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
Q
qiaolongfei 已提交
1851 1852 1853
            self._add_accumulator(
                name=self._beta2_pow_acc_str,
                param=p,
1854 1855
                fill_value=0.999 if isinstance(self._beta2, Variable) \
                        else self._beta2,
1856
                shape=[1],
1857
                type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
1858 1859 1860 1861 1862 1863 1864 1865

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)

        moment1 = self._get_accumulator(self._moment1_acc_str,
                                        param_and_grad[0])
        moment2 = self._get_accumulator(self._moment2_acc_str,
                                        param_and_grad[0])
Q
qiaolongfei 已提交
1866 1867 1868 1869
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                              param_and_grad[0])
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                              param_and_grad[0])
1870
        lr = self._create_param_lr(param_and_grad)
1871
        # create the adam optimize op
1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886

        if framework.in_dygraph_mode():
            _beta1 = self._beta1 if not isinstance(
                self._beta1, Variable) else self._beta1.numpy().item(0)
            _beta2 = self._beta2 if not isinstance(
                self._beta2, Variable) else self._beta2.numpy().item(0)
            _, _, _, _, _ = core.ops.adam(
                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
                beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
                moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
                'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
                1000, 'beta1', _beta1, 'beta2', _beta2)

            return None

1887
        inputs = {
1888 1889
            "Param": [param_and_grad[0]],
            "Grad": [param_and_grad[1]],
1890
            "LearningRate": [lr],
1891 1892 1893 1894
            "Moment1": [moment1],
            "Moment2": [moment2],
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc]
1895 1896
        }
        outputs = {
1897 1898 1899 1900 1901
            "ParamOut": [param_and_grad[0]],
            "Moment1Out": [moment1],
            "Moment2Out": [moment2],
            "Beta1PowOut": [beta1_pow_acc],
            "Beta2PowOut": [beta2_pow_acc],
1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917
        }
        attrs = {
            "epsilon": self._epsilon,
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000
        }

        if isinstance(self._beta1, Variable):
            inputs['Beta1Tensor'] = self._beta1
        else:
            attrs['beta1'] = self._beta1
        if isinstance(self._beta2, Variable):
            inputs['Beta2Tensor'] = self._beta2
        else:
            attrs['beta2'] = self._beta2

1918 1919
        adam_op = block.append_op(
            type=self.type,
1920 1921 1922
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
M
minqiyang 已提交
1923
            stop_gradient=True)
1924 1925 1926

        return adam_op

1927 1928

class AdamaxOptimizer(Optimizer):
    """
    The Adamax optimizer is implemented based on the Adamax Optimization
    in Section 7 of `Adam paper <https://arxiv.org/abs/1412.6980>`_.
    The Adamax algorithm is a variant of the Adam algorithm based on the infinite norm,
    which makes the learning rate update algorithm more stable and simple.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        t & = t + 1

        moment\_out & = {\\beta}_1 * moment + (1 - {\\beta}_1) * grad

        inf\_norm\_out & = max({\\beta}_2 * inf\_norm + \epsilon, |grad|)

        learning\_rate & = \\frac{learning\_rate}{1 - {\\beta}_1^t}

        param\_out & = param - learning\_rate * \\frac{moment\_out}{inf\_norm\_out}

    Related paper: `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_

    The original paper does not have an ``epsilon`` attribute,
    it is added here for numerical stability to prevent the division by 0 error.

    Args:
        learning_rate (float|Variable, optional): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Variable`` with a float type. The default value is 0.001.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            The default value is 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            The default value is 0.999.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-08.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        regularization (WeightDecayRegularizer, optional): A ``Regularizer``, such as
             :ref:`api_fluid_regularizer_L2DecayRegularizer`. The default value is None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, AdamaxOptimizer doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy

          # First create the Executor.
          place = fluid.CPUPlace() # fluid.CUDAPlace(0)
          exe = fluid.Executor(place)

          train_program = fluid.Program()
          startup_program = fluid.Program()
          with fluid.program_guard(train_program, startup_program):
              data = fluid.data(name='X', shape=[None, 1], dtype='float32')
              hidden = fluid.layers.fc(input=data, size=10)
              loss = fluid.layers.mean(hidden)
              adamax = fluid.optimizer.AdamaxOptimizer(learning_rate=0.2)
              adamax.minimize(loss)

          # Run the startup program once and only once.
          exe.run(startup_program)

          x = numpy.random.random(size=(10, 1)).astype('float32')
          outs = exe.run(program=train_program,
                         feed={'X': x},
                         fetch_list=[loss.name])
    """
    _moment_acc_str = "moment"
    _inf_norm_acc_str = "inf_norm"
    _beta1_pow_acc_str = "beta1_pow_acc"

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        super(AdamaxOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "adamax"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        # For every parameter create the first-moment and infinity-norm
        # accumulators, plus a one-element beta1 power accumulator that
        # starts at beta1 and is advanced in _finish_update.
        for p in parameters:
            self._add_accumulator(self._moment_acc_str, p)
            self._add_accumulator(self._inf_norm_acc_str, p)
            self._add_accumulator(
                name=self._beta1_pow_acc_str,
                param=p,
                fill_value=self._beta1,
                shape=[1])

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ``adamax`` op that updates a single parameter.

        Args:
            block (framework.Block): The block the op is appended to.
            param_and_grad: The ``(param, grad)`` pair to update.

        Returns:
            The appended ``adamax`` operator. Param, Moment and InfNorm
            are written back in place.
        """
        assert isinstance(block, framework.Block)

        moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0])
        inf_norm = self._get_accumulator(self._inf_norm_acc_str,
                                         param_and_grad[0])
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                              param_and_grad[0])
        # create the adamax optimize op. Beta1Pow is only read here; there is
        # no Beta1PowOut because _finish_update scales it once per step.
        adamax_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad),
                "Moment": moment,
                "InfNorm": inf_norm,
                "Beta1Pow": beta1_pow_acc
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "MomentOut": moment,
                "InfNormOut": inf_norm
            },
            attrs={
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon
            },
            stop_gradient=True)

        return adamax_op

    def _finish_update(self, block, parameters_and_grads):
        """Update Beta1 Power accumulator

        For every parameter that has a gradient and is trainable, multiply
        its beta1 power accumulator by beta1 in place with a ``scale`` op.
        """
        assert isinstance(block, framework.Block)
        for param, grad in parameters_and_grads:
            # Frozen parameters and parameters without a gradient were not
            # updated this step, so their beta1 power must not advance.
            if grad is None or param.trainable is False:
                continue
            # Group the ops under this optimizer's own name ("adamax").
            with param.block.program._optimized_guard(
                [param, grad]), name_scope('adamax'):
                beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                      param)
                block.append_op(
                    type="scale",
                    inputs={"X": beta1_pow_acc},
                    outputs={"Out": beta1_pow_acc},
                    attrs={"scale": self._beta1},
                    stop_gradient=True)
2089 2090


2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128
class DpsgdOptimizer(Optimizer):
    """
    We implement the Dpsgd optimizer according to CCS16 paper -
    Deep Learning with Differential Privacy.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import numpy

          # First create the Executor.
          place = fluid.CPUPlace() # fluid.CUDAPlace(0)
          exe = fluid.Executor(place)

          train_program = fluid.Program()
          startup_program = fluid.Program()
          with fluid.program_guard(train_program, startup_program):
              data = fluid.layers.data(name='X', shape=[1], dtype='float32')
              hidden = fluid.layers.fc(input=data, size=10)
              loss = fluid.layers.mean(hidden)
              optimizer = fluid.optimizer.Dpsgd(learning_rate=0.01, clip=10.0, batch_size=16.0, sigma=1.0)
              optimizer.minimize(loss)

          # Run the startup program once and only once.
          exe.run(startup_program)

          x = numpy.random.random(size=(10, 1)).astype('float32')
          outs = exe.run(program=train_program,
                         feed={'X': x},
                         fetch_list=[loss.name])

    Args:
        learning_rate (float|Variable, optional): The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
            The default value is 0.001.
        clip (float, optional): The clipping threshold. The default value is 0.9.
        batch_size (float, optional): The batch size. The default value is 0.999.
        sigma (float, optional): Parameter controlling the Gaussian noise.
            The default value is 1e-08.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
    Notes:
       Currently, DpsgdOptimizer doesn't support sparse parameter optimization.
    """

    def __init__(self,
                 learning_rate=0.001,
                 clip=0.9,
                 batch_size=0.999,
                 sigma=1e-8,
                 parameter_list=None):
        # NOTE(review): these defaults equal AdamaxOptimizer's beta1/beta2/
        # epsilon defaults and look copied; callers should pass explicit
        # clip/batch_size/sigma values (as the example above does).
        assert learning_rate is not None
        assert clip is not None
        assert batch_size is not None
        assert sigma is not None
        super(DpsgdOptimizer, self).__init__(
            learning_rate=learning_rate, parameter_list=parameter_list)
        self.type = "dpsgd"
        self._clip = clip
        self._batch_size = batch_size
        self._sigma = sigma
        # Note(wangzhongpu):
        # This property is only used for debugging; there is no need to set it.
        # The dpsgd operator uses time(NULL) as the random seed to generate
        # random numbers. During debugging we need deterministic results, so
        # self._seed can be set to a fixed number.
        self._seed = None

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ``dpsgd`` op that updates ``param_and_grad[0]`` in place.

        Returns:
            The appended ``dpsgd`` operator.
        """
        assert isinstance(block, framework.Block)

        # No debugging seed was set: pass 0 to the op.
        # NOTE(review): presumably 0 makes the op fall back to its time-based
        # seed described in the note in __init__ -- confirm against the op.
        if self._seed is None:
            self._seed = 0

        # create the dpsgd optimize op
        dpsgd_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={"ParamOut": param_and_grad[0]},
            attrs={
                "clip": self._clip,
                "batch_size": self._batch_size,
                "sigma": self._sigma,
                "seed": self._seed
            },
            stop_gradient=True)

        return dpsgd_op


2186
class DecayedAdagradOptimizer(Optimizer):
    """
    Decayed Adagrad optimizer: an Adagrad variant whose squared-gradient
    accumulator decays over time. This avoids the sharp drop in the effective
    learning rate that plain AdagradOptimizer can show during training.

    The parameter ``param_out`` update rule with gradient ``grad``:

    .. math::

        moment\_out & = decay * moment + (1 - decay) * grad * grad

        param\_out & = param - \\frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}

    Related paper: `Adaptive Subgradient Methods for Online Learning and Stochastic
    Optimization <http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf>`_.

    The ``epsilon`` attribute is not part of the original paper. It is added
    here for numerical stability, to avoid division by zero.

    Args:
        learning_rate (float|Variable): The learning rate used to update ``Parameter``.
            It can be a float value or a ``Variable`` with a float type.
        decay (float, optional): The decay rate. The default value is 0.95.
        epsilon (float, optional): A small float value for numerical stability.
            The default value is 1e-06.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        regularization (WeightDecayRegularizer, optional): A ``Regularizer``, such as
             :ref:`api_fluid_regularizer_L2DecayRegularizer`. The default value is None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    **Notes**:
        **Currently, DecayedAdagradOptimizer doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x = fluid.data( name='x', shape=[None, 10], dtype='float32' )
            trans = fluid.layers.fc( x, 100 )
            cost = fluid.layers.reduce_mean( trans )
            optimizer = fluid.optimizer.DecayedAdagradOptimizer(learning_rate=0.2)
            optimizer.minimize(cost)
    """
    _moment_acc_str = "moment"

    def __init__(self,
                 learning_rate,
                 decay=0.95,
                 epsilon=1.0e-6,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        # All three hyper-parameters are mandatory (AssertionError otherwise).
        assert learning_rate is not None and decay is not None and epsilon is not None

        super(DecayedAdagradOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "decayed_adagrad"
        self._decay = decay
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        # One decayed squared-gradient accumulator per parameter.
        for param in parameters:
            self._add_accumulator(self._moment_acc_str, param)

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ``decayed_adagrad`` op; Param and Moment are updated in place."""
        assert isinstance(block, framework.Block)

        param = param_and_grad[0]
        moment = self._get_accumulator(self._moment_acc_str, param)

        op_inputs = {
            "Param": param,
            "Grad": param_and_grad[1],
            "Moment": moment,
            "LearningRate": self._create_param_lr(param_and_grad)
        }
        op_outputs = {"ParamOut": param, "MomentOut": moment}
        op_attrs = {"epsilon": self._epsilon, "decay": self._decay}

        # Create the decayed adagrad optimizer op
        return block.append_op(
            type=self.type,
            inputs=op_inputs,
            outputs=op_outputs,
            attrs=op_attrs,
            stop_gradient=True)
2285 2286


2287
class AdadeltaOptimizer(Optimizer):
    """
    **Notes: This API does not support sparse parameter optimization.**

    Adadelta Optimizer. For details, see
    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD <https://arxiv.org/abs/1212.5701>`_.

    Each step applies the following update:

    .. math::

        E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2

        learning\_rate &= \sqrt{ ( E(dx_{t-1}^2) + \\epsilon ) / ( E(g_t^2) + \\epsilon ) }

        E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\_rate)^2

    Args:
        learning_rate (float|Variable): global learning rate.
        epsilon (float): a small float number for numeric stability. Default 1.0e-6.
        rho (float): a floating point value indicating the decay rate. Default 0.95.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        regularization (WeightDecayRegularizer, optional): A Regularizer, such as
                fluid.regularizer.L2DecayRegularizer. Default None, meaning that there is no
                regularization.
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Raises:
        ValueError: If learning_rate, epsilon or rho is None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            image = fluid.data(name='image', shape=[None, 28], dtype='float32')
            fc = fluid.layers.fc(image, size=10)
            cost = fluid.layers.reduce_mean(fc)
            optimizer = fluid.optimizer.Adadelta(
                learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)

            # optimizer_ops is a list of optimizer operators to update parameters
            # params_grads is a list of (param, param_grad), where param is each
            # parameter and param_grad is the gradient variable of param.
            optimizer_ops, params_grads = optimizer.minimize(cost)
    """

    _avg_squared_grad_acc_str = "_avg_squared_grad"
    _avg_squared_update_acc_str = "_avg_squared_update"

    def __init__(self,
                 learning_rate,
                 epsilon=1.0e-6,
                 rho=0.95,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        # Validate before the base class does any setup.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        super(AdadeltaOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "adadelta"
        self._epsilon = epsilon
        self._rho = rho

    def _create_accumulators(self, block, parameters):
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        # Two running averages per parameter: E(g^2) and E(dx^2).
        for param in parameters:
            self._add_accumulator(self._avg_squared_grad_acc_str, param)
            self._add_accumulator(self._avg_squared_update_acc_str, param)

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ``adadelta`` op that updates the parameter and both averages in place."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        param, grad = param_and_grad[0], param_and_grad[1]
        sq_grad_avg = self._get_accumulator(self._avg_squared_grad_acc_str,
                                            param)
        sq_update_avg = self._get_accumulator(
            self._avg_squared_update_acc_str, param)

        # NOTE(review): no "LearningRate" input is passed to the op, so the
        # learning_rate given to __init__ does not appear to affect the update
        # built here -- confirm against the adadelta op definition.
        op_inputs = {
            "Param": param,
            "Grad": grad,
            "AvgSquaredGrad": sq_grad_avg,
            "AvgSquaredUpdate": sq_update_avg
        }
        op_outputs = {
            "ParamOut": param,
            "AvgSquaredGradOut": sq_grad_avg,
            "AvgSquaredUpdateOut": sq_update_avg
        }

        # Create the adadelta optimizer op
        return block.append_op(
            type=self.type,
            inputs=op_inputs,
            outputs=op_outputs,
            attrs={"epsilon": self._epsilon,
                   "rho": self._rho},
            stop_gradient=True)


Q
qingqing01 已提交
2398 2399 2400 2401 2402 2403 2404 2405 2406 2407
class RMSPropOptimizer(Optimizer):
    """
    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning
    rate method. The original slides proposed RMSProp: Slide 29 of
    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf .

    The original equation is as follows:

    ..  math::

        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2

        w & = w - \\frac{\\eta} {\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w)

    The first equation calculates the moving average of the squared gradient for
    each weight. The gradient is then divided by :math:`\\sqrt{r(w,t)}`.

    In some cases, adding a momentum term :math:`\\beta` is beneficial.
    In our implementation, Nesterov momentum is used:

    ..  math::

        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2

        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) +
            \\epsilon}} \\nabla Q_{i}(w)

        w & = w - v(w, t)

    if centered is True:

    ..  math::

        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2

        g(w, t) & = \\rho g(w, t-1) + (1 - \\rho)\\nabla Q_{i}(w)

        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) - (g(w, t))^2 +
            \\epsilon}} \\nabla Q_{i}(w)

        w & = w - v(w, t)

    where :math:`\\rho` is a hyperparameter with typical values such as 0.9 or 0.95,
    :math:`\\beta` is the momentum term, and :math:`\\epsilon` is a
    smoothing term to avoid division by zero, usually set somewhere in range
    from 1e-4 to 1e-8.


    Parameters:
        learning_rate(float): Global learning rate.
        rho(float): rho is :math:`\\rho` in equation, default is 0.95.
        epsilon(float): :math:`\\epsilon` in equation is smoothing term to
            avoid division by zero, default is 1e-6.
        momentum(float): :math:`\\beta` in equation is the momentum term,
            default is 0.0.
        centered(bool): If True, gradients are normalized by the estimated variance of
            the gradient; if False, by the uncentered second moment. Setting this to
            True may help with training, but is slightly more expensive in terms of
            computation and memory. Defaults to False.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        regularization: A Regularizer, such as :ref:`api_fluid_regularizer_L2DecayRegularizer`. \
            Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information. \
            For details, please refer to :ref:`api_guide_Name`. Default is None.

    Raises:
        ValueError: If learning_rate, rho, epsilon, momentum are None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
                x = fluid.layers.data(name='x', shape=[13], dtype='float32')
                y = fluid.layers.data(name='y', shape=[1], dtype='float32')
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                rms_optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
                rms_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    """

    _momentum_acc_str = "momentum"
    _mean_square_acc_str = "mean_square"
    _mean_grad_acc_str = "mean_grad"

    def __init__(self,
                 learning_rate,
                 rho=0.95,
                 epsilon=1.0e-6,
                 momentum=0.0,
                 centered=False,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        # Validate the hyper-parameters before the base-class setup runs, as
        # AdadeltaOptimizer does, so invalid arguments fail with the
        # documented ValueError and nothing is initialized for a bad call.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if momentum is None:
            raise ValueError("momentum is not set.")

        super(RMSPropOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)

        self.type = "rmsprop"
        self._rho = rho
        self._epsilon = epsilon
        self._momentum = momentum
        self._centered = centered

    def _create_accumulators(self, block, parameters):
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        # Per parameter: velocity v(w, t), mean square r(w, t), and mean
        # gradient g(w, t). The mean gradient is always created and passed to
        # the op even when centered is False.
        for p in parameters:
            self._add_accumulator(self._momentum_acc_str, p)
            self._add_accumulator(self._mean_square_acc_str, p)
            self._add_accumulator(self._mean_grad_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ``rmsprop`` op; the parameter and all accumulators are updated in place.

        Returns:
            The appended ``rmsprop`` operator.
        """
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        momentum_acc = self._get_accumulator(self._momentum_acc_str,
                                             param_and_grad[0])
        mean_square_acc = self._get_accumulator(self._mean_square_acc_str,
                                                param_and_grad[0])
        mean_grad_acc = self._get_accumulator(self._mean_grad_acc_str,
                                              param_and_grad[0])
        rmsprop_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": momentum_acc,
                "MeanSquare": mean_square_acc,
                "MeanGrad": mean_grad_acc,
                "LearningRate": self._create_param_lr(param_and_grad),
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "MomentOut": momentum_acc,
                "MeanSquareOut": mean_square_acc,
                "MeanGradOut": mean_grad_acc
            },
            attrs={
                "epsilon": self._epsilon,
                # The op names the rho hyper-parameter "decay".
                "decay": self._rho,
                "momentum": self._momentum,
                "centered": self._centered
            },
            stop_gradient=True)

        return rmsprop_op


Q
qiaolongfei 已提交
2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616
class FtrlOptimizer(Optimizer):
    """
    FTRL (Follow The Regularized Leader) Optimizer.

    The paper that proposed Follow The Regularized Leader (FTRL):
    (https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)

    ..  math::

        &new\_accum = squared\_accum + grad^2

        &if (lr\_power == -0.5):

        &\quad  linear\_accum += grad - \\frac{\\sqrt{new\_accum} - \\sqrt{squared\_accum}}{learning\_rate} * param

        &else:

        &\quad   linear\_accum += grad - \\frac{new\_accum^{-lr\_power} - squared\_accum^{-lr\_power}}{learning\_rate} * param


        &x = l1 * sign(linear\_accum) - linear\_accum

        &if (lr\_power == -0.5):

        &\quad   y = \\frac{\\sqrt{new\_accum}}{learning\_rate} + (2 * l2)

        &\quad   pre\_shrink = \\frac{x}{y}

        &\quad   param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0)

        &else:

        &\quad   y = \\frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2)

        &\quad   pre\_shrink = \\frac{x}{y}

        &\quad   param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0)

        &squared\_accum += grad^2

    Parameters:
        learning_rate (float|Variable): Global learning rate.
        l1 (float): L1 regularization strength, default is 0.0.
        l2 (float): L2 regularization strength, default is 0.0.
        lr_power (float): Learning Rate Power, default is -0.5.
        parameter_list (list, optional):  List of ``Variable`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        regularization: A Regularizer, such as :ref:`api_fluid_regularizer_L2DecayRegularizer`. \
            Optional, default is None.
        name (str, optional): This parameter is used by developers to print debugging information. \
            For details, please refer to :ref:`api_guide_Name`. Default is None.

    Raises:
        ValueError: If learning_rate is None.

    Examples:
          .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            place = fluid.CPUPlace()
            main = fluid.Program()
            with fluid.program_guard(main):
                x = fluid.layers.data(name='x', shape=[13], dtype='float32')
                y = fluid.layers.data(name='y', shape=[1], dtype='float32')
                y_predict = fluid.layers.fc(input=x, size=1, act=None)
                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
                avg_cost = fluid.layers.mean(cost)

                ftrl_optimizer = fluid.optimizer.Ftrl(learning_rate=0.1)
                ftrl_optimizer.minimize(avg_cost)

                fetch_list = [avg_cost]
                train_reader = paddle.batch(
                    paddle.dataset.uci_housing.train(), batch_size=1)
                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
                exe = fluid.Executor(place)
                exe.run(fluid.default_startup_program())
                for data in train_reader():
                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    NOTE:
       Currently, FtrlOptimizer doesn't support sparse parameter optimization.
    """

    _squared_acc_str = "squared"
    _linear_acc_str = "linear"

    def __init__(self,
                 learning_rate,
                 l1=0.0,
                 l2=0.0,
                 lr_power=-0.5,
                 parameter_list=None,
                 regularization=None,
                 name=None):
        super(FtrlOptimizer, self).__init__(
            learning_rate=learning_rate,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")

        self.type = "ftrl"
        self._l1 = l1
        self._l2 = l2
        self._lr_power = lr_power

    def _create_accumulators(self, block, parameters):
        """Create the squared and linear accumulators for every parameter.

        Raises:
            TypeError: If ``block`` is not a ``framework.Block``.
        """
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        for p in parameters:
            self._add_accumulator(self._squared_acc_str, p)
            self._add_accumulator(self._linear_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append one ftrl op that updates a parameter and its accumulators.

        Args:
            block (framework.Block): Block the op is appended to.
            param_and_grad (tuple): ``(param, grad)`` pair to optimize.

        Returns:
            The operator returned by ``block.append_op``.

        Raises:
            TypeError: If ``block`` is not a ``framework.Block``.
        """
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        squared_acc = self._get_accumulator(self._squared_acc_str,
                                            param_and_grad[0])
        linear_acc = self._get_accumulator(self._linear_acc_str,
                                           param_and_grad[0])
        ftrl_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "SquaredAccumulator": squared_acc,
                "LinearAccumulator": linear_acc,
                "LearningRate": self._create_param_lr(param_and_grad),
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "SquaredAccumOut": squared_acc,
                "LinearAccumOut": linear_acc
            },
            # Previously "l2" was fed self._l1, which silently ignored the
            # user's l2 setting and applied the L1 strength instead.
            attrs={"l1": self._l1,
                   "l2": self._l2,
                   "lr_power": self._lr_power},
            stop_gradient=True)

        return ftrl_op


Y
Yibing Liu 已提交
2727 2728 2729 2730 2731 2732
class LambOptimizer(AdamOptimizer):
    """
    LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.

    LAMB is designed to grow the training batch size without losing accuracy.
    It combines adaptive element-wise updating with an accurate layer-wise
    correction. For details see `Large Batch Optimization for Deep Learning:
    Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .

    Parameters are updated as follows:

    ..  math::

        m_t &= \\beta_1 m_{t - 1}+ (1 - \\beta_1)g_t

        v_t &= \\beta_2 v_{t - 1}  + (1 - \\beta_2)g_t^2

        r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon}

        w_t &= w_{t-1} -\\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})


    Here :math:`m` is the first moment, :math:`v` the second moment,
    :math:`\\eta` the learning rate and :math:`\\lambda` the LAMB weight
    decay rate.

    Args:
        learning_rate (float|Variable, optional): Learning rate used to update
            parameters; a float or a float32 Variable. Default 0.001.
        lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01.
        beta1 (float, optional): Exponential decay rate of the first moment
            estimates. Default 0.9.
        beta2 (float, optional): Exponential decay rate of the second moment
            estimates. Default 0.999.
        epsilon (float, optional): Small value added for numerical stability.
            Default 1e-6.
        parameter_list (list, optional): List of ``Variable`` names to update to
            minimize ``loss``. Required in dygraph mode; in static mode the
            default None means every parameter is updated.
        regularization (Regularizer|None): A Regularizer, such as
            fluid.regularizer.L1DecayRegularizer. Default None.
        exclude_from_weight_decay_fn (function|None): When given and
            **exclude_from_weight_decay_fn(parameter)** is true, no LAMB
            weight decay is applied to that parameter. Default None.
        name(str|None): See :ref:`api_guide_Name`. Normally left unset;
            default None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
            hidden = fluid.layers.fc(input=data, size=10)
            cost = fluid.layers.mean(hidden)

            def exclude_fn(param):
                return param.name.endswith('.b_0')

            optimizer = fluid.optimizer.Lamb(learning_rate=0.002,
                                             exclude_from_weight_decay_fn=exclude_fn)
            optimizer.minimize(cost)
    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    # Power accumulators of beta1/beta2; passed to the lamb op below as the
    # Beta1Pow/Beta2Pow inputs.
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(self,
                 learning_rate=0.001,
                 lamb_weight_decay=0.01,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-6,
                 parameter_list=None,
                 regularization=None,
                 exclude_from_weight_decay_fn=None,
                 name=None):
        # Checked in declaration order so the first missing value trips first.
        for required in (learning_rate, lamb_weight_decay, beta1, beta2,
                         epsilon):
            assert required is not None
        super(LambOptimizer, self).__init__(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            parameter_list=parameter_list,
            regularization=regularization,
            name=name)
        self.type = "lamb"
        self._weight_decay = lamb_weight_decay
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn

    def _append_optimize_op(self, block, param_and_grad):
        """Append one lamb op that updates a parameter and its moments.

        Also marks ``block.program`` with ``_use_lamb = True``.

        Args:
            block (framework.Block): Block the op is appended to.
            param_and_grad (tuple): ``(param, grad)`` pair to optimize.

        Returns:
            The operator returned by ``block.append_op``.
        """
        assert isinstance(block, framework.Block)
        block.program._use_lamb = True

        param = param_and_grad[0]
        moment1, moment2, beta1_pow_acc, beta2_pow_acc = [
            self._get_accumulator(acc_name, param)
            for acc_name in (self._moment1_acc_str, self._moment2_acc_str,
                             self._beta1_pow_acc_str, self._beta2_pow_acc_str)
        ]

        # Parameters picked out by the user callback get no weight decay.
        skip_decay = (self._exclude_from_weight_decay_fn is not None and
                      self._exclude_from_weight_decay_fn(param))
        weight_decay = 0.0 if skip_decay else self._weight_decay

        lr = self._create_param_lr(param_and_grad)
        return block.append_op(
            type=self.type,
            inputs={
                "Param": param,
                "Grad": param_and_grad[1],
                "LearningRate": lr,
                "Moment1": moment1,
                "Moment2": moment2,
                "Beta1Pow": beta1_pow_acc,
                "Beta2Pow": beta2_pow_acc
            },
            outputs={
                "ParamOut": param,
                "Moment1Out": moment1,
                "Moment2Out": moment2
            },
            attrs={
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon,
                "weight_decay": weight_decay
            },
            stop_gradient=True)


2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878 2879 2880
# Short public aliases. Users normally reach these classes through the package
# namespace, for example:
#
#     import paddle.fluid as fluid
#     sgd = fluid.optimizer.SGD(...)
#
# so repeating the ``Optimizer`` suffix in the public name adds nothing.
SGD = SGDOptimizer
Momentum = MomentumOptimizer
Adagrad = AdagradOptimizer
Adam = AdamOptimizer
Adamax = AdamaxOptimizer
Dpsgd = DpsgdOptimizer
DecayedAdagrad = DecayedAdagradOptimizer
Adadelta = AdadeltaOptimizer
RMSProp = RMSPropOptimizer
Ftrl = FtrlOptimizer
LarsMomentum = LarsMomentumOptimizer
Lamb = LambOptimizer
2888 2889 2890


class ModelAverage(Optimizer):
    """
    The ModelAverage optimizer accumulates specific continuous historical parameters
    during training. The accumulated historical range can be controlled by the passed
    ``average_window_rate`` argument. The averaged ``Parameter`` are used in the prediction,
    which usually can improve the accuracy of the prediction.

    Accumulate the average of the ``Parameter`` in the sliding window, the result will be saved
    in a temporary variable, can be applied to the current model's ``Parameter`` by calling
    the ``apply()`` method, and the current model ``Parameter`` can be restored by calling
    the ``restore()`` method.

    The window size for calculating the average is determined by ``average_window_rate``,
    ``min_average_window``, ``max_average_window`` and the current ``Parameter`` update times (num_updates).

    When the cumulative times (num_accumulates) is greater than the specific window
    threshold (average_window), the accumulated ``Parameter`` temporary variable is set to 0.0.
    The following example will help to understand the role of these arguments:

    ::

        if num_accumulates >= min_average_window and num_accumulates >= min(max_average_window, num_updates * average_window_rate):
            num_accumulates = 0

    In the above conditional judgment statement, ``num_accumulates`` indicates the current
    accumulated number, which can be abstractly understood as the length of the cumulative window.
    The length of the window must be at least the length set by the ``min_average_window`` argument,
    and cannot exceed the length specified by the ``max_average_window`` argument or
    ``num_updates * average_window_rate``, where ``num_updates`` indicates the current ``Parameter``
    update times, ``average_window_rate`` is a coefficient that calculates the length of the window.

    Args:
        average_window_rate (float): The calculate ratio of the window length relative to ``Parameter`` update times.
        min_average_window (int, optional): the minimum size of average window length. The default value is 10000.
        max_average_window (int, optional): The maximum size of average window length. The default value is 10000.
        regularization (WeightDecayRegularizer, optional): A ``Regularizer``, such as
             :ref:`api_fluid_regularizer_L2DecayRegularizer`. The default value is None.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Raises:
        Exception: If constructed in dygraph mode, which is not supported.

    Examples:

      .. code-block:: python

        import paddle.fluid as fluid
        import numpy

        # First create the Executor.
        place = fluid.CPUPlace()  # fluid.CUDAPlace(0)
        exe = fluid.Executor(place)

        train_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            # build net
            data = fluid.data(name='X', shape=[None, 1], dtype='float32')
            hidden = fluid.layers.fc(input=data, size=10)
            loss = fluid.layers.mean(hidden)
            optimizer = fluid.optimizer.Momentum(learning_rate=0.2, momentum=0.1)
            optimizer.minimize(loss)

            # build ModelAverage optimizer
            model_average = fluid.optimizer.ModelAverage(0.15,
                                                         min_average_window=10000,
                                                         max_average_window=12500)

            exe.run(startup_program)
            for i in range(12500):
                x = numpy.random.random(size=(10, 1)).astype('float32')
                outs = exe.run(program=train_program,
                               feed={'X': x},
                               fetch_list=[loss.name])

            # apply ModelAverage
            with model_average.apply(exe):
                x = numpy.random.random(size=(10, 1)).astype('float32')
                exe.run(program=train_program,
                        feed={'X': x},
                        fetch_list=[loss.name])
    """

    def __init__(self,
                 average_window_rate,
                 min_average_window=10000,
                 max_average_window=10000,
                 regularization=None,
                 name=None):
        if framework.in_dygraph_mode():
            raise Exception("In dygraph, don't support ModelAverage.")
        super(ModelAverage, self).__init__(
            0.0, regularization=regularization, name=name)
        self.average_window = average_window_rate
        self.min_average_window = min_average_window
        self.max_average_window = max_average_window

        # Pair each averaged parameter with a fresh, non-persistable temporary
        # variable; apply() uses it to back up the live parameter value so
        # restore() can copy it back.
        self.params_grads = []
        for param in framework.default_main_program().global_block(
        ).all_parameters():
            # Any value other than an explicit False (e.g. None or True)
            # opts the parameter in.
            if param.do_model_average != False:
                grad = param.block.create_var(
                    name=unique_name.generate_with_ignorable_key(".".join(
                        [param.name, 'tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self.params_grads.append((param, grad))

        # Append the accumulation ops to each parameter's own program.
        for param, grad in self.params_grads:
            if grad is None:
                continue
            with param.block.program._optimized_guard(
                [param, grad]), name_scope('move_average'):
                self._append_average_accumulate_op(param)

        # Standalone programs run by apply() and restore() respectively.
        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            for param_grad in self.params_grads:
                self._add_average_apply_op(block, param_grad)

        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param_grad in self.params_grads:
                self._add_average_restore_op(block, param_grad)

    def _add_average_apply_op(self, block, param_grad):
        """Append ops to ``block`` that back up a parameter and overwrite it
        with its accumulated average.

        Args:
            block (framework.Block): Global block of ``self.apply_program``.
            param_grad (tuple): ``(param, backup_var)`` pair from
                ``self.params_grads``.
        """
        param = block._clone_variable(param_grad[0])
        grad = block._clone_variable(param_grad[1])
        sum_1 = block._clone_variable(self._get_accumulator('sum_1', param))
        sum_2 = block._clone_variable(self._get_accumulator('sum_2', param))
        sum_3 = block._clone_variable(self._get_accumulator('sum_3', param))
        num_accumulates = block._clone_variable(
            self._get_accumulator('num_accumulates', param))
        old_num_accumulates = block._clone_variable(
            self._get_accumulator('old_num_accumulates', param))
        # Not read by the ops below; the clone is kept so the apply program's
        # variable set is unchanged.
        num_updates = block._clone_variable(
            self._get_accumulator('num_updates', param))
        # backup param value to grad
        layers.assign(input=param, output=grad)
        # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates)
        target_dtype = 'float32' if self._dtype is None else self._dtype
        total_count = layers.sum(x=[num_accumulates, old_num_accumulates])
        total_sum = layers.sum(x=[sum_1, sum_2, sum_3])
        # The counters are int64 accumulators; cast both operands to one
        # float dtype before dividing.
        total_count = layers.cast(x=total_count, dtype=target_dtype)
        total_sum = layers.cast(x=total_sum, dtype=target_dtype)
        ops._elementwise_div(x=total_sum, y=total_count, out=param)

    def _add_average_restore_op(self, block, param_grad):
        """Append an op to ``block`` that copies the backed-up value (stored
        in the paired temporary variable) back into the parameter."""
        param = block._clone_variable(param_grad[0])
        grad = block._clone_variable(param_grad[1])
        layers.assign(input=grad, output=param)

    def _append_average_accumulate_op(self, param):
        """Create the averaging accumulators for ``param`` and append the
        ``average_accumulates`` op that updates them in place."""
        self.helper = LayerHelper("average_accumulate")
        sum_1 = self._add_accumulator('sum_1', param)
        sum_2 = self._add_accumulator('sum_2', param)
        sum_3 = self._add_accumulator('sum_3', param)
        num_accumulates = self._add_accumulator(
            'num_accumulates', param, dtype='int64', shape=[1])
        old_num_accumulates = self._add_accumulator(
            'old_num_accumulates', param, dtype='int64', shape=[1])
        num_updates = self._add_accumulator(
            'num_updates', param, dtype='int64', shape=[1])

        self.helper.append_op(
            type='average_accumulates',
            inputs={
                "param": param,
                "in_sum_1": sum_1,
                "in_sum_2": sum_2,
                "in_sum_3": sum_3,
                "in_num_accumulates": num_accumulates,
                "in_old_num_accumulates": old_num_accumulates,
                "in_num_updates": num_updates
            },
            outputs={
                "out_sum_1": sum_1,
                "out_sum_2": sum_2,
                "out_sum_3": sum_3,
                "out_num_accumulates": num_accumulates,
                "out_old_num_accumulates": old_num_accumulates,
                "out_num_updates": num_updates,
            },
            attrs={
                "average_window": self.average_window,
                "min_average_window": self.min_average_window,
                "max_average_window": self.max_average_window,
            },
            stop_gradient=True)

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply the average of the cumulative ``Parameter`` to the parameters of the current model.

        Args:
            executor(fluid.Executor): The current network executor.
            need_restore(bool): Restore flag variable, if set to True, the network will restore
                the parameters of the network to the default value, if set to False,
                it will not be restored. The default value is True.

        Examples:

          .. code-block:: python

            import paddle.fluid as fluid
            import numpy

            # First create the Executor.
            place = fluid.CPUPlace()  # fluid.CUDAPlace(0)
            exe = fluid.Executor(place)

            train_program = fluid.Program()
            startup_program = fluid.Program()
            with fluid.program_guard(train_program, startup_program):
                # build net
                data = fluid.data(name='X', shape=[None, 1], dtype='float32')
                hidden = fluid.layers.fc(input=data, size=10)
                loss = fluid.layers.mean(hidden)
                optimizer = fluid.optimizer.Momentum(learning_rate=0.2, momentum=0.1)
                optimizer.minimize(loss)

                # build ModelAverage optimizer
                model_average = fluid.optimizer.ModelAverage(0.15,
                                                             min_average_window=10000,
                                                             max_average_window=12500)

                exe.run(startup_program)
                for i in range(12500):
                    x = numpy.random.random(size=(10, 1)).astype('float32')
                    outs = exe.run(program=train_program,
                                   feed={'X': x},
                                   fetch_list=[loss.name])

                # apply ModelAverage
                with model_average.apply(exe):
                    x = numpy.random.random(size=(10, 1)).astype('float32')
                    exe.run(program=train_program,
                            feed={'X': x},
                            fetch_list=[loss.name])
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            # Restore even if the body raised, so parameters are not left
            # holding the averaged values unintentionally.
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """
        Restore ``Parameter`` values of current model.

        Args:
            executor(fluid.Executor): The current network executor.

        Examples:

          .. code-block:: python

            import paddle.fluid as fluid
            import numpy

            # First create the Executor.
            place = fluid.CPUPlace()  # fluid.CUDAPlace(0)
            exe = fluid.Executor(place)

            train_program = fluid.Program()
            startup_program = fluid.Program()
            with fluid.program_guard(train_program, startup_program):
                # build net
                data = fluid.data(name='X', shape=[None, 1], dtype='float32')
                hidden = fluid.layers.fc(input=data, size=10)
                loss = fluid.layers.mean(hidden)
                optimizer = fluid.optimizer.Momentum(learning_rate=0.2, momentum=0.1)
                optimizer.minimize(loss)

                # build ModelAverage optimizer
                model_average = fluid.optimizer.ModelAverage(0.15,
                                                             min_average_window=10000,
                                                             max_average_window=12500)

                exe.run(startup_program)
                for i in range(12500):
                    x = numpy.random.random(size=(10, 1)).astype('float32')
                    outs = exe.run(program=train_program,
                                   feed={'X': x},
                                   fetch_list=[loss.name])

                # apply ModelAverage
                with model_average.apply(exe, False):
                    x = numpy.random.random(size=(10, 1)).astype('float32')
                    exe.run(program=train_program,
                            feed={'X': x},
                            fetch_list=[loss.name])

                # restore Parameters
                model_average.restore(exe)
        """
        executor.run(self.restore_program)
3192 3193 3194 3195 3196 3197 3198 3199 3200 3201


class ExponentialMovingAverage(object):
    """
    Compute the moving average of parameters with exponential decay.
    Given a parameter :math:`\\theta`, its exponential moving average (EMA)
    will be

    ..  math::

        \\text{EMA}_0 & = 0

        \\text{EMA}_t & = \\text{decay} * \\text{EMA}_{t-1} + (1 - \\text{decay}) * \\theta_t

    The average results calculated by **update()** method will be saved in
    temporary variables which are created and maintained by the object, and can
    be applied to parameters of current model by calling **apply()** method. And
    the **restore()** method is used to restore the parameters.

    **Bias correction**. All EMAs are initialized to :math:`0` and hence they will be
    zero biased, which can be corrected by dividing by a factor
    :math:`(1 - \\text{decay}^t)` , i.e., the actual EMAs applied to parameters
    when calling **apply()** method would be

    ..  math::

        \\widehat{\\text{EMA}}_t = \\frac{\\text{EMA}_t}{1 - \\text{decay}^t}

    The corrected value is written to the parameters only; the stored EMA
    variables are not modified by **apply()**, so it can be called any number
    of times without compounding the correction.

    **Decay rate scheduling**. A large decay rate very close to 1 would result
    in that the averages move very slowly. And a better strategy is to set a
    relative smaller decay rate in the very beginning. The argument **thres_steps**
    allows users to pass a Variable to schedule the decay rate, in this case,
    the actual decay rate becomes

    ..  math::

        \\min(\\text{decay}, \\frac{1 + \\text{thres\\_steps}}{10 + \\text{thres\\_steps}})

    Usually **thres_steps** can be the global training steps.


    Args:
        decay (float, optional): The exponential decay rate, usually close to 1, such as
            0.999, 0.9999, ... . Default 0.999.
        thres_steps (Variable|None): If not `None`, schedule the decay rate.
            Default None.
        name (str|None): For detailed information, please refer to
            :ref:`api_guide_Name`. Usually name is no need to set and None by
            default.


    Examples:

        .. code-block:: python

            import numpy
            import paddle
            import paddle.fluid as fluid

            data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
            hidden = fluid.layers.fc(input=data, size=10)
            cost = fluid.layers.mean(hidden)

            test_program = fluid.default_main_program().clone(for_test=True)

            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
            optimizer.minimize(cost)

            global_steps = fluid.layers.autoincreased_step_counter()
            ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
            ema.update()

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            for pass_id in range(3):
                for batch_id in range(6):
                    data = numpy.random.random(size=(10, 5)).astype('float32')
                    exe.run(program=fluid.default_main_program(),
                            feed={'x': data},
                            fetch_list=[cost.name])

                # usage 1
                with ema.apply(exe):
                    data = numpy.random.random(size=(10, 5)).astype('float32')
                    exe.run(program=test_program,
                            feed={'x': data},
                            fetch_list=[hidden.name])

                # usage 2
                with ema.apply(exe, need_restore=False):
                    data = numpy.random.random(size=(10, 5)).astype('float32')
                    exe.run(program=test_program,
                            feed={'x': data},
                            fetch_list=[hidden.name])
                ema.restore(exe)
    """

    def __init__(self, decay=0.999, thres_steps=None, name=None):
        if framework.in_dygraph_mode():
            raise Exception(
                "In dygraph, don't support ExponentialMovingAverage.")
        self._decay = decay
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        # float32 global variable holding the (possibly scheduled) decay rate.
        self._decay_var = self._get_ema_decay()

        # Name of the step counter that update() advances via
        # autoincreased_step_counter and that the apply program reads back
        # (in _get_decay_pow) for bias correction.
        self._step_counter_name = "@EMA_STEP_COUNTER@"

        # (param, tmp) pairs. `tmp` is a non-persistable backup slot: apply()
        # copies the live parameter into it and restore() copies it back.
        self._params_tmps = []
        for param in default_main_program().global_block().all_parameters():
            if param.do_model_average != False:
                tmp = param.block.create_var(
                    name=unique_name.generate(".".join(
                        [self._name + param.name, 'ema_tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self._params_tmps.append((param, tmp))

        # One persistable, zero-initialized EMA accumulator per parameter,
        # keyed by parameter name.
        self._ema_vars = {}
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                self._ema_vars[param.name] = self._create_ema_vars(param)

        # Program run by apply(): back up each parameter, then overwrite it
        # with its (bias-corrected) EMA.
        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            decay_pow, global_step = self._get_decay_pow(block)
            for param, tmp in self._params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(self._ema_vars[param.name])
                layers.assign(input=param, output=tmp)
                # Bias correction. The corrected value goes into the
                # parameter only: the persistable EMA accumulator must stay
                # untouched, otherwise every apply() would divide it again and
                # subsequent update() steps would build on a corrupted average.
                with layers.control_flow.Switch() as switch:
                    with switch.case(global_step > 0):
                        layers.assign(
                            output=param, input=ema / (1.0 - decay_pow))
                    with switch.default():
                        # No update() has run yet; nothing to correct.
                        layers.assign(output=param, input=ema)

        # Program run by restore(): copy the backups taken by apply() back
        # into the parameters.
        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param, tmp in self._params_tmps:
                tmp = block._clone_variable(tmp)
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)

    def _get_ema_decay(self):
        """Create the float32 global variable holding the decay rate.

        Without `thres_steps` it is the constant `self._decay`. Otherwise it is
        min(decay, (1 + thres_steps) / (10 + thres_steps)), built with LR-schedule ops.
        """
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")

            if self._thres_steps is not None:
                # Cast first so the ratio is a floating-point division even
                # when thres_steps is an integer step counter (the class
                # example passes autoincreased_step_counter()).
                thres_steps = layers.cast(self._thres_steps, 'float32')
                decay_t = (thres_steps + 1.0) / (thres_steps + 10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array(
                                [self._decay], dtype=np.float32),
                            decay_var)
        return decay_var

    def _get_decay_pow(self, block):
        """Return (decay ** step, step) as float32 variables inside `block`.

        `step` is the persistable int64 counter advanced by update(), read
        here by name and cast to float32.
        """
        global_step = layers.create_global_var(
            name=self._step_counter_name,
            shape=[1],
            value=0,
            dtype='int64',
            persistable=True)
        global_step = layers.cast(global_step, "float32")
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_step)
        return decay_pow_acc, global_step

    def _create_ema_vars(self, param):
        """Create a persistable, zero-initialized EMA accumulator shaped like `param`."""
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)

        return param_ema

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.
        """
        global_step = layers.autoincreased_step_counter(
            counter_name=self._step_counter_name)
        param_master_emas = []
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                param_ema = self._ema_vars[param.name]
                if param.name + '.master' in self._ema_vars:
                    master_ema = self._ema_vars[param.name + '.master']
                    param_master_emas.append([param_ema, master_ema])
                else:
                    # EMA_t = decay * EMA_{t-1} + (1 - decay) * theta_t
                    ema_t = param_ema * self._decay_var + param * (
                        1 - self._decay_var)
                    layers.assign(input=ema_t, output=param_ema)

        # for fp16 params: refresh the low-precision EMA from its master copy.
        for param_ema, master_ema in param_master_emas:
            default_main_program().global_block().append_op(
                type="cast",
                inputs={"X": master_ema},
                outputs={"Out": param_ema},
                attrs={
                    "in_dtype": master_ema.dtype,
                    "out_dtype": param_ema.dtype
                })

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool, optional): Whether to restore parameters after
                applying. Default True.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)
H
hutuxian 已提交
3440 3441 3442


class PipelineOptimizer(object):
    """
    Pipeline Optimizer

    Train with pipeline mode. The program will be split by cut_list.

    If the len of cut_list is k, then the whole program (including
    backward part) will be split to 2*k-1 sections.

    So the length of place_list and concurrency_list must be also 2*k-1.

    Note: Though the asynchronous mode is applied in pipeline training to speed up,
    the final performance depends on the training progress of each pipeline heavily.

    And we will try the synchronous mode in the future.

    Args:
        optimizer (Optimizer): The based optimizer, such as SGD.
        cut_list (list of Variable list): The cut variable of the main_program.
            If None or empty, the whole program runs as a single section.
        place_list (list of Place): The place where the section will run on.
        concurrency_list (list of int): The concurrency degree.
        queue_size (int): Each section will consume scopes from its in-scope queue
                        and produce scopes to out-scope queue. And this parameter
                        specify the scope queue size. [Optional. Default: 30].
        sync_steps (int): The synchronization steps between different cards. [Optional. Default: 1].
        start_cpu_core_id (int): specify the first cpu core id. [Optional. Default:0].

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.fluid.layers as layers

            x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
            y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
            emb_x = layers.embedding(input=x, param_attr=fluid.ParamAttr(name="embx"), size=[10,2], is_sparse=False)
            emb_y = layers.embedding(input=y, param_attr=fluid.ParamAttr(name="emby",learning_rate=0.9), size=[10,2], is_sparse=False)
            concat = layers.concat([emb_x, emb_y], axis=1)
            fc = layers.fc(input=concat, name="fc", size=1, num_flatten_dims=1, bias_attr=False)
            loss = layers.reduce_mean(fc)
            optimizer = fluid.optimizer.SGD(learning_rate=0.5)
            optimizer = fluid.optimizer.PipelineOptimizer(optimizer,
                    cut_list=[[emb_x, emb_y], [loss]],
                    place_list=[fluid.CPUPlace(), fluid.CUDAPlace(0), fluid.CPUPlace()],
                    concurrency_list=[1, 1, 4],
                    queue_size=2,
                    sync_steps=1,
                    )
            optimizer.minimize(loss)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            filelist = [] # you should set your own filelist, e.g. filelist = ["dataA.txt"]
            batch_size = 32
            dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset")
            dataset.set_use_var([x,y])
            dataset.set_batch_size(batch_size)
            dataset.set_filelist(filelist)
            exe.train_from_dataset(
                        fluid.default_main_program(),
                        dataset,
                        thread=2,
                        debug=False,
                        fetch_list=[],
                        fetch_info=[],
                        print_period=1)
    """

    def __init__(self,
                 optimizer,
                 cut_list=None,
                 place_list=None,
                 concurrency_list=None,
                 queue_size=30,
                 sync_steps=1,
                 start_cpu_core_id=0):
        if framework.in_dygraph_mode():
            raise Exception("In dygraph, don't support PipelineOptimizer.")
        # TODO: check properties
        self._optimizer = optimizer
        self._cut_list = cut_list
        self._place_list = place_list
        self._concurrency_list = concurrency_list
        self._queue_size = queue_size
        self._sync_steps = sync_steps
        self._start_cpu_core_id = start_cpu_core_id

    def _create_vars(self, block, main_program):
        """Clone into `block` every variable referenced by its ops.

        Each variable definition is taken from block 0 of `main_program`.
        """
        source_block = main_program.block(0)
        used_var_set = set()
        for op_idx in range(block.desc.op_size()):
            op_desc = block.desc.op(op_idx)
            var_names = op_desc.input_arg_names() + op_desc.output_arg_names()
            for var_name in var_names:
                if var_name in used_var_set:
                    continue
                used_var_set.add(var_name)
                source_var = source_block.var(str(var_name))
                block._clone_variable(source_var, False)

    def _extract_section_opt_ops(self, ops, cut_point_name):
        """
        Extract opt ops in the given section
        """
        # Walk backwards, keeping every op that (transitively) produces one
        # of the wanted names.
        output_names = set(cut_point_name)
        relevant_op_flags = [True] * len(ops)
        for i, op in reversed(list(enumerate(ops))):
            if _some_in_set_(op.desc.output_arg_names(), output_names):
                for name in op.desc.input_arg_names():
                    output_names.add(name)
            else:
                relevant_op_flags[i] = False

        op_path = [ops[i] for i in range(len(ops)) if relevant_op_flags[i]]
        return op_path

    def _find_input_output(self, ops, name, is_forward=True):
        """
        Find the inputs or outputs of a section.

        With is_forward=True, returns names used by `ops` but not produced
        by them (the section inputs). With is_forward=False, returns names
        produced by `ops` but not consumed by them (the section outputs).
        `name` is accepted for interface compatibility and is not used.
        """
        all_set = set()
        part_set = set()
        for op in ops:
            if is_forward:
                part_set.update(op.desc.output_arg_names())
            else:
                part_set.update(op.desc.input_arg_names())
            all_set.update(op.desc.output_arg_names())
            all_set.update(op.desc.input_arg_names())
        return all_set - part_set

    def _find_persistable_vars(self, ops, whole_parameters):
        """
        find the persistable input vars in current section
        """
        res = set()
        for op in ops:
            for var_name in op.desc.input_arg_names():
                if var_name in whole_parameters:
                    res.add(var_name)
        return res

    def _is_opt_role_op(self, op):
        """Return True if `op` has the Optimize bit set in its role attribute."""
        op_maker = core.op_proto_and_checker_maker
        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
        if op_maker.kOpRoleAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) & int(optimize_role) != 0:
            return True
        return False

    def _is_lr_role_op(self, op):
        """Return True if `op`'s role attribute is exactly LRSched."""
        op_maker = core.op_proto_and_checker_maker
        lr_sched_role = core.op_proto_and_checker_maker.OpRole.LRSched
        if op_maker.kOpRoleAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(lr_sched_role):
            return True
        return False

    def _extract_section_ops(self, ops, cut_point_name):
        """
        Extract ops in the given section
        """
        # Walk backwards from the cut point and keep the non-optimizer ops
        # it depends on. Also keep `print` ops whose first input is needed.
        output_names = set(cut_point_name)
        relevant_op_flags = [True] * len(ops)
        for i, op in reversed(list(enumerate(ops))):
            if not self._is_opt_role_op(op) and _some_in_set_(
                    op.desc.output_arg_names(), output_names):
                for name in op.desc.input_arg_names():
                    output_names.add(name)
            elif op.desc.type() == "print" and op.desc.input_arg_names()[
                    0] in output_names:
                continue
            else:
                relevant_op_flags[i] = False

        op_path = [ops[i] for i in range(len(ops)) if relevant_op_flags[i]]
        return op_path

    def _find_section_opt(self, ops, params):
        """Return the optimizer ops needed to update `params`."""
        res = self._extract_section_opt_ops(ops, params)
        return res

    def _split_program(self, main_program, cut_list):
        """Split block 0 of `main_program` into 2*len(cut_list)-1 sections.

        Returns a list of dicts with keys "program", "input_set" and
        "output_set".
        """
        programs = []
        block = main_program.block(0)
        # A set: only used for membership tests.
        whole_parameters = set(e.name for e in block.all_parameters())
        cut_var_names = []
        cut_len = len(cut_list)
        sec_params = []
        # Forward cut points...
        for cut_vars in cut_list[:-1]:
            cut_var_names.append([cut_var.name for cut_var in cut_vars])
        # ...followed by the matching gradient cut points in reverse order.
        for i, cut_vars in reversed(list(enumerate(cut_list[:-1]))):
            cut_var_names.append(
                [_append_grad_suffix_(cut_var.name) for cut_var in cut_vars])
            if i == 0:
                cut_var_names[-1] += [var.name for var in cut_list[-1]]
        ops = block.ops[:]
        for i, cut_vars in enumerate(cut_var_names):
            program = {
                "program": Program(),
                "input_set": set(),
                "output_set": set()
            }
            cur_ops = self._extract_section_ops(ops, cut_vars)
            if i == 0:
                for op in ops:
                    if self._is_lr_role_op(op):
                        cur_ops.append(op)
            #prevent inplace in/out
            program["input_set"].update(
                self._find_input_output(
                    cur_ops, [], is_forward=True))
            for e in cur_ops:
                ops.remove(e)

            if i < cut_len:
                sec_params.append(
                    self._find_persistable_vars(cur_ops, whole_parameters))
            if i >= cut_len - 1:
                # Backward sections also own the optimizer ops for the
                # parameters of their mirrored forward section.
                opt_ops = self._find_section_opt(
                    ops, sec_params[2 * cut_len - 2 - i])

                for e in opt_ops:
                    ops.remove(e)
                cur_ops += opt_ops

            op_descs = [op.desc for op in cur_ops]
            for op_desc in op_descs:
                ap_op = program["program"].block(0).desc.append_op()
                ap_op.copy_from(op_desc)
            program["input_set"].update(
                self._find_input_output(
                    cur_ops, cut_vars, is_forward=True))
            program["input_set"].update(sec_params[min(i, 2 * cut_len - 2 - i)])
            program["output_set"].update(
                self._find_input_output(
                    cur_ops, cut_vars, is_forward=False))
            programs.append(program)
        # The remaining ops form the final section.
        program = {
            "program": Program(),
            "input_set": set(),
            "output_set": set()
        }
        op_descs = [op.desc for op in ops]
        for op_desc in op_descs:
            ap_op = program["program"].block(0).desc.append_op()
            ap_op.copy_from(op_desc)
        program["input_set"].update(
            [cut_var.name + "@GRAD" for cut_var in cut_list[0]])
        program["input_set"].update(
            self._find_input_output(
                ops, [], is_forward=True))
        program["input_set"].update(sec_params[0])
        programs.append(program)
        # Drop outputs that no later section consumes.
        inputs = set()
        for program in reversed(list(programs)):
            output_list = list(program["output_set"])
            for output in output_list:
                if output not in inputs:
                    program["output_set"].remove(output)
            inputs.update(program["input_set"])
        return programs

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """Run the wrapped optimizer's minimize, then split the program into
        pipeline sections and attach the trainer config as
        `program._pipeline_opt`.

        Returns:
            Whatever the wrapped optimizer's `minimize` returns.
        """
        result = self._optimizer.minimize(loss, startup_program,
                                          parameter_list, no_grad_set)
        program = loss.block.program
        # `cut_list` defaults to None; treat None like an empty list (a
        # single section) instead of failing on len(None).
        if not self._cut_list:
            program_list = [{
                "program": program,
                "input_set": set(),
                "output_set": set()
            }]
        else:
            program_list = self._split_program(program, self._cut_list)
            for p in program_list:
                self._create_vars(p["program"].block(0), program)
        whole_parameters = set(
            e.name for e in program.block(0).all_parameters())
        # Parameters used by any section placed on a GPU must be synced.
        param_need_sync = []
        for i, section_p in enumerate(program_list):
            if not isinstance(self._place_list[i], core.CUDAPlace):
                continue
            for var_name in section_p["program"].block(0).vars:
                if var_name in whole_parameters:
                    param_need_sync.append(var_name)
        program._pipeline_opt = {
            "trainer": "PipelineTrainer",
            "device_worker": "Section",
            "section_program_list": program_list,
            "place_list": self._place_list,
            "concurrency_list": self._concurrency_list,
            "queue_size": self._queue_size,
            "start_cpu_core_id": self._start_cpu_core_id,
            "sync_steps": self._sync_steps,
            "param_need_sync": param_need_sync
        }
        return result
M
mapingshuo 已提交
3740 3741


M
mapingshuo 已提交
3742 3743 3744 3745 3746 3747 3748 3749 3750 3751 3752 3753 3754 3755 3756 3757 3758 3759 3760 3761 3762 3763 3764 3765 3766 3767 3768 3769 3770 3771 3772 3773 3774 3775 3776 3777 3778 3779 3780 3781 3782 3783 3784 3785 3786 3787 3788 3789 3790 3791 3792 3793 3794 3795 3796 3797 3798 3799 3800 3801 3802 3803
class RecomputeOptimizer(Optimizer):
    """
    Recompute Optimizer Wrapper

    Normally, a training step contains three sub-steps: first, run forward
    Operators to calculate the loss; second, run backward Operators to
    calculate gradient of the parameters; third, apply optimization method
    to update the value of the parameters.

    In the forward computation process, all variables that are needed by
    backward computation process will be kept in memory, which occupy a great
    amount of memory when the network becomes very deep.

    Recompute splits the network into k segments. In each segment, it
    recomputes the forward Operators before running the backward Operators.
    It is very helpful for saving memory.

    The Variables that separate a network into segments are called
    checkpoints, and users should set them manually by calling
    ``_set_checkpoints`` before ``minimize`` (see the example below).

    Args:
        optimizer (Optimizer): The optimizer that is applied to parameters.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np

            def gen_data():
                return {"x": np.random.random(size=(32, 32)).astype('float32'),
                        "y": np.random.randint(2, size=(32, 1)).astype('int64')}

            def mlp(input_x, input_y, hid_dim=128, label_dim=2):
                print(input_x)
                fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
                prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
                cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
                sum_cost = fluid.layers.reduce_mean(cost)
                return sum_cost, fc_1, prediction

            input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
            input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
            cost, fc_1, pred = mlp(input_x, input_y)

            sgd = fluid.optimizer.Adam(learning_rate=0.01)
            sgd = fluid.optimizer.RecomputeOptimizer(sgd)
            sgd._set_checkpoints([fc_1, pred])
            sgd.minimize(cost)

            print("Finished optimize")
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            step = 10

            for i in range(step):
                cost_val = exe.run(feed=gen_data(),
                                   program=fluid.default_main_program(),
                                   fetch_list=[cost.name])
                print("step=%d cost=%f" % (i, cost_val[0]))

    """

    def __init__(self, optimizer):
        if framework.in_dygraph_mode():
            raise Exception("In dygraph, don't support RecomputeOptimizer.")
        # The wrapped optimizer does the actual parameter update; this class
        # only changes how the backward pass is built.
        self._optimizer = optimizer
        self._checkpoints = None

    def _set_checkpoints(self, checkpoints):
        """
        Set the checkpoint Variables that split the network into segments.

        The value is stored as-is and forwarded to ``append_backward`` by
        ``backward``; ``minimize`` refuses to run until this has been called.

        Args:
            checkpoints (list): list of Variables used as checkpoints.
        """
        self._checkpoints = checkpoints

    def load(self, stat_dict):
        """
        load function is not supported by Recompute Optimizer for now.
        :return: None

        Args:
            stat_dict: the dict load by load_persistable method

        Raises:
            NotImplementedError: always.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import paddle.compat as cpt

                def mlp(input_x, input_y, hid_dim=128, label_dim=2):
                    fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
                    prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
                    cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
                    sum_cost = fluid.layers.reduce_mean(cost)
                    return sum_cost, fc_1, prediction

                input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
                cost, fc_1, pred = mlp(input_x, input_y)
                print("Finished FF")

                sgd = fluid.optimizer.Adam(learning_rate=0.01)
                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
                sgd._set_checkpoints([fc_1, pred])
                try:
                    stat_dict = {}
                    sgd.load(stat_dict)
                except NotImplementedError as e:
                    print(cpt.get_exception_message(e))
        """
        raise NotImplementedError(
            "load function is not supported by Recompute Optimizer for now")

    def apply_gradients(self, params_grads):
        """
        call apply_gradients function of self._optimizer.

        Args:
            params_grads (list): list of (param, grad) pair to do optimization.

        Returns:
            list: A list of operators appended to the current program.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import paddle.fluid.framework as framework

                def mlp(input_x, input_y, hid_dim=128, label_dim=2):
                    fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
                    prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
                    cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
                    sum_cost = fluid.layers.reduce_mean(cost)
                    return sum_cost, fc_1, prediction


                input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
                cost, fc_1, pred = mlp(input_x, input_y)
                print("Finished FF")

                sgd = fluid.optimizer.Adam(learning_rate=0.01)
                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
                sgd._set_checkpoints([fc_1, pred])
                params_grads = sgd.backward(
                    cost,
                    startup_program=None,
                    parameter_list=None,
                    no_grad_set=None)

                program = cost.block.program
                with framework.program_guard(program, None):
                    optimize_ops = sgd.apply_gradients(params_grads)

                print("Finished apply gradients")
        """

        return self._optimizer.apply_gradients(params_grads=params_grads)

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """
        call append_backward with checkpoints.

        The checkpoints used are the ones previously given to
        ``_set_checkpoints``.

        Args:
            loss (Variable): loss variable to run optimizations.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables or Variable.names to update.
            no_grad_set (set|None): set of Variables or Variables.names should be ignored.
            callbacks (list|None): list of callables to run when appending backward
                operator for one parameter. Forwarded to ``append_backward``.

        Returns:
            list: list of (param, grad) pairs produced by ``append_backward``.

        Raises:
            NotImplementedError: in dygraph mode.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                def mlp(input_x, input_y, hid_dim=128, label_dim=2):
                    fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
                    prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
                    cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
                    sum_cost = fluid.layers.reduce_mean(cost)
                    return sum_cost, fc_1, prediction


                input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
                cost, fc_1, pred = mlp(input_x, input_y)
                print("Finished FF")

                sgd = fluid.optimizer.Adam(learning_rate=0.01)
                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
                sgd._set_checkpoints([fc_1, pred])
                params_grads = sgd.backward(
                    cost,
                    startup_program=None,
                    parameter_list=None,
                    no_grad_set=None)
                print("Finished backward")
        """

        if framework.in_dygraph_mode():
            raise NotImplementedError(
                "DyGraph current does not support recompute")

        self._dtype = loss.dtype
        program = loss.block.program
        with program_guard(program, startup_program):
            # `callbacks` used to be accepted but silently dropped; forward it
            # so the documented argument actually takes effect.
            params_grads = append_backward(
                loss,
                parameter_list,
                no_grad_set,
                callbacks=callbacks,
                checkpoints=self._checkpoints)
            # Note: since we can't use all_reduce_op now,
            #  dgc_op should be the last op of one grad.
            self._optimizer._append_dgc_ops(params_grads)
        return params_grads

    def apply_optimize(self, loss, startup_program, params_grads):
        """
        call the apply_optimize function of self._optimizer

        Args:
            loss (Variable): loss variable to run optimizations.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            params_grads (list): list of (param, grad) pair to do optimization.

        Returns:
            Whatever ``self._optimizer.apply_optimize`` returns.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                def mlp(input_x, input_y, hid_dim=128, label_dim=2):
                    fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
                    prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
                    cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
                    sum_cost = fluid.layers.reduce_mean(cost)
                    return sum_cost, fc_1, prediction

                input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
                cost, fc_1, pred = mlp(input_x, input_y)
                print("Finished FF")

                sgd = fluid.optimizer.Adam(learning_rate=0.01)
                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
                sgd._set_checkpoints([fc_1, pred])
                params_grads = sgd.backward(
                    cost,
                    startup_program=None,
                    parameter_list=None,
                    no_grad_set=None)

                optimize_ops = sgd.apply_optimize(
                    cost, startup_program=None, params_grads=params_grads)

                print("Finished apply_optimize")
        """

        return self._optimizer.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 grad_clip=None):
        """
        Build the recompute backward pass and the optimization ops for ``loss``.

        ``_set_checkpoints`` must have been called before.

        Args:
            loss (Variable): loss variable to run optimizations.
            startup_program (Program, optional): startup_program for
                initializing parameters in ``parameter_list``.
            parameter_list (list, optional): list of Variables or
                Variable.names to update.
            no_grad_set (set, optional): set of Variables or Variable.names
                that should be ignored.
            grad_clip (GradientClipBase, optional): if given, it is installed
                as the gradient clipping strategy of the wrapped optimizer.

        Returns:
            tuple: (optimize_ops, params_grads) where optimize_ops is the
            result of ``apply_optimize`` and params_grads is the list of
            (param, grad) pairs returned by ``backward``.

        Raises:
            AssertionError: if ``loss`` is not a Variable or checkpoints
                have not been set.
            NotImplementedError: in dygraph mode.
            TypeError: if ``grad_clip`` is not a GradientClipBase instance.
        """
        assert isinstance(loss, Variable), "The loss should be an Variable."
        assert (self._checkpoints is not None
                ), "You should call _set_checkpoints first"
        if framework.in_dygraph_mode():
            raise NotImplementedError(
                "DyGraph current does not support recompute")
        if grad_clip is not None:
            if not isinstance(grad_clip, GradientClipBase):
                raise TypeError(
                    "'grad_clip' should be an instance of GradientClipBase's derived class"
                )
            self._optimizer._grad_clip = grad_clip
        params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set)

        optimize_ops = self.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

        return optimize_ops, params_grads


M
mapingshuo 已提交
4034 4035 4036 4037 4038 4039 4040 4041 4042 4043 4044 4045 4046 4047 4048 4049 4050 4051 4052 4053 4054 4055 4056 4057 4058 4059 4060 4061 4062 4063 4064 4065 4066 4067 4068 4069 4070 4071 4072 4073 4074 4075 4076 4077 4078 4079 4080 4081 4082 4083 4084 4085 4086 4087 4088
class LookaheadOptimizer(object):
    r"""
    This implements the Lookahead optimizer of the
    paper : https://arxiv.org/abs/1907.08610.

    Lookahead keeps two sets of params: the fast_params and
    the slow_params. inner_optimizer updates fast_params every
    training step. Lookahead updates the slow_params and fast_params
    every k training steps as follows:

    .. math::

        slow\_param_t &= slow\_param_{t-1} + \alpha * (fast\_param_{t-1} - slow\_param_{t-1})

        fast\_param_t &=  slow\_param_t

    Args:
        inner_optimizer (Optimizer): The optimizer that updates fast params step by step.
        alpha (float): The learning rate of Lookahead, in [0.0, 1.0].
        k (int): The slow params are updated every k steps; a positive integer.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle.fluid as fluid

            x = fluid.layers.data(name='x', shape=[2], dtype='float32')
            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
            y = fluid.layers.fc(input=[x], size=2, act="softmax")
            loss = fluid.layers.cross_entropy(input=y, label=label)
            loss = fluid.layers.mean(x=loss)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            optimizer = fluid.optimizer.LookaheadOptimizer(sgd,
                                                           alpha=0.5,
                                                           k=5)
            optimizer.minimize(loss)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            feeder = fluid.DataFeeder(feed_list=[x, label], place=place)

            for step in range(10):
                # one mini-batch of 4 random (x, label) samples
                batch_data = [(np.random.random([2]).astype('float32'),
                               np.random.randint(2, size=[1]).astype('int64'))
                              for _ in range(4)]
                exe.run(fluid.default_main_program(),
                        feed=feeder.feed(batch_data))

    """

    def __init__(self, inner_optimizer, alpha=0.5, k=5):

        if framework.in_dygraph_mode():
            raise Exception("In dygraph, don't support LookaheadOptimizer.")
        assert (inner_optimizer is not None), "inner optimizer can not be None"
        assert (
            0.0 <= alpha <= 1.0
        ), "alpha should be larger or equal to 0.0, and less or equal than 1.0"
        assert (isinstance(k, int) and k > 0), "k should be a positive integer"

        self.inner_optimizer = inner_optimizer
        self.alpha = alpha
        self.k = k
        self.type = "lookahead"

    def minimize(self, loss, startup_program=None):
        """
        Append the inner optimizer's ops, then the Lookahead ops, for ``loss``.

        For every parameter of ``loss.block`` a persistable ``<name>@SLOW``
        copy is created and initialized from the parameter in the startup
        program. A step counter is incremented on every run. Whenever
        ``step % k == 0``, both the slow and the fast parameter are set to
        ``alpha * fast + (1 - alpha) * slow``.

        Args:
            loss (Variable): The loss variable to minimize.
            startup_program (Program, optional): Program that initializes the
                slow params. Defaults to ``default_startup_program()``.

        Returns:
            Whatever ``inner_optimizer.minimize`` returns, unchanged.
        """
        # Apply inner optimizer to the main_program
        mini_out = self.inner_optimizer.minimize(
            loss, startup_program=startup_program)

        # Get startup_program and main_program
        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block

        # Create a persistable "<param>@SLOW" var for every parameter in the
        # main program.
        params = [param.name for param in main_block.all_parameters()]
        param_to_slow = {}
        for param in params:
            fast_var = main_block.var(param)
            assert (fast_var is not None)
            slow_var = main_block.create_var(
                name=param + "@SLOW",
                shape=fast_var.shape,
                dtype=fast_var.dtype,
                persistable=True)
            param_to_slow[param] = slow_var

        # Mirror the slow vars in the startup program and initialize each one
        # from its fast parameter.
        startup_block = startup_program.global_block()
        for param in params:
            fast_var = startup_block.var(param)
            assert (fast_var is not None)
            slow_var = startup_block.create_var(
                name=param + "@SLOW",
                shape=fast_var.shape,
                dtype=fast_var.dtype,
                persistable=True)

            startup_block.append_op(
                type="assign",
                inputs={"X": fast_var},
                outputs={"Out": slow_var})

        # create_global_var and the layers.* calls below target the *default*
        # programs. Guard them so they land in the program that owns `loss`
        # and in `startup_program`, even when the caller did not open a
        # program_guard of their own.
        with program_guard(main_block.program, startup_program):
            # Add Var k to main prog and startup prog
            k = layers.create_global_var(
                name="lookahead_k",
                shape=[1],
                value=int(self.k),
                dtype='int32',
                persistable=True)

            # Add Var alpha to main prog and startup prog
            alpha = layers.create_global_var(
                name="lookahead_alpha",
                shape=[1],
                value=float(self.alpha),
                dtype='float32',
                persistable=True)

            # Add Var step; it is incremented once per run of the main program,
            # after the inner optimizer's ops.
            step = layers.create_global_var(
                name="lookahead_step",
                shape=[1],
                value=int(0),
                dtype='int32',
                persistable=True)
            layers.increment(x=step, value=1.0, in_place=True)

            # lookahead
            zero_var = layers.fill_constant(
                shape=[1], dtype='float32', value=0.0)

            one_var = layers.fill_constant(
                shape=[1], dtype='float32', value=1.0)

            mod = layers.elementwise_mod(step, k)
            with layers.control_flow.Switch() as switch:
                # Every k-th step: slow = alpha * fast + (1 - alpha) * slow,
                # then reset fast to the new slow value.
                with switch.case(mod == zero_var):
                    for param_name in params:
                        fast_var = main_block.var(param_name)
                        slow_var = param_to_slow[param_name]
                        tmp_var = layers.elementwise_add(
                            layers.elementwise_mul(fast_var, alpha),
                            layers.elementwise_mul(
                                slow_var,
                                layers.elementwise_sub(one_var, alpha)))
                        layers.assign(input=tmp_var, output=slow_var)
                        layers.assign(input=tmp_var, output=fast_var)
                with switch.default():
                    pass
        return mini_out