Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
834a4071
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
834a4071
编写于
4月 21, 2020
作者:
L
leilei_snow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add a function to check for NaN or Inf
上级
46acf238
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files with 31 additions and 3 deletions
+31
-3
mindspore/_checkparam.py
mindspore/_checkparam.py
+11
-0
mindspore/nn/dynamic_lr.py
mindspore/nn/dynamic_lr.py
+20
-3
未找到文件。
mindspore/_checkparam.py
浏览文件 @
834a4071
...
...
@@ -15,6 +15,7 @@
"""Check parameters."""
import
re
import
inspect
import
math
from
enum
import
Enum
from
functools
import
reduce
,
wraps
from
itertools
import
repeat
...
...
@@ -318,6 +319,16 @@ class Validator:
raise
ValueError
(
f
'
{
msg_prefix
}
type of `
{
arg_name
}
` should be one of
{
type_names
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
@
staticmethod
def
check_float_legal_value
(
arg_name
,
arg_value
,
prim_name
):
"""Checks whether a legal value of float type"""
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
"The"
if
isinstance
(
arg_value
,
float
):
if
math
.
isinf
(
arg_value
)
or
math
.
isnan
(
arg_value
):
raise
ValueError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be legal value, but got
{
arg_value
}
."
)
return
arg_value
raise
TypeError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be float."
)
class
ParamValidator
:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
...
...
mindspore/nn/dynamic_lr.py
浏览文件 @
834a4071
...
...
@@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates):
`milestone`. Let the output learning rate be `y`.
.. math::
y[i] = x_t
for
i \in [M_{t-1}, M_t)
y[i] = x_t
,\ for\
i \in [M_{t-1}, M_t)
Args:
milestone (list[int]): A list of milestone. This list is a monotone increasing list.
...
...
@@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
last_item
=
0
for
i
,
item
in
enumerate
(
milestone
):
validator
.
check_integer
(
f
'milestone[
{
i
}
]'
,
item
,
0
,
Rel
.
GT
,
None
)
validator
.
check_
value_type
(
f
'learning_rates[
{
i
}
]'
,
learning_rates
[
i
],
[
float
],
None
)
validator
.
check_
float_legal_value
(
f
'learning_rates[
{
i
}
]'
,
learning_rates
[
i
],
None
)
if
item
<
last_item
:
raise
ValueError
(
f
'The value of milestone[
{
i
}
] must be greater than milestone[
{
i
-
1
}
]'
)
lr
+=
[
learning_rates
[
i
]]
*
(
item
-
last_item
)
...
...
@@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_legal_value
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_positive
(
'decay_rate'
,
decay_rate
,
None
)
validator
.
check_float_legal_value
(
'decay_rate'
,
decay_rate
,
None
)
validator
.
check_value_type
(
'is_stair'
,
is_stair
,
[
bool
],
None
)
...
...
@@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
"""
validator
.
check_float_positive
(
'min_lr'
,
min_lr
,
None
)
validator
.
check_float_legal_value
(
'min_lr'
,
min_lr
,
None
)
validator
.
check_float_positive
(
'max_lr'
,
max_lr
,
None
)
validator
.
check_float_legal_value
(
'max_lr'
,
max_lr
,
None
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
...
...
@@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_legal_value
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_positive
(
'end_learning_rate'
,
end_learning_rate
,
None
)
validator
.
check_float_legal_value
(
'end_learning_rate'
,
end_learning_rate
,
None
)
validator
.
check_float_positive
(
'power'
,
power
,
None
)
validator
.
check_float_legal_value
(
'power'
,
power
,
None
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_value_type
(
'power'
,
power
,
[
float
],
None
)
validator
.
check_value_type
(
'update_decay_epoch'
,
update_decay_epoch
,
[
bool
],
None
)
function
=
lambda
x
,
y
:
(
x
,
min
(
x
,
y
))
...
...
@@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
decay_epoch
,
tmp_epoch
=
function
(
decay_epoch
,
current_epoch
)
lr
.
append
(
delta
*
(
1
-
tmp_epoch
/
decay_epoch
)
**
power
+
end_learning_rate
)
return
lr
# Public names exported by this module via `from ... import *`.
__all__ = [
    'piecewise_constant_lr',
    'exponential_decay_lr',
    'natural_exp_decay_lr',
    'inverse_decay_lr',
    'cosine_decay_lr',
    'polynomial_decay_lr',
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录