magicwindyyd / mindspore — forked from MindSpore / mindspore
Commit 20782294
Authored on April 14, 2020 by fary86

Add prim name to error message for array_ops

Parent: 789edcb2
Showing 7 changed files with 358 additions and 472 deletions (+358, -472)
mindspore/_checkparam.py                      +1    -243
mindspore/ccsrc/optimizer/ad/dfunctor.cc      +1    -1
mindspore/nn/layer/pooling.py                 +5    -6
mindspore/ops/operations/array_ops.py         +184  -214
tests/ut/python/ops/test_array_ops_check.py   +159  -0
tests/ut/python/ops/test_tensor_slice.py      +2    -2
tests/vm_impl/vm_me.py                        +6    -6
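The pattern shared by every hunk below: call sites move from the deprecated ParamValidator helpers to Validator equivalents that take the primitive's name as a trailing argument, so a failed check can report which operator rejected the input. A minimal sketch of that pattern, with a simplified stand-in function (the function body and message text are illustrative, not MindSpore's actual implementation):

def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """Simplified stand-in for the new Validator.check_value_type."""
    if isinstance(arg_value, bool) and bool not in valid_types:
        # bool is a subclass of int, so it must be rejected explicitly.
        pass_check = False
    else:
        pass_check = isinstance(arg_value, tuple(valid_types))
    if not pass_check:
        prefix = f"For '{prim_name}'" if prim_name else "The"
        type_names = [t.__name__ for t in valid_types]
        raise TypeError(f"{prefix} type of `{arg_name}` should be one of "
                        f"{type_names}, but got {type(arg_value).__name__}.")
    return arg_value

check_value_type("axis", 0, [int], "ExpandDims")      # passes, returns 0
# check_value_type("axis", 0.5, [int], "ExpandDims")  # TypeError: For 'ExpandDims' type of `axis` ...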
mindspore/_checkparam.py
@@ -210,7 +210,7 @@ class Validator:
         type_names = []
         for t in valid_values:
             type_names.append(str(t))
-        types_info = '[' + ", ".join(type_names) + ']'
+        types_info = '[' + ', '.join(type_names) + ']'
         raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
                         f' but got {elem_type}.')
     return (arg_key, elem_type)
@@ -320,224 +320,6 @@ class Validator:
         raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
 
-class ParamValidator:
-    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
-
-    @staticmethod
-    def equal(arg_name, arg_value, cond_str, cond):
-        """Judging valid value."""
-        if not cond:
-            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
-
-    @staticmethod
-    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
-        """This method is only used for check int values, since when compare float values,
-        we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
-            raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
-
-    @staticmethod
-    def check_integer(arg_name, arg_value, value, rel):
-        """Integer value judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_shape_length(arg_name, arg_value, value, rel):
-        """Shape length judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
-        return arg_value
-
-    @staticmethod
-    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """This method is only used for check int values,
-        since when compare float values, we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_isinstance(arg_name, arg_value, classes):
-        """Check arg isinstance of classes"""
-        if not isinstance(arg_value, classes):
-            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """Is it necessary to consider error when comparing float values."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_subclass(arg_name, type_, template_type, with_type_of=True):
-        """Check whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
-            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
-
-    @staticmethod
-    def check_args_tensor(args):
-        """Check whether args are all tensor."""
-        if not isinstance(args, dict):
-            raise TypeError("The args should be a dict.")
-        for arg, value in args.items():
-            ParamValidator.check_subclass(arg, value, mstype.tensor)
-
-    @staticmethod
-    def check_bool(arg_name, arg_value):
-        """Check arg isinstance of bool"""
-        if not isinstance(arg_value, bool):
-            raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_type(arg_name, arg_value, valid_types):
-        """Type checking."""
-        def raise_error_msg():
-            """func for raising error message when check failed"""
-            type_names = [t.__name__ for t in valid_types]
-            num_types = len(valid_types)
-            raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
-                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
-
-        if isinstance(arg_value, type(mstype.tensor)):
-            arg_value = arg_value.element_type()
-        # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
-        #         `check_type('x', True, [bool, int])` will check pass
-        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-            raise_error_msg()
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
-
-    @staticmethod
-    def check_typename(arg_name, arg_type, valid_types):
-        """Does it contain the _name_ attribute."""
-
-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
-
-        if isinstance(arg_type, type(mstype.tensor)):
-            arg_type = arg_type.element_type()
-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        if len(valid_types) == 1:
-            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
-                             f' but got {get_typename(arg_type)}.')
-        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
-                         f' but got {get_typename(arg_type)}.')
-
-    @staticmethod
-    def check_string(arg_name, arg_value, valid_values):
-        """String type judgment."""
-        if isinstance(arg_value, str) and arg_value in valid_values:
-            return arg_value
-        if len(valid_values) == 1:
-            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
-                             f' but got {arg_value}.')
-        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
-                         f' but got {arg_value}.')
-
-    @staticmethod
-    def check_type_same(args, valid_values):
-        """Determine whether the types are the same."""
-        name = list(args.keys())[0]
-        value = list(args.values())[0]
-        if isinstance(value, type(mstype.tensor)):
-            value = value.element_type()
-        for arg_name, arg_value in args.items():
-            if isinstance(arg_value, type(mstype.tensor)):
-                arg_value = arg_value.element_type()
-
-            if arg_value not in valid_values:
-                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
-                                f' but `{arg_name}` is {arg_value}.')
-            if arg_value != value:
-                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
-                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
-
-    @staticmethod
-    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
-        """Determine whether the types of two variables are the same."""
-        if arg1_type != arg2_type:
-            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
-
-    @staticmethod
-    def check_value_on_integer(arg_name, arg_value, value, rel):
-        """Judging integer type."""
-        rel_fn = Rel.get_fns(rel)
-        type_match = isinstance(arg_value, int)
-        if type_match and (not rel_fn(arg_value, value)):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
-        """Judging the equality of parameters."""
-        if param1_value != param2_value:
-            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
-                             f" but got `{param1_name}` = {param1_value},"
-                             f" `{param2_name}` = {param2_value}.")
-
-    @staticmethod
-    def check_const_input(arg_name, arg_value):
-        """Check valid value."""
-        if arg_value is None:
-            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
-
-    @staticmethod
-    def check_float_positive(arg_name, arg_value):
-        """Float type judgment."""
-        if isinstance(arg_value, float):
-            if arg_value > 0:
-                return arg_value
-            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
-        raise TypeError(f"`{arg_name}` must be float!")
-
-    @staticmethod
-    def check_pad_value_by_mode(op_name, pad_mode, padding):
-        """Validate value of padding according to pad_mode"""
-        if pad_mode != 'pad' and padding != 0:
-            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
-        return padding
-
-    @staticmethod
-    def check_empty_shape_input(arg_name, arg_value):
-        """Check zeros value."""
-        if 0 in arg_value:
-            raise ValueError(f"Input `{arg_name}` cannot be empty.")
-
-    @staticmethod
-    def check_scalar_shape_input(arg_name, arg_value):
-        """Check scalar shape input."""
-        if arg_value != []:
-            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
-
 def check_int(input_param):
     """Int type judgment."""
     if isinstance(input_param, int) and not isinstance(input_param, bool):
@@ -653,30 +435,6 @@ def check_output_data(data):
         raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
 
-def check_axis_type_int(axis):
-    """Check axis type."""
-    if not isinstance(axis, int):
-        raise TypeError('Wrong type for axis, should be int.')
-
-def check_axis_range(axis, rank):
-    """Check axis range."""
-    if not -rank <= axis < rank:
-        raise ValueError('The axis should be in range [{}, {}),'
-                         ' but got {}.'.format(-rank, rank, axis))
-
-def check_attr_int(attr_name, attr):
-    """Check int type."""
-    if not isinstance(attr, int):
-        raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
-
-def check_t_in_range(t):
-    """Check input range."""
-    if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
-        raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
-
 once = _expand_tuple(1)
 twice = _expand_tuple(2)
 triple = _expand_tuple(3)
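The deleted ParamValidator.check_integer guarded against a Python subtlety worth keeping in mind when reading the replacement calls: bool is a subclass of int, so isinstance(True, int) is True and bool has to be rejected explicitly. A small sketch of that guard with the prim name threaded in, as the Validator replacements do (the function name and message text are illustrative):

def check_int_ge(arg_name, arg_value, minimum, prim_name):
    """Sketch of the integer check; bool is excluded even though it passes isinstance(..., int)."""
    type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
    if type_mismatch or arg_value < minimum:
        raise ValueError(f"For '{prim_name}' the `{arg_name}` should be an int "
                         f">= {minimum}, but got {arg_value!r}.")
    return arg_value

check_int_ge("kernel_size", 3, 1, "AvgPool")       # returns 3
# check_int_ge("kernel_size", True, 1, "AvgPool")  # ValueError even though True >= 1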
mindspore/ccsrc/optimizer/ad/dfunctor.cc
@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   UpdateAdjoint(node_adjoint);
   anfnode_to_adjoin_[morph] = node_adjoint;
   if (cnode_morph->stop_gradient()) {
-    MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped.";
+    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
     return node_adjoint;
   }
mindspore/nn/layer/pooling.py
@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
 from ... import context
 from ..cell import Cell
 from ..._checkparam import Rel
-from ..._checkparam import ParamValidator
 
 class _PoolNd(Cell):
@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
                  stride=1,
                  pad_mode="valid"):
         super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
-        ParamValidator.check_type('kernel_size', kernel_size, [int,])
-        ParamValidator.check_type('stride', stride, [int,])
-        self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
-        ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
-        ParamValidator.check_integer("stride", stride, 1, Rel.GE)
+        validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
+        validator.check_value_type('stride', stride, [int], self.cls_name)
+        self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
+        validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
+        validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
         self.kernel_size = (1, kernel_size)
         self.stride = (1, stride)
         self.avg_pool = P.AvgPool(ksize=self.kernel_size,
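The pooling change threads self.cls_name through the same checks, so a bad pad_mode or kernel_size is reported against the cell that received it. A simplified stand-in for the check_string call, assuming the "For '<name>'" message convention seen in _checkparam.py (the exact Validator message text is not shown in this diff):

def check_string(arg_name, arg_value, valid_values, prim_name):
    """Simplified stand-in for Validator.check_string as called by AvgPool1d."""
    if isinstance(arg_value, str) and arg_value in valid_values:
        return arg_value
    raise ValueError(f"For '{prim_name}' the `{arg_name}` should be one of "
                     f"{valid_values}, but got {arg_value!r}.")

# Mirrors the migrated call site:
#   self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(),
#                                          ['VALID', 'SAME'], self.cls_name)
pad_mode = check_string('pad_mode', 'same'.upper(), ['VALID', 'SAME'], 'AvgPool1d')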
mindspore/ops/operations/array_ops.py
@@ -24,7 +24,7 @@ import itertools
 import numbers
 import numpy as np
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
@@ -32,12 +32,12 @@ from ..operations.math_ops import _infer_shape_reduce
 from .._utils import _get_concat_offset
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 
-def _check_infer_attr_reduce(axis, keep_dims):
-    validator.check_type('keep_dims', keep_dims, [bool])
-    validator.check_type('axis', axis, [int, tuple])
+def _check_infer_attr_reduce(axis, keep_dims, prim_name):
+    validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
+    validator.check_value_type('axis', axis, [int, tuple], prim_name)
     if isinstance(axis, tuple):
         for index, value in enumerate(axis):
-            validator.check_type('axis[%d]' % index, value, [int])
+            validator.check_value_type('axis[%d]' % index, value, [int], prim_name)
 
 class ExpandDims(PrimitiveWithInfer):
@@ -74,13 +74,11 @@ class ExpandDims(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
 
     def __infer__(self, x, axis):
-        validator.check_subclass("input_x", x['dtype'], mstype.tensor)
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         axis_v = axis['value']
         rank = len(x_shape)
-        validator.check_const_input('axis', axis_v)
-        validator.check_type("axis", axis_v, [int])
-        validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH)
+        validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
         if axis_v < 0:
             axis_v = rank + 1 + axis_v
         x_shape.insert(axis_v, 1)
@@ -110,7 +108,7 @@ class DType(PrimitiveWithInfer):
         """init DType"""
 
     def __infer__(self, x):
-        validator.check_subclass("input_x", x['dtype'], mstype.tensor)
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
         out = {'shape': (),
                'dtype': mstype.type_type,
                'value': x['dtype'].element_type()}
@@ -144,19 +142,17 @@ class SameTypeShape(PrimitiveWithInfer):
     def __call__(self, x, y):
         """run in PyNative mode"""
-        if x.dtype() != y.dtype():
-            raise TypeError(f"The {x} and {y} should be same dtype.")
-        if x.shape() != y.shape():
-            raise TypeError(f"The {x} and {y} should have same shape.")
+        validator.check_subclass('x', x.dtype(), mstype.tensor, self.name)
+        validator.check_subclass('y', y.dtype(), mstype.tensor, self.name)
+        validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
+        validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
         return x
 
     def __infer__(self, x, y):
-        if x['dtype'] != y['dtype']:
-            raise TypeError(f"The {x} and {y} should be same dtype,"
-                            f" but got {x['dtype']} {y['dtype']}.")
-        if x['shape'] != y['shape']:
-            raise ValueError(f"The {x} and {y} should be same shape,"
-                             f" but got {x['shape']} {y['shape']}.")
+        validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
+        validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
+        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
+        validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
         return x
@@ -191,8 +187,8 @@ class Cast(PrimitiveWithInfer):
         src_type = x['dtype']
         dst_type = t['value']
 
-        validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number])
-        validator.check_subclass("type", dst_type, mstype.number, with_type_of=False)
+        validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
+        validator.check_subclass("type", dst_type, mstype.number, self.name)
 
         if isinstance(src_type, type(mstype.tensor)):
             src_type = x['dtype'].element_type()
@@ -238,8 +234,8 @@ class IsSubClass(PrimitiveWithInfer):
         sub_type_t = sub_type['value']
         type_v = type_['value']
-        validator.check_type("sub_type", sub_type_t, [mstype.Type])
-        validator.check_type("type_", type_v, [mstype.Type])
+        validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
+        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
         value = mstype.issubclass_(sub_type_t, type_v)
@@ -273,8 +269,8 @@ class IsInstance(PrimitiveWithInfer):
         sub_type_t = inst['dtype']
         type_v = type_['value']
-        validator.check_const_input("inst", inst['value'])
-        validator.check_type("type_", type_v, [mstype.Type])
+        validator.check_const_input("inst", inst['value'], self.name)
+        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
         value = mstype.issubclass_(sub_type_t, type_v)
@@ -316,14 +312,13 @@ class Reshape(PrimitiveWithInfer):
     def __infer__(self, x, shape):
         shape_v = shape['value']
         x_shp = x['shape']
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
-        validator.check_const_input("shape", shape_v)
-        validator.check_type("shape", shape_v, [tuple])
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
+        validator.check_value_type("shape", shape_v, [tuple], self.name)
         shape_v = list(shape_v)
         neg_index = -1
         dim_prod = 1
         for i, shp_i in enumerate(shape_v):
-            validator.check_type("shape[%d]" % i, shp_i, [int])
+            validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
             if shp_i == -1:
                 if neg_index != -1:
                     raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.')
@@ -332,7 +327,7 @@ class Reshape(PrimitiveWithInfer):
                 dim_prod *= shp_i
         arr_prod = np.prod(x_shp)
         if dim_prod <= 0 or arr_prod % dim_prod != 0:
-            raise ValueError(f'The product of shape should > 0 and'
+            raise ValueError(f'For \'{self.name}\' the product of shape should > 0 and'
                              f' can be divided by prod of input {arr_prod},'
                              f' but shape {shape}, product of shape {dim_prod}.')
@@ -340,7 +335,7 @@ class Reshape(PrimitiveWithInfer):
             shape_v[neg_index] = int(arr_prod / dim_prod)
             dim_prod *= shape_v[neg_index]
         if dim_prod != arr_prod:
-            raise ValueError(f'The shape arg for reshape must match array''s size'
+            raise ValueError(f'For \'{self.name}\' The shape arg for reshape must match array''s size'
                              f' input shape {arr_prod}, shape {dim_prod}.')
         value = None
@@ -406,10 +401,10 @@ class Squeeze(PrimitiveWithInfer):
     def __init__(self, axis=()):
         """init Squeeze"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type('axis', axis, [int, tuple])
+        validator.check_value_type('axis', axis, [int, tuple], self.name)
         if isinstance(axis, tuple):
-            for item in axis:
-                validator.check_type("item", item, [int])
+            for idx, item in enumerate(axis):
+                validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
         else:
             self.axis = (axis,)
             self.add_prim_attr("axis", (axis,))
@@ -422,14 +417,14 @@ class Squeeze(PrimitiveWithInfer):
             ret = [d for d in x_shape if d != 1]
         else:
             for a in axis:
-                validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH)
+                validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
                 if x_shape[a] != 1:
                     raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
             ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
         return ret
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
+        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
         return x_dtype
@@ -467,14 +462,13 @@ class Transpose(PrimitiveWithInfer):
         if len(x_shape) != len(p_value):
             raise ValueError('The dimension of x and perm must be equal.')
 
-        validator.check_const_input("perm", p_value)
-        validator.check_type("p_value", p_value, [tuple])
-        validator.check_subclass("x_type", x_type, mstype.tensor)
+        validator.check_value_type("p_value", p_value, [tuple], self.name)
+        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
 
         tmp = list(p_value)
         for i, dim in enumerate(p_value):
-            validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE)
-            validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT)
+            validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
+            validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
             tmp.remove(dim)
             if dim in tmp:
                 raise ValueError('The value of perm is wrong.')
@@ -517,15 +511,13 @@ class GatherV2(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
 
     def __infer__(self, params, indices, axis):
-        validator.check_subclass("params", params['dtype'], mstype.tensor)
-        validator.check_subclass("indices", indices['dtype'], mstype.tensor)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_)
-        validator.check_typename("element of indices", indices['dtype'], mstype.int_type)
-        validator.check_const_input("axis", axis['value'])
+        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
         axis_v = axis['value']
         params_shp = params['shape']
         rank = len(params_shp)
-        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
         if axis_v < 0:
             axis_v += rank
         out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
@@ -564,19 +556,20 @@ class Split(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=0, output_num=1):
         """init Split"""
-        validator.check_type("axis", axis, [int])
-        validator.check_type("output_num", output_num, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
+        validator.check_value_type("output_num", output_num, [int], self.name)
         self.axis = axis
         self.output_num = output_num
 
     def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
-        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
-        validator.check_integer("output_num", self.output_num, 0, Rel.GT)
+        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
+        validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
         output_valid_check = x_shape[self.axis] % self.output_num
-        validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ)
+        validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0,
+                                Rel.EQ, self.name)
         x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
         out_shapes = []
         out_dtypes = []
@@ -615,7 +608,7 @@ class Rank(PrimitiveWithInfer):
         """init Rank"""
 
     def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         out = {'shape': None,
                'dtype': None,
                'value': len(x['shape'])}
@@ -647,15 +640,14 @@ class TruncatedNormal(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, seed=0, dtype=mstype.float32):
         """init TruncatedNormal"""
-        validator.check_type('seed', seed, [int])
-        validator.check_typename('dtype', dtype, mstype.number_type)
+        validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
 
     def __infer__(self, shape):
         shape_value = shape['value']
-        validator.check_const_input("shape", shape_value)
-        validator.check_type("shape", shape_value, [tuple])
+        validator.check_value_type("shape", shape_value, [tuple], self.name)
         for i, value in enumerate(shape_value):
-            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
+            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name)
         out = {'shape': shape_value,
                'dtype': mstype.tensor_type(self.dtype),
                'value': None}
@@ -687,7 +679,7 @@ class Size(PrimitiveWithInfer):
     def __infer__(self, x):
         size = 1
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         shp = x['shape']
         if not shp:
             size = 0
@@ -723,25 +715,20 @@ class Fill(PrimitiveWithInfer):
         """init Fill"""
 
     def __infer__(self, dtype, dims, x):
-        validator.check_const_input("type", dtype['value'])
-        validator.check_const_input("shape", dims['value'])
-        validator.check_const_input("value", x['value'])
-        validator.check_type("shape", dims['value'], [tuple])
-        validator.check_type("value", x['value'], [numbers.Number, bool])
-        for item in dims['value']:
-            validator.check_type("item", item, [int])
-            validator.check_integer("item", item, 0, Rel.GT)
-        x_dtype = dtype['value']
+        validator.check_value_type("shape", dims['value'], [tuple], self.name)
+        validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
+        for idx, item in enumerate(dims['value']):
+            validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
         valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint32, mstype.uint64,
                        mstype.float16, mstype.float32, mstype.float64]
-        validator.check_typename("value", x_dtype, valid_types)
-        x_nptype = mstype.dtype_to_nptype(x_dtype)
+        validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
+        x_nptype = mstype.dtype_to_nptype(dtype['value'])
         ret = np.full(dims['value'], x['value'], x_nptype)
         out = {
             'value': Tensor(ret),
             'shape': dims['value'],
-            'dtype': x_dtype,
+            'dtype': x['dtype'],
         }
         return out
@@ -772,8 +759,7 @@ class OnesLike(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
        return x_dtype
@@ -804,8 +790,7 @@ class ZerosLike(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
        return x_dtype
@@ -830,14 +815,13 @@ class TupleToArray(PrimitiveWithInfer):
         """init TupleToArray"""
 
     def infer_value(self, x):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [tuple])
-        validator.check("size of x", len(x), '', 0, Rel.GT)
+        validator.check_value_type("x", x, [tuple], self.name)
+        validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
         dtype = type(x[0])
         for i, item in enumerate(x):
-            validator.check_type(f"x[{i}]", item, [numbers.Number])
+            validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
         if not all(isinstance(item, dtype) for item in x):
-            raise TypeError("All elements of input x must be have same type.")
+            raise TypeError("For \'{self.name}\' all elements of input x must be have same type.")
         if isinstance(x[0], int):
             ret = np.array(x, np.int32)
         else:
@@ -867,8 +851,7 @@ class ScalarToArray(PrimitiveWithInfer):
         pass
 
     def infer_value(self, x):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [int, float])
+        validator.check_value_type("x", x, [int, float], self.name)
         if isinstance(x, int):
             ret = np.array(x, np.int32)
         else:
@@ -899,9 +882,8 @@ class ScalarToTensor(PrimitiveWithInfer):
         pass
 
     def infer_value(self, x, dtype=mstype.float32):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [int, float])
-        validator.check_subclass("dtype", dtype, mstype.number, with_type_of=False)
+        validator.check_value_type("x", x, [int, float], self.name)
+        validator.check_subclass("dtype", dtype, mstype.number, self.name)
         data_type = mstype.dtype_to_nptype(dtype)
         return Tensor(np.array(x, data_type))
@@ -943,15 +925,14 @@ class InvertPermutation(PrimitiveWithInfer):
     def __infer__(self, x):
         x_shp = x['shape']
         x_value = x['value']
-        validator.check_const_input("shape", x_shp)
-        validator.check_type("shape", x_shp, [tuple])
+        validator.check_value_type("shape", x_shp, [tuple], self.name)
         z = [x_value[i] for i in range(len(x_value))]
         z.sort()
 
         y = [None]*len(x_value)
         for i, value in enumerate(x_value):
-            validator.check_type("input[%d]" % i, value, [int])
-            validator.check(f'value', z[i], f'index', i)
+            validator.check_value_type("input[%d]" % i, value, [int], self.name)
+            validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
             y[value] = i
             z.append(value)
         return {'shape': x_shp,
@@ -986,8 +967,8 @@ class Argmax(PrimitiveWithInfer):
     def __init__(self, axis=-1, output_type=mstype.int64):
         """init Argmax"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type("axis", axis, [int])
-        validator.check_typename('output_type', output_type, [mstype.int32, mstype.int64])
+        validator.check_value_type("axis", axis, [int], self.name)
+        validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name)
         self.axis = axis
         self.add_prim_attr('output_type', output_type)
@@ -996,14 +977,13 @@ class Argmax(PrimitiveWithInfer):
         if axis is None:
             axis = 0
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         axis = axis + x_rank if axis < 0 else axis
         ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
         return ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(self.output_type)
@@ -1035,7 +1015,7 @@ class Argmin(PrimitiveWithInfer):
     def __init__(self, axis=-1, output_type=mstype.int64):
         """init Argmin"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
         self.axis = axis
         self.add_prim_attr('output_type', output_type)
@@ -1044,13 +1024,13 @@ class Argmin(PrimitiveWithInfer):
         if axis is None:
             axis = 0
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         axis = axis + x_rank if axis < 0 else axis
         ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
         return ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(self.output_type)
@@ -1087,17 +1067,17 @@ class ArgMaxWithValue(PrimitiveWithInfer):
         """init ArgMaxWithValue"""
         self.axis = axis
         self.keep_dims = keep_dims
-        _check_infer_attr_reduce(axis, keep_dims)
+        _check_infer_attr_reduce(axis, keep_dims, self.name)
 
     def infer_shape(self, x_shape):
         axis = self.axis
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(mstype.int32), x_dtype
@@ -1133,17 +1113,17 @@ class ArgMinWithValue(PrimitiveWithInfer):
         """init ArgMinWithValue"""
         self.axis = axis
         self.keep_dims = keep_dims
-        _check_infer_attr_reduce(axis, keep_dims)
+        _check_infer_attr_reduce(axis, keep_dims, self.name)
 
     def infer_shape(self, x_shape):
         axis = self.axis
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(mstype.int32), x_dtype
@@ -1183,13 +1163,11 @@ class Tile(PrimitiveWithInfer):
     def __infer__(self, x, multiples):
         multiples_v = multiples['value']
         x_shp = x['shape']
-        validator.check_const_input("shape", multiples_v)
-        validator.check_type("shape", multiples_v, [tuple])
+        validator.check_value_type("shape", multiples_v, [tuple], self.name)
         for i, multiple in enumerate(multiples_v):
-            validator.check_type("multiples[%d]" % i, multiple, [int])
-        validator.check_typename('x', x['dtype'],
-                                 valid_types=[mstype.int16, mstype.int32, mstype.bool_,
-                                              mstype.float16, mstype.float32])
+            validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
+        valid_types = [mstype.int16, mstype.int32, mstype.bool_,
+                       mstype.float16, mstype.float32]
+        validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
         len_sub = len(multiples_v) - len(x_shp)
         multiples_w = None
         if len_sub == 0:
@@ -1199,7 +1177,8 @@ class Tile(PrimitiveWithInfer):
             x_shp.insert(0, 1)
             multiples_w = multiples_v
         elif len_sub < 0:
-            raise ValueError("The length of multiples can not be smaller than the length of dimension in input_x.")
+            raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than '
+                             f'the length of dimension in input_x.')
         for i, a in enumerate(multiples_w):
             x_shp[i] *= a
         value = None
@@ -1246,23 +1225,23 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
     def __infer__(self, x, segment_ids, num_segments):
         x_type = x['dtype']
         x_shp = x['shape']
-        validator.check_subclass("input_x", x_type, mstype.tensor)
-        validator.check_type("x_shape", x_shp, [list])
+        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
+        validator.check_value_type("x_shape", x_shp, [list], self.name)
         x_shp_len = len(x_shp)
-        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT)
+        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
         segment_ids_shp = segment_ids['shape']
         segment_ids_type = segment_ids['dtype']
-        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor)
-        validator.check_type("segment_ids", segment_ids_shp, [list])
+        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
+        validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
         segment_ids_shp_len = len(segment_ids_shp)
-        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT)
-        validator.check(f'rank of input_x', len(x_shp),
-                        'rank of segments_id', len(segment_ids_shp), Rel.GE)
+        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
+        validator.check(f'rank of input_x', len(x_shp),
+                        'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
         for i, value in enumerate(segment_ids_shp):
-            validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i])
+            validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
-        validator.check_type('num_segments', num_segments_v, [int])
-        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT)
+        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
+        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
         shp = [num_segments_v]
         shp += x_shp[segment_ids_shp_len:]
         out = {'shape': shp,
@@ -1306,7 +1285,7 @@ class Concat(PrimitiveWithInfer):
     def __init__(self, axis=0):
         """init Tile"""
         self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
 
     def __infer__(self, input_x):
         axis = self.axis
@@ -1323,25 +1302,25 @@ class Concat(PrimitiveWithInfer):
         return out
 
-def _get_pack_shape(x_shape, x_type, axis):
+def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
-    validator.check_type("shape", x_shape, [tuple, list])
-    validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT)
-    validator.check_subclass("shape0", x_type[0], mstype.tensor)
-    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT)
+    validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
+    validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT, prim_name)
+    validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
+    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
    out_shape = x_shape[0]
-    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
+    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
     if axis < 0:
         axis = axis + rank_base + 1
     for i in range(1, N):
         v = x_shape[i]
-        validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base)
-        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
+        validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
+        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name)
         for j in range(rank_base):
             if v[j] != x_shape[0][j]:
-                raise ValueError("Pack evaluator element %d shape in input can not pack with first element" % i)
+                raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
     out_shape.insert(axis, N)
     return out_shape
@@ -1376,14 +1355,14 @@ class Pack(PrimitiveWithInfer):
     def __init__(self, axis=0):
         """init Pack"""
         self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
         self.axis = axis
 
     def __infer__(self, value):
         x_shape = value['shape']
         x_type = value['dtype']
         self.add_prim_attr('num', len(x_shape))
-        all_shape = _get_pack_shape(x_shape, x_type, self.axis)
+        all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name)
         out = {'shape': all_shape,
                'dtype': x_type[0],
                'value': None}
...
@@ -1429,22 +1408,23 @@ class Unpack(PrimitiveWithInfer):
    def __init__(self, axis=0):
        """init Unpack"""
        self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
-        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
+        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
        if self.axis < 0:
            self.axis = self.axis + dim
        output_num = x_shape[self.axis]
-        validator.check_type("num", output_num, [int])
-        validator.check_integer("output_num", output_num, 0, Rel.GT)
+        validator.check_value_type("num", output_num, [int], self.name)
+        validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
        self.add_prim_attr('num', output_num)
        output_valid_check = x_shape[self.axis] - output_num
-        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ)
+        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
+                                self.name)
        out_shapes = []
        out_dtypes = []
        out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
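A corresponding sketch for Unpack (illustrative values; same assumed build):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    unpack = P.Unpack(axis=0)
    x = Tensor(np.arange(8).reshape(2, 4).astype(np.float32))
    out = unpack(x)  # tuple of 2 tensors, each of shape (4,)
    # An out-of-range axis now fails check_int_range with 'Unpack' named in
    # the error message.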
...
@@ -1486,8 +1466,8 @@ class Slice(PrimitiveWithInfer):
    def __infer__(self, x, begin, size):
        x_shape = x['shape']
        x_shp_len = len(x_shape)
-        validator.check_const_input('begin', begin['value'])
-        validator.check_const_input('size', size['value'])
+        validator.check_const_input('begin', begin['value'], self.name)
+        validator.check_const_input('size', size['value'], self.name)
        begin_v, size_v = begin['value'], size['value']
        if begin_v is None or size_v is None:
            return {'shape': None,
...
@@ -1499,7 +1479,8 @@ class Slice(PrimitiveWithInfer):
        for i in range(x_shp_len):
            if x_shape[i] < begin_v[i] + size_v[i]:
                y = begin_v[i] + size_v[i]
-                raise ValueError("Slice shape can not bigger than orign shape %d, %d." % (x_shape[i], y))
+                raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." %
+                                 (self.name, x_shape[i], y))
        return {'shape': size_v,
                'dtype': x['dtype'],
                'value': None}
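For Slice, the reworded ValueError can be provoked as follows (an illustrative sketch, not from the commit):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    slice_op = P.Slice()
    x = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float32))
    y = slice_op(x, (0, 1, 0), (1, 2, 4))  # ok: output shape (1, 2, 4)
    # slice_op(x, (0, 1, 0), (1, 3, 4)) exceeds dim 1 (begin 1 + size 3 > 3)
    # and now raises "For 'Slice' slice shape can not bigger than orign shape ..."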
...
@@ -1565,11 +1546,11 @@ class Select(PrimitiveWithInfer):
    def infer_dtype(self, cond_type, x_type, y_type):
        self.add_prim_attr('T', x_type)
-        validator.check_subclass("x_type", x_type, mstype.tensor)
-        validator.check_subclass("y_type", y_type, mstype.tensor)
-        validator.check_typename("cond_type", cond_type, [mstype.bool_])
+        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
+        validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
+        validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
        if x_type != y_type:
-            raise TypeError('The x_type %s must be the same as y_type %s.' % (x_type, y_type))
+            raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
        return x_type
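A short Select sketch showing where infer_dtype's checks apply (illustrative; assumes the same build):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    select = P.Select()
    cond = Tensor(np.array([True, False]))
    a = Tensor(np.array([1.0, 2.0]).astype(np.float32))
    b = Tensor(np.array([3.0, 4.0]).astype(np.float32))
    out = select(cond, a, b)  # elementwise pick: [1.0, 4.0]
    # Passing a float32 x with an int32 y now raises a TypeError that leads
    # with the primitive name, per the reworded message above.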
...
@@ -1637,26 +1618,23 @@ class StridedSlice(PrimitiveWithInfer):
                 shrink_axis_mask=0):
        """init StrideSlice"""
        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
-        validator.check_type('begin_mask', begin_mask, [int])
-        validator.check_type('end_mask', end_mask, [int])
-        validator.check_type('ellipsis_mask', ellipsis_mask, [int])
-        validator.check_type('new_axis_mask', new_axis_mask, [int])
-        validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
+        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
+        validator.check_value_type('end_mask', end_mask, [int], self.name)
+        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
+        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
+        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)

    def __infer__(self, x, begin, end, strides):
        begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
-        validator.check_const_input("begin", begin_v)
-        validator.check_const_input("end", end_v)
-        validator.check_const_input("strides", strides_v)
-        validator.check_type("begin", begin_v, [tuple])
-        validator.check_type("end", end_v, [tuple])
-        validator.check_type("strides", strides_v, [tuple])
+        validator.check_value_type("begin", begin_v, [tuple], self.name)
+        validator.check_value_type("end", end_v, [tuple], self.name)
+        validator.check_value_type("strides", strides_v, [tuple], self.name)
        x_shape = x['shape']
        x_shp_len = len(x_shape)
        if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
-            raise ValueError(f"The length of begin index {begin_v}, end index {end_v} and strides {strides_v} "
-                             f"must be equal to the dims({x_shp_len}) of input.")
+            raise ValueError(f"For \'{self.name}\' the length of begin index {begin_v}, end index {end_v} and "
+                             f"strides {strides_v} must be equal to the dims({x_shp_len}) of input.")
        ret_shape = []
        append_dimensions = []
...
@@ -1669,8 +1647,8 @@ class StridedSlice(PrimitiveWithInfer):
                append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
                continue
            if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
-                validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE)
-                validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT)
+                validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
+                validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
                continue
            begin_idx = begin_v[i]
...
@@ -1680,9 +1658,9 @@ class StridedSlice(PrimitiveWithInfer):
                begin_idx = 0
            if self.end_mask:
                end_idx = x_shape[i]
-            validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE)
-            validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE)
-            validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE)
+            validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
+            validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
+            validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
            if strides_idx > 0:
                # If sliced forward, end_idx >= begin_idx
                validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
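The updated StridedSlice messages are exactly what the adjusted unit tests below assert; a minimal sketch of how they surface (illustrative values):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    op = P.StridedSlice()
    x = Tensor(np.arange(36).reshape(6, 6).astype(np.float32))
    y = op(x, (0, 0), (3, 6), (1, 2))  # rows 0..2, every 2nd column -> (3, 3)
    # begin[0] = -7 on a dim of size 6 now raises:
    # "For 'StridedSlice' the `begin[0]` should be an int and must greater
    # or equal to -6, but got `-7`"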
...
@@ -1736,7 +1714,7 @@ class Diag(PrimitiveWithInfer):
        """init Diag"""

    def infer_dtype(self, x_type):
-        validator.check_subclass('input_x', x_type, mstype.tensor)
+        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
        return x_type

    def infer_shape(self, x_shape):
...
@@ -1748,7 +1726,7 @@ class Diag(PrimitiveWithInfer):
    def infer_value(self, x):
        if x is None:
            return None
-        validator.check("input x rank", len(x.shape()), "", 1)
+        validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
        ret = np.diag(x.asnumpy())
        return Tensor(ret)
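Diag's constant-folding path (infer_value) now reports rank errors through check_integer with the primitive name; a sketch (illustrative):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    diag = P.Diag()
    x = Tensor(np.array([1, 2, 3]).astype(np.float32))
    out = diag(x)  # 3x3 tensor with [1, 2, 3] on the main diagonal
    # A non-1-D constant input now fails
    # check_integer("input x rank", ..., 1, Rel.EQ, self.name)
    # rather than the generic check used before.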
...
@@ -1783,13 +1761,13 @@ class DiagPart(PrimitiveWithInfer):
        """init DiagPart"""

    def infer_dtype(self, x_type):
-        validator.check_subclass('input_x', x_type, mstype.tensor)
+        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
        return x_type

    def infer_shape(self, x_shape):
        if len(x_shape) % 2 != 0 or \
                not x_shape:
-            raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
+            raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
                             f"with shapes {x_shape}")
        length = len(x_shape) // 2
        ret_shape = x_shape[0:length]
...
@@ -1798,7 +1776,7 @@ class DiagPart(PrimitiveWithInfer):
    def infer_value(self, x):
        if x is None:
            return None
-        validator.check("x rank", len(x.shape()), "", 2)
+        validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
        ret = np.diag(x.asnumpy())
        return Tensor(ret)
...
@@ -1826,12 +1804,10 @@ class Eye(PrimitiveWithInfer):
        """init Eye"""

    def infer_value(self, n, m, t):
-        validator.check_type("n", n, [int])
-        validator.check_integer("n", n, 0, Rel.GT)
-        validator.check_type("m", m, [int])
-        validator.check_integer("m", m, 0, Rel.GT)
+        validator.check_integer("n", n, 0, Rel.GT, self.name)
+        validator.check_integer("m", m, 0, Rel.GT, self.name)
        args = {"dtype": t}
-        validator.check_type_same(args, mstype.number_type + (mstype.bool_,))
+        validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
        np_type = mstype.dtype_to_nptype(t)
        ret = np.eye(n, m, dtype=np_type)
        return Tensor(ret)
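Eye folds to a constant via np.eye, and the separate type/range checks on n and m collapse into named check_integer calls; a sketch (illustrative, matching the infer_value(n, m, t) signature above):

    import numpy as np
    from mindspore.ops import operations as P
    from mindspore.common import dtype as mstype

    eye = P.Eye()
    out = eye(3, 2, mstype.float32)  # 3x2 matrix with ones on the diagonal
    # eye(0, 2, mstype.float32) now raises with "For 'Eye'" in the message,
    # since 0 fails the Rel.GT lower bound.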
...
@@ -1866,16 +1842,15 @@ class ScatterNd(PrimitiveWithInfer):
    def __infer__(self, indices, update, shape):
        shp = shape['value']
-        validator.check_subclass("indices_dtype", indices['dtype'], mstype.tensor)
-        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor)
-        validator.check_typename("indices_dtype", indices['dtype'], mstype.int_type)
-        validator.check_type("shape", shp, [tuple])
+        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_value_type("shape", shp, [tuple], self.name)
        for i, x in enumerate(shp):
-            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT)
+            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
        indices_shape, update_shape = indices["shape"], update["shape"]
        if indices_shape[0] != update_shape[0]:
-            raise ValueError('The indices_shape[0] and update_shape[0] must be equal.')
+            raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.')
        return {'shape': shp,
                'dtype': update['dtype'],
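A ScatterNd sketch matching the __infer__(indices, update, shape) signature above (illustrative values):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    op = P.ScatterNd()
    indices = Tensor(np.array([[0], [2]]).astype(np.int32))
    update = Tensor(np.array([3.2, 1.1]).astype(np.float32))
    out = op(indices, update, (4,))  # scatters into zeros: [3.2, 0.0, 1.1, 0.0]
    # Float indices now fail check_tensor_type_same against mstype.int_type,
    # and the error names 'ScatterNd'.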
...
@@ -1913,7 +1888,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
        self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

    def infer_shape(self, x):
-        validator.check('the dimension of input_x', len(x), '', 2, Rel.GE)
+        validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
        return tuple(x)[:-2] + tuple(self.size)

    def infer_dtype(self, x):
...
@@ -1947,13 +1922,12 @@ class GatherNd(PrimitiveWithInfer):
    def infer_shape(self, x_shape, indices_shape):
        validator.check('the dimension of x', len(x_shape),
-                        'the dimension of indices', indices_shape[-1], Rel.GE)
+                        'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
        return indices_shape[:-1] + x_shape[indices_shape[-1]:]

    def infer_dtype(self, x_dtype, indices_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
-        validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor)
-        validator.check_typename("indices_dtype", indices_dtype, mstype.int_type)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
        return x_dtype
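GatherNd, for comparison (an illustrative sketch; the int_type constraint on indices is what the new check_tensor_type_same call enforces):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    op = P.GatherNd()
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
    out = op(x, indices)  # gathers x[0][0] and x[1][1] -> [1.0, 4.0]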
...
@@ -1995,12 +1969,9 @@ class ScatterNdUpdate(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
-        validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor)
-        validator.check_subclass("value_dtype", value_dtype, mstype.tensor)
-        validator.check_typename('indices_dtype', indices_dtype, mstype.int_type)
-        args = {"x_dtype": x_dtype, "value_dtype": value_dtype}
-        validator.check_type_same(args, (mstype.bool_,) + mstype.number_type)
+        validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+        args = {"x": x_dtype, "value": value_dtype}
+        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return x_dtype
...
@@ -2038,7 +2009,7 @@ class SpaceToDepth(PrimitiveWithInfer):
    def __init__(self, block_size):
        """Init SpaceToDepth"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
-        validator.check_type('block_size', block_size, [int])
+        validator.check_value_type('block_size', block_size, [int], self.name)
        validator.check('block_size', block_size, '', 2, Rel.GE)
        self.block_size = block_size
        self.add_prim_attr("data_format", "NCHW")
...
@@ -2048,7 +2019,7 @@ class SpaceToDepth(PrimitiveWithInfer):
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            if out_shape[i + 2] % self.block_size != 0:
-                raise ValueError(f'SpaceToDepth input shape[{i + 2}] {out_shape[i + 2]} should be '
+                raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
                                 f'fully divided by block_size {self.block_size}')
            out_shape[i + 2] //= self.block_size
...
@@ -2056,7 +2027,7 @@ class SpaceToDepth(PrimitiveWithInfer):
        return out_shape

    def infer_dtype(self, x_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
        return x_dtype
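SpaceToDepth's divisibility error now carries the primitive name too; a sketch (illustrative shapes):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    op = P.SpaceToDepth(block_size=2)
    x = Tensor(np.ones([1, 3, 4, 4]).astype(np.float32))
    out = op(x)  # (1, 3, 4, 4) -> (1, 12, 2, 2)
    # A 5x5 spatial input would raise "For 'SpaceToDepth' input shape[...] ...
    # should be fully divided by block_size 2"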
...
@@ -2096,8 +2067,8 @@ class DepthToSpace(PrimitiveWithInfer):
    def __init__(self, block_size):
        """Init DepthToSpace"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 2, Rel.GE)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
        self.block_size = block_size
        self.add_prim_attr("data_format", "NCHW")
...
@@ -2107,12 +2078,13 @@ class DepthToSpace(PrimitiveWithInfer):
        for i in range(2):
            out_shape[i + 2] *= self.block_size
-        validator.check('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), '', 0)
+        validator.check_integer('x_shape[1] % (block_size*block_size)',
+                                x_shape[1] % (self.block_size*self.block_size), 0, Rel.EQ, self.name)
        out_shape[1] //= self.block_size * self.block_size
        return out_shape

    def infer_dtype(self, x_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
        return x_dtype
...
@@ -2159,27 +2131,26 @@ class SpaceToBatch(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, block_size, paddings):
        """Init SpaceToBatch"""
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 1, Rel.GT)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
        self.block_size = block_size
-        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2))
+        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
        for elem in itertools.chain(*paddings):
-            validator.check_type('paddings element', elem, [int])
+            validator.check_value_type('paddings element', elem, [int], self.name)
        self.paddings = paddings

    def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, mstype.number_type)
+        validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape):
-        validator.check('rank of input_x', len(x_shape), '', 4)
+        validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            padded = out_shape[i + 2] + self.paddings[i][0] + \
                     self.paddings[i][1]
            if padded % self.block_size != 0:
-                raise ValueError(f'padded[{i}] {padded} should be divisible by '
+                raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                 f'block_size {self.block_size}')
            out_shape[i + 2] = padded // self.block_size
        out_shape[0] *= self.block_size * self.block_size
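SpaceToBatch, with the shape arithmetic from infer_shape above (illustrative sketch):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    op = P.SpaceToBatch(block_size=2, paddings=[[0, 0], [0, 0]])
    x = Tensor(np.ones([1, 1, 4, 4]).astype(np.float32))
    out = op(x)  # batch *= 2*2, spatial dims /= 2 -> shape (4, 1, 2, 2)
    # Padded spatial sizes not divisible by block_size raise
    # "For 'SpaceToBatch' padded[...] ... should be divisible by block_size 2"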
...
@@ -2227,17 +2198,16 @@ class BatchToSpace(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, block_size, crops):
        """Init BatchToSpace"""
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 1, Rel.GT)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
        self.block_size = block_size
        validator.check('crops shape', np.array(crops).shape, '', (2, 2))
        for elem in itertools.chain(*crops):
-            validator.check_type('crops element', elem, [int])
+            validator.check_value_type('crops element', elem, [int], self.name)
        self.crops = crops

    def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, mstype.number_type)
+        validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape):
...
@@ -2246,11 +2216,11 @@ class BatchToSpace(PrimitiveWithInfer):
        for i in range(2):
            x_block_prod = out_shape[i + 2] * self.block_size
            crops_sum = self.crops[i][0] + self.crops[i][1]
-            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT)
+            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
            out_shape[i + 2] = x_block_prod - crops_sum
        block_size_prod = self.block_size * self.block_size
        if out_shape[0] % block_size_prod != 0:
-            raise ValueError(f'input_x dimension 0 {out_shape[0]} should be divisible by '
+            raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
                             f'block_size_prod {block_size_prod}')
        out_shape[0] = out_shape[0] // block_size_prod
        return out_shape
tests/ut/python/ops/test_array_ops_check.py
0 → 100755
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
            pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
    import pipeline_for_compile_grad_ge_graph_for_case_by_case_config


class ExpandDimsNet(nn.Cell):
    def __init__(self, axis):
        super(ExpandDimsNet, self).__init__()
        self.axis = axis
        self.op = P.ExpandDims()

    def construct(self, x):
        return self.op(x, self.axis)


class IsInstanceNet(nn.Cell):
    def __init__(self, inst):
        super(IsInstanceNet, self).__init__()
        self.inst = inst
        self.op = P.IsInstance()

    def construct(self, t):
        return self.op(self.inst, t)


class ReshapeNet(nn.Cell):
    def __init__(self, shape):
        super(ReshapeNet, self).__init__()
        self.shape = shape
        self.op = P.Reshape()

    def construct(self, x):
        return self.op(x, self.shape)


raise_set = [
    # input is scalar, not Tensor
    ('ExpandDims0', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [5.0, 1],
        'skip': ['backward']}),
    # axis is passed as a parameter
    ('ExpandDims1', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
        'skip': ['backward']}),
    # axis as an attribute, but less than lower limit
    ('ExpandDims2', {
        'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # axis as an attribute, but greater than upper limit
    ('ExpandDims3', {
        'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # input is scalar, not Tensor
    ('DType0', {
        'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
        'desc_inputs': [5.0],
        'skip': ['backward']}),
    # input x is scalar, not Tensor
    ('SameTypeShape0', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # input y is scalar, not Tensor
    ('SameTypeShape1', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
        'skip': ['backward']}),
    # type of x and y not match
    ('SameTypeShape2', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
        'skip': ['backward']}),
    # shape of x and y not match
    ('SameTypeShape3', {
        'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
        'skip': ['backward']}),
    # sub_type is None
    ('IsSubClass0', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [None, mstype.number],
        'skip': ['backward']}),
    # type_ is None
    ('IsSubClass1', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [mstype.number, None],
        'skip': ['backward']}),
    # inst is var
    ('IsInstance0', {
        'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [5.0, mstype.number],
        'skip': ['backward']}),
    # t is not mstype.Type
    ('IsInstance1', {
        'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [None],
        'skip': ['backward']}),
    # input x is scalar, not Tensor
    ('Reshape0', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [5.0, (1, 2)],
        'skip': ['backward']}),
    # input shape is var
    ('Reshape1', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
        'skip': ['backward']}),
    # element of shape is not int
    ('Reshape3', {
        'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
    return raise_set
tests/ut/python/ops/test_tensor_slice.py
...
@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
    net = NetWork()
    with pytest.raises(ValueError) as ex:
        net(input_tensor)
-    assert "The `begin[0]` should be an int and must greater or equal to -6, but got -7" in str(ex.value)
+    assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(ex.value)


def test_tensor_slice_reduce_out_of_bounds_positive():
...
@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
    net = NetWork()
    with pytest.raises(ValueError) as ex:
        net(input_tensor)
-    assert "The `begin[0]` should be an int and must less than 6, but got 6" in str(ex.value)
+    assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
tests/vm_impl/vm_me.py
...
@@ -16,7 +16,7 @@
import numpy as np

from mindspore._checkparam import Rel
-from mindspore._checkparam import ParamValidator as validator
+from mindspore._checkparam import Validator as validator


def avg_pooling(x, pool_h, pool_w, stride):
...
@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
    Returns:
        numpy.ndarray, an output array after applying average pooling on input array.
    """
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
...
@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
           dilation=1, groups=1, padding_mode='zeros'):
    """Convolution 2D."""
    # pylint: disable=unused-argument
-    validator.check_type('stride', stride, (int, tuple))
+    validator.check_value_type('stride', stride, (int, tuple), None)
    if isinstance(stride, int):
        stride = (stride, stride)
    elif len(stride) == 4:
...
@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
                         f"a tuple of two positive int numbers, but got {stride}")
    stride_h = stride[0]
    stride_w = stride[1]
-    validator.check_type('dilation', dilation, (int, tuple))
+    validator.check_value_type('dilation', dilation, (int, tuple), None)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    elif len(dilation) == 4:
...
@@ -384,7 +384,7 @@ def matmul(x, w, b=None):
def max_pooling(x, pool_h, pool_w, stride):
    """Max pooling."""
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
...
@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
def max_pool_with_argmax(x, pool_h, pool_w, stride):
    """Max pooling with argmax."""
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
...
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录