Commit 883fde04 authored by mindspore-ci-bot, committed by Gitee

!839 add parameter verification for RMSprop

Merge pull request !839 from wangnan39/add_parameter_verification_for_rmsprop
mindspore/nn/optim/adam.py

@@ -145,9 +145,12 @@ class Adam(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
             Other cases are not supported. Default: 1e-3.
-        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
-        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
-        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
+            0.9.
+        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
+            0.999.
+        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
+            1e-8.
         use_locking (bool): Whether to enable a lock to protect updating variable tensors.
             If True, updating of the var, m, and v tensors will be protected by a lock.
             If False, the result is unpredictable. Default: False.
@@ -155,8 +158,8 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-            Should be equal to or greater than 1.
+        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
+            1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
             lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
......
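
For reference, a minimal sketch (not part of this commit) of an Adam constructor call spelling out the defaults documented above; the toy `weight` Parameter stands in for a real network's trainable parameters:

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.nn.optim import Adam

# A single toy parameter stands in for a network's trainable parameters.
weight = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="weight")

# The documented defaults, written out: beta1=0.9, beta2=0.999, eps=1e-8;
# loss_scale must be equal to or greater than 1.
optimizer = Adam([weight], learning_rate=1e-3, beta1=0.9, beta2=0.999,
                 eps=1e-8, weight_decay=0.0, loss_scale=1.0)
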
mindspore/nn/optim/optimizer.py

@@ -46,8 +46,8 @@ class Optimizer(Cell):
         learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
         parameters (list): A list of parameter, which will be updated. The element in `parameters`
             should be class mindspore.Parameter.
-        weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay`
-            input is int, it will be convertd to float. Default: 0.0.
+        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
+            If the type of `weight_decay` input is int, it will be convertd to float. Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
             type of `loss_scale` input is int, it will be convertd to float. Default: 1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
@@ -87,21 +87,15 @@ class Optimizer(Cell):
         if isinstance(weight_decay, int):
             weight_decay = float(weight_decay)
-        validator.check_float_legal_value('weight_decay', weight_decay, None)
+        validator.check_value_type("weight_decay", weight_decay, [float], None)
+        validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
         if isinstance(loss_scale, int):
             loss_scale = float(loss_scale)
-        validator.check_float_legal_value('loss_scale', loss_scale, None)
-        if loss_scale <= 0.0:
-            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
+        validator.check_value_type("loss_scale", loss_scale, [float], None)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None)
         self.loss_scale = loss_scale
-        if weight_decay < 0.0:
-            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
         self.learning_rate = Parameter(learning_rate, name="learning_rate")
         self.parameters = ParameterTuple(parameters)
         self.reciprocal_scale = 1.0 / loss_scale
......
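
The two new range checks differ only in how they treat the lower bound: Rel.INC_LEFT accepts it, Rel.INC_NEITHER excludes it. A small standalone sketch of that behaviour, reusing the calls from the hunk above with illustrative values (not taken from the patch):

from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel

# Rel.INC_LEFT includes the lower bound, so weight_decay == 0.0 is accepted.
validator.check_number_range("weight_decay", 0.0, 0.0, float("inf"), Rel.INC_LEFT, None)

# Rel.INC_NEITHER excludes it, so loss_scale == 0.0 raises ValueError.
try:
    validator.check_number_range("loss_scale", 0.0, 0.0, float("inf"), Rel.INC_NEITHER, None)
except ValueError as err:
    print(err)
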
mindspore/nn/optim/rmsprop.py

@@ -15,6 +15,7 @@
 """rmsprop"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from .optimizer import Optimizer

 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -91,14 +92,16 @@ class RMSProp(Optimizer):
             take the i-th value as the learning rate.
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
-            Other cases are not supported.
-        decay (float): Decay rate.
-        momentum (float): Hyperparameter of type float, means momentum for the moving average.
-        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Other cases are not supported. Default: 0.1.
+        decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
+        momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
+            greater than 0. Default: 0.0.
+        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
+            0. Default: 1e-10.
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
-        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
             lambda x: 'beta' not in x.name and 'gamma' not in x.name.
@@ -118,17 +121,15 @@ class RMSProp(Optimizer):
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
-        if isinstance(momentum, float) and momentum < 0.0:
-            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        if decay < 0.0:
-            raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
-        self.decay = decay
-        self.epsilon = epsilon
+        validator.check_value_type("decay", decay, [float], self.cls_name)
+        validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("momentum", momentum, [float], self.cls_name)
+        validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
+        validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("centered", centered, [bool], self.cls_name)
         self.centered = centered
         if centered:
             self.opt = P.ApplyCenteredRMSProp(use_locking)
@@ -137,11 +138,10 @@ class RMSProp(Optimizer):
             self.opt = P.ApplyRMSProp(use_locking)
         self.momentum = momentum
         self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
         self.moment = self.parameters.clone(prefix="moment", init='zeros')
         self.hyper_map = C.HyperMap()
-        self.epsilon = epsilon
         self.decay = decay

     def construct(self, gradients):
......
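
A usage sketch (not part of the commit) of the RMSProp constructor with the defaults now documented above; because of the new check_value_type calls, decay, momentum and epsilon have to be passed as floats. The toy `weight` Parameter is illustrative only:

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.nn.optim import RMSProp

# A toy parameter stands in for a network's trainable parameters.
weight = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="weight")

# Documented defaults: decay=0.9, momentum=0.0, epsilon=1e-10.
# Passing an int instead (e.g. momentum=1) now raises TypeError.
optimizer = RMSProp([weight], learning_rate=0.1, decay=0.9, momentum=0.0,
                    epsilon=1e-10, use_locking=False, centered=False,
                    loss_scale=1.0, weight_decay=0.0)
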
mindspore/nn/optim/sgd.py

@@ -49,12 +49,12 @@ class SGD(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
             Other cases are not supported. Default: 0.1.
-        momentum (float): A floating point value the momentum. Default: 0.
-        dampening (float): A floating point value of dampening for momentum. Default: 0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.
+        momentum (float): A floating point value the momentum. Default: 0.0.
+        dampening (float): A floating point value of dampening for momentum. Default: 0.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         nesterov (bool): Enables the Nesterov momentum. Default: False.
         loss_scale (float): A floating point value for the loss scale, which should be larger
             than 0.0. Default: 1.0.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
......
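
SGD itself is only touched in the docstring, but it inherits the weight_decay and loss_scale verification from the Optimizer base class above. A sketch of the resulting behaviour, assuming SGD forwards both arguments to Optimizer.__init__ (the toy `weight` Parameter is illustrative):

import numpy as np
import pytest
from mindspore import Tensor, Parameter
from mindspore.nn.optim import SGD

weight = Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="weight")

# weight_decay below 0.0 is rejected by the base-class range check.
with pytest.raises(ValueError):
    SGD([weight], learning_rate=0.1, weight_decay=-0.1)

# loss_scale must be strictly greater than 0.0.
with pytest.raises(ValueError):
    SGD([weight], learning_rate=0.1, loss_scale=0.0)
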
tests/ut/python/nn/optim/test_rmsprop.py (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.nn.optim import RMSProp
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
self.matmul = P.MatMul()
self.biasAdd = P.BiasAdd()
def construct(self, x):
x = self.biasAdd(self.matmul(x, self.weight), self.bias)
return x
def test_rmsprop_compile():
""" test_adamw_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_rmsprop_e():
net = Net()
with pytest.raises(ValueError):
RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
with pytest.raises(TypeError):
RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
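
The new checks on decay and epsilon are not exercised by the test above; a hypothetical additional case (not part of this commit) could look like:

def test_rmsprop_decay_epsilon():
    """ sketch: decay below 0.0 and a non-float epsilon are rejected """
    net = Net()
    # decay must be equal to or greater than 0
    with pytest.raises(ValueError):
        RMSProp(net.get_parameters(), decay=-0.1, learning_rate=0.1)
    # epsilon must be a float, so the int 1 raises TypeError
    with pytest.raises(TypeError):
        RMSProp(net.get_parameters(), epsilon=1, learning_rate=0.1)
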