Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
6dd72f65
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6dd72f65
编写于
4月 07, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add prim name to error message for nn_ops.py
上级
475f62f6
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
821 addition
and
498 deletion
+821
-498
mindspore/_checkparam.py
mindspore/_checkparam.py
+18
-28
mindspore/context.py
mindspore/context.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+318
-448
tests/ut/python/nn/test_dynamic_lr.py
tests/ut/python/nn/test_dynamic_lr.py
+10
-10
tests/ut/python/nn/test_ssim.py
tests/ut/python/nn/test_ssim.py
+1
-1
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+10
-10
tests/ut/python/ops/test_nn_ops_check.py
tests/ut/python/ops/test_nn_ops_check.py
+463
-0
未找到文件。
mindspore/_checkparam.py
浏览文件 @
6dd72f65
...
...
@@ -117,10 +117,12 @@ class Validator:
"""Integer value judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
or
isinstance
(
arg_value
,
bool
)
excp_cls
=
TypeError
if
type_mismatch
else
ValueError
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
"The"
raise
ValueError
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
raise
excp_cls
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got `
{
arg_value
}
`'
f
' with type `
{
type
(
arg_value
).
__name__
}
`.'
)
return
arg_value
@
staticmethod
...
...
@@ -137,10 +139,11 @@ class Validator:
"""Method for checking whether an int value is in some range."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
excp_cls
=
TypeError
if
type_mismatch
else
ValueError
if
type_mismatch
or
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be an int in range
{
rel_str
}
,'
f
' but got
{
arg_value
}
.'
)
raise
excp_cls
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be an int in range
{
rel_str
}
,'
f
' but got `
{
arg_value
}
` with type `
{
type
(
arg_value
).
__name__
}
`
.'
)
return
arg_value
@
staticmethod
...
...
@@ -192,19 +195,23 @@ class Validator:
@
staticmethod
def
check_const_input
(
arg_name
,
arg_value
,
prim_name
):
"""Check valid value."""
"""Check
s
valid value."""
if
arg_value
is
None
:
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` must be a const input, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_
scalar_
type_same
(
args
,
valid_values
,
prim_name
):
"""
check
whether the types of inputs are the same."""
def
check_type_same
(
args
,
valid_values
,
prim_name
):
"""
Checks
whether the types of inputs are the same."""
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
elem_type
=
arg_val
type_names
=
[]
if
not
elem_type
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_key
}
` is
{
elem_type
}
.'
)
for
t
in
valid_values
:
type_names
.
append
(
str
(
t
))
types_info
=
'['
+
", "
.
join
(
type_names
)
+
']'
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg_key
}
` should be in
{
types_info
}
,'
f
' but got
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
...
...
@@ -212,7 +219,7 @@ class Validator:
arg2_name
,
arg2_type
=
arg2
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
`
is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
f
' but `
{
arg1_name
}
`
with type
{
arg1_type
}
and `
{
arg2_name
}
` with type
{
arg2_type
}
.'
)
return
arg1
elem_types
=
map
(
_check_tensor_type
,
args
.
items
())
...
...
@@ -221,25 +228,8 @@ class Validator:
@
staticmethod
def
check_tensor_type_same
(
args
,
valid_values
,
prim_name
):
"""Checks whether the element types of input tensors are the same."""
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
Validator
.
check_subclass
(
arg_key
,
arg_val
,
mstype
.
tensor
,
prim_name
)
elem_type
=
arg_val
.
element_type
()
if
not
elem_type
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
element type of `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but element type of `
{
arg_key
}
` is
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
arg1_name
,
arg1_type
=
arg1
arg2_name
,
arg2_type
=
arg2
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
element type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
` is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
return
arg1
elem_types
=
map
(
_check_tensor_type
,
args
.
items
())
reduce
(
_check_types_same
,
elem_types
)
tensor_types
=
[
mstype
.
tensor_type
(
t
)
for
t
in
valid_values
]
Validator
.
check_type_same
(
args
,
tensor_types
,
prim_name
)
@
staticmethod
def
check_scalar_or_tensor_type_same
(
args
,
valid_values
,
prim_name
,
allow_mix
=
False
):
...
...
mindspore/context.py
浏览文件 @
6dd72f65
...
...
@@ -34,7 +34,7 @@ GRAPH_MODE = 0
PYNATIVE_MODE
=
1
def
_make_directory
(
path
:
str
):
def
_make_directory
(
path
):
"""Make directory."""
real_path
=
None
if
path
is
None
or
not
isinstance
(
path
,
str
)
or
path
.
strip
()
==
""
:
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
6dd72f65
此差异已折叠。
点击以展开。
tests/ut/python/nn/test_dynamic_lr.py
浏览文件 @
6dd72f65
...
...
@@ -41,7 +41,7 @@ class TestInputs:
dr
.
piecewise_constant_lr
(
milestone1
,
learning_rates
)
milestone2
=
[
1.0
,
2.0
,
True
]
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
piecewise_constant_lr
(
milestone2
,
learning_rates
)
def
test_learning_rates1
(
self
):
...
...
@@ -92,13 +92,13 @@ class TestInputs:
def
test_total_step1
(
self
):
total_step1
=
2.0
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
exponential_decay_lr
(
learning_rate
,
decay_rate
,
total_step1
,
step_per_epoch
,
decay_epoch
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
cosine_decay_lr
(
min_lr
,
max_lr
,
total_step1
,
step_per_epoch
,
decay_epoch
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
polynomial_decay_lr
(
learning_rate
,
end_learning_rate
,
total_step1
,
step_per_epoch
,
decay_epoch
,
power
)
def
test_total_step2
(
self
):
...
...
@@ -114,13 +114,13 @@ class TestInputs:
def
test_step_per_epoch1
(
self
):
step_per_epoch1
=
True
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
exponential_decay_lr
(
learning_rate
,
decay_rate
,
total_step
,
step_per_epoch1
,
decay_epoch
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
cosine_decay_lr
(
min_lr
,
max_lr
,
total_step
,
step_per_epoch1
,
decay_epoch
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
polynomial_decay_lr
(
learning_rate
,
end_learning_rate
,
total_step
,
step_per_epoch1
,
decay_epoch
,
power
)
def
test_step_per_epoch2
(
self
):
...
...
@@ -136,13 +136,13 @@ class TestInputs:
def
test_decay_epoch1
(
self
):
decay_epoch1
=
'm'
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
exponential_decay_lr
(
learning_rate
,
decay_rate
,
total_step
,
step_per_epoch
,
decay_epoch1
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
cosine_decay_lr
(
min_lr
,
max_lr
,
total_step
,
step_per_epoch
,
decay_epoch1
)
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
dr
.
polynomial_decay_lr
(
learning_rate
,
end_learning_rate
,
total_step
,
step_per_epoch
,
decay_epoch1
,
power
)
def
test_decay_epoch2
(
self
):
...
...
tests/ut/python/nn/test_ssim.py
浏览文件 @
6dd72f65
...
...
@@ -60,7 +60,7 @@ def test_ssim_max_val_zero():
net
=
SSIMNet
(
max_val
)
def
test_ssim_filter_size_float
():
with
pytest
.
raises
(
Valu
eError
):
with
pytest
.
raises
(
Typ
eError
):
net
=
SSIMNet
(
filter_size
=
1.1
)
def
test_ssim_filter_size_zero
():
...
...
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
6dd72f65
...
...
@@ -516,7 +516,7 @@ test_cases = [
test_cases_for_verify_exception
=
[
(
'Conv2d_ValueError_1'
,
{
'block'
:
(
lambda
_
:
P
.
Conv2D
(
3
,
4
,
mode
=-
2.0
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
Conv2D
(
3
,
4
,
mode
=-
2.0
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'Conv2d_ValueError_2'
,
{
...
...
@@ -528,7 +528,7 @@ test_cases_for_verify_exception = [
'desc_inputs'
:
[
0
],
}),
(
'MaxPoolWithArgmax_ValueError_2'
,
{
'block'
:
(
lambda
_
:
P
.
MaxPoolWithArgmax
(
ksize
=
'1'
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
MaxPoolWithArgmax
(
ksize
=
'1'
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'MaxPoolWithArgmax_ValueError_3'
,
{
...
...
@@ -540,7 +540,7 @@ test_cases_for_verify_exception = [
'desc_inputs'
:
[
0
],
}),
(
'FusedBatchNorm_ValueError_1'
,
{
'block'
:
(
lambda
_
:
P
.
FusedBatchNorm
(
mode
=
"1"
,
epsilon
=
1e-5
,
momentum
=
0.1
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
FusedBatchNorm
(
mode
=
"1"
,
epsilon
=
1e-5
,
momentum
=
0.1
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'FusedBatchNorm_ValueError_2'
,
{
...
...
@@ -560,31 +560,31 @@ test_cases_for_verify_exception = [
'desc_inputs'
:
[
0
],
}),
(
'Softmax_ValueError_1'
,
{
'block'
:
(
lambda
_
:
P
.
Softmax
(
"1"
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
Softmax
(
"1"
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'Softmax_ValueError_2'
,
{
'block'
:
(
lambda
_
:
P
.
Softmax
(
1.1
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
Softmax
(
1.1
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'Softmax_ValueError_3'
,
{
'block'
:
(
lambda
_
:
P
.
Softmax
(
axis
=
"1"
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
Softmax
(
axis
=
"1"
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'DropoutGenMask_ValueError_1'
,
{
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed0
=
"seed0"
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed0
=
"seed0"
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'DropoutGenMask_ValueError_2'
,
{
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed0
=
1.0
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed0
=
1.0
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'DropoutGenMask_ValueError_3'
,
{
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed1
=
"seed1"
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed1
=
"seed1"
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'DropoutGenMask_ValueError_4'
,
{
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed1
=
2.0
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
_
:
P
.
DropoutGenMask
(
Seed1
=
2.0
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
0
],
}),
(
'MaxPool2d_ValueError_1'
,
{
...
...
tests/ut/python/ops/test_nn_ops_check.py
0 → 100755
浏览文件 @
6dd72f65
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录