#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
W
wanghaoshuang 已提交
16
import re
17
from collections import defaultdict
18
from paddle.fluid.framework import Program, Variable, name_scope
19 20 21 22 23 24 25 26 27
from . import framework
from . import layers
from .backward import append_backward
from .framework import program_guard
from . import unique_name
from .initializer import Constant
from .layer_helper import LayerHelper
from .regularizer import append_regularization_ops
from .clip import append_gradient_clip_ops, error_clip_callback
28
from contextlib import contextmanager
29

30
# Public names exported by ``from <this module> import *``.
# Each name is listed exactly once.
__all__ = [
    'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
    'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
    'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
    'FtrlOptimizer', 'Adadelta', 'ModelAverage'
]
Q
Qiao Longfei 已提交
36 37 38 39 40 41


class Optimizer(object):
    """Optimizer Base class.

    Define the common interface of an optimizer.
    User should not use this class directly,
    but need to use one of its implementations.
    """

    def __init__(self,
                 learning_rate,
                 regularization=None,
                 LARS_weight_decay=0.0,
                 name=None):
        """
        Args:
            learning_rate (float|int|Variable): the global learning rate. A
                plain int is promoted to float. A Variable is registered as
                the learning rate of the default main program at
                construction time.
            regularization: regularizer handed to
                ``append_regularization_ops`` in ``minimize``.
            LARS_weight_decay (float): when > 0, ``layers.append_LARS`` is
                applied to the (param, grad) pairs before the update ops.
            name (str|None): optional prefix for accumulator names.

        Raises:
            TypeError: if learning_rate is not a float, int or Variable.
        """
        # Accept plain ints such as ``learning_rate=1`` by promoting them to
        # float; bool is excluded because it is a subclass of int.
        if isinstance(learning_rate, int) and \
                not isinstance(learning_rate, bool):
            learning_rate = float(learning_rate)
        if not isinstance(learning_rate, float) and \
                not isinstance(learning_rate, framework.Variable):
            raise TypeError("learning rate should be float or Variable")
        self._name = name
        self.regularization = regularization
        self._learning_rate = learning_rate
        # the learning rate dtype is taken from the loss later on
        self._dtype = None
        # each program should have an independent learning rate
        # program -> Variable(learning_rate)
        self._learning_rate_map = dict()
        if isinstance(self._learning_rate, framework.Variable):
            self._learning_rate_map[framework.default_main_program(
            )] = self._learning_rate
        # Dictionary of accumulators. Some optimizer subclasses need to
        # allocate and manage extra variables associated with the parameters
        # to train. These variables are called accumulators.
        # {accum_name : { parameter_name : accumulator_for_parameter, ...}, ...}
        self._accumulators = defaultdict(dict)
        self.helper = None
        self._LARS_weight_decay = LARS_weight_decay

    def _create_global_learning_rate(self):
        """Create the learning-rate Variable for the default main program.

        Does nothing if that program already has one. Otherwise a persistable
        global Variable of shape [1] holding the float learning rate is
        created, with dtype ``self._dtype`` (float32 if it is unset).

        Raises:
            TypeError: if the learning rate was given as a Variable, so no
                new Variable can be created for a different program.
        """
        lr = self._global_learning_rate()

        if isinstance(lr, framework.Variable):
            return
        if not isinstance(self._learning_rate, float):
            raise TypeError(
                "learning rate variable is created outside optimizer, "
                "can not create new learning rate variable for new program")

        # create learning rate in the current main program
        self._learning_rate_map[framework.default_main_program(
        )] = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(self._learning_rate),
            dtype='float32' if self._dtype is None else self._dtype,
            persistable=True)

    def _global_learning_rate(self, program=None):
        """Return the learning-rate Variable registered for a program.

        Args:
            program (Program|None): defaults to the default main program.

        Returns:
            Variable|None: the learning rate, or None if none has been
            registered for ``program`` yet.
        """
        if program is None:
            program = framework.default_main_program()
        return self._learning_rate_map.get(program, None)

    def _append_optimize_op(self, block, param_and_grad):
        """ append optimize operator to block and return all the added optimize_op
        """
        raise NotImplementedError()

    def _create_param_lr(self, param_and_grad):
        """Return the effective learning rate for one parameter.

        ``param.optimize_attr['learning_rate']`` is either a Variable (the
        learning rate has already been updated, e.g. by LARS), which is
        returned unchanged, or a multiplier on the global learning rate.
        """
        param = param_and_grad[0]
        param_lr = param.optimize_attr['learning_rate']
        if isinstance(param_lr, framework.Variable):
            # param learning rate has been updated (LARS)
            return param_lr
        if param_lr == 1.0:
            # skip the multiplication for the common unit multiplier
            return self._global_learning_rate()
        return self._global_learning_rate() * param_lr

    def _create_accumulators(self, block, parameters):
        """Create all accumulators needed by the parameters

        The base implementation creates nothing; subclasses override it.

        Args:
            block: the block in which the loss variable is present
            parameters: list of parameter variables for the optimizer
        """
        pass

    def _finish_update(self, block, parameters_and_grads):
        """Finish any custom updates needed
           before completing an optimization step

        The base implementation adds nothing; subclasses override it.

        Args:
            block: the block in which the loss variable is present
            parameters_and_grads: list of (parameter, gradient) pairs

        Returns:
            None
        """
        pass

    def _add_accumulator(self,
                         name,
                         param,
                         dtype=None,
                         fill_value=0.0,
                         shape=None):
        """Utility function to add an accumulator for a parameter

        Args:
            name: name of the accumulator; prefixed with ``<self._name>_``
                when the optimizer was constructed with a name
            param: parameter variable for which accumulator is to be added
            dtype: data type of the accumulator; defaults to ``param.dtype``
            fill_value: value to initialize the accumulator variable
            shape: shape of the accumulator; defaults to ``param.shape``

        Returns:
            the new persistable global accumulator variable

        Raises:
            Exception: if this accumulator already exists for ``param``.
        """
        if self._name is not None:
            name = self._name + "_" + name
        if (name in self._accumulators and
                param.name in self._accumulators[name]):
            raise Exception("Accumulator {} already exists for parameter {}".
                            format(name, param.name))
        if shape is None:
            shape = param.shape
        assert isinstance(self.helper, LayerHelper)
        var = self.helper.create_global_variable(
            name=unique_name.generate(name),
            persistable=True,
            dtype=dtype or param.dtype,
            type=param.type,
            shape=shape)
        self.helper.set_variable_initializer(
            var, initializer=Constant(value=float(fill_value)))
        self._accumulators[name][param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        Args:
            name: name of the accumulator (prefixed with ``<self._name>_``
                when the optimizer has a name, as in ``_add_accumulator``)
            param: parameter variable for which accumulator is to be fetched

        Returns:
            accumulator variable for the parameter

        Raises:
            Exception: if no such accumulator exists for ``param``.
        """
        if self._name is not None:
            name = self._name + "_" + name
        if (name not in self._accumulators or
                param.name not in self._accumulators[name]):
            raise Exception("Accumulator {} does not exist for parameter {}".
                            format(name, param.name))
        return self._accumulators[name][param.name]

    def _create_optimization_pass(self,
                                  parameters_and_grads,
                                  loss,
                                  startup_program=None):
        """Add optimization operators to update gradients to variables.

        Args:
          parameters_and_grads(list(tuple(Variable, Variable))):
          a list of (variable, gradient) pair to update.
          loss(Variable): the target that this optimization is for.
          startup_program(Program|None): passed to ``program_guard`` while
          the optimizer ops and variables are created.

        Returns:
          return_op_list: a list of operators that will complete one step of
          optimization, i.e. every op appended to the main program's global
          block during this call (parameter update ops, LARS ops, and any
          custom ops added by subclasses in ``_finish_update``).
        """
        # This is a default implementation of create_optimization_pass that
        # can be shared by most optimizers. It assumes the subclass implements
        # _append_optimize_op. A subclass can extend _create_accumulators if
        # it needs per-parameter accumulators and extend _finish_update to
        # add custom ops.
        program = loss.block.program
        # the learning-rate Variable is created with the loss dtype
        self._dtype = loss.dtype
        with program_guard(program, startup_program):
            global_block = framework.default_main_program().global_block()
            # remember the op count so the ops added here can be sliced out
            start = len(global_block.ops)
            self.helper = LayerHelper(self.__class__.__name__)
            # Create any accumulators
            self._create_accumulators(loss.block,
                                      [p[0] for p in parameters_and_grads])
            self._create_global_learning_rate()
            if self._LARS_weight_decay > 0.0:
                layers.append_LARS(parameters_and_grads,
                                   self._global_learning_rate(),
                                   self._LARS_weight_decay)

            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    # no gradient for this parameter: nothing to update
                    continue
                with param_and_grad[0].block.program.optimized_guard(
                        param_and_grad), name_scope("optimizer"):
                    if param_and_grad[0].trainable is True:
                        self._append_optimize_op(loss.block, param_and_grad)

            # Get custom finish ops for subclasses
            # FIXME: Need to fix this once we figure out how to handle dependencies
            self._finish_update(loss.block, parameters_and_grads)

            end = len(global_block.ops)
            return global_block._slice_ops(start, end)

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """Add operations to minimize `loss` by updating `parameter_list`.

        This method combines interface `append_backward()` and
        `_create_optimization_pass()` into one: it appends the backward ops,
        gradient clipping ops, regularization ops and finally the optimizer
        ops.

        Args:
            loss (Variable): the variable to minimize.
            startup_program (Program|None): startup program for variable
                initialization.
            parameter_list (list|None): parameters to update; forwarded to
                ``append_backward``.
            no_grad_set (set|None): variables excluded from gradient
                computation; forwarded to ``append_backward``.

        Returns:
            tuple: (optimize_ops, params_grads) -- the ops added by the
            optimization pass and the final (param, grad) pairs.
        """
        params_grads = append_backward(loss, parameter_list, no_grad_set,
                                       [error_clip_callback])

        # sort by parameter name so ops are appended in a deterministic order
        params_grads = sorted(params_grads, key=lambda x: x[0].name)

        params_grads = append_gradient_clip_ops(params_grads)

        # Add regularization if any
        params_grads = append_regularization_ops(params_grads,
                                                 self.regularization)

        optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                      startup_program)
        return optimize_ops, params_grads
Q
Qiao Longfei 已提交
277 278 279


class SGDOptimizer(Optimizer):
    r"""
    Optimizer of the stochastic gradient descent algorithm.

    .. math::

        param\_out = param - learning\_rate * grad

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.

    Examples:
        .. code-block:: python

            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.2)
            sgd_optimizer.minimize(cost)
    """

    def __init__(self, learning_rate, **kwargs):
        assert learning_rate is not None
        super(SGDOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "sgd"

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``sgd`` op that writes the updated value back to Param.

        Returns:
            the appended operator.
        """
        assert isinstance(block, framework.Block)

        # create the optimize op
        sgd_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={"ParamOut": param_and_grad[0]})

        return sgd_op
