Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
20782294
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
20782294
编写于
4月 14, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add prim name to error message for array_ops
上级
789edcb2
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
358 addition
and
472 deletion
+358
-472
mindspore/_checkparam.py
mindspore/_checkparam.py
+1
-243
mindspore/ccsrc/optimizer/ad/dfunctor.cc
mindspore/ccsrc/optimizer/ad/dfunctor.cc
+1
-1
mindspore/nn/layer/pooling.py
mindspore/nn/layer/pooling.py
+5
-6
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+184
-214
tests/ut/python/ops/test_array_ops_check.py
tests/ut/python/ops/test_array_ops_check.py
+159
-0
tests/ut/python/ops/test_tensor_slice.py
tests/ut/python/ops/test_tensor_slice.py
+2
-2
tests/vm_impl/vm_me.py
tests/vm_impl/vm_me.py
+6
-6
未找到文件。
mindspore/_checkparam.py
浏览文件 @
20782294
...
...
@@ -210,7 +210,7 @@ class Validator:
type_names
=
[]
for
t
in
valid_values
:
type_names
.
append
(
str
(
t
))
types_info
=
'['
+
", "
.
join
(
type_names
)
+
']'
types_info
=
'['
+
', '
.
join
(
type_names
)
+
']'
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg_key
}
` should be in
{
types_info
}
,'
f
' but got
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
...
...
@@ -320,224 +320,6 @@ class Validator:
raise
TypeError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be float."
)
class
ParamValidator
:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@
staticmethod
def
equal
(
arg_name
,
arg_value
,
cond_str
,
cond
):
"""Judging valid value."""
if
not
cond
:
raise
ValueError
(
f
'The `
{
arg_name
}
` must be
{
cond_str
}
, but got
{
arg_value
}
.'
)
@
staticmethod
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
f
'
{
value_name
}
:
{
value
}
'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be
{
rel_str
}
, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_integer
(
arg_name
,
arg_value
,
value
,
rel
):
"""Integer value judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
or
isinstance
(
arg_value
,
bool
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_shape_length
(
arg_name
,
arg_value
,
value
,
rel
):
"""Shape length judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'The length of `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
'
)
return
arg_value
@
staticmethod
def
check_int_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int in range
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_isinstance
(
arg_name
,
arg_value
,
classes
):
"""Check arg isinstance of classes"""
if
not
isinstance
(
arg_value
,
classes
):
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isinstance of
{
classes
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_number_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
):
"""Is it necessary to consider error when comparing float values."""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be in range
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
with_type_of
=
True
):
"""Check whether some type is subclass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
type_str
=
(
type
(
type_
).
__name__
if
isinstance
(
type_
,
(
tuple
,
list
))
else
""
)
+
str
(
type_
)
raise
TypeError
(
f
'The
{
"type of"
if
with_type_of
else
""
}
`
{
arg_name
}
` should be subclass'
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
))
}
, but got
{
type_str
}
.'
)
@
staticmethod
def
check_args_tensor
(
args
):
"""Check whether args are all tensor."""
if
not
isinstance
(
args
,
dict
):
raise
TypeError
(
"The args should be a dict."
)
for
arg
,
value
in
args
.
items
():
ParamValidator
.
check_subclass
(
arg
,
value
,
mstype
.
tensor
)
@
staticmethod
def
check_bool
(
arg_name
,
arg_value
):
"""Check arg isinstance of bool"""
if
not
isinstance
(
arg_value
,
bool
):
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isinstance of bool, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_type
(
arg_name
,
arg_value
,
valid_types
):
"""Type checking."""
def
raise_error_msg
():
"""func for raising error message when check failed"""
type_names
=
[
t
.
__name__
for
t
in
valid_types
]
num_types
=
len
(
valid_types
)
raise
TypeError
(
f
'The type of `
{
arg_name
}
` should be
{
"one of "
if
num_types
>
1
else
""
}
'
f
'
{
type_names
if
num_types
>
1
else
type_names
[
0
]
}
, but got
{
type
(
arg_value
).
__name__
}
.'
)
if
isinstance
(
arg_value
,
type
(
mstype
.
tensor
)):
arg_value
=
arg_value
.
element_type
()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if
isinstance
(
arg_value
,
bool
)
and
bool
not
in
tuple
(
valid_types
):
raise_error_msg
()
if
isinstance
(
arg_value
,
tuple
(
valid_types
)):
return
arg_value
raise_error_msg
()
@
staticmethod
def
check_typename
(
arg_name
,
arg_type
,
valid_types
):
"""Does it contain the _name_ attribute."""
def
get_typename
(
t
):
return
t
.
__name__
if
hasattr
(
t
,
'__name__'
)
else
str
(
t
)
if
isinstance
(
arg_type
,
type
(
mstype
.
tensor
)):
arg_type
=
arg_type
.
element_type
()
if
arg_type
in
valid_types
:
return
arg_type
type_names
=
[
get_typename
(
t
)
for
t
in
valid_types
]
if
len
(
valid_types
)
==
1
:
raise
ValueError
(
f
'The type of `
{
arg_name
}
` should be
{
type_names
[
0
]
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
raise
ValueError
(
f
'The type of `
{
arg_name
}
` should be one of
{
type_names
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
@
staticmethod
def
check_string
(
arg_name
,
arg_value
,
valid_values
):
"""String type judgment."""
if
isinstance
(
arg_value
,
str
)
and
arg_value
in
valid_values
:
return
arg_value
if
len
(
valid_values
)
==
1
:
raise
ValueError
(
f
'The `
{
arg_name
}
` should be str and must be
{
valid_values
[
0
]
}
,'
f
' but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be str and must be one of
{
valid_values
}
,'
f
' but got
{
arg_value
}
.'
)
@
staticmethod
def
check_type_same
(
args
,
valid_values
):
"""Determine whether the types are the same."""
name
=
list
(
args
.
keys
())[
0
]
value
=
list
(
args
.
values
())[
0
]
if
isinstance
(
value
,
type
(
mstype
.
tensor
)):
value
=
value
.
element_type
()
for
arg_name
,
arg_value
in
args
.
items
():
if
isinstance
(
arg_value
,
type
(
mstype
.
tensor
)):
arg_value
=
arg_value
.
element_type
()
if
arg_value
not
in
valid_values
:
raise
TypeError
(
f
'The `
{
arg_name
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_name
}
` is
{
arg_value
}
.'
)
if
arg_value
!=
value
:
raise
TypeError
(
f
'`
{
arg_name
}
` should be same as `
{
name
}
`,'
f
' but `
{
arg_name
}
` is
{
arg_value
}
, `
{
name
}
` is
{
value
}
.'
)
@
staticmethod
def
check_two_types_same
(
arg1_name
,
arg1_type
,
arg2_name
,
arg2_type
):
"""Determine whether the types of two variables are the same."""
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'The type of `
{
arg1_name
}
` and `
{
arg2_name
}
` should be same.'
)
@
staticmethod
def
check_value_on_integer
(
arg_name
,
arg_value
,
value
,
rel
):
"""Judging integer type."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_match
=
isinstance
(
arg_value
,
int
)
if
type_match
and
(
not
rel_fn
(
arg_value
,
value
)):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_param_equal
(
param1_name
,
param1_value
,
param2_name
,
param2_value
):
"""Judging the equality of parameters."""
if
param1_value
!=
param2_value
:
raise
ValueError
(
f
"`
{
param1_name
}
` must equal `
{
param2_name
}
`,"
f
" but got `
{
param1_name
}
` =
{
param1_value
}
,"
f
" `
{
param2_name
}
` =
{
param2_value
}
."
)
@
staticmethod
def
check_const_input
(
arg_name
,
arg_value
):
"""Check valid value."""
if
arg_value
is
None
:
raise
ValueError
(
f
'The `
{
arg_name
}
` must be a const input, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_float_positive
(
arg_name
,
arg_value
):
"""Float type judgment."""
if
isinstance
(
arg_value
,
float
):
if
arg_value
>
0
:
return
arg_value
raise
ValueError
(
f
"The `
{
arg_name
}
` must be positive, but got
{
arg_value
}
."
)
raise
TypeError
(
f
"`
{
arg_name
}
` must be float!"
)
@
staticmethod
def
check_pad_value_by_mode
(
op_name
,
pad_mode
,
padding
):
"""Validate value of padding according to pad_mode"""
if
pad_mode
!=
'pad'
and
padding
!=
0
:
raise
ValueError
(
f
"For op '
{
op_name
}
', padding must be zero when pad_mode is '
{
pad_mode
}
'."
)
return
padding
@
staticmethod
def
check_empty_shape_input
(
arg_name
,
arg_value
):
"""Check zeros value."""
if
0
in
arg_value
:
raise
ValueError
(
f
"Input `
{
arg_name
}
` cannot be empty."
)
@
staticmethod
def
check_scalar_shape_input
(
arg_name
,
arg_value
):
"""Check scalar shape input."""
if
arg_value
!=
[]:
raise
ValueError
(
f
"Input `
{
arg_name
}
` shape should be (). got
{
arg_value
}
"
)
def
check_int
(
input_param
):
"""Int type judgment."""
if
isinstance
(
input_param
,
int
)
and
not
isinstance
(
input_param
,
bool
):
...
...
@@ -653,30 +435,6 @@ def check_output_data(data):
raise
RuntimeError
(
'Executor return data '
+
str
(
data
)
+
', please check your net or input data.'
)
def
check_axis_type_int
(
axis
):
"""Check axis type."""
if
not
isinstance
(
axis
,
int
):
raise
TypeError
(
'Wrong type for axis, should be int.'
)
def
check_axis_range
(
axis
,
rank
):
"""Check axis range."""
if
not
-
rank
<=
axis
<
rank
:
raise
ValueError
(
'The axis should be in range [{}, {}),'' but got {}.'
.
format
(
-
rank
,
rank
,
axis
))
def
check_attr_int
(
attr_name
,
attr
):
"""Check int type."""
if
not
isinstance
(
attr
,
int
):
raise
TypeError
(
"The attr {} should be int, but got {}."
.
format
(
attr_name
,
type
(
attr
)))
def
check_t_in_range
(
t
):
"""Check input range."""
if
t
not
in
(
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
,
mstype
.
int32
,
mstype
.
int64
):
raise
ValueError
(
"The param T should be (float16, float32, float64, int32, int64)."
)
once
=
_expand_tuple
(
1
)
twice
=
_expand_tuple
(
2
)
triple
=
_expand_tuple
(
3
)
...
...
mindspore/ccsrc/optimizer/ad/dfunctor.cc
浏览文件 @
20782294
...
...
@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
UpdateAdjoint
(
node_adjoint
);
anfnode_to_adjoin_
[
morph
]
=
node_adjoint
;
if
(
cnode_morph
->
stop_gradient
())
{
MS_LOG
(
WARNIN
G
)
<<
"MapMorphism node "
<<
morph
->
ToString
()
<<
" is stopped."
;
MS_LOG
(
DEBU
G
)
<<
"MapMorphism node "
<<
morph
->
ToString
()
<<
" is stopped."
;
return
node_adjoint
;
}
...
...
mindspore/nn/layer/pooling.py
浏览文件 @
20782294
...
...
@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
from
...
import
context
from
..cell
import
Cell
from
..._checkparam
import
Rel
from
..._checkparam
import
ParamValidator
class
_PoolNd
(
Cell
):
...
...
@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
stride
=
1
,
pad_mode
=
"valid"
):
super
(
AvgPool1d
,
self
).
__init__
(
kernel_size
,
stride
,
pad_mode
)
ParamValidator
.
check_type
(
'kernel_size'
,
kernel_size
,
[
int
,]
)
ParamValidator
.
check_type
(
'stride'
,
stride
,
[
int
,]
)
self
.
pad_mode
=
ParamValidator
.
check_string
(
'pad_mode'
,
pad_mode
.
upper
(),
[
'VALID'
,
'SAME'
]
)
ParamValidator
.
check_integer
(
"kernel_size"
,
kernel_size
,
1
,
Rel
.
GE
)
ParamValidator
.
check_integer
(
"stride"
,
stride
,
1
,
Rel
.
GE
)
validator
.
check_value_type
(
'kernel_size'
,
kernel_size
,
[
int
],
self
.
cls_name
)
validator
.
check_value_type
(
'stride'
,
stride
,
[
int
],
self
.
cls_name
)
self
.
pad_mode
=
validator
.
check_string
(
'pad_mode'
,
pad_mode
.
upper
(),
[
'VALID'
,
'SAME'
],
self
.
cls_name
)
validator
.
check_integer
(
"kernel_size"
,
kernel_size
,
1
,
Rel
.
GE
,
self
.
cls_name
)
validator
.
check_integer
(
"stride"
,
stride
,
1
,
Rel
.
GE
,
self
.
cls_name
)
self
.
kernel_size
=
(
1
,
kernel_size
)
self
.
stride
=
(
1
,
stride
)
self
.
avg_pool
=
P
.
AvgPool
(
ksize
=
self
.
kernel_size
,
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
20782294
此差异已折叠。
点击以展开。
tests/ut/python/ops/test_array_ops_check.py
0 → 100755
浏览文件 @
20782294
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import
functools
import
numpy
as
np
from
mindspore
import
ops
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.operations
import
_grad_ops
as
G
import
mindspore.ops.composite
as
C
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common.parameter
import
Parameter
from
..ut_filter
import
non_graph_engine
from
mindspore.common.api
import
_executor
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
from
....mindspore_test_framework.pipeline.forward.compile_forward
\
import
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
,
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
)
from
....mindspore_test_framework.pipeline.gradient.compile_gradient
\
import
pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class
ExpandDimsNet
(
nn
.
Cell
):
def
__init__
(
self
,
axis
):
super
(
ExpandDimsNet
,
self
).
__init__
()
self
.
axis
=
axis
self
.
op
=
P
.
ExpandDims
()
def
construct
(
self
,
x
):
return
self
.
op
(
x
,
self
.
axis
)
class
IsInstanceNet
(
nn
.
Cell
):
def
__init__
(
self
,
inst
):
super
(
IsInstanceNet
,
self
).
__init__
()
self
.
inst
=
inst
self
.
op
=
P
.
IsInstance
()
def
construct
(
self
,
t
):
return
self
.
op
(
self
.
inst
,
t
)
class
ReshapeNet
(
nn
.
Cell
):
def
__init__
(
self
,
shape
):
super
(
ReshapeNet
,
self
).
__init__
()
self
.
shape
=
shape
self
.
op
=
P
.
Reshape
()
def
construct
(
self
,
x
):
return
self
.
op
(
x
,
self
.
shape
)
raise_set
=
[
# input is scala, not Tensor
(
'ExpandDims0'
,
{
'block'
:
(
P
.
ExpandDims
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'ExpandDims'
]}),
'desc_inputs'
:
[
5.0
,
1
],
'skip'
:
[
'backward'
]}),
# axis is as a parameter
(
'ExpandDims1'
,
{
'block'
:
(
P
.
ExpandDims
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'ExpandDims'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
)),
1
],
'skip'
:
[
'backward'
]}),
# axis as an attribute, but less then lower limit
(
'ExpandDims2'
,
{
'block'
:
(
ExpandDimsNet
(
-
4
),
{
'exception'
:
ValueError
,
'error_keywords'
:
[
'ExpandDims'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
# axis as an attribute, but greater then upper limit
(
'ExpandDims3'
,
{
'block'
:
(
ExpandDimsNet
(
3
),
{
'exception'
:
ValueError
,
'error_keywords'
:
[
'ExpandDims'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
# input is scala, not Tensor
(
'DType0'
,
{
'block'
:
(
P
.
DType
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'DType'
]}),
'desc_inputs'
:
[
5.0
],
'skip'
:
[
'backward'
]}),
# input x scala, not Tensor
(
'SameTypeShape0'
,
{
'block'
:
(
P
.
SameTypeShape
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'SameTypeShape'
]}),
'desc_inputs'
:
[
5.0
,
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
# input y scala, not Tensor
(
'SameTypeShape1'
,
{
'block'
:
(
P
.
SameTypeShape
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'SameTypeShape'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
)),
5.0
],
'skip'
:
[
'backward'
]}),
# type of x and y not match
(
'SameTypeShape2'
,
{
'block'
:
(
P
.
SameTypeShape
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'SameTypeShape'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
)),
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
int32
))],
'skip'
:
[
'backward'
]}),
# shape of x and y not match
(
'SameTypeShape3'
,
{
'block'
:
(
P
.
SameTypeShape
(),
{
'exception'
:
ValueError
,
'error_keywords'
:
[
'SameTypeShape'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
)),
Tensor
(
np
.
ones
([
3
,
3
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
# sub_type is None
(
'IsSubClass0'
,
{
'block'
:
(
P
.
IsSubClass
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'IsSubClass'
]}),
'desc_inputs'
:
[
None
,
mstype
.
number
],
'skip'
:
[
'backward'
]}),
# type_ is None
(
'IsSubClass1'
,
{
'block'
:
(
P
.
IsSubClass
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'IsSubClass'
]}),
'desc_inputs'
:
[
mstype
.
number
,
None
],
'skip'
:
[
'backward'
]}),
# inst is var
(
'IsInstance0'
,
{
'block'
:
(
P
.
IsInstance
(),
{
'exception'
:
ValueError
,
'error_keywords'
:
[
'IsInstance'
]}),
'desc_inputs'
:
[
5.0
,
mstype
.
number
],
'skip'
:
[
'backward'
]}),
# t is not mstype.Type
(
'IsInstance1'
,
{
'block'
:
(
IsInstanceNet
(
5.0
),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'IsInstance'
]}),
'desc_inputs'
:
[
None
],
'skip'
:
[
'backward'
]}),
# input x is scalar, not Tensor
(
'Reshape0'
,
{
'block'
:
(
P
.
Reshape
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'Reshape'
]}),
'desc_inputs'
:
[
5.0
,
(
1
,
2
)],
'skip'
:
[
'backward'
]}),
# input shape is var
(
'Reshape1'
,
{
'block'
:
(
P
.
Reshape
(),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'Reshape'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
)),
(
2
,
3
,
2
)],
'skip'
:
[
'backward'
]}),
# element of shape is not int
(
'Reshape3'
,
{
'block'
:
(
ReshapeNet
((
2
,
3.0
,
2
)),
{
'exception'
:
TypeError
,
'error_keywords'
:
[
'Reshape'
]}),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
3
,
4
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
]
@
mindspore_test
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
)
def
test_check_exception
():
return
raise_set
tests/ut/python/ops/test_tensor_slice.py
浏览文件 @
20782294
...
...
@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
net
=
NetWork
()
with
pytest
.
raises
(
ValueError
)
as
ex
:
net
(
input_tensor
)
assert
"
The `begin[0]` should be an int and must greater or equal to -6, but got -7
"
in
str
(
ex
.
value
)
assert
"
For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`
"
in
str
(
ex
.
value
)
def
test_tensor_slice_reduce_out_of_bounds_positive
():
...
...
@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
net
=
NetWork
()
with
pytest
.
raises
(
ValueError
)
as
ex
:
net
(
input_tensor
)
assert
"
The `begin[0]` should be an int and must less than 6, but got 6
"
in
str
(
ex
.
value
)
assert
"
For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`
"
in
str
(
ex
.
value
)
tests/vm_impl/vm_me.py
浏览文件 @
20782294
...
...
@@ -16,7 +16,7 @@
import
numpy
as
np
from
mindspore._checkparam
import
Rel
from
mindspore._checkparam
import
Param
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
def
avg_pooling
(
x
,
pool_h
,
pool_w
,
stride
):
...
...
@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
Returns:
numpy.ndarray, an output array after applying average pooling on input array.
"""
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
,
None
)
num
,
channel
,
height
,
width
=
x
.
shape
out_h
=
(
height
-
pool_h
)
//
stride
+
1
out_w
=
(
width
-
pool_w
)
//
stride
+
1
...
...
@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation
=
1
,
groups
=
1
,
padding_mode
=
'zeros'
):
"""Convolution 2D."""
# pylint: disable=unused-argument
validator
.
check_
type
(
'stride'
,
stride
,
(
int
,
tuple
)
)
validator
.
check_
value_type
(
'stride'
,
stride
,
(
int
,
tuple
),
None
)
if
isinstance
(
stride
,
int
):
stride
=
(
stride
,
stride
)
elif
len
(
stride
)
==
4
:
...
...
@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
f
"a tuple of two positive int numbers, but got
{
stride
}
"
)
stride_h
=
stride
[
0
]
stride_w
=
stride
[
1
]
validator
.
check_
type
(
'dilation'
,
dilation
,
(
int
,
tuple
)
)
validator
.
check_
value_type
(
'dilation'
,
dilation
,
(
int
,
tuple
),
None
)
if
isinstance
(
dilation
,
int
):
dilation
=
(
dilation
,
dilation
)
elif
len
(
dilation
)
==
4
:
...
...
@@ -384,7 +384,7 @@ def matmul(x, w, b=None):
def
max_pooling
(
x
,
pool_h
,
pool_w
,
stride
):
"""Max pooling."""
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
,
None
)
num
,
channel
,
height
,
width
=
x
.
shape
out_h
=
(
height
-
pool_h
)
//
stride
+
1
out_w
=
(
width
-
pool_w
)
//
stride
+
1
...
...
@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
def
max_pool_with_argmax
(
x
,
pool_h
,
pool_w
,
stride
):
"""Max pooling with argmax."""
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
)
validator
.
check_integer
(
"stride"
,
stride
,
0
,
Rel
.
GT
,
None
)
num
,
channel
,
height
,
width
=
x
.
shape
out_h
=
(
height
-
pool_h
)
//
stride
+
1
out_w
=
(
width
-
pool_w
)
//
stride
+
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录