Commit 60958d6b in magicwindyyd/mindspore (fork of MindSpore/mindspore)

Authored April 14, 2020 by mindspore-ci-bot; committed April 14, 2020 via Gitee.

!239 Add dynamic learning rate decay and review optimizer code

Merge pull request !239 from fanglei/master

Parents: e9594596, 7d700295
8 changed files with 650 additions and 134 deletions (+650, -134)
mindspore/nn/dynamic_lr.py                    +300    -0
mindspore/nn/optim/adam.py                      +5   -23
mindspore/nn/optim/momentum.py                  +5   -33
mindspore/nn/optim/optimizer.py                +90   -12
mindspore/nn/optim/rmsprop.py                   +5   -27
mindspore/nn/optim/sgd.py                       +8   -32
tests/ut/python/nn/optim/test_optimizer.py      +3    -7
tests/ut/python/nn/test_dynamic_lr.py         +234    -0
mindspore/nn/dynamic_lr.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dynamic learning rate"""
import math

from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel


def piecewise_constant_lr(milestone, learning_rates):
    r"""
    Get piecewise constant learning rate.

    Calculate learning rate by given `milestone` and `learning_rates`. Let the value of `milestone` be
    :math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length of
    `milestone`. Let the output learning rate be `y`.

    .. math::
        y[i] = x_t for i \in [M_{t-1}, M_t)

    Args:
        milestone (list[int]): A list of milestone. This list is a monotone increasing list.
        learning_rates (list[float]): A list of learning rates.

    Returns:
        list[float]. The size of list is :math:`M_N`.

    Examples:
        >>> milestone = [2, 5, 10]
        >>> learning_rates = [0.1, 0.05, 0.01]
        >>> lr = piecewise_constant_lr(milestone, learning_rates)
        [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
    """
    validator.check_type('milestone', milestone, (tuple, list))
    validator.check_type('learning_rates', learning_rates, (tuple, list))
    if len(milestone) != len(learning_rates):
        raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')

    lr = []
    last_item = 0
    for i, item in enumerate(milestone):
        validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
        validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
        if item < last_item:
            raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
        lr += [learning_rates[i]] * (item - last_item)
        last_item = item

    return lr


def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('decay_rate', decay_rate)
    validator.check_type('is_stair', is_stair, [bool])


def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate base on exponential decay function.

    For the i-th step, the formula of computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps in per epoch.
        decay_epoch (int): A value used to calculate decayed learning rate.
        is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.

    Returns:
        list[float]. The size of list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 1
        >>> lr = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch))
        else:
            lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch))
    return lr


def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate base on natural exponential decay function.

    For the i-th step, the formula of computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps in per epoch.
        decay_epoch (int): A value used to calculate decayed learning rate.
        is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.

    Returns:
        list[float]. The size of list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> lr = natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    function = lambda x, y: x
    if is_stair:
        function = lambda x, y: math.floor(x / y) * y

    lr = []
    for i in range(total_step):
        lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch)))
    return lr


def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate base on inverse-time decay function.

    For the i-th step, the formula of computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch)

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps in per epoch.
        decay_epoch (int): A value used to calculate decayed learning rate.
        is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.

    Returns:
        list[float]. The size of list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.5
        >>> total_step = 6
        >>> step_per_epoch = 1
        >>> decay_epoch = 1
        >>> lr = inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch)))
        else:
            lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch))
    return lr


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    r"""
    Calculate learning rate base on cosine decay function.

    For the i-th step, the formula of computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
        (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        min_lr (float): The minimum value of learning rate.
        max_lr (float): The maximum value of learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps in per epoch.
        decay_epoch (int): A value used to calculate decayed learning rate.

    Returns:
        list[float]. The size of list is `total_step`.

    Examples:
        >>> min_lr = 0.01
        >>> max_lr = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    validator.check_float_positive('min_lr', min_lr)
    validator.check_float_positive('max_lr', max_lr)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)

    delta = 0.5 * (max_lr - min_lr)
    lr = []
    for i in range(total_step):
        tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
        lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch)))
    return lr


def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                        update_decay_epoch=False):
    r"""
    Calculate learning rate base on polynomial decay function.

    For the i-th step, the formula of computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
        (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate

    Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
    If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is
    :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)`

    Args:
        learning_rate (float): The initial value of learning rate.
        end_learning_rate (float): The end value of learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps in per epoch.
        decay_epoch (int): A value used to calculate decayed learning rate.
        power (float): A value used to calculate decayed learning rate.
        update_decay_epoch (bool): If true, update `decay_epoch`. Default: False.

    Returns:
        list[float]. The size of list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> end_learning_rate = 0.01
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> power = 0.5
        >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('end_learning_rate', end_learning_rate)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_type('power', power, [float])
    validator.check_type('update_decay_epoch', update_decay_epoch, [bool])

    function = lambda x, y: (x, min(x, y))
    if update_decay_epoch:
        function = lambda x, y: (x * max(math.ceil(y / x), 1), y)

    lr = []
    delta = learning_rate - end_learning_rate
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
        lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
    return lr
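
The helpers above return a plain Python list with one learning-rate value per training step; the optimizer changes in the rest of this merge request are what let such a list be passed directly as `learning_rate`. A minimal usage sketch follows (the `nn.Dense` stand-in network, the step counts, and the `nn.Momentum` wiring are illustrative assumptions, not part of this diff):

# Sketch only: generate a per-step schedule and hand it to an optimizer.
from mindspore import nn
from mindspore.nn import dynamic_lr as dr

net = nn.Dense(10, 1)                    # stand-in network for the sketch
step_per_epoch = 60                      # assumed dataset size / batch size
total_step = step_per_epoch * 100        # assumed 100 epochs
decay_epoch = 10

# One float per step; the rate drops by 0.9 every 10 epochs because is_stair=True.
lr_schedule = dr.exponential_decay_lr(0.1, 0.9, total_step, step_per_epoch, decay_epoch, is_stair=True)

# The reworked Optimizer base class (see optimizer.py below) converts the list
# into a 1-D float32 Tensor and indexes it with an internal global_step counter.
opt = nn.Momentum(net.trainable_params(), learning_rate=lr_schedule, momentum=0.9)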
mindspore/nn/optim/adam.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """adam"""
-from typing import Iterable
 import numpy as np
 from mindspore.common import dtype as mstype
@@ -25,7 +24,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import ParamValidator as validator
 from mindspore._checkparam import Rel
-from .optimizer import Optimizer, apply_decay, grad_scale
+from .optimizer import Optimizer
 
 _learning_rate_update_func = ['linear', 'cos', 'sin']
@@ -168,22 +167,13 @@ class Adam(Optimizer):
     def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                  use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Adam, self).__init__(learning_rate, params)
+        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         _check_param_value(beta1, beta2, eps, weight_decay)
         validator.check_type("use_locking", use_locking, [bool])
         validator.check_type("use_nesterov", use_nesterov, [bool])
         validator.check_type("loss_scale", loss_scale, [float])
         validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
 
-        self.dynamic_lr = False
-        if isinstance(learning_rate, Iterable) or \
-                (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-
         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
         self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
@@ -196,8 +186,6 @@ class Adam(Optimizer):
         self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.Adam(use_locking, use_nesterov)
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
         self.pow = P.Pow()
         self.sqrt = P.Sqrt()
@@ -208,15 +196,9 @@ class Adam(Optimizer):
         params = self.parameters
         moment1 = self.moment1
         moment2 = self.moment2
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-
-        lr = self.learning_rate
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
 
         beta1_power = self.beta1_power * self.beta1
         self.beta1_power = beta1_power
mindspore/nn/optim/momentum.py

@@ -13,14 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """momentum"""
-from typing import Iterable
-
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
-import mindspore.common.dtype as mstype
-from mindspore.common import Tensor
-from .optimizer import Optimizer, apply_decay, grad_scale
+from .optimizer import Optimizer
 
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -88,43 +83,20 @@ class Momentum(Optimizer):
     """
 
     def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Momentum, self).__init__(learning_rate, params)
+        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        if isinstance(learning_rate, Iterable) or \
-                (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-        else:
-            self.dynamic_lr = False
-            self.gather = None
-            self.assignadd = None
-            self.global_step = None
-            self.axis = None
         self.momentum = Parameter(momentum, name="momentum")
         self.params = self.parameters
         self.moments = self.params.clone(prefix="moments", init='zeros')
-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyMomentum()
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
-        self.one = Tensor(1, mstype.int32)
 
     def construct(self, gradients):
         params = self.params
         moments = self.moments
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
         return success
mindspore/nn/optim/optimizer.py

@@ -17,9 +17,11 @@ from typing import Iterable
 import numpy as np
 
+import mindspore
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.initializer import initializer
 from mindspore._checkparam import ParamValidator as validator
 from mindspore._checkparam import Rel
 from mindspore.common.tensor import Tensor
@@ -42,34 +44,110 @@ class Optimizer(Cell):
     Args:
         learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
         parameters (list): A list of parameter, which will be updated. The element in `parameters`
-          should be class mindspore.Parameter.
+            should be class mindspore.Parameter.
+        weight_decay (float): A floating point value for the weight decay. Default: 0.0.
+        loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
+            x: 'beta' not in x.name and 'gamma' not in x.name.
 
     Raises:
         ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1.
         TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable.
     """
 
-    def __init__(self, learning_rate, parameters):
+    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(Optimizer, self).__init__()
         if isinstance(learning_rate, float):
+            self.dynamic_lr = False
+            self.gather = None
+            self.assignadd = None
+            self.global_step = None
             validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
-        elif isinstance(learning_rate, Iterable):
-            learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
-        elif isinstance(learning_rate, Tensor):
-            if learning_rate.dim() > 1:
-                raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
-                                 f"but got {learning_rate.dim()}.")
-        else:
-            raise TypeError("Learning rate should be float, Tensor or Iterable.")
+        else:
+            self.dynamic_lr = True
+            self.gather = P.GatherV2()
+            self.assignadd = P.AssignAdd()
+            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
+            if isinstance(learning_rate, Iterable):
+                learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
+            elif isinstance(learning_rate, Tensor):
+                if learning_rate.dim() > 1:
+                    raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
                                      f"but got {learning_rate.dim()}.")
+                if learning_rate.dim() == 1 and learning_rate.size() < 2:
+                    logger.warning("If want to use the dynamic learning rate, please make sure that the number "
+                                   "of elements in the list, tuple or tensor passed is greater than 1.")
+            else:
+                raise TypeError("Learning rate should be float, Tensor or Iterable.")
+
+        if loss_scale <= 0.0:
+            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
+        if weight_decay < 0.0:
+            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
 
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1 and learning_rate.size() < 2:
-            logger.warning("If want to use the dynamic learning rate, please make sure that "
-                           "the number of elements in the list, tuple or tensor passed is greater than 1.")
         self.learning_rate = Parameter(learning_rate, name="learning_rate")
         self.parameters = ParameterTuple(parameters)
+        self.reciprocal_scale = 1.0 / loss_scale
+        self.weight_decay = weight_decay * loss_scale
+        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
+
         if not self.parameters:
             raise ValueError("optimizer got an empty parameter list.")
 
+    def decay_weight(self, gradients):
+        """
+        Weight decay.
+
+        An approach to reduce the overfitting of a deep learning neural network model.
+
+        Args:
+            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
+                `self.parameters`.
+
+        Returns:
+            tuple[Tensor], The gradients after weight decay.
+        """
+        if self.weight_decay > 0:
+            params = self.params
+            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
+                                       params, gradients)
+
+        return gradients
+
+    def scale_grad(self, gradients):
+        """
+        Loss scale for mixed precision.
+
+        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
+        network.
+
+        Args:
+            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
+                `self.parameters`.
+
+        Returns:
+            tuple[Tensor], The gradients after loss scale.
+        """
+        if self.reciprocal_scale != 1.0:
+            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
+
+        return gradients
+
+    def get_lr(self):
+        """
+        Get the learning rate of current step.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            lr = self.gather(self.learning_rate, self.global_step, 0)
+            F.control_depend(lr, self.assignadd(self.global_step, 1))
+
+        return lr
 
     def construct(self, *hyper_params):
         raise NotImplementedError
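
With `decay_weight`, `scale_grad`, and `get_lr` on the base class, each optimizer's `construct` now reduces to the same three delegating calls before its own update primitive, as the Adam and Momentum diffs above (and RMSProp and SGD below) show. For a list or tuple learning rate, the base class stores a 1-D float32 Tensor and `get_lr` gathers the element at `global_step` before incrementing the counter, so step i of training sees `schedule[i]`. A plain-NumPy sketch of that indexing behaviour (the `ScheduleStepper` class here is illustrative, not MindSpore API):

# Sketch only: models Optimizer.get_lr() for a per-step schedule, outside the graph.
import numpy as np

class ScheduleStepper:
    def __init__(self, schedule):
        # Mirrors the Iterable -> 1-D float32 Tensor conversion in Optimizer.__init__.
        self.learning_rate = np.array(schedule, dtype=np.float32)
        self.global_step = 0

    def get_lr(self):
        lr = self.learning_rate[self.global_step]   # GatherV2 at global_step
        self.global_step += 1                        # AssignAdd(global_step, 1)
        return lr

# One value per step, e.g. as produced by dr.piecewise_constant_lr([2, 4], [0.1, 0.01]).
schedule = [0.1, 0.1, 0.01, 0.01]
stepper = ScheduleStepper(schedule)
for step in range(len(schedule)):
    assert stepper.get_lr() == np.float32(schedule[step])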
mindspore/nn/optim/rmsprop.py

@@ -14,12 +14,8 @@
 # ============================================================================
 """rmsprop"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
-from mindspore._checkparam import ParamValidator as validator
-import mindspore.common.dtype as mstype
-from mindspore.common import Tensor
-from .optimizer import Optimizer, grad_scale, apply_decay
+from .optimizer import Optimizer
 
 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -138,7 +134,7 @@ class RMSProp(Optimizer):
     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(RMSProp, self).__init__(learning_rate, params)
+        super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -157,15 +153,6 @@ class RMSProp(Optimizer):
         else:
             self.opt = P.ApplyRMSProp(use_locking)
 
-        self.dynamic_lr = False
-        if not isinstance(learning_rate, float):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-            self.one = Tensor(1, mstype.int32)
-
         self.momentum = momentum
         self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
@@ -173,21 +160,12 @@ class RMSProp(Optimizer):
         self.hyper_map = C.HyperMap()
         self.decay = decay
-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
-        self.reciprocal_scale = 1.0 / loss_scale
-        self.weight_decay = weight_decay * loss_scale
 
     def construct(self, gradients):
         params = self.parameters
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         if self.centered:
             success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                                self.momentum), params, self.mg, self.ms, self.moment, gradients)
mindspore/nn/optim/sgd.py

@@ -14,11 +14,9 @@
 # ============================================================================
 """sgd"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
-import mindspore.common.dtype as mstype
-from .optimizer import Optimizer, grad_scale
+from .optimizer import Optimizer
 
 sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@@ -83,7 +81,7 @@ class SGD(Optimizer):
     def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
                  loss_scale=1.0):
-        super(SGD, self).__init__(learning_rate, params)
+        super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -92,44 +90,22 @@ class SGD(Optimizer):
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening
 
-        if weight_decay < 0.0:
-            raise ValueError("weight_decay should be at least 0.0, but got weight_decay {}".format(weight_decay))
-        self.weight_decay = weight_decay
 
         validator.check_type("nesterov", nesterov, [bool])
         self.nesterov = nesterov
 
         self.opt = P.SGD(dampening, weight_decay, nesterov)
 
-        self.dynamic_lr = False
-        self.gather = None
-        self.global_step = None
-        self.axis = None
-        if not isinstance(learning_rate, float):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
         self.momentum = Parameter(momentum, name="momentum")
-        self.params = self.parameters
-        self.accum = self.params.clone(prefix="accum", init='zeros')
-        self.stat = self.params.clone(prefix="stat", init='ones')
+        self.accum = self.parameters.clone(prefix="accum", init='zeros')
+        self.stat = self.parameters.clone(prefix="stat", init='ones')
         self.hyper_map = C.HyperMap()
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
 
     def construct(self, gradients):
-        params = self.params
+        params = self.parameters
        accum = self.accum
         stat = self.stat
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, 1))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
         return success
tests/ut/python/nn/optim/test_optimizer.py

@@ -15,17 +15,11 @@
 """ test optimizer """
 import numpy as np
 import pytest
-from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore import Tensor
+from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore.common.parameter import Parameter
 
-gradient = Tensor(np.zeros([1, 2, 3]))
-accumulation = gradient
-variable = accumulation
-
-paramsTensor = Tensor(np.zeros([1, 2, 3]))
 
 class IterableObjc:
     def __iter__(self):
         cont = 0
@@ -56,6 +50,7 @@ class TestAdam():
     def test_construct(self):
         with pytest.raises(TypeError):
+            gradient = Tensor(np.zeros([1, 2, 3]))
             adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                         use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
             adam.construct(gradient)
@@ -105,4 +100,5 @@ class TestUnsupportParam():
     def test_Sgd_init(self):
         with pytest.raises(TypeError):
+            paramsTensor = Tensor(np.zeros([1, 2, 3]))
             SGD(paramsTensor)
tests/ut/python/nn/test_dynamic_lr.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Test Dynamic Learning Rate """
import pytest
import mindspore
from mindspore.nn import dynamic_lr as dr

milestone = [10, 20, 30]
learning_rates = [0.1, 0.05, 0.01]
learning_rate = 0.1
end_learning_rate = 0.01
decay_rate = 0.9
total_step = 30
step_per_epoch = 3
decay_epoch = 2
min_lr = 0.01
max_lr = 0.1
power = 0.5


class TestInputs:
    def test_milestone1(self):
        milestone1 = 1
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone1, learning_rates)

    def test_milestone2(self):
        milestone1 = [20, 10, 1]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone1, learning_rates)

        milestone2 = [1.0, 2.0, True]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone2, learning_rates)

    def test_learning_rates1(self):
        lr = True
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone, lr)

    def test_learning_rates2(self):
        lr = [1, 2, 1]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone, lr)

    def test_learning_rate_type(self):
        lr = True
        with pytest.raises(TypeError):
            dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)

        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)

    def test_learning_rate_value(self):
        lr = -1.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)

    def test_end_learning_rate_type(self):
        lr = True
        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)

    def test_end_learning_rate_value(self):
        lr = -1.0
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)

    def test_decay_rate_type(self):
        rate = 'a'
        with pytest.raises(TypeError):
            dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)

    def test_decay_rate_value(self):
        rate = -1.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)

    def test_total_step1(self):
        total_step1 = 2.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)

    def test_total_step2(self):
        total_step1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)

    def test_step_per_epoch1(self):
        step_per_epoch1 = True
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)

    def test_step_per_epoch2(self):
        step_per_epoch1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)

    def test_decay_epoch1(self):
        decay_epoch1 = 'm'
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)

    def test_decay_epoch2(self):
        decay_epoch1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)

    def test_is_stair(self):
        is_stair = 1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    def test_min_lr_type(self):
        min_lr1 = True
        with pytest.raises(TypeError):
            dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)

    def test_min_lr_value(self):
        min_lr1 = -1.0
        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)

    def test_max_lr_type(self):
        max_lr1 = 'a'
        with pytest.raises(TypeError):
            dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)

    def test_max_lr_value(self):
        max_lr1 = -1.0
        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)

    def test_power(self):
        power1 = True
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)

    def test_update_decay_epoch(self):
        update_decay_epoch = 1
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
                                   power, update_decay_epoch)


def test_learning_rate():
    lr = dr.piecewise_constant_lr(milestone, learning_rates)
    assert len(lr) == milestone[-1]


def test_exponential_decay():
    lr1 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_enatural_exp_decay():
    lr1 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_inverse_decay():
    lr1 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_cosine_decay():
    lr = dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
    assert len(lr) == total_step


def test_polynomial_decay():
    lr1 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
    assert len(lr1) == total_step

    lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                                 True)
    assert len(lr2) == total_step