Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8cbbbd95
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8cbbbd95
编写于
4月 17, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add cell name to error message
上级
2d31ae97
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
272 addition
and
175 deletion
+272
-175
mindspore/_checkparam.py
mindspore/_checkparam.py
+114
-14
mindspore/nn/cell.py
mindspore/nn/cell.py
+4
-0
mindspore/nn/dynamic_lr.py
mindspore/nn/dynamic_lr.py
+23
-23
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+3
-3
mindspore/nn/layer/embedding.py
mindspore/nn/layer/embedding.py
+2
-2
mindspore/nn/layer/image.py
mindspore/nn/layer/image.py
+11
-11
mindspore/nn/layer/lstm.py
mindspore/nn/layer/lstm.py
+2
-2
mindspore/nn/layer/pooling.py
mindspore/nn/layer/pooling.py
+21
-30
mindspore/nn/metrics/fbeta.py
mindspore/nn/metrics/fbeta.py
+2
-2
mindspore/nn/metrics/precision.py
mindspore/nn/metrics/precision.py
+2
-2
mindspore/nn/metrics/recall.py
mindspore/nn/metrics/recall.py
+2
-2
mindspore/nn/optim/adam.py
mindspore/nn/optim/adam.py
+17
-17
mindspore/nn/optim/ftrl.py
mindspore/nn/optim/ftrl.py
+20
-18
mindspore/nn/optim/lamb.py
mindspore/nn/optim/lamb.py
+17
-17
mindspore/nn/optim/optimizer.py
mindspore/nn/optim/optimizer.py
+2
-2
mindspore/nn/optim/rmsprop.py
mindspore/nn/optim/rmsprop.py
+3
-3
mindspore/nn/optim/sgd.py
mindspore/nn/optim/sgd.py
+2
-2
mindspore/ops/op_info_register.py
mindspore/ops/op_info_register.py
+2
-2
mindspore/train/amp.py
mindspore/train/amp.py
+9
-9
mindspore/train/loss_scale_manager.py
mindspore/train/loss_scale_manager.py
+2
-2
tests/ut/python/nn/test_dynamic_lr.py
tests/ut/python/nn/test_dynamic_lr.py
+6
-6
tests/ut/python/nn/test_psnr.py
tests/ut/python/nn/test_psnr.py
+1
-1
tests/ut/python/nn/test_ssim.py
tests/ut/python/nn/test_ssim.py
+2
-2
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+2
-2
tests/ut/python/pynative_mode/nn/test_pooling.py
tests/ut/python/pynative_mode/nn/test_pooling.py
+1
-1
未找到文件。
mindspore/_checkparam.py
浏览文件 @
8cbbbd95
...
...
@@ -17,7 +17,7 @@ import re
from
enum
import
Enum
from
functools
import
reduce
from
itertools
import
repeat
from
collections
import
Iterable
from
collections
.abc
import
Iterable
import
numpy
as
np
from
mindspore
import
log
as
logger
...
...
@@ -98,7 +98,7 @@ class Validator:
"""validator for checking input parameters"""
@
staticmethod
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
,
prim_name
=
None
):
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
,
prim_name
=
None
,
excp_cls
=
ValueError
):
"""
Method for judging relation between two int values or list/tuple made up of ints.
...
...
@@ -108,8 +108,8 @@ class Validator:
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
f
'
{
value_name
}
:
{
value
}
'
)
msg_prefix
=
f
'For
{
prim_name
}
the'
if
prim_name
else
"The"
raise
ValueError
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be
{
rel_str
}
, but got
{
arg_value
}
.'
)
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
"The"
raise
excp_cls
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be
{
rel_str
}
, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_integer
(
arg_name
,
arg_value
,
value
,
rel
,
prim_name
):
...
...
@@ -118,8 +118,17 @@ class Validator:
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
or
isinstance
(
arg_value
,
bool
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'For
{
prim_name
}
the `
{
arg_name
}
` should be an int and must
{
rel_str
}
,'
f
' but got
{
arg_value
}
.'
)
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
"The"
raise
ValueError
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_number
(
arg_name
,
arg_value
,
value
,
rel
,
prim_name
):
"""Integer value judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
...
...
@@ -133,9 +142,46 @@ class Validator:
f
' but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_number_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
,
prim_name
):
"""Method for checking whether a numeric value is in some range."""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be in range
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_string
(
arg_name
,
arg_value
,
valid_values
,
prim_name
):
"""Checks whether a string is in some value list"""
if
isinstance
(
arg_value
,
str
)
and
arg_value
in
valid_values
:
return
arg_value
if
len
(
valid_values
)
==
1
:
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be str and must be
{
valid_values
[
0
]
}
,'
f
' but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be str and must be one of
{
valid_values
}
,'
f
' but got
{
arg_value
}
.'
)
@
staticmethod
def
check_pad_value_by_mode
(
pad_mode
,
padding
,
prim_name
):
"""Validates value of padding according to pad_mode"""
if
pad_mode
!=
'pad'
and
padding
!=
0
:
raise
ValueError
(
f
"For '
{
prim_name
}
', padding must be zero when pad_mode is '
{
pad_mode
}
'."
)
return
padding
@
staticmethod
def
check_float_positive
(
arg_name
,
arg_value
,
prim_name
):
"""Float type judgment."""
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
"The"
if
isinstance
(
arg_value
,
float
):
if
arg_value
>
0
:
return
arg_value
raise
ValueError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be positive, but got
{
arg_value
}
."
)
raise
TypeError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be float."
)
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
prim_name
):
"""Check whether some type is sublcass of another type"""
"""Check
s
whether some type is sublcass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
...
...
@@ -143,16 +189,44 @@ class Validator:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be subclass'
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
))
}
, but got
{
type_str
}
.'
)
@
staticmethod
def
check_const_input
(
arg_name
,
arg_value
,
prim_name
):
"""Check valid value."""
if
arg_value
is
None
:
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` must be a const input, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_scalar_type_same
(
args
,
valid_values
,
prim_name
):
"""check whether the types of inputs are the same."""
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
elem_type
=
arg_val
if
not
elem_type
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_key
}
` is
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
arg1_name
,
arg1_type
=
arg1
arg2_name
,
arg2_type
=
arg2
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
` is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
return
arg1
elem_types
=
map
(
_check_tensor_type
,
args
.
items
())
reduce
(
_check_types_same
,
elem_types
)
@
staticmethod
def
check_tensor_type_same
(
args
,
valid_values
,
prim_name
):
"""
check
whether the element types of input tensors are the same."""
"""
Checks
whether the element types of input tensors are the same."""
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
Validator
.
check_subclass
(
arg_key
,
arg_val
,
mstype
.
tensor
,
prim_name
)
elem_type
=
arg_val
.
element_type
()
if
not
elem_type
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
element type of `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_key
}
` is
{
elem_type
}
.'
)
f
' but
element type of
`
{
arg_key
}
` is
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
...
...
@@ -168,8 +242,13 @@ class Validator:
@
staticmethod
def
check_scalar_or_tensor_type_same
(
args
,
valid_values
,
prim_name
):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
def
check_scalar_or_tensor_type_same
(
args
,
valid_values
,
prim_name
,
allow_mix
=
False
):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
"""
def
_check_argument_type
(
arg
):
arg_key
,
arg_val
=
arg
if
isinstance
(
arg_val
,
type
(
mstype
.
tensor
)):
...
...
@@ -188,6 +267,9 @@ class Validator:
arg2_type
=
arg2_type
.
element_type
()
elif
not
(
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
or
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
))):
pass
elif
allow_mix
:
arg1_type
=
arg1_type
.
element_type
()
if
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
else
arg1_type
arg2_type
=
arg2_type
.
element_type
()
if
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
))
else
arg2_type
else
:
excp_flag
=
True
...
...
@@ -199,13 +281,14 @@ class Validator:
@
staticmethod
def
check_value_type
(
arg_name
,
arg_value
,
valid_types
,
prim_name
):
"""Check whether a values is instance of some types."""
"""Checks whether a value is instance of some types."""
valid_types
=
valid_types
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
def
raise_error_msg
():
"""func for raising error message when check failed"""
type_names
=
[
t
.
__name__
for
t
in
valid_types
]
num_types
=
len
(
valid_types
)
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be
'
f
'
{
"one of "
if
num_types
>
1
else
""
}
'
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
'The
'
raise
TypeError
(
f
'
{
msg_prefix
}
type of `
{
arg_name
}
` should be
{
"one of "
if
num_types
>
1
else
""
}
'
f
'
{
type_names
if
num_types
>
1
else
type_names
[
0
]
}
, but got
{
type
(
arg_value
).
__name__
}
.'
)
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
...
...
@@ -216,6 +299,23 @@ class Validator:
return
arg_value
raise_error_msg
()
@
staticmethod
def
check_type_name
(
arg_name
,
arg_type
,
valid_types
,
prim_name
):
"""Checks whether a type in some specified types"""
valid_types
=
valid_types
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
def
get_typename
(
t
):
return
t
.
__name__
if
hasattr
(
t
,
'__name__'
)
else
str
(
t
)
if
arg_type
in
valid_types
:
return
arg_type
type_names
=
[
get_typename
(
t
)
for
t
in
valid_types
]
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
'The'
if
len
(
valid_types
)
==
1
:
raise
ValueError
(
f
'
{
msg_prefix
}
type of `
{
arg_name
}
` should be
{
type_names
[
0
]
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
raise
ValueError
(
f
'
{
msg_prefix
}
type of `
{
arg_name
}
` should be one of
{
type_names
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
class
ParamValidator
:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
...
...
mindspore/nn/cell.py
浏览文件 @
8cbbbd95
...
...
@@ -103,6 +103,10 @@ class Cell:
def
parameter_layout_dict
(
self
):
return
self
.
_parameter_layout_dict
@
property
def
cls_name
(
self
):
return
self
.
__class__
.
__name__
@
parameter_layout_dict
.
setter
def
parameter_layout_dict
(
self
,
value
):
if
not
isinstance
(
value
,
dict
):
...
...
mindspore/nn/dynamic_lr.py
浏览文件 @
8cbbbd95
...
...
@@ -15,7 +15,7 @@
"""dynamic learning rate"""
import
math
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
...
...
@@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates):
>>> lr = piecewise_constant_lr(milestone, learning_rates)
[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
"""
validator
.
check_
type
(
'milestone'
,
milestone
,
(
tuple
,
list
)
)
validator
.
check_
type
(
'learning_rates'
,
learning_rates
,
(
tuple
,
list
)
)
validator
.
check_
value_type
(
'milestone'
,
milestone
,
(
tuple
,
list
),
None
)
validator
.
check_
value_type
(
'learning_rates'
,
learning_rates
,
(
tuple
,
list
),
None
)
if
len
(
milestone
)
!=
len
(
learning_rates
):
raise
ValueError
(
'The size of `milestone` must be same with the size of `learning_rates`.'
)
lr
=
[]
last_item
=
0
for
i
,
item
in
enumerate
(
milestone
):
validator
.
check_integer
(
f
'milestone[
{
i
}
]'
,
item
,
0
,
Rel
.
GT
)
validator
.
check_
type
(
f
'learning_rates[
{
i
}
]'
,
learning_rates
[
i
],
[
float
]
)
validator
.
check_integer
(
f
'milestone[
{
i
}
]'
,
item
,
0
,
Rel
.
GT
,
None
)
validator
.
check_
value_type
(
f
'learning_rates[
{
i
}
]'
,
learning_rates
[
i
],
[
float
],
None
)
if
item
<
last_item
:
raise
ValueError
(
f
'The value of milestone[
{
i
}
] must be greater than milestone[
{
i
-
1
}
]'
)
lr
+=
[
learning_rates
[
i
]]
*
(
item
-
last_item
)
...
...
@@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates):
def
_check_inputs
(
learning_rate
,
decay_rate
,
total_step
,
step_per_epoch
,
decay_epoch
,
is_stair
):
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
)
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
)
validator
.
check_float_positive
(
'decay_rate'
,
decay_rate
)
validator
.
check_
type
(
'is_stair'
,
is_stair
,
[
bool
]
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_positive
(
'decay_rate'
,
decay_rate
,
None
)
validator
.
check_
value_type
(
'is_stair'
,
is_stair
,
[
bool
],
None
)
def
exponential_decay_lr
(
learning_rate
,
decay_rate
,
total_step
,
step_per_epoch
,
decay_epoch
,
is_stair
=
False
):
...
...
@@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
>>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
"""
validator
.
check_float_positive
(
'min_lr'
,
min_lr
)
validator
.
check_float_positive
(
'max_lr'
,
max_lr
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
)
validator
.
check_float_positive
(
'min_lr'
,
min_lr
,
None
)
validator
.
check_float_positive
(
'max_lr'
,
max_lr
,
None
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
delta
=
0.5
*
(
max_lr
-
min_lr
)
lr
=
[]
...
...
@@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
>>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
)
validator
.
check_float_positive
(
'end_learning_rate'
,
end_learning_rate
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
)
validator
.
check_
type
(
'power'
,
power
,
[
float
]
)
validator
.
check_
type
(
'update_decay_epoch'
,
update_decay_epoch
,
[
bool
]
)
validator
.
check_float_positive
(
'learning_rate'
,
learning_rate
,
None
)
validator
.
check_float_positive
(
'end_learning_rate'
,
end_learning_rate
,
None
)
validator
.
check_integer
(
'total_step'
,
total_step
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'step_per_epoch'
,
step_per_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_integer
(
'decay_epoch'
,
decay_epoch
,
0
,
Rel
.
GT
,
None
)
validator
.
check_
value_type
(
'power'
,
power
,
[
float
],
None
)
validator
.
check_
value_type
(
'update_decay_epoch'
,
update_decay_epoch
,
[
bool
],
None
)
function
=
lambda
x
,
y
:
(
x
,
min
(
x
,
y
))
if
update_decay_epoch
:
...
...
mindspore/nn/layer/basic.py
浏览文件 @
8cbbbd95
...
...
@@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter
from
mindspore._extends
import
cell_attr_register
from
..cell
import
Cell
from
.activation
import
get_activation
from
..._checkparam
import
Param
Validator
as
validator
from
..._checkparam
import
Validator
as
validator
class
Dropout
(
Cell
):
...
...
@@ -73,7 +73,7 @@ class Dropout(Cell):
super
(
Dropout
,
self
).
__init__
()
if
keep_prob
<=
0
or
keep_prob
>
1
:
raise
ValueError
(
"dropout probability should be a number in range (0, 1], but got {}"
.
format
(
keep_prob
))
validator
.
check_subclass
(
"dtype"
,
dtype
,
mstype
.
number_type
)
validator
.
check_subclass
(
"dtype"
,
dtype
,
mstype
.
number_type
,
self
.
cls_name
)
self
.
keep_prob
=
Tensor
(
keep_prob
)
self
.
seed0
=
seed0
self
.
seed1
=
seed1
...
...
@@ -421,7 +421,7 @@ class Pad(Cell):
super
(
Pad
,
self
).
__init__
()
self
.
mode
=
mode
self
.
paddings
=
paddings
validator
.
check_string
(
'mode'
,
self
.
mode
,
[
"CONSTANT"
,
"REFLECT"
,
"SYMMETRIC"
])
validator
.
check_string
(
'mode'
,
self
.
mode
,
[
"CONSTANT"
,
"REFLECT"
,
"SYMMETRIC"
]
,
self
.
cls_name
)
if
not
isinstance
(
paddings
,
tuple
):
raise
TypeError
(
'Paddings must be tuple type.'
)
for
item
in
paddings
:
...
...
mindspore/nn/layer/embedding.py
浏览文件 @
8cbbbd95
...
...
@@ -19,7 +19,7 @@ from mindspore.ops import operations as P
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.initializer
import
initializer
from
..cell
import
Cell
from
..._checkparam
import
Param
Validator
as
validator
from
..._checkparam
import
Validator
as
validator
class
Embedding
(
Cell
):
...
...
@@ -59,7 +59,7 @@ class Embedding(Cell):
"""
def
__init__
(
self
,
vocab_size
,
embedding_size
,
use_one_hot
=
False
,
embedding_table
=
'normal'
,
dtype
=
mstype
.
float32
):
super
(
Embedding
,
self
).
__init__
()
validator
.
check_subclass
(
"dtype"
,
dtype
,
mstype
.
number_type
)
validator
.
check_subclass
(
"dtype"
,
dtype
,
mstype
.
number_type
,
self
.
cls_name
)
self
.
vocab_size
=
vocab_size
self
.
embedding_size
=
embedding_size
self
.
use_one_hot
=
use_one_hot
...
...
mindspore/nn/layer/image.py
浏览文件 @
8cbbbd95
...
...
@@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.ops.primitive
import
constexpr
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
..cell
import
Cell
...
...
@@ -134,15 +134,15 @@ class SSIM(Cell):
"""
def
__init__
(
self
,
max_val
=
1.0
,
filter_size
=
11
,
filter_sigma
=
1.5
,
k1
=
0.01
,
k2
=
0.03
):
super
(
SSIM
,
self
).
__init__
()
validator
.
check_
type
(
'max_val'
,
max_val
,
[
int
,
float
]
)
validator
.
check
(
'max_val'
,
max_val
,
''
,
0.0
,
Rel
.
GT
)
validator
.
check_
value_type
(
'max_val'
,
max_val
,
[
int
,
float
],
self
.
cls_name
)
validator
.
check
_number
(
'max_val'
,
max_val
,
0.0
,
Rel
.
GT
,
self
.
cls_name
)
self
.
max_val
=
max_val
self
.
filter_size
=
validator
.
check_integer
(
'filter_size'
,
filter_size
,
1
,
Rel
.
GE
)
self
.
filter_sigma
=
validator
.
check_float_positive
(
'filter_sigma'
,
filter_sigma
)
validator
.
check_
type
(
'k1'
,
k1
,
[
float
]
)
self
.
k1
=
validator
.
check_number_range
(
'k1'
,
k1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
validator
.
check_
type
(
'k2'
,
k2
,
[
float
]
)
self
.
k2
=
validator
.
check_number_range
(
'k2'
,
k2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
self
.
filter_size
=
validator
.
check_integer
(
'filter_size'
,
filter_size
,
1
,
Rel
.
GE
,
self
.
cls_name
)
self
.
filter_sigma
=
validator
.
check_float_positive
(
'filter_sigma'
,
filter_sigma
,
self
.
cls_name
)
validator
.
check_
value_type
(
'k1'
,
k1
,
[
float
],
self
.
cls_name
)
self
.
k1
=
validator
.
check_number_range
(
'k1'
,
k1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
self
.
cls_name
)
validator
.
check_
value_type
(
'k2'
,
k2
,
[
float
],
self
.
cls_name
)
self
.
k2
=
validator
.
check_number_range
(
'k2'
,
k2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
self
.
cls_name
)
self
.
mean
=
P
.
DepthwiseConv2dNative
(
channel_multiplier
=
1
,
kernel_size
=
filter_size
)
def
construct
(
self
,
img1
,
img2
):
...
...
@@ -231,8 +231,8 @@ class PSNR(Cell):
"""
def
__init__
(
self
,
max_val
=
1.0
):
super
(
PSNR
,
self
).
__init__
()
validator
.
check_
type
(
'max_val'
,
max_val
,
[
int
,
float
]
)
validator
.
check
(
'max_val'
,
max_val
,
''
,
0.0
,
Rel
.
GT
)
validator
.
check_
value_type
(
'max_val'
,
max_val
,
[
int
,
float
],
self
.
cls_name
)
validator
.
check
_number
(
'max_val'
,
max_val
,
0.0
,
Rel
.
GT
,
self
.
cls_name
)
self
.
max_val
=
max_val
def
construct
(
self
,
img1
,
img2
):
...
...
mindspore/nn/layer/lstm.py
浏览文件 @
8cbbbd95
...
...
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from
mindspore.nn.cell
import
Cell
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.initializer
import
initializer
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
class
LSTM
(
Cell
):
...
...
@@ -114,7 +114,7 @@ class LSTM(Cell):
self
.
hidden_size
=
hidden_size
self
.
num_layers
=
num_layers
self
.
has_bias
=
has_bias
self
.
batch_first
=
validator
.
check_
type
(
"batch_first"
,
batch_first
,
[
bool
]
)
self
.
batch_first
=
validator
.
check_
value_type
(
"batch_first"
,
batch_first
,
[
bool
],
self
.
cls_name
)
self
.
dropout
=
float
(
dropout
)
self
.
bidirectional
=
bidirectional
...
...
mindspore/nn/layer/pooling.py
浏览文件 @
8cbbbd95
...
...
@@ -14,8 +14,7 @@
# ============================================================================
"""pooling"""
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
ParamValidator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore._checkparam
import
Validator
as
validator
from
...
import
context
from
..cell
import
Cell
...
...
@@ -24,35 +23,27 @@ class _PoolNd(Cell):
"""N-D AvgPool"""
def
__init__
(
self
,
kernel_size
,
stride
,
pad_mode
):
name
=
self
.
__class__
.
__name__
super
(
_PoolNd
,
self
).
__init__
()
validator
.
check_type
(
'kernel_size'
,
kernel_size
,
[
int
,
tuple
])
validator
.
check_type
(
'stride'
,
stride
,
[
int
,
tuple
])
self
.
pad_mode
=
validator
.
check_string
(
'pad_mode'
,
pad_mode
.
upper
(),
[
'VALID'
,
'SAME'
])
if
isinstance
(
kernel_size
,
int
):
validator
.
check_integer
(
"kernel_size"
,
kernel_size
,
1
,
Rel
.
GE
)
else
:
if
(
len
(
kernel_size
)
!=
2
or
(
not
isinstance
(
kernel_size
[
0
],
int
))
or
(
not
isinstance
(
kernel_size
[
1
],
int
))
or
kernel_size
[
0
]
<=
0
or
kernel_size
[
1
]
<=
0
):
raise
ValueError
(
f
'The kernel_size passed to cell
{
name
}
should be an positive int number or'
f
'a tuple of two positive int numbers, but got
{
kernel_size
}
'
)
self
.
kernel_size
=
kernel_size
if
isinstance
(
stride
,
int
):
validator
.
check_integer
(
"stride"
,
stride
,
1
,
Rel
.
GE
)
else
:
if
(
len
(
stride
)
!=
2
or
(
not
isinstance
(
stride
[
0
],
int
))
or
(
not
isinstance
(
stride
[
1
],
int
))
or
stride
[
0
]
<=
0
or
stride
[
1
]
<=
0
):
raise
ValueError
(
f
'The stride passed to cell
{
name
}
should be an positive int number or'
f
'a tuple of two positive int numbers, but got
{
stride
}
'
)
self
.
stride
=
stride
self
.
pad_mode
=
validator
.
check_string
(
'pad_mode'
,
pad_mode
.
upper
(),
[
'VALID'
,
'SAME'
],
self
.
cls_name
)
def
_check_int_or_tuple
(
arg_name
,
arg_value
):
validator
.
check_value_type
(
arg_name
,
arg_value
,
[
int
,
tuple
],
self
.
cls_name
)
error_msg
=
f
'For
\'
{
self
.
cls_name
}
\'
the
{
arg_name
}
should be an positive int number or '
\
f
'a tuple of two positive int numbers, but got
{
arg_value
}
'
if
isinstance
(
arg_value
,
int
):
if
arg_value
<=
0
:
raise
ValueError
(
error_msg
)
elif
len
(
arg_value
)
==
2
:
for
item
in
arg_value
:
if
isinstance
(
item
,
int
)
and
item
>
0
:
continue
raise
ValueError
(
error_msg
)
else
:
raise
ValueError
(
error_msg
)
return
arg_value
self
.
kernel_size
=
_check_int_or_tuple
(
'kernel_size'
,
kernel_size
)
self
.
stride
=
_check_int_or_tuple
(
'stride'
,
stride
)
def
construct
(
self
,
*
inputs
):
pass
...
...
mindspore/nn/metrics/fbeta.py
浏览文件 @
8cbbbd95
...
...
@@ -15,7 +15,7 @@
"""Fbeta."""
import
sys
import
numpy
as
np
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
.metric
import
Metric
...
...
@@ -104,7 +104,7 @@ class Fbeta(Metric):
Returns:
Float, computed result.
"""
validator
.
check_
type
(
"average"
,
average
,
[
bool
]
)
validator
.
check_
value_type
(
"average"
,
average
,
[
bool
],
self
.
__class__
.
__name__
)
if
self
.
_class_num
==
0
:
raise
RuntimeError
(
'Input number of samples can not be 0.'
)
...
...
mindspore/nn/metrics/precision.py
浏览文件 @
8cbbbd95
...
...
@@ -17,7 +17,7 @@ import sys
import
numpy
as
np
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
.evaluation
import
EvaluationBase
...
...
@@ -136,7 +136,7 @@ class Precision(EvaluationBase):
if
self
.
_class_num
==
0
:
raise
RuntimeError
(
'Input number of samples can not be 0.'
)
validator
.
check_
type
(
"average"
,
average
,
[
bool
]
)
validator
.
check_
value_type
(
"average"
,
average
,
[
bool
],
self
.
__class__
.
__name__
)
result
=
self
.
_true_positives
/
(
self
.
_positives
+
self
.
eps
)
if
average
:
...
...
mindspore/nn/metrics/recall.py
浏览文件 @
8cbbbd95
...
...
@@ -17,7 +17,7 @@ import sys
import
numpy
as
np
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
.evaluation
import
EvaluationBase
...
...
@@ -136,7 +136,7 @@ class Recall(EvaluationBase):
if
self
.
_class_num
==
0
:
raise
RuntimeError
(
'Input number of samples can not be 0.'
)
validator
.
check_
type
(
"average"
,
average
,
[
bool
]
)
validator
.
check_
value_type
(
"average"
,
average
,
[
bool
],
self
.
__class__
.
__name__
)
result
=
self
.
_true_positives
/
(
self
.
_actual_positives
+
self
.
eps
)
if
average
:
...
...
mindspore/nn/optim/adam.py
浏览文件 @
8cbbbd95
...
...
@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from
mindspore.ops
import
functional
as
F
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.tensor
import
Tensor
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
.optimizer
import
Optimizer
...
...
@@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
return
next_v
def
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
):
def
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
prim_name
):
"""Check the type of inputs."""
validator
.
check_
type
(
"beta1"
,
beta1
,
[
float
]
)
validator
.
check_
type
(
"beta2"
,
beta2
,
[
float
]
)
validator
.
check_
type
(
"eps"
,
eps
,
[
float
]
)
validator
.
check_
type
(
"weight_dacay"
,
weight_decay
,
[
float
]
)
validator
.
check_number_range
(
"beta1"
,
beta1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"beta2"
,
beta2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"eps"
,
eps
,
0.0
,
float
(
"inf"
),
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"weight_decay"
,
weight_decay
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
)
validator
.
check_
value_type
(
"beta1"
,
beta1
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"beta2"
,
beta2
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"eps"
,
eps
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"weight_dacay"
,
weight_decay
,
[
float
],
prim_name
)
validator
.
check_number_range
(
"beta1"
,
beta1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"beta2"
,
beta2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"eps"
,
eps
,
0.0
,
float
(
"inf"
),
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"weight_decay"
,
weight_decay
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
prim_name
)
@
adam_opt
.
register
(
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Number"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
...
...
@@ -168,11 +168,11 @@ class Adam(Optimizer):
use_nesterov
=
False
,
weight_decay
=
0.0
,
loss_scale
=
1.0
,
decay_filter
=
lambda
x
:
'beta'
not
in
x
.
name
and
'gamma'
not
in
x
.
name
):
super
(
Adam
,
self
).
__init__
(
learning_rate
,
params
,
weight_decay
,
loss_scale
,
decay_filter
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
)
validator
.
check_
type
(
"use_locking"
,
use_locking
,
[
bool
]
)
validator
.
check_
type
(
"use_nesterov"
,
use_nesterov
,
[
bool
]
)
validator
.
check_
type
(
"loss_scale"
,
loss_scale
,
[
float
]
)
validator
.
check_number_range
(
"loss_scale"
,
loss_scale
,
1.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
self
.
cls_name
)
validator
.
check_
value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
cls_name
)
validator
.
check_
value_type
(
"use_nesterov"
,
use_nesterov
,
[
bool
],
self
.
cls_name
)
validator
.
check_
value_type
(
"loss_scale"
,
loss_scale
,
[
float
],
self
.
cls_name
)
validator
.
check_number_range
(
"loss_scale"
,
loss_scale
,
1.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
self
.
cls_name
)
self
.
beta1
=
Tensor
(
beta1
,
mstype
.
float32
)
self
.
beta2
=
Tensor
(
beta2
,
mstype
.
float32
)
...
...
@@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer):
"""
def
__init__
(
self
,
params
,
learning_rate
=
1e-3
,
beta1
=
0.9
,
beta2
=
0.999
,
eps
=
1e-6
,
weight_decay
=
0.0
):
super
(
AdamWeightDecay
,
self
).
__init__
(
learning_rate
,
params
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
self
.
cls_name
)
self
.
lr
=
Tensor
(
np
.
array
([
learning_rate
]).
astype
(
np
.
float32
))
self
.
beta1
=
Tensor
(
np
.
array
([
beta1
]).
astype
(
np
.
float32
))
self
.
beta2
=
Tensor
(
np
.
array
([
beta2
]).
astype
(
np
.
float32
))
...
...
@@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
eps
=
1e-6
,
weight_decay
=
0.0
):
super
(
AdamWeightDecayDynamicLR
,
self
).
__init__
(
learning_rate
,
params
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
)
_check_param_value
(
beta1
,
beta2
,
eps
,
weight_decay
,
self
.
cls_name
)
# turn them to scalar when me support scalar/tensor mix operations
self
.
global_step
=
Parameter
(
initializer
(
0
,
[
1
]),
name
=
"global_step"
)
...
...
mindspore/nn/optim/ftrl.py
浏览文件 @
8cbbbd95
...
...
@@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer
from
mindspore.common.parameter
import
Parameter
from
mindspore.common
import
Tensor
import
mindspore.common.dtype
as
mstype
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
.optimizer
import
Optimizer
,
apply_decay
,
grad_scale
...
...
@@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
success
=
F
.
depend
(
success
,
opt
(
weight
,
moment
,
linear
,
gradient
,
learning_rate
,
l1
,
l2
,
lr_power
))
return
success
def
_check_param
(
initial_accum
,
learning_rate
,
lr_power
,
l1
,
l2
,
use_locking
,
loss_scale
=
1.0
,
weight_decay
=
0.0
):
validator
.
check_type
(
"initial_accum"
,
initial_accum
,
[
float
])
validator
.
check
(
"initial_accum"
,
initial_accum
,
""
,
0.0
,
Rel
.
GE
)
def
_check_param
(
initial_accum
,
learning_rate
,
lr_power
,
l1
,
l2
,
use_locking
,
loss_scale
=
1.0
,
weight_decay
=
0.0
,
prim_name
=
None
):
validator
.
check_value_type
(
"initial_accum"
,
initial_accum
,
[
float
],
prim_name
)
validator
.
check_number
(
"initial_accum"
,
initial_accum
,
0.0
,
Rel
.
GE
,
prim_name
)
validator
.
check_
type
(
"learning_rate"
,
learning_rate
,
[
float
]
)
validator
.
check
(
"learning_rate"
,
learning_rate
,
""
,
0.0
,
Rel
.
GT
)
validator
.
check_
value_type
(
"learning_rate"
,
learning_rate
,
[
float
],
prim_name
)
validator
.
check
_number
(
"learning_rate"
,
learning_rate
,
0.0
,
Rel
.
GT
,
prim_name
)
validator
.
check_
type
(
"lr_power"
,
lr_power
,
[
float
]
)
validator
.
check
(
"lr_power"
,
lr_power
,
""
,
0.0
,
Rel
.
LE
)
validator
.
check_
value_type
(
"lr_power"
,
lr_power
,
[
float
],
prim_name
)
validator
.
check
_number
(
"lr_power"
,
lr_power
,
0.0
,
Rel
.
LE
,
prim_name
)
validator
.
check_
type
(
"l1"
,
l1
,
[
float
]
)
validator
.
check
(
"l1"
,
l1
,
""
,
0.0
,
Rel
.
GE
)
validator
.
check_
value_type
(
"l1"
,
l1
,
[
float
],
prim_name
)
validator
.
check
_number
(
"l1"
,
l1
,
0.0
,
Rel
.
GE
,
prim_name
)
validator
.
check_
type
(
"l2"
,
l2
,
[
float
]
)
validator
.
check
(
"l2"
,
l2
,
""
,
0.0
,
Rel
.
GE
)
validator
.
check_
value_type
(
"l2"
,
l2
,
[
float
],
prim_name
)
validator
.
check
_number
(
"l2"
,
l2
,
0.0
,
Rel
.
GE
,
prim_name
)
validator
.
check_
type
(
"use_locking"
,
use_locking
,
[
bool
]
)
validator
.
check_
value_type
(
"use_locking"
,
use_locking
,
[
bool
],
prim_name
)
validator
.
check_
type
(
"loss_scale"
,
loss_scale
,
[
float
]
)
validator
.
check
(
"loss_scale"
,
loss_scale
,
""
,
1.0
,
Rel
.
GE
)
validator
.
check_
value_type
(
"loss_scale"
,
loss_scale
,
[
float
],
prim_name
)
validator
.
check
_number
(
"loss_scale"
,
loss_scale
,
1.0
,
Rel
.
GE
,
prim_name
)
validator
.
check_
type
(
"weight_decay"
,
weight_decay
,
[
float
]
)
validator
.
check
(
"weight_decay"
,
weight_decay
,
""
,
0.0
,
Rel
.
GE
)
validator
.
check_
value_type
(
"weight_decay"
,
weight_decay
,
[
float
],
prim_name
)
validator
.
check
_number
(
"weight_decay"
,
weight_decay
,
0.0
,
Rel
.
GE
,
prim_name
)
class
FTRL
(
Optimizer
):
...
...
@@ -94,7 +95,8 @@ class FTRL(Optimizer):
use_locking
=
False
,
loss_scale
=
1.0
,
weight_decay
=
0.0
):
super
(
FTRL
,
self
).
__init__
(
learning_rate
,
params
)
_check_param
(
initial_accum
,
learning_rate
,
lr_power
,
l1
,
l2
,
use_locking
,
loss_scale
,
weight_decay
)
_check_param
(
initial_accum
,
learning_rate
,
lr_power
,
l1
,
l2
,
use_locking
,
loss_scale
,
weight_decay
,
self
.
cls_name
)
self
.
moments
=
self
.
parameters
.
clone
(
prefix
=
"moments"
,
init
=
initial_accum
)
self
.
linear
=
self
.
parameters
.
clone
(
prefix
=
"linear"
,
init
=
'zeros'
)
self
.
l1
=
l1
...
...
mindspore/nn/optim/lamb.py
浏览文件 @
8cbbbd95
...
...
@@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from
mindspore.ops
import
functional
as
F
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.tensor
import
Tensor
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
.optimizer
import
Optimizer
from
..
import
layer
...
...
@@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def
_check_param_value
(
decay_steps
,
warmup_steps
,
start_learning_rate
,
end_learning_rate
,
power
,
beta1
,
beta2
,
eps
,
weight_decay
):
end_learning_rate
,
power
,
beta1
,
beta2
,
eps
,
weight_decay
,
prim_name
):
"""Check the type of inputs."""
validator
.
check_
type
(
"decay_steps"
,
decay_steps
,
[
int
]
)
validator
.
check_
type
(
"warmup_steps"
,
warmup_steps
,
[
int
]
)
validator
.
check_
type
(
"start_learning_rate"
,
start_learning_rate
,
[
float
]
)
validator
.
check_
type
(
"end_learning_rate"
,
end_learning_rate
,
[
float
]
)
validator
.
check_
type
(
"power"
,
power
,
[
float
]
)
validator
.
check_
type
(
"beta1"
,
beta1
,
[
float
]
)
validator
.
check_
type
(
"beta2"
,
beta2
,
[
float
]
)
validator
.
check_
type
(
"eps"
,
eps
,
[
float
]
)
validator
.
check_
type
(
"weight_dacay"
,
weight_decay
,
[
float
]
)
validator
.
check_number_range
(
"decay_steps"
,
decay_steps
,
1
,
float
(
"inf"
),
Rel
.
INC_LEFT
)
validator
.
check_number_range
(
"beta1"
,
beta1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"beta2"
,
beta2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"eps"
,
eps
,
0.0
,
float
(
"inf"
),
Rel
.
INC_NEITHER
)
validator
.
check_number_range
(
"weight_decay"
,
weight_decay
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
)
validator
.
check_
value_type
(
"decay_steps"
,
decay_steps
,
[
int
],
prim_name
)
validator
.
check_
value_type
(
"warmup_steps"
,
warmup_steps
,
[
int
],
prim_name
)
validator
.
check_
value_type
(
"start_learning_rate"
,
start_learning_rate
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"end_learning_rate"
,
end_learning_rate
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"power"
,
power
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"beta1"
,
beta1
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"beta2"
,
beta2
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"eps"
,
eps
,
[
float
],
prim_name
)
validator
.
check_
value_type
(
"weight_dacay"
,
weight_decay
,
[
float
],
prim_name
)
validator
.
check_number_range
(
"decay_steps"
,
decay_steps
,
1
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
prim_name
)
validator
.
check_number_range
(
"beta1"
,
beta1
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"beta2"
,
beta2
,
0.0
,
1.0
,
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"eps"
,
eps
,
0.0
,
float
(
"inf"
),
Rel
.
INC_NEITHER
,
prim_name
)
validator
.
check_number_range
(
"weight_decay"
,
weight_decay
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
prim_name
)
class
Lamb
(
Optimizer
):
...
...
@@ -182,7 +182,7 @@ class Lamb(Optimizer):
super
(
Lamb
,
self
).
__init__
(
start_learning_rate
,
params
)
_check_param_value
(
decay_steps
,
warmup_steps
,
start_learning_rate
,
end_learning_rate
,
power
,
beta1
,
beta2
,
eps
,
weight_decay
)
power
,
beta1
,
beta2
,
eps
,
weight_decay
,
self
.
cls_name
)
# turn them to scalar when me support scalar/tensor mix operations
self
.
global_step
=
Parameter
(
initializer
(
0
,
[
1
]),
name
=
"global_step"
)
...
...
mindspore/nn/optim/optimizer.py
浏览文件 @
8cbbbd95
...
...
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from
mindspore.nn.cell
import
Cell
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
from
mindspore.common.initializer
import
initializer
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore.common.tensor
import
Tensor
from
mindspore
import
log
as
logger
...
...
@@ -63,7 +63,7 @@ class Optimizer(Cell):
self
.
gather
=
None
self
.
assignadd
=
None
self
.
global_step
=
None
validator
.
check_number_range
(
"learning rate"
,
learning_rate
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
)
validator
.
check_number_range
(
"learning rate"
,
learning_rate
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
self
.
cls_name
)
else
:
self
.
dynamic_lr
=
True
self
.
gather
=
P
.
GatherV2
()
...
...
mindspore/nn/optim/rmsprop.py
浏览文件 @
8cbbbd95
...
...
@@ -14,7 +14,7 @@
# ============================================================================
"""rmsprop"""
from
mindspore.ops
import
functional
as
F
,
composite
as
C
,
operations
as
P
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
.optimizer
import
Optimizer
rmsprop_opt
=
C
.
MultitypeFuncGraph
(
"rmsprop_opt"
)
...
...
@@ -144,8 +144,8 @@ class RMSProp(Optimizer):
self
.
decay
=
decay
self
.
epsilon
=
epsilon
validator
.
check_
type
(
"use_locking"
,
use_locking
,
[
bool
]
)
validator
.
check_
type
(
"centered"
,
centered
,
[
bool
]
)
validator
.
check_
value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
cls_name
)
validator
.
check_
value_type
(
"centered"
,
centered
,
[
bool
],
self
.
cls_name
)
self
.
centered
=
centered
if
centered
:
self
.
opt
=
P
.
ApplyCenteredRMSProp
(
use_locking
)
...
...
mindspore/nn/optim/sgd.py
浏览文件 @
8cbbbd95
...
...
@@ -15,7 +15,7 @@
"""sgd"""
from
mindspore.ops
import
functional
as
F
,
composite
as
C
,
operations
as
P
from
mindspore.common.parameter
import
Parameter
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
.optimizer
import
Optimizer
sgd_opt
=
C
.
MultitypeFuncGraph
(
"sgd_opt"
)
...
...
@@ -100,7 +100,7 @@ class SGD(Optimizer):
raise
ValueError
(
"dampening should be at least 0.0, but got dampening {}"
.
format
(
dampening
))
self
.
dampening
=
dampening
validator
.
check_
type
(
"nesterov"
,
nesterov
,
[
bool
]
)
validator
.
check_
value_type
(
"nesterov"
,
nesterov
,
[
bool
],
self
.
cls_name
)
self
.
nesterov
=
nesterov
self
.
opt
=
P
.
SGD
(
dampening
,
weight_decay
,
nesterov
)
...
...
mindspore/ops/op_info_register.py
浏览文件 @
8cbbbd95
...
...
@@ -19,7 +19,7 @@ import os
import
json
import
inspect
from
mindspore._c_expression
import
Oplib
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
# path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH
=
"mindspore/ops/_op_impl"
...
...
@@ -43,7 +43,7 @@ def op_info_register(op_info):
op_info_real
=
json
.
dumps
(
op_info
)
else
:
op_info_real
=
op_info
validator
.
check_
type
(
"op_info"
,
op_info_real
,
[
str
]
)
validator
.
check_
value_type
(
"op_info"
,
op_info_real
,
[
str
],
None
)
op_lib
=
Oplib
()
file_path
=
os
.
path
.
realpath
(
inspect
.
getfile
(
func
))
# keep the path custom ops implementation.
...
...
mindspore/train/amp.py
浏览文件 @
8cbbbd95
...
...
@@ -16,7 +16,7 @@
from
easydict
import
EasyDict
as
edict
from
..
import
nn
from
.._checkparam
import
Param
Validator
as
validator
from
.._checkparam
import
Validator
as
validator
from
.._checkparam
import
Rel
from
..common
import
dtype
as
mstype
from
..nn.wrap.cell_wrapper
import
_VirtualDatasetCell
...
...
@@ -73,14 +73,14 @@ def _check_kwargs(key_words):
raise
ValueError
(
f
"Unsupported arg '
{
arg
}
'"
)
if
'cast_model_type'
in
key_words
:
validator
.
check
(
'cast_model_type'
,
key_words
[
'cast_model_type'
],
[
mstype
.
float16
,
mstype
.
float32
],
Rel
.
IN
)
validator
.
check
_type_name
(
'cast_model_type'
,
key_words
[
'cast_model_type'
],
[
mstype
.
float16
,
mstype
.
float32
],
None
)
if
'keep_batchnorm_fp32'
in
key_words
:
validator
.
check_
isinstance
(
'keep_batchnorm_fp32'
,
key_words
[
'keep_batchnorm_fp32'
],
bool
)
validator
.
check_
value_type
(
'keep_batchnorm_fp32'
,
key_words
[
'keep_batchnorm_fp32'
],
bool
,
None
)
if
'loss_scale_manager'
in
key_words
:
loss_scale_manager
=
key_words
[
'loss_scale_manager'
]
if
loss_scale_manager
:
validator
.
check_
isinstance
(
'loss_scale_manager'
,
loss_scale_manager
,
LossScaleManager
)
validator
.
check_
value_type
(
'loss_scale_manager'
,
loss_scale_manager
,
LossScaleManager
,
None
)
def
_add_loss_network
(
network
,
loss_fn
,
cast_model_type
):
...
...
@@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
label
=
_mp_cast_helper
(
mstype
.
float32
,
label
)
return
self
.
_loss_fn
(
F
.
cast
(
out
,
mstype
.
float32
),
label
)
validator
.
check_
isinstance
(
'loss_fn'
,
loss_fn
,
nn
.
Cell
)
validator
.
check_
value_type
(
'loss_fn'
,
loss_fn
,
nn
.
Cell
,
None
)
if
cast_model_type
==
mstype
.
float16
:
network
=
WithLossCell
(
network
,
loss_fn
)
else
:
...
...
@@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting.
"""
validator
.
check_
isinstance
(
'network'
,
network
,
nn
.
Cell
)
validator
.
check_
isinstance
(
'optimizer'
,
optimizer
,
nn
.
Optimizer
)
validator
.
check
(
'level'
,
level
,
""
,
[
'O0'
,
'O2'
],
Rel
.
IN
)
validator
.
check_
value_type
(
'network'
,
network
,
nn
.
Cell
,
None
)
validator
.
check_
value_type
(
'optimizer'
,
optimizer
,
nn
.
Optimizer
,
None
)
validator
.
check
(
'level'
,
level
,
""
,
[
'O0'
,
'O2'
],
Rel
.
IN
,
None
)
_check_kwargs
(
kwargs
)
config
=
dict
(
_config_level
[
level
],
**
kwargs
)
config
=
edict
(
config
)
...
...
mindspore/train/loss_scale_manager.py
浏览文件 @
8cbbbd95
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Loss scale manager abstract class."""
from
.._checkparam
import
Param
Validator
as
validator
from
.._checkparam
import
Validator
as
validator
from
.._checkparam
import
Rel
from
..
import
nn
...
...
@@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
if
init_loss_scale
<
1.0
:
raise
ValueError
(
"Loss scale value should be > 1"
)
self
.
loss_scale
=
init_loss_scale
validator
.
check_integer
(
"scale_window"
,
scale_window
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
"scale_window"
,
scale_window
,
0
,
Rel
.
GT
,
self
.
__class__
.
__name__
)
self
.
scale_window
=
scale_window
if
scale_factor
<=
0
:
raise
ValueError
(
"Scale factor should be > 1"
)
...
...
tests/ut/python/nn/test_dynamic_lr.py
浏览文件 @
8cbbbd95
...
...
@@ -32,7 +32,7 @@ power = 0.5
class
TestInputs
:
def
test_milestone1
(
self
):
milestone1
=
1
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
piecewise_constant_lr
(
milestone1
,
learning_rates
)
def
test_milestone2
(
self
):
...
...
@@ -46,12 +46,12 @@ class TestInputs:
def
test_learning_rates1
(
self
):
lr
=
True
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
piecewise_constant_lr
(
milestone
,
lr
)
def
test_learning_rates2
(
self
):
lr
=
[
1
,
2
,
1
]
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
piecewise_constant_lr
(
milestone
,
lr
)
def
test_learning_rate_type
(
self
):
...
...
@@ -158,7 +158,7 @@ class TestInputs:
def
test_is_stair
(
self
):
is_stair
=
1
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
exponential_decay_lr
(
learning_rate
,
decay_rate
,
total_step
,
step_per_epoch
,
decay_epoch
,
is_stair
)
def
test_min_lr_type
(
self
):
...
...
@@ -183,12 +183,12 @@ class TestInputs:
def
test_power
(
self
):
power1
=
True
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
polynomial_decay_lr
(
learning_rate
,
end_learning_rate
,
total_step
,
step_per_epoch
,
decay_epoch
,
power1
)
def
test_update_decay_epoch
(
self
):
update_decay_epoch
=
1
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
polynomial_decay_lr
(
learning_rate
,
end_learning_rate
,
total_step
,
step_per_epoch
,
decay_epoch
,
power
,
update_decay_epoch
)
...
...
tests/ut/python/nn/test_psnr.py
浏览文件 @
8cbbbd95
...
...
@@ -52,7 +52,7 @@ def test_psnr_max_val_negative():
def
test_psnr_max_val_bool
():
max_val
=
True
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
net
=
PSNRNet
(
max_val
)
def
test_psnr_max_val_zero
():
...
...
tests/ut/python/nn/test_ssim.py
浏览文件 @
8cbbbd95
...
...
@@ -51,7 +51,7 @@ def test_ssim_max_val_negative():
def
test_ssim_max_val_bool
():
max_val
=
True
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
net
=
SSIMNet
(
max_val
)
def
test_ssim_max_val_zero
():
...
...
@@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value():
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=
0.0
)
with
pytest
.
raises
(
ValueError
):
net
=
SSIMNet
(
k2
=-
1.0
)
\ No newline at end of file
net
=
SSIMNet
(
k2
=-
1.0
)
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
8cbbbd95
...
...
@@ -577,14 +577,14 @@ test_cases_for_verify_exception = [
(
'MaxPool2d_ValueError_2'
,
{
'block'
:
(
lambda
_
:
nn
.
MaxPool2d
(
kernel_size
=
120
,
stride
=
True
,
pad_mode
=
"valid"
),
{
'exception'
:
Valu
eError
},
{
'exception'
:
Typ
eError
},
),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randn
(
32
,
3
,
112
,
112
).
astype
(
np
.
float32
).
transpose
(
0
,
3
,
1
,
2
))],
}),
(
'MaxPool2d_ValueError_3'
,
{
'block'
:
(
lambda
_
:
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
True
,
pad_mode
=
"valid"
),
{
'exception'
:
Valu
eError
},
{
'exception'
:
Typ
eError
},
),
'desc_inputs'
:
[
Tensor
(
np
.
random
.
randn
(
32
,
3
,
112
,
112
).
astype
(
np
.
float32
).
transpose
(
0
,
3
,
1
,
2
))],
}),
...
...
tests/ut/python/pynative_mode/nn/test_pooling.py
浏览文件 @
8cbbbd95
...
...
@@ -38,7 +38,7 @@ def test_avgpool2d_error_input():
""" test_avgpool2d_error_input """
kernel_size
=
5
stride
=
2.3
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
nn
.
AvgPool2d
(
kernel_size
,
stride
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录