# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adam"""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
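# The implementations registered on `_adam_opt` below are selected according to the input types:
# RowTensor gradients dispatch to the sparse kernel, dense Tensor gradients to the fused Adam kernel,
# and the Tensor-only signature performs the AdamWeightDecay update.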


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether to apply weight decay.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the new value of the parameter after updating. If `optim_filter` is False, the gradient is
        returned unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta2, op_square(gradient_fp32))
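        # Note: this overload applies no bias correction; the decoupled weight decay term (when
        # decay_flag is set) is added to the Adam update before it is scaled by the learning rate.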

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient


@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
                         gradient, params, moment1, moment2, ps_parameter):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
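    # For parameters hosted on a parameter server (ps_parameter), the update is pushed to and pulled
    # from the server; otherwise FusedSparseAdam applies the sparse update locally.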
    if ps_parameter:
        op_shape = P.Shape()
        shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), params))
    else:
        success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    return success


@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient,
                             params, moment1, moment2, ps_parameter):
    """Apply adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    else:
        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success

def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)


class Adam(Optimizer):
    r"""
    Updates parameters by the Adaptive Moment Estimation (Adam) algorithm.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
    :math:`\epsilon` represents `eps`.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied when the SparseGatherV2 operator is used in the forward network.
        The sparse feature is under continuous development. The sparse
        behavior is currently performed on the CPU.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys can be parsed.

            - params: Required. The value should be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value should be the order of parameters
              and this order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters in 'order_params' should also be included in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
            dimension, use fixed learning rate. Other cases are not supported. The float learning rate should be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
                       Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
                       Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
                     1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Adam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        >>>                 {'params': no_conv_params, 'lr': 0.01},
        >>>                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final order of the parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)

        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        self.hyper_map = C.HyperMap()
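        # P.Adam applies the dense update; P.FusedSparseAdam handles RowTensor (sparse) gradients.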
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)

        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
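        # beta1_power and beta2_power accumulate beta1**t and beta2**t across steps; the fused
        # Adam kernels use them for bias correction of the moment estimates.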
        if self.is_group_lr:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                                lr, gradients, params, moment1, moment2, self.ps_parameters)
        else:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                                gradients, params, moment1, moment2, self.ps_parameters)
        return success


class AdamWeightDecay(Optimizer):
    """
    Implements the Adam algorithm with weight decay that is decoupled from the gradient-based moment estimates.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys can be parsed.

            - params: Required. The value should be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value should be the order of parameters
              and this order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters in 'order_params' should also be included in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
            dimension, use fixed learning rate. Other cases are not supported. The float learning rate should be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        >>>                 {'params': no_conv_params, 'lr': 0.01},
        >>>                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final order of the parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
   """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        lr = self.get_lr()
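        # Dispatch depends on whether the learning rate and weight decay are set per parameter group:
        # group values are mapped element-wise, while scalar hyper-parameters are bound via F.partial.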
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result