318 319 320


class MomentumOptimizer(Optimizer):
    r"""
    Simple Momentum optimizer with velocity state

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float): momentum factor
        use_nesterov (bool): enables Nesterov momentum

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.Momentum(learning_rate=0.2, momentum=0.1)
            optimizer.minimize(cost)
    """
    _velocity_acc_str = "velocity"

    def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
        assert learning_rate is not None
        assert momentum is not None
        super(MomentumOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)

    def _create_accumulators(self, block, parameters):
        """Create one velocity accumulator per parameter."""
        assert isinstance(block, framework.Block)

        for p in parameters:
            self._add_accumulator(self._velocity_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append a ``momentum`` op updating Param and Velocity in place.

        Returns:
            the appended operator.
        """
        assert isinstance(block, framework.Block)

        velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                             param_and_grad[0])
        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Velocity": velocity_acc,
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "VelocityOut": velocity_acc
            },
            attrs={"mu": self._momentum,
                   "use_nesterov": self._use_nesterov})

        return momentum_op
392 393 394


class AdagradOptimizer(Optimizer):
    r"""
    **Adaptive Gradient Algorithm (Adagrad)**

    The update is done as follows:

    .. math::

        moment\_out &= moment + grad * grad

        param\_out &= param - \frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}

    The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
    does not have the epsilon attribute. It is added here in our implementation
    as also proposed here: http://cs231n.github.io/neural-networks-3/#ada
    for numerical stability to avoid the division by zero error.

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        epsilon (float): a small float value for numerical stability.

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
            optimizer.minimize(cost)
    """
    _moment_acc_str = "moment"

    def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
        assert learning_rate is not None
        assert epsilon is not None
        super(AdagradOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "adagrad"
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        """Create one squared-gradient moment accumulator per parameter."""
        assert isinstance(block, framework.Block)

        for p in parameters:
            self._add_accumulator(self._moment_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``adagrad`` op updating Param and Moment in place.

        Returns:
            the appended operator.
        """
        assert isinstance(block, framework.Block)

        moment_acc = self._get_accumulator(self._moment_acc_str,
                                           param_and_grad[0])

        # Create the adagrad optimizer op
        adagrad_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": moment_acc,
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={"ParamOut": param_and_grad[0],
                     "MomentOut": moment_acc},
            attrs={"epsilon": self._epsilon})

        return adagrad_op
