Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
69ed72f1
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
69ed72f1
编写于
4月 02, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add primitive name to param error message for math_ops.py
上级
cc0ba93d
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
1003 addition
and
129 deletion
+1003
-129
mindspore/_checkparam.py
mindspore/_checkparam.py
+125
-1
mindspore/ops/_utils/broadcast.py
mindspore/ops/_utils/broadcast.py
+4
-2
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+10
-3
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+92
-114
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+3
-0
tests/mindspore_test_framework/components/executor/check_exceptions.py
...re_test_framework/components/executor/check_exceptions.py
+8
-3
tests/mindspore_test_framework/utils/config_util.py
tests/mindspore_test_framework/utils/config_util.py
+3
-2
tests/mindspore_test_framework/utils/facade_util.py
tests/mindspore_test_framework/utils/facade_util.py
+5
-3
tests/mindspore_test_framework/utils/keyword.py
tests/mindspore_test_framework/utils/keyword.py
+1
-0
tests/ut/python/ops/test_array_ops.py
tests/ut/python/ops/test_array_ops.py
+1
-1
tests/ut/python/ops/test_math_ops_check.py
tests/ut/python/ops/test_math_ops_check.py
+751
-0
未找到文件。
mindspore/_checkparam.py
浏览文件 @
69ed72f1
...
...
@@ -15,6 +15,7 @@
"""Check parameters."""
import
re
from
enum
import
Enum
from
functools
import
reduce
from
itertools
import
repeat
from
collections
import
Iterable
...
...
@@ -93,8 +94,131 @@ rel_strs = {
}
class
Validator
:
"""validator for checking input parameters"""
@
staticmethod
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
,
prim_name
=
None
):
"""
Method for judging relation between two int values or list/tuple made up of ints.
This method is not suitable for judging relation between floats, since it does not consider float error.
"""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
f
'
{
value_name
}
:
{
value
}
'
)
msg_prefix
=
f
'For
{
prim_name
}
the'
if
prim_name
else
"The"
raise
ValueError
(
f
'
{
msg_prefix
}
`
{
arg_name
}
` should be
{
rel_str
}
, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_integer
(
arg_name
,
arg_value
,
value
,
rel
,
prim_name
):
"""Integer value judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
or
isinstance
(
arg_value
,
bool
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'For
{
prim_name
}
the `
{
arg_name
}
` should be an int and must
{
rel_str
}
,'
f
' but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_int_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
,
prim_name
):
"""Method for checking whether an int value is in some range."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_name
}
` should be an int in range
{
rel_str
}
,'
f
' but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
prim_name
):
"""Check whether some type is sublcass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
type_str
=
(
type
(
type_
).
__name__
if
isinstance
(
type_
,
(
tuple
,
list
))
else
""
)
+
str
(
type_
)
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be subclass'
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
))
}
, but got
{
type_str
}
.'
)
@
staticmethod
def
check_tensor_type_same
(
args
,
valid_values
,
prim_name
):
"""check whether the element types of input tensors are the same."""
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
Validator
.
check_subclass
(
arg_key
,
arg_val
,
mstype
.
tensor
,
prim_name
)
elem_type
=
arg_val
.
element_type
()
if
not
elem_type
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
element type of `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_key
}
` is
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
arg1_name
,
arg1_type
=
arg1
arg2_name
,
arg2_type
=
arg2
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
element type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
` is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
return
arg1
elem_types
=
map
(
_check_tensor_type
,
args
.
items
())
reduce
(
_check_types_same
,
elem_types
)
@
staticmethod
def
check_scalar_or_tensor_type_same
(
args
,
valid_values
,
prim_name
):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
def
_check_argument_type
(
arg
):
arg_key
,
arg_val
=
arg
if
isinstance
(
arg_val
,
type
(
mstype
.
tensor
)):
arg_val
=
arg_val
.
element_type
()
if
not
arg_val
in
valid_values
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the `
{
arg_key
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_key
}
` is
{
arg_val
}
.'
)
return
arg
def
_check_types_same
(
arg1
,
arg2
):
arg1_name
,
arg1_type
=
arg1
arg2_name
,
arg2_type
=
arg2
excp_flag
=
False
if
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
and
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
)):
arg1_type
=
arg1_type
.
element_type
()
arg2_type
=
arg2_type
.
element_type
()
elif
not
(
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
or
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
))):
pass
else
:
excp_flag
=
True
if
excp_flag
or
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
` is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
return
arg1
reduce
(
_check_types_same
,
map
(
_check_argument_type
,
args
.
items
()))
@
staticmethod
def
check_value_type
(
arg_name
,
arg_value
,
valid_types
,
prim_name
):
"""Check whether a values is instance of some types."""
def
raise_error_msg
():
"""func for raising error message when check failed"""
type_names
=
[
t
.
__name__
for
t
in
valid_types
]
num_types
=
len
(
valid_types
)
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be '
f
'
{
"one of "
if
num_types
>
1
else
""
}
'
f
'
{
type_names
if
num_types
>
1
else
type_names
[
0
]
}
, but got
{
type
(
arg_value
).
__name__
}
.'
)
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if
isinstance
(
arg_value
,
bool
)
and
bool
not
in
tuple
(
valid_types
):
raise_error_msg
()
if
isinstance
(
arg_value
,
tuple
(
valid_types
)):
return
arg_value
raise_error_msg
()
class
ParamValidator
:
"""Parameter validator."""
"""Parameter validator.
NOTICE: this class will be replaced by `class Validator`
"""
@
staticmethod
def
equal
(
arg_name
,
arg_value
,
cond_str
,
cond
):
...
...
mindspore/ops/_utils/broadcast.py
浏览文件 @
69ed72f1
...
...
@@ -16,13 +16,14 @@
"""broadcast"""
def
_get_broadcast_shape
(
x_shape
,
y_shape
):
def
_get_broadcast_shape
(
x_shape
,
y_shape
,
prim_name
):
"""
Doing broadcast between tensor x and tensor y.
Args:
x_shape (list): The shape of tensor x.
y_shape (list): The shape of tensor y.
prim_name (str): Primitive name.
Returns:
List, the shape that broadcast between tensor x and tensor y.
...
...
@@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape):
elif
x_shape
[
i
]
==
y_shape
[
i
]:
broadcast_shape_back
.
append
(
x_shape
[
i
])
else
:
raise
ValueError
(
"The x_shape {} and y_shape {} can not broadcast."
.
format
(
x_shape
,
y_shape
))
raise
ValueError
(
"For '{}' the x_shape {} and y_shape {} can not broadcast."
.
format
(
prim_name
,
x_shape
,
y_shape
))
broadcast_shape_front
=
y_shape
[
0
:
y_len
-
length
]
if
length
==
x_len
else
x_shape
[
0
:
x_len
-
length
]
broadcast_shape
=
broadcast_shape_front
+
broadcast_shape_back
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
69ed72f1
...
...
@@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
..operations.math_ops
import
_
check_infer_attr_reduce
,
_
infer_shape_reduce
from
..operations.math_ops
import
_infer_shape_reduce
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
def
_check_infer_attr_reduce
(
axis
,
keep_dims
):
validator
.
check_type
(
'keep_dims'
,
keep_dims
,
[
bool
])
validator
.
check_type
(
'axis'
,
axis
,
[
int
,
tuple
])
if
isinstance
(
axis
,
tuple
):
for
index
,
value
in
enumerate
(
axis
):
validator
.
check_type
(
'axis[%d]'
%
index
,
value
,
[
int
])
class
ExpandDims
(
PrimitiveWithInfer
):
"""
...
...
@@ -1091,7 +1098,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis
=
self
.
axis
x_rank
=
len
(
x_shape
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
prim_name
()
)
return
ouput_shape
,
ouput_shape
def
infer_dtype
(
self
,
x_dtype
):
...
...
@@ -1137,7 +1144,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis
=
self
.
axis
x_rank
=
len
(
x_shape
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
prim_name
()
)
return
ouput_shape
,
ouput_shape
def
infer_dtype
(
self
,
x_dtype
):
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
69ed72f1
此差异已折叠。
点击以展开。
mindspore/ops/primitive.py
浏览文件 @
69ed72f1
...
...
@@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive):
Primitive
.
__init__
(
self
,
name
)
self
.
set_prim_type
(
prim_type
.
py_infer_shape
)
def
prim_name
(
self
):
return
self
.
__class__
.
__name__
def
_clone
(
self
):
"""
Deeply clones the primitive object.
...
...
tests/mindspore_test_framework/components/executor/check_exceptions.py
浏览文件 @
69ed72f1
...
...
@@ -23,20 +23,25 @@ from ...utils import keyword
class
CheckExceptionsEC
(
IExectorComponent
):
"""
Check if the function raises the expected Exception.
Check if the function raises the expected Exception
and the error message contains specified keywords if not None
.
Examples:
{
'block': f,
'exception': Exception
'exception': Exception,
'error_keywords': ['TensorAdd', 'shape']
}
"""
def
run_function
(
self
,
function
,
inputs
,
verification_set
):
f
=
function
[
keyword
.
block
]
args
=
inputs
[
keyword
.
desc_inputs
]
e
=
function
.
get
(
keyword
.
exception
,
Exception
)
error_kws
=
function
.
get
(
keyword
.
error_keywords
,
None
)
try
:
with
pytest
.
raises
(
e
):
with
pytest
.
raises
(
e
)
as
exec_info
:
f
(
*
args
)
except
:
raise
Exception
(
f
"Expect
{
e
}
, but got
{
sys
.
exc_info
()[
0
]
}
"
)
if
error_kws
and
any
(
keyword
not
in
str
(
exec_info
.
value
)
for
keyword
in
error_kws
):
raise
ValueError
(
'Error message `{}` does not contain all keywords `{}`'
.
format
(
str
(
exec_info
.
value
),
error_kws
))
tests/mindspore_test_framework/utils/config_util.py
浏览文件 @
69ed72f1
...
...
@@ -87,8 +87,9 @@ def get_function_config(function):
init_param_with
=
function
.
get
(
keyword
.
init_param_with
,
None
)
split_outputs
=
function
.
get
(
keyword
.
split_outputs
,
True
)
exception
=
function
.
get
(
keyword
.
exception
,
Exception
)
error_keywords
=
function
.
get
(
keyword
.
error_keywords
,
None
)
return
delta
,
max_error
,
input_selector
,
output_selector
,
sampling_times
,
\
reduce_output
,
init_param_with
,
split_outputs
,
exception
reduce_output
,
init_param_with
,
split_outputs
,
exception
,
error_keywords
def
get_grad_checking_options
(
function
,
inputs
):
"""
...
...
@@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs):
"""
f
=
function
[
keyword
.
block
]
args
=
inputs
[
keyword
.
desc_inputs
]
delta
,
max_error
,
input_selector
,
output_selector
,
sampling_times
,
reduce_output
,
_
,
_
,
_
=
\
delta
,
max_error
,
input_selector
,
output_selector
,
sampling_times
,
reduce_output
,
_
,
_
,
_
,
_
=
\
get_function_config
(
function
)
return
f
,
args
,
delta
,
max_error
,
input_selector
,
output_selector
,
sampling_times
,
reduce_output
tests/mindspore_test_framework/utils/facade_util.py
浏览文件 @
69ed72f1
...
...
@@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
block
=
block_config
delta
,
max_error
,
input_selector
,
output_selector
,
\
sampling_times
,
reduce_output
,
init_param_with
,
split_outputs
,
exception
=
get_function_config
({})
sampling_times
,
reduce_output
,
init_param_with
,
split_outputs
,
exception
,
error_keywords
=
get_function_config
({})
if
isinstance
(
block_config
,
tuple
)
and
isinstance
(
block_config
[
-
1
],
dict
):
block
=
block_config
[
0
]
delta
,
max_error
,
input_selector
,
output_selector
,
\
sampling_times
,
reduce_output
,
init_param_with
,
split_outputs
,
exception
=
get_function_config
(
block_config
[
-
1
])
sampling_times
,
reduce_output
,
init_param_with
,
\
split_outputs
,
exception
,
error_keywords
=
get_function_config
(
block_config
[
-
1
])
if
block
:
func_list
.
append
({
...
...
@@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
keyword
.
const_first
:
const_first
,
keyword
.
add_fake_input
:
add_fake_input
,
keyword
.
split_outputs
:
split_outputs
,
keyword
.
exception
:
exception
keyword
.
exception
:
exception
,
keyword
.
error_keywords
:
error_keywords
})
if
desc_inputs
or
desc_const
:
...
...
tests/mindspore_test_framework/utils/keyword.py
浏览文件 @
69ed72f1
...
...
@@ -73,5 +73,6 @@ keyword.const_first = "const_first"
keyword
.
add_fake_input
=
"add_fake_input"
keyword
.
fake_input_type
=
"fake_input_type"
keyword
.
exception
=
"exception"
keyword
.
error_keywords
=
"error_keywords"
sys
.
modules
[
__name__
]
=
keyword
tests/ut/python/ops/test_array_ops.py
浏览文件 @
69ed72f1
...
...
@@ -234,7 +234,7 @@ raise_set = [
'block'
:
(
lambda
x
:
P
.
Squeeze
(
axis
=
((
1.2
,
1.3
))),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
(
shape
=
[
3
,
1
,
5
]))]}),
(
'ReduceSum_Error'
,
{
'block'
:
(
lambda
x
:
P
.
ReduceSum
(
keep_dims
=
1
),
{
'exception'
:
Valu
eError
}),
'block'
:
(
lambda
x
:
P
.
ReduceSum
(
keep_dims
=
1
),
{
'exception'
:
Typ
eError
}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
(
shape
=
[
3
,
1
,
5
]))]}),
]
...
...
tests/ut/python/ops/test_math_ops_check.py
0 → 100755
浏览文件 @
69ed72f1
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录