magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit ed3c2d72
Author: zhaoting
Date: March 31, 2020
Parent: 5c22c088

add RMSProp optimizer

Showing 10 changed files with 390 additions and 5 deletions (+390 -5)
mindspore/ccsrc/transform/convert.cc                 +5    -1
mindspore/ccsrc/transform/op_declare.cc              +16   -0
mindspore/ccsrc/transform/op_declare.h               +6    -0
mindspore/nn/optim/__init__.py                       +2    -1
mindspore/nn/optim/rmsprop.py                        +187  -0
mindspore/ops/_grad/grad_math_ops.py                 +2    -2
mindspore/ops/operations/__init__.py                 +4    -1
mindspore/ops/operations/nn_ops.py                   +152  -0
tests/mindspore_test_framework/utils/block_util.py   +4    -0
tests/ut/python/ops/test_ops.py                      +12   -0
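Taken together, the commit adds a user-facing `nn.RMSProp` optimizer backed by two new primitives, `ApplyRMSProp` and `ApplyCenteredRMSProp`, together with their GE adapter registrations and unit-test entries. A minimal usage sketch, adapted from the Examples block of the rmsprop.py docstring shown below; the toy `Net`, the `Dense` layer sizes, and the hyperparameter values are illustrative, and it assumes `mindspore.nn` re-exports the optimizers from `nn.optim`:

    import mindspore.nn as nn

    class Net(nn.Cell):
        """Toy single-layer network, used only to obtain trainable parameters."""
        def __init__(self):
            super(Net, self).__init__()
            self.fc = nn.Dense(4, 3)

        def construct(self, x):
            return self.fc(x)

    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    # decay, momentum and epsilon mirror the defaults introduced in rmsprop.py
    opt = nn.RMSProp(params=net.trainable_params(), learning_rate=0.01,
                     decay=0.9, momentum=0.0, epsilon=1e-10)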

mindspore/ccsrc/transform/convert.cc
@@ -183,6 +183,8 @@ const char kNameDiagPart[] = "DiagPart";
 const char kNameSpaceToBatch[] = "SpaceToBatch";
 const char kNameBatchToSpace[] = "BatchToSpace";
 const char kNameAtan2[] = "Atan2";
+const char kNameApplyRMSProp[] = "ApplyRMSProp";
+const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";

 // -----------------OpAdapter initialization--------------
 std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
...
@@ -367,7 +369,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameDiagPart), ADPT_DESC(DiagPart)},
     {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
     {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
-    {string(kNameAtan2), ADPT_DESC(Atan2)}};
+    {string(kNameAtan2), ADPT_DESC(Atan2)},
+    {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
+    {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
 #endif
...

mindspore/ccsrc/transform/op_declare.cc
@@ -1202,6 +1202,22 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
 ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};

+// ApplyRMSPropD
+INPUT_MAP(ApplyRMSPropD) = {
+  {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}};
+INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
+                                 {7, ATTR_DESC(momentum, AnyTraits<float>())},
+                                 {8, ATTR_DESC(epsilon, AnyTraits<float>())}};
+ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
+
+// ApplyCenteredRMSProp
+INPUT_MAP(ApplyCenteredRMSProp) = {
+  {1, INPUT_DESC(var)},  {2, INPUT_DESC(mg)},       {3, INPUT_DESC(ms)},
+  {4, INPUT_DESC(mom)},  {5, INPUT_DESC(grad)},     {6, INPUT_DESC(lr)},
+  {7, INPUT_DESC(rho)},  {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
+ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
+
 #ifdef ENABLE_GE
 // Print
 INPUT_MAP(Print) = EMPTY_INPUT_MAP;
...

mindspore/ccsrc/transform/op_declare.h
@@ -445,6 +445,12 @@ DECLARE_OP_ADAPTER(BatchToSpaceD)
 DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
 DECLARE_OP_ADAPTER(Atan2)
 DECLARE_OP_USE_OUTPUT(Atan2)
+DECLARE_OP_ADAPTER(ApplyRMSPropD)
+DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
+DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
+DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
+DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
 #ifdef ENABLE_GE
 DECLARE_OP_ADAPTER(Print)
 DECLARE_OP_USE_DYN_INPUT(Print)
...

mindspore/nn/optim/__init__.py
@@ -25,6 +25,7 @@ from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
 from .ftrl import FTRL
+from .rmsprop import RMSProp

 __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
-           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL']
+           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp']
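With the new export in place, a trivial import check (assuming the package layout above):

    from mindspore.nn.optim import RMSProp
    print(RMSProp.__name__)  # -> RMSProp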

mindspore/nn/optim/rmsprop.py (new file, 0 → 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
import mindspore.common.dtype as mstype
from .optimizer import Optimizer, grad_scale

rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")


@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
    """Apply rmsprop optimizer to the weight parameter."""
    success = True
    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
    return success


@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
    """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
    success = True
    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
    return success


@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
                               "Tensor", "Tensor")
def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
    """Apply centered rmsprop optimizer to the weight parameter."""
    success = True
    success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
    return success


@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
                               "Tensor", "Tensor")
def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
    """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
    success = True
    success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
    return success


class RMSProp(Optimizer):
    """
    Implements Root Mean Squared Propagation (RMSProp) algorithm.

    Note:
        Update `params` according to the RMSProp algorithm. The equation is as follows:

        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)

        .. math::
            w = w - m_{t}

        The first equation calculates the moving average of the squared gradient for each weight;
        the gradient is then divided by :math:`\\sqrt{s_{t} + \\epsilon}`.

        If `centered` is True:

        .. math::
            g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)

        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)

        .. math::
            w = w - m_{t}

        where :math:`w` represents `params`, which will be updated.
        :math:`g_{t}` is the mean gradient, :math:`g_{t-1}` is the last moment of :math:`g_{t}`.
        :math:`s_{t}` is the mean square gradient, :math:`s_{t-1}` is the last moment of :math:`s_{t}`,
        :math:`m_{t}` is the moment, the delta of `w`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`.
        :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, representing `momentum`.
        :math:`\\epsilon` is a smoothing term to avoid division by zero, representing `epsilon`.
        :math:`\\eta` is the learning rate, representing `learning_rate`.
        :math:`\\nabla Q_{i}(w)` is the gradient, representing `gradients`.

    Args:
        params (list[Parameter]): A list of parameters, which will be updated. The elements in `params`
            should be of class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When `learning_rate` is
            an Iterable or a 1-dim Tensor, a dynamic learning rate is used and the i-th step takes the i-th
            value as the learning rate. When `learning_rate` is a float or a 0-dim Tensor, a fixed learning
            rate is used. Other cases are not supported.
        decay (float): Decay rate.
        momentum (float): Hyperparameter of type float, means momentum for the moving average.
        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
        use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors.
            Default: False.
        centered (bool): If True, gradients are normalized by the estimated variance of the gradient.
            Default: False.
        loss_scale (float): A floating point value for the loss scale. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr)
        >>> model = Model(net, loss, opt)
    """
    def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
                 use_locking=False, centered=False, loss_scale=1.0):
        super(RMSProp, self).__init__(learning_rate, params)

        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))

        if decay < 0.0:
            raise ValueError("decay should be at least 0.0, but got decay {}".format(decay))
        self.decay = decay
        self.epsilon = epsilon

        validator.check_type("use_locking", use_locking, [bool])
        validator.check_type("centered", centered, [bool])
        self.centered = centered
        if centered:
            self.opt = P.ApplyCenteredRMSProp(use_locking)
            self.mg = self.parameters.clone(prefix="mean_grad", init='zeros')
        else:
            self.opt = P.ApplyRMSProp(use_locking)

        self.dynamic_lr = False
        if not isinstance(learning_rate, float):
            self.dynamic_lr = True
            self.gather = P.GatherV2()
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
            self.axis = 0

        self.momentum = momentum

        self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
        self.moment = self.parameters.clone(prefix="moment", init='zeros')
        self.hyper_map = C.HyperMap()

        self.decay = decay
        self.reciprocal_scale = 1.0 / loss_scale

    def construct(self, gradients):
        params = self.parameters
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, self.axis)
            F.control_depend(lr, self.assignadd(self.global_step, self.one))
        else:
            lr = self.learning_rate
        if self.centered:
            success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                               self.momentum), params, self.mg, self.ms, self.moment, gradients)
        else:
            success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                               self.momentum), params, self.ms, self.moment, gradients)
        return success
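For intuition, a small NumPy sketch of one non-centered update step, following the three equations in the docstring above; the function name and the values are illustrative and not part of the commit:

    import numpy as np

    def rmsprop_step(w, s, m, grad, lr=0.01, decay=0.9, momentum=0.0, eps=1e-10):
        """One RMSProp step: s_t, m_t and w as in the docstring equations."""
        s = decay * s + (1 - decay) * grad ** 2          # s_t = rho*s_{t-1} + (1 - rho)*grad^2
        m = momentum * m + lr * grad / np.sqrt(s + eps)  # m_t = beta*m_{t-1} + eta*grad/sqrt(s_t + eps)
        w = w - m                                        # w   = w - m_t
        return w, s, m

    w = np.array([1.0, -2.0])
    s = np.zeros_like(w)
    m = np.zeros_like(w)
    w, s, m = rmsprop_step(w, s, m, grad=np.array([0.5, -0.5]))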

mindspore/ops/_grad/grad_math_ops.py
...
@@ -394,8 +394,8 @@ def _split_shape_index(input_shape, axis):
         axis = tuple([axis])
     reduction_indices = tuple([(i + rank) % rank for i in axis])
     other_indices = tuple(set(range(rank)) - set(reduction_indices))
-    reduced_num = reduce(lambda x, y: x * y, [input_shape[i] for i in reduction_indices])
-    other_num = reduce(lambda x, y: x * y, [input_shape[i] for i in other_indices])
+    reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
+    other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
     perm = reduction_indices + other_indices
     return tuple([reduced_num, other_num]), perm
...
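The `[1] +` prepended to both `reduce` calls looks like a guard for the case where `reduction_indices` or `other_indices` is empty (for example, when reducing over every axis), since `functools.reduce` over an empty sequence with no initial value raises a TypeError; with the leading 1 the empty product simply evaluates to 1. A plain-Python illustration (not part of the commit):

    from functools import reduce

    shape = (3, 4)
    other_indices = ()  # e.g. all axes are reduced

    # Pre-change form: fails on an empty index tuple.
    try:
        reduce(lambda x, y: x * y, [shape[i] for i in other_indices])
    except TypeError as err:
        print("empty sequence:", err)

    # Post-change form: the leading 1 makes the empty product well defined.
    print(reduce(lambda x, y: x * y, [1] + [shape[i] for i in other_indices]))  # -> 1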

mindspore/ops/operations/__init__.py
...
@@ -65,7 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      SmoothL1Loss, Softmax, SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
-                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl)
+                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
+                     ApplyRMSProp, ApplyCenteredRMSProp)
 from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
...
@@ -228,6 +229,8 @@ __all__ = [
     "SpaceToBatch",
     "BatchToSpace",
     "Atan2",
+    "ApplyRMSProp",
+    "ApplyCenteredRMSProp"
 ]

 __all__.sort()

mindspore/ops/operations/nn_ops.py
...
@@ -1359,6 +1359,158 @@ class SGD(PrimitiveWithInfer):
        validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32])
        return parameters_dtype


class ApplyRMSProp(PrimitiveWithInfer):
    """
    Optimizer that implements the Root Mean Square prop (RMSProp) algorithm.

    Note:
        Update `var` according to the RMSProp algorithm.

        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)

        .. math::
            w = w - m_{t}

        where :math:`w` represents `var`, which will be updated.
        :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last moment of :math:`s_{t}`,
        :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`.
        :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, representing `momentum`.
        :math:`\\epsilon` is a smoothing term to avoid division by zero, representing `epsilon`.
        :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.

    Args:
        use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.

    Inputs:
        - **var** (Tensor) - Weights to be updated.
        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
        - **grad** (Tensor) - Gradients, must have the same type as `var`.
        - **learning_rate** (Union[Number, Tensor]) - Learning rate.
        - **decay** (float) - Decay rate.
        - **momentum** (float) - Momentum.
        - **epsilon** (float) - Ridge term.

    Outputs:
        Tensor, parameters to be updated.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate)
        >>> model = Model(net, loss, opt)
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        self.use_locking = validator.check_type("use_locking", use_locking, [bool])

    def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape,
                    decay_shape, momentum_shape, epsilon_shape):
        validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
        validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
        validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
        return var_shape

    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype,
                    decay_dtype, momentum_dtype, epsilon_dtype):
        validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
        validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
        validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
        validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
        args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype,
                "moment_dtype": moment_dtype, "grad_dtype": grad_dtype}
        validator.check_type_same(args, mstype.number_type)

        args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype,
                'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype}
        validator.check_type_same(args, [mstype.float16, mstype.float32])
        return var_dtype


class ApplyCenteredRMSProp(PrimitiveWithInfer):
    """
    Optimizer that implements the centered RMSProp algorithm.

    Note:
        Update `var` according to the centered RMSProp algorithm.

        .. math::
            g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)

        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)

        .. math::
            w = w - m_{t}

        where :math:`w` represents `var`, which will be updated.
        :math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last moment of :math:`g_{t}`.
        :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last moment of :math:`s_{t}`,
        :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`.
        :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, representing `momentum`.
        :math:`\\epsilon` is a smoothing term to avoid division by zero, representing `epsilon`.
        :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.

    Args:
        use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.

    Inputs:
        - **var** (Tensor) - Weights to be updated.
        - **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`.
        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
        - **grad** (Tensor) - Gradients, must have the same type as `var`.
        - **learning_rate** (Union[Number, Tensor]) - Learning rate.
        - **decay** (float) - Decay rate.
        - **momentum** (float) - Momentum.
        - **epsilon** (float) - Ridge term.

    Outputs:
        Tensor, parameters to be updated.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate, centered=True)
        >>> model = Model(net, loss, opt)
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        self.use_locking = validator.check_type("use_locking", use_locking, [bool])

    def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
                    learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
        validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape)
        validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
        validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
        validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
        return var_shape

    def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
                    learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
        validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
        validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor)
        validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
        validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
        validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
        args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype,
                "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype}
        validator.check_type_same(args, mstype.number_type)

        args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype,
                'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype}
        validator.check_type_same(args, [mstype.float16, mstype.float32])
        return var_dtype


class LayerNorm(Primitive):
    r"""
...
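For completeness, a sketch of driving the new primitive directly from a Cell, following the input order documented above (`var`, `mean_square`, `moment`, `grad`, `learning_rate`, `decay`, `momentum`, `epsilon`); the wrapper class, shapes, and scalar values are illustrative assumptions rather than part of the commit:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    import mindspore.common.dtype as mstype
    from mindspore import Parameter, Tensor

    class ApplyRMSPropCell(nn.Cell):
        """Holds the optimizer state as Parameters and applies one RMSProp update per call."""
        def __init__(self):
            super(ApplyRMSPropCell, self).__init__()
            self.apply_rms_prop = P.ApplyRMSProp()
            self.var = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="var")
            self.mean_square = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="mean_square")
            self.moment = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="moment")

        def construct(self, grad):
            # learning_rate=0.01, decay=0.9, momentum=0.0, epsilon=1e-10 are passed as scalar
            # constants, similar in spirit to the desc_const entries in the unit test below.
            return self.apply_rms_prop(self.var, self.mean_square, self.moment, grad,
                                       0.01, 0.9, 0.0, 1e-10)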

tests/mindspore_test_framework/utils/block_util.py
...
@@ -223,6 +223,10 @@ class InputOpNet(nn.Cell):
         x = self.op(x1, x2, x3, x4, x5, self.c1)
         return x

+    def construct5_c4(self, x1, x2, x3, x4, x5):
+        x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
+        return x
+
 def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
     if isinstance(op, nn.Cell):
         return op
...

tests/ut/python/ops/test_ops.py
...
@@ -805,6 +805,18 @@ test_case_nn_ops = [
         'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
+    ('ApplyRMSProp', {
+        'block': P.ApplyRMSProp(),
+        'desc_const': [0.9, 0.0, 1e-10, 0.001],
+        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
+        'desc_bprop': [3, 3],
+        'skip': ['backward']}),
+    ('ApplyCenteredRMSProp', {
+        'block': P.ApplyCenteredRMSProp(),
+        'desc_const': [0.9, 0.0, 1e-10, 0.001],
+        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
+        'desc_bprop': [3, 3],
+        'skip': ['backward']}),
 ]

 test_case_array_ops = [
...