458 459 460


class AdamOptimizer(Optimizer):
    r"""
    This implements the Adam optimizer from Section 2 of the Adam
    paper : https://arxiv.org/abs/1412.6980.
    Adam is a first-order gradient-based optimization method based on
    adaptive estimates of lower-order moments.

    Adam updates:

    .. math::

        t & = t + 1

        moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad

        moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad

        learning\_rate & = learning\_rate * \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        beta1 (float): The exponential decay rate for the 1st moment estimates.
        beta2 (float): The exponential decay rate for the 2nd moment estimates.
        epsilon (float): a small float value for numerical stability.

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.Adam(learning_rate=0.2)
            optimizer.minimize(cost)

    """
    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 **kwargs):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        super(AdamOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "adam"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        """Create the per-parameter Adam state.

        For each parameter this creates two moment accumulators (same shape
        and dtype as the parameter) and two shape-[1] float32 accumulators
        for beta1^t / beta2^t, initialized to beta1 / beta2. The latter are
        multiplied by beta1 / beta2 in ``_finish_update``.
        """
        assert isinstance(block, framework.Block)

        # Create accumulator tensors for first and second moments
        for p in parameters:
            self._add_accumulator(self._moment1_acc_str, p)
            self._add_accumulator(self._moment2_acc_str, p)
            self._add_accumulator(
                name=self._beta1_pow_acc_str,
                param=p,
                dtype='float32',
                fill_value=self._beta1,
                shape=[1])
            self._add_accumulator(
                name=self._beta2_pow_acc_str,
                param=p,
                dtype='float32',
                fill_value=self._beta2,
                shape=[1])

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``adam`` op updating Param, Moment1 and Moment2 in place.

        Returns:
            the appended operator.
        """
        assert isinstance(block, framework.Block)

        moment1 = self._get_accumulator(self._moment1_acc_str,
                                        param_and_grad[0])
        moment2 = self._get_accumulator(self._moment2_acc_str,
                                        param_and_grad[0])
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                              param_and_grad[0])
        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                              param_and_grad[0])

        # create the adam optimize op
        adam_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad),
                "Moment1": moment1,
                "Moment2": moment2,
                "Beta1Pow": beta1_pow_acc,
                "Beta2Pow": beta2_pow_acc
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "Moment1Out": moment1,
                "Moment2Out": moment2
            },
            attrs={
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon
            })

        return adam_op

    def _finish_update(self, block, param_and_grads):
        """Update Beta1 and Beta2 Power accumulators

        For every parameter that has a gradient, appends two in-place
        ``scale`` ops to the program's global block that multiply its
        beta1/beta2 power accumulators by beta1/beta2.
        """
        assert isinstance(block, framework.Block)
        main_block = block.program.global_block()
        for param, grad in param_and_grads:
            if grad is None:
                continue
            with param.block.program.optimized_guard([param, grad]):
                beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                      param)
                beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                                      param)
                main_block.append_op(
                    type="scale",
                    inputs={"X": beta1_pow_acc},
                    outputs={"Out": beta1_pow_acc},
                    attrs={"scale": self._beta1})

                main_block.append_op(
                    type="scale",
                    inputs={"X": beta2_pow_acc},
                    outputs={"Out": beta2_pow_acc},
                    attrs={"scale": self._beta2})
599 600 601


class AdamaxOptimizer(Optimizer):
    r"""
    We implement the Adamax optimizer from Section 7 of the Adam
    paper: https://arxiv.org/abs/1412.6980. Adamax is a variant of the
    Adam algorithm based on the infinity norm.

    Adamax updates:

    .. math::

        t & = t + 1

        moment\_out & = {\beta}_1 * moment + (1 - {\beta}_1) * grad

        inf\_norm\_out & = max({\beta}_2 * inf\_norm + \epsilon, |grad|)

        learning\_rate & = \frac{learning\_rate}{1 - {\beta}_1^t}

        param\_out & = param - learning\_rate * \frac{moment\_out}{inf\_norm\_out}


    The original paper does not have an epsilon attribute.
    However, it is added here for numerical stability to prevent the
    division by 0 error.

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        beta1 (float): The exponential decay rate for the 1st moment estimates.
        beta2 (float): The exponential decay rate for the 2nd moment estimates.
        epsilon (float): a small float value for numerical stability.

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
            optimizer.minimize(cost)
    """
    _moment_acc_str = "moment"
    _inf_norm_acc_str = "inf_norm"
    _beta1_pow_acc_str = "beta1_pow_acc"

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 **kwargs):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        super(AdamaxOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "adamax"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        """Create the per-parameter Adamax state.

        For each parameter: a moment and an infinity-norm accumulator (same
        shape and dtype as the parameter) plus a shape-[1] float32
        beta1^t accumulator initialized to beta1.
        """
        # Create accumulator tensors for first moment and infinity norm
        for p in parameters:
            self._add_accumulator(self._moment_acc_str, p)
            self._add_accumulator(self._inf_norm_acc_str, p)
            self._add_accumulator(
                name=self._beta1_pow_acc_str,
                param=p,
                dtype='float32',
                fill_value=self._beta1,
                shape=[1])

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``adamax`` op updating Param, Moment and InfNorm in place.

        Returns:
            the appended operator.
        """
        assert isinstance(block, framework.Block)

        moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0])
        inf_norm = self._get_accumulator(self._inf_norm_acc_str,
                                         param_and_grad[0])
        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                              param_and_grad[0])
        # create the adamax optimize op
        adamax_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "LearningRate": self._create_param_lr(param_and_grad),
                "Moment": moment,
                "InfNorm": inf_norm,
                "Beta1Pow": beta1_pow_acc
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "MomentOut": moment,
                "InfNormOut": inf_norm
            },
            attrs={
                "beta1": self._beta1,
                "beta2": self._beta2,
                "epsilon": self._epsilon
            })

        return adamax_op

    def _finish_update(self, block, parameters_and_grads):
        """Update Beta1 Power accumulator

        For every parameter that has a gradient, appends an in-place
        ``scale`` op to the program's global block that multiplies its
        beta1 power accumulator by beta1.
        """
        assert isinstance(block, framework.Block)
        main_block = block.program.global_block()
        for param, grad in parameters_and_grads:
            if grad is None:
                continue
            with param.block.program.optimized_guard([param, grad]):
                beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                      param)
                main_block.append_op(
                    type="scale",
                    inputs={"X": beta1_pow_acc},
                    outputs={"Out": beta1_pow_acc},
                    attrs={"scale": self._beta1})
720 721 722


class DecayedAdagradOptimizer(Optimizer):
    """
    **Decayed Adagrad Optimizer**

    The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)

    The update is done as follows:

    .. math::

        moment\_out & = decay * moment + (1 - decay) * grad * grad

        param\_out & = param - \\frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}

    The original paper does not have an epsilon attribute. It is added here
    for numerical stability to avoid the division by zero error.

    Args:
        learning_rate (float|Variable): the learning rate used to update parameters. \
        Can be a float value or a Variable with one float value as data element.
        decay (float): decay rate of the moment accumulator; forwarded to the
            op as its ``decay`` attribute.
        epsilon (float): a small float value for numerical stability.

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
            optimizer.minimize(cost)

    """
    _moment_acc_str = "moment"

    def __init__(self, learning_rate, decay=0.95, epsilon=1.0e-6, **kwargs):
        assert learning_rate is not None
        assert decay is not None
        assert epsilon is not None

        super(DecayedAdagradOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "decayed_adagrad"
        self._decay = decay
        self._epsilon = epsilon

    def _create_accumulators(self, block, parameters):
        """Create one moment accumulator per parameter."""
        assert isinstance(block, framework.Block)

        for p in parameters:
            self._add_accumulator(self._moment_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append a ``decayed_adagrad`` op updating the parameter in place.

        Args:
            block (framework.Block): block to append the op to.
            param_and_grad (tuple): ``(param, grad)`` pair.

        Returns:
            The appended operator.
        """
        assert isinstance(block, framework.Block)

        moment_acc = self._get_accumulator(self._moment_acc_str,
                                           param_and_grad[0])

        # Create the decayed adagrad optimizer op. ``decay`` must be passed
        # explicitly; otherwise the user-supplied value would be ignored.
        decayed_adagrad_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": moment_acc,
                "LearningRate": self._create_param_lr(param_and_grad)
            },
            outputs={"ParamOut": param_and_grad[0],
                     "MomentOut": moment_acc},
            attrs={"epsilon": self._epsilon,
                   "decay": self._decay})

        return decayed_adagrad_op
791 792


793
class AdadeltaOptimizer(Optimizer):
    """
    **Adadelta Optimizer**

    Adadelta keeps two running averages per parameter: one of the squared
    gradients and one of the squared updates. See
    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
    <http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf>`_
    for the details of the algorithm.

    ..  math::

        E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\
        learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\
                          E(g_t^2) + \\epsilon ) ) \\\\
        E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2

    Args:
        learning_rate(float): global learning rate. NOTE(review): it is
            handed to the base class, but the ``adadelta`` op appended here
            receives no ``LearningRate`` input; confirm whether it has any
            effect.
        rho(float): rho in equation
        epsilon(float): epsilon in equation

    Examples:
        .. code-block:: python

            optimizer = fluid.optimizer.Adadelta(
                learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
            _, params_grads = optimizer.minimize(cost)

    """

    _avg_squared_grad_acc_str = "_avg_squared_grad"
    _avg_squared_update_acc_str = "_avg_squared_update"

    def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs):
        # Checked in a fixed order, so the first missing argument is reported.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        super(AdadeltaOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)
        self.type = "adadelta"
        self._epsilon, self._rho = epsilon, rho

    def _create_accumulators(self, block, parameters):
        """Create both running-average accumulators for every parameter."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        acc_names = (self._avg_squared_grad_acc_str,
                     self._avg_squared_update_acc_str)
        for param in parameters:
            for acc_name in acc_names:
                self._add_accumulator(acc_name, param)

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``adadelta`` op that updates the parameter and both
        accumulators in place, and return it."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        param, grad = param_and_grad[0], param_and_grad[1]
        sq_grad = self._get_accumulator(self._avg_squared_grad_acc_str,
                                        param)
        sq_update = self._get_accumulator(self._avg_squared_update_acc_str,
                                          param)

        op_inputs = {
            "Param": param,
            "Grad": grad,
            "AvgSquaredGrad": sq_grad,
            "AvgSquaredUpdate": sq_update
        }
        # Outputs alias the inputs: the update happens in place.
        op_outputs = {
            "ParamOut": param,
            "AvgSquaredGradOut": sq_grad,
            "AvgSquaredUpdateOut": sq_update
        }
        op_attrs = {"epsilon": self._epsilon, "rho": self._rho}
        return block.append_op(
            type=self.type,
            inputs=op_inputs,
            outputs=op_outputs,
            attrs=op_attrs)


Q
qingqing01 已提交
876 877 878 879 880 881 882 883 884 885
class RMSPropOptimizer(Optimizer):
    """
    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning
    rate method. The original slides proposed RMSProp: Slide 29 of
    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf .

    The original equation is as follows:

    ..  math::

        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2

        w & = w - \\frac{\\eta} {\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w)

    The first equation calculates the moving average of the squared gradient
    for each weight. The gradient is then divided by :math:`\\sqrt{r(w,t)}`.

    In some cases, adding a momentum term :math:`\\beta` is beneficial.
    In our implementation, the momentum term is applied as follows:

    ..  math::

        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2

        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w)

        w & = w - v(w, t)

    where :math:`\\rho` is a hyperparameter and typical values are 0.9, 0.95
    and so on. :math:`\\beta` is the momentum term. :math:`\\epsilon` is a
    smoothing term to avoid division by zero, usually set somewhere in range
    from 1e-4 to 1e-8.


    Args:
        learning_rate(float): global learning rate.
        rho(float): rho is :math:`\\rho` in equation, set 0.95 by default.
        epsilon(float): :math:`\\epsilon` in equation is smoothing term to
            avoid division by zero, set 1e-6 by default.
        momentum(float): :math:`\\beta` in equation is the momentum term,
            set 0.0 by default.

    Raises:
        ValueError: If learning_rate, rho, epsilon, momentum are None.

    Examples:
          .. code-block:: python

              optimizer = fluid.optimizer.RMSProp(0.0001)
              _, params_grads = optimizer.minimize(cost)
    """

    _momentum_acc_str = "momentum"
    _mean_square_acc_str = "mean_square"

    def __init__(self,
                 learning_rate,
                 rho=0.95,
                 epsilon=1.0e-6,
                 momentum=0.0,
                 **kwargs):
        # Validate before calling the base-class constructor. A missing
        # learning rate is then reported with the documented ValueError
        # rather than depending on how the base class treats None.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        if rho is None:
            raise ValueError("rho is not set.")
        if epsilon is None:
            raise ValueError("epsilon is not set.")
        if momentum is None:
            raise ValueError("momentum is not set.")
        super(RMSPropOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)

        self.type = "rmsprop"
        self._rho = rho
        self._epsilon = epsilon
        self._momentum = momentum

    def _create_accumulators(self, block, parameters):
        """Create the momentum and mean-square accumulators per parameter."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        for p in parameters:
            self._add_accumulator(self._momentum_acc_str, p)
            self._add_accumulator(self._mean_square_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``rmsprop`` op updating the parameter and accumulators
        in place, and return it."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        momentum_acc = self._get_accumulator(self._momentum_acc_str,
                                             param_and_grad[0])
        mean_square_acc = self._get_accumulator(self._mean_square_acc_str,
                                                param_and_grad[0])
        rmsprop_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "Moment": momentum_acc,
                "MeanSquare": mean_square_acc,
                "LearningRate": self._create_param_lr(param_and_grad),
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "MomentOut": momentum_acc,
                "MeanSquareOut": mean_square_acc
            },
            attrs={
                "epsilon": self._epsilon,
                # The op calls rho "decay".
                "decay": self._rho,
                "momentum": self._momentum
            })

        return rmsprop_op


Q
qiaolongfei 已提交
993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099
class FtrlOptimizer(Optimizer):
    """
    FTRL (Follow The Regularized Leader) Optimizer.

    The paper that proposed Follow The Regularized Leader (FTRL):
    (https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)

    ..  math::

        &new\_accum = squared\_accum + grad^2

        &if (lr\_power == -0.5):

        &\quad  linear\_accum += grad - \\frac{\\sqrt{new\_accum} - \\sqrt{squared\_accum}}{learning\_rate} * param

        &else:

        &\quad   linear\_accum += grad - \\frac{new\_accum^{-lr\_power} - squared\_accum^{-lr\_power}}{learning\_rate} * param


        &x = l1 * sign(linear\_accum) - linear\_accum

        &if (lr\_power == -0.5):

        &\quad   y = \\frac{\\sqrt{new\_accum}}{learning\_rate} + (2 * l2)

        &\quad   pre\_shrink = \\frac{x}{y}

        &\quad   param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0)

        &else:

        &\quad   y = \\frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2)

        &\quad   pre\_shrink = \\frac{x}{y}

        &\quad   param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0)

        &squared\_accum += grad^2

    Args:
        learning_rate (float|Variable): global learning rate.
        l1 (float): L1 regularization strength (``l1`` in the equations
            above), 0.0 by default.
        l2 (float): L2 regularization strength (``l2`` in the equations
            above), 0.0 by default.
        lr_power (float): learning rate power (``lr_power`` in the equations
            above), -0.5 by default.

    Raises:
        ValueError: If learning_rate is None.

    Examples:
          .. code-block:: python

              optimizer = fluid.optimizer.Ftrl(0.0001)
              _, params_grads = optimizer.minimize(cost)
    """

    _squared_acc_str = "squared"
    _linear_acc_str = "linear"

    def __init__(self, learning_rate, l1=0.0, l2=0.0, lr_power=-0.5, **kwargs):
        # Validate before calling the base-class constructor so the
        # documented ValueError is what a missing learning rate produces.
        if learning_rate is None:
            raise ValueError("learning_rate is not set.")
        super(FtrlOptimizer, self).__init__(
            learning_rate=learning_rate, **kwargs)

        self.type = "ftrl"
        self._l1 = l1
        self._l2 = l2
        self._lr_power = lr_power

    def _create_accumulators(self, block, parameters):
        """Create the squared and linear accumulators per parameter."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        for p in parameters:
            self._add_accumulator(self._squared_acc_str, p)
            self._add_accumulator(self._linear_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        """Append an ``ftrl`` op updating the parameter and accumulators in
        place, and return it."""
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        squared_acc = self._get_accumulator(self._squared_acc_str,
                                            param_and_grad[0])
        linear_acc = self._get_accumulator(self._linear_acc_str,
                                           param_and_grad[0])
        ftrl_op = block.append_op(
            type=self.type,
            inputs={
                "Param": param_and_grad[0],
                "Grad": param_and_grad[1],
                "SquaredAccumulator": squared_acc,
                "LinearAccumulator": linear_acc,
                "LearningRate": self._create_param_lr(param_and_grad),
            },
            outputs={
                "ParamOut": param_and_grad[0],
                "SquaredAccumOut": squared_acc,
                "LinearAccumOut": linear_acc
            },
            # Each regularization strength goes to its own attribute.
            # Previously "l2" was fed self._l1.
            attrs={"l1": self._l1,
                   "l2": self._l2,
                   "lr_power": self._lr_power})

        return ftrl_op


1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113
# Short aliases for the optimizer classes. Users reach the optimizers through
# the package name, for example:
#
# import paddle.fluid as fluid
#
# sgd = fluid.optimizer.SGD(...)
#
# so there is no need to repeat `Optimizer` as a class-name suffix.
(SGD, Momentum, Adagrad, Adam, Adamax, DecayedAdagrad, Adadelta, RMSProp,
 Ftrl) = (SGDOptimizer, MomentumOptimizer, AdagradOptimizer, AdamOptimizer,
          AdamaxOptimizer, DecayedAdagradOptimizer, AdadeltaOptimizer,
          RMSPropOptimizer, FtrlOptimizer)
1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133


class ModelAverage(Optimizer):
    """Accumulate the average of parameters within a sliding window. The
    average result will be saved in temporary variables which can be applied
    to the parameter variables of the current model by calling the 'apply()'
    method. The 'restore()' method is used to restore the parameter values of
    the current model.

    The size of average window is determined by average_window_rate,
    min_average_window, max_average_window and current update times.

    Args:
        average_window_rate: The rate of average window.
        min_average_window: The minimum size of average window.
        max_average_window: The maximum size of average window.

    Examples:

      .. code-block:: python

        optimizer = fluid.optimizer.Momentum()
        optimizer.minimize(cost)
        model_average = fluid.optimizer.ModelAverage(0.15,
                                                min_average_window=10000,
                                                max_average_window=20000)
        for pass_id in range(args.pass_num):
            for data in train_reader():
                exe.run(fluid.default_main_program()...)

            with model_average.apply(exe):
                for data in test_reader():
                    exe.run(inference_program...)
    """

    def __init__(self,
                 average_window_rate,
                 min_average_window=10000,
                 max_average_window=10000,
                 **kwargs):
        # A constant learning rate of 0.0 is handed to the base class; the
        # averaging ops appended below do not read it.
        super(ModelAverage, self).__init__(0.0, **kwargs)
        self.average_window = average_window_rate
        self.min_average_window = min_average_window
        self.max_average_window = max_average_window

        # Pair each averaged parameter with a temporary, non-persistable
        # variable used to back up its value while the average is applied.
        self.params_grads = []
        for param in framework.default_main_program().global_block(
        ).all_parameters():
            # `!= False` (not `is not False`) on purpose: an unset/None flag
            # counts as enabled; only a value equal to False opts out.
            if param.do_model_average != False:
                grad = param.block.create_var(
                    name=unique_name.generate(".".join([param.name, 'tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self.params_grads.append((param, grad))

        # Every pair was built above with a freshly created backup variable,
        # so no None check is needed here.
        for param, grad in self.params_grads:
            with param.block.program.optimized_guard([param, grad]):
                self._append_average_accumulate_op(param)

        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            for param_grad in self.params_grads:
                self._add_average_apply_op(block, param_grad)

        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param_grad in self.params_grads:
                self._add_average_restore_op(block, param_grad)

    def _add_average_apply_op(self, block, param_grad):
        """Append ops that back up a parameter and overwrite it with the
        accumulated average.

        Args:
            block: global block of ``self.apply_program``.
            param_grad (tuple): ``(param, backup_var)`` pair taken from
                ``self.params_grads``.
        """
        param = block._clone_variable(param_grad[0])
        grad = block._clone_variable(param_grad[1])
        sum_1 = block._clone_variable(self._get_accumulator('sum_1', param))
        sum_2 = block._clone_variable(self._get_accumulator('sum_2', param))
        sum_3 = block._clone_variable(self._get_accumulator('sum_3', param))
        num_accumulates = block._clone_variable(
            self._get_accumulator('num_accumulates', param))
        old_num_accumulates = block._clone_variable(
            self._get_accumulator('old_num_accumulates', param))
        # Cloned so the variable exists in the apply program, although none
        # of the ops below read it.
        block._clone_variable(self._get_accumulator('num_updates', param))
        # Back up the current param value into the temporary variable so
        # restore() can copy it back.
        layers.assign(input=param, output=grad)
        # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates)
        total_count = layers.sum(x=[num_accumulates, old_num_accumulates])
        total_sum = layers.sum(x=[sum_1, sum_2, sum_3])
        # The counters are created as int64, so cast both operands to the
        # optimizer dtype (float32 when unset) before dividing.
        dtype = 'float32' if self._dtype is None else self._dtype
        total_count = layers.cast(x=total_count, dtype=dtype)
        total_sum = layers.cast(x=total_sum, dtype=dtype)
        layers.elementwise_div(x=total_sum, y=total_count, out=param)

    def _add_average_restore_op(self, block, param_grad):
        """Append an op copying the backed-up value back into the parameter.

        Args:
            block: global block of ``self.restore_program``.
            param_grad (tuple): ``(param, backup_var)`` pair.
        """
        param = block._clone_variable(param_grad[0])
        grad = block._clone_variable(param_grad[1])
        layers.assign(input=grad, output=param)

    def _append_average_accumulate_op(self, param):
        """Create the averaging accumulators for ``param`` and append an
        ``average_accumulates`` op that updates them in place."""
        # Set before the _add_accumulator calls below.
        # NOTE(review): presumably _add_accumulator relies on self.helper;
        # confirm in the base class.
        self.helper = LayerHelper("average_accumulate")
        sum_1 = self._add_accumulator('sum_1', param)
        sum_2 = self._add_accumulator('sum_2', param)
        sum_3 = self._add_accumulator('sum_3', param)
        num_accumulates = self._add_accumulator(
            'num_accumulates', param, dtype='int64', shape=[1])
        old_num_accumulates = self._add_accumulator(
            'old_num_accumulates', param, dtype='int64', shape=[1])
        num_updates = self._add_accumulator(
            'num_updates', param, dtype='int64', shape=[1])

        self.helper.append_op(
            type='average_accumulates',
            inputs={
                "param": param,
                "in_sum_1": sum_1,
                "in_sum_2": sum_2,
                "in_sum_3": sum_3,
                "in_num_accumulates": num_accumulates,
                "in_old_num_accumulates": old_num_accumulates,
                "in_num_updates": num_updates
            },
            outputs={
                "out_sum_1": sum_1,
                "out_sum_2": sum_2,
                "out_sum_3": sum_3,
                "out_num_accumulates": num_accumulates,
                "out_old_num_accumulates": old_num_accumulates,
                "out_num_updates": num_updates,
            },
            attrs={
                "average_window": self.average_window,
                "min_average_window": self.min_average_window,
                "max_average_window": self.max_average_window,
            })

    @contextmanager
    def apply(self, executor, need_restore=True):
        """Apply average values to parameters of current model.

        Runs ``self.apply_program`` on entry. When ``need_restore`` is true,
        ``restore(executor)`` runs on exit, even if the body raised.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameter values of current model by running
        ``self.restore_program``.
        """
        executor.run(self.restore_program)