提交 4042f16c 编写于 作者: W wangnan39@huawei.com

add lazy adam optim and support sparse adam & ftrl for cpu backend

上级 af85b2ce
......@@ -27,6 +27,7 @@ from .lars import LARS
from .ftrl import FTRL
from .rmsprop import RMSProp
from .proximal_ada_grad import ProximalAdagrad
from .lazyadam import LazyAdam
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
......@@ -101,10 +101,21 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor")
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
moment2):
@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient[1], gradient[0]))
return success
@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
......@@ -144,6 +155,10 @@ class Adam(Optimizer):
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU, weight decay and loss scale is not supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
......@@ -231,12 +246,9 @@ class Adam(Optimizer):
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.map_ = C.Map()
self.opt = P.Adam(use_locking, use_nesterov)
self.pow = P.Pow()
self.sqrt = P.Sqrt()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.realdiv = P.RealDiv()
self.sparse_opt = P.SparseApplyAdam()
def construct(self, gradients):
params = self.parameters
......@@ -251,13 +263,13 @@ class Adam(Optimizer):
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
else:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
return success
......
......@@ -23,8 +23,18 @@ from .optimizer import Optimizer, apply_decay, grad_scale
ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
"Tensor")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success = True
success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
return success
@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor")
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
"""Apply ftrl optimizer to the weight parameter."""
success = True
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
......@@ -67,6 +77,11 @@ class FTRL(Optimizer):
<https://arxiv.org/abs/1002.4908>`_. Refer to paper `Ad Click Prediction: a View from the Trenches
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU, weight decay and loss scale is not supported.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be Parameter.
......@@ -109,8 +124,9 @@ class FTRL(Optimizer):
self.weight_decay = weight_decay
self.decay_tf = tuple((lambda: True)() for x in self.parameters)
self.hyper_map = C.HyperMap()
self.map_ = C.Map()
self.opt = P.ApplyFtrl(use_locking=use_locking)
self.one = Tensor(1, mstype.int32)
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
def construct(self, grads):
params = self.parameters
......@@ -121,6 +137,6 @@ class FTRL(Optimizer):
if self.reciprocal_scale != 1.0:
grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
lr = self.learning_rate
success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
linear, grads, params, moments)
success = self.map_(F.partial(ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
linear, grads, params, moments)
return success
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lazy adam"""
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
success = True
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient[1], gradient[0]))
return success
@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class LazyAdam(Optimizer):
r"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
Note:
The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set as True. The sparse behavior, to be notice, is not equivalent to the
original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
continuous development. The sparse behavior is currently performed on the CPU, weight decay and loss scale is
not supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr" and "weight_decay" are the keys can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
0.999.
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
1e-8.
use_locking (bool): Whether to enable a lock to protect updating variable tensors.
If True, updating of the var, m, and v tensors will be protected by a lock.
If False, the result is unpredictable. Default: False.
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If True, updates the gradients using NAG.
If False, updates the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.LazyAdam(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
>>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a
>>> # learning rate of 0.1 and a weight decay of 0.0.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = eps
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.map_ = C.Map()
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov)
def construct(self, gradients):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
self.beta1_power = self.beta1_power * self.beta1
self.beta2_power = self.beta2_power * self.beta2
if self.is_group_lr:
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2)
else:
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
self.beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, self.parameters, self.moment1, self.moment2)
return success
......@@ -21,7 +21,7 @@ from mindspore import Tensor, Parameter
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore.ops import operations as P
......@@ -50,6 +50,19 @@ class NetWithoutWeight(nn.Cell):
return x
class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()
def construct(self, indices, label):
return self.gather(self.weight1, indices, self.axis) + self.weight2
def test_adamwithoutparam():
net = NetWithoutWeight()
net.set_train()
......@@ -72,6 +85,33 @@ def test_adamw_compile():
_executor.compile(train_network, inputs, label)
def test_adam_compile():
""" test adam compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Adam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_spares_adam_compile():
""" test_sparse_adam_compile """
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_AdamWeightDecay_beta1():
net = Net()
print("**********", net.get_parameters())
......
......@@ -37,6 +37,19 @@ class Net(nn.Cell):
return x
class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()
def construct(self, indices, label):
return self.gather(self.weight1, indices, self.axis) + self.weight2
def test_ftrl():
""" test_ftrl """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
......@@ -48,3 +61,15 @@ def test_ftrl():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_spares_ftrl_compile():
""" test sparse ftrl compile """
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = FTRL(net.trainable_params())
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test lazy adam """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import LazyAdam
from mindspore.ops import operations as P
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
self.matmul = P.MatMul()
self.biasAdd = P.BiasAdd()
def construct(self, x):
x = self.biasAdd(self.matmul(x, self.weight), self.bias)
return x
class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()
def construct(self, indices, label):
return self.gather(self.weight1, indices, self.axis) + self.weight2
def test_lazy_adam_compile():
""" test lazy adam compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_spares_lazy_adam_compile():
""" test sparse adam compile """
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_lazy_adam_error():
net = Net()
with pytest.raises(ValueError):
LazyAdam(net.get_parameters(), learning_rate=-0.1)
with pytest.raises(TypeError):
LazyAdam(net.get_parameters(), learning_rate=0.1, beta1=2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册