Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
bccfa485
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bccfa485
编写于
7月 10, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2980 Prevent empty column names
Merge pull request !2980 from nhussain/empty_column_b
上级
9daeeb5a
2c7fd248
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
75 addition
and
36 deletion
+75
-36
mindspore/dataset/core/validator_helpers.py
mindspore/dataset/core/validator_helpers.py
+42
-28
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+1
-6
tests/ut/python/dataset/test_bucket_batch_by_length.py
tests/ut/python/dataset/test_bucket_batch_by_length.py
+1
-1
tests/ut/python/dataset/test_dataset_numpy_slices.py
tests/ut/python/dataset/test_dataset_numpy_slices.py
+31
-1
未找到文件。
mindspore/dataset/core/validator_helpers.py
浏览文件 @
bccfa485
...
@@ -123,25 +123,39 @@ def check_valid_detype(type_):
...
@@ -123,25 +123,39 @@ def check_valid_detype(type_):
def
check_columns
(
columns
,
name
):
def
check_columns
(
columns
,
name
):
"""
Validate strings in column_names.
Args:
columns (list): list of column_names.
name (str): name of columns.
Returns:
Exception: when the value is not correct, otherwise nothing.
"""
type_check
(
columns
,
(
list
,
str
),
name
)
type_check
(
columns
,
(
list
,
str
),
name
)
if
isinstance
(
columns
,
list
):
if
isinstance
(
columns
,
list
):
if
not
columns
:
if
not
columns
:
raise
ValueError
(
"Column names should not be empty"
)
raise
ValueError
(
"{0} should not be empty"
.
format
(
name
))
col_names
=
[
"col_{0}"
.
format
(
i
)
for
i
in
range
(
len
(
columns
))]
for
i
,
column_name
in
enumerate
(
columns
):
if
not
column_name
:
raise
ValueError
(
"{0}[{1}] should not be empty"
.
format
(
name
,
i
))
col_names
=
[
"{0}[{1}]"
.
format
(
name
,
i
)
for
i
in
range
(
len
(
columns
))]
type_check_list
(
columns
,
(
str
,),
col_names
)
type_check_list
(
columns
,
(
str
,),
col_names
)
def
parse_user_args
(
method
,
*
args
,
**
kwargs
):
def
parse_user_args
(
method
,
*
args
,
**
kwargs
):
"""
"""
Parse user arguments in a function
Parse user arguments in a function
.
Args:
Args:
method (method): a callable function
method (method): a callable function
.
*args: user passed args
*args: user passed args
.
**kwargs: user passed kwargs
**kwargs: user passed kwargs
.
Returns:
Returns:
user_filled_args (list): values of what the user passed in for the arguments
,
user_filled_args (list): values of what the user passed in for the arguments
.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
"""
"""
sig
=
inspect
.
signature
(
method
)
sig
=
inspect
.
signature
(
method
)
...
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
...
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
def
type_check_list
(
args
,
types
,
arg_names
):
def
type_check_list
(
args
,
types
,
arg_names
):
"""
"""
Check the type of each parameter in the list
Check the type of each parameter in the list
.
Args:
Args:
args (list, tuple): a list or tuple of any variable
args (list, tuple): a list or tuple of any variable
.
types (tuple): tuple of all valid types for arg
types (tuple): tuple of all valid types for arg
.
arg_names (list, tuple of str): the names of args
arg_names (list, tuple of str): the names of args
.
Returns:
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing
.
"""
"""
type_check
(
args
,
(
list
,
tuple
,),
arg_names
)
type_check
(
args
,
(
list
,
tuple
,),
arg_names
)
if
len
(
args
)
!=
len
(
arg_names
):
if
len
(
args
)
!=
len
(
arg_names
):
...
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
...
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
def
type_check
(
arg
,
types
,
arg_name
):
def
type_check
(
arg
,
types
,
arg_name
):
"""
"""
Check the type of the parameter
Check the type of the parameter
.
Args:
Args:
arg : any variable
arg : any variable
.
types (tuple): tuple of all valid types for arg
types (tuple): tuple of all valid types for arg
.
arg_name (str): the name of arg
arg_name (str): the name of arg
.
Returns:
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing
.
"""
"""
# handle special case of booleans being a subclass of ints
# handle special case of booleans being a subclass of ints
print_value
=
'
\"\"
'
if
repr
(
arg
)
==
repr
(
''
)
else
arg
print_value
=
'
\"\"
'
if
repr
(
arg
)
==
repr
(
''
)
else
arg
...
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
...
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
def
check_filename
(
path
):
def
check_filename
(
path
):
"""
"""
check the filename in the path
check the filename in the path
.
Args:
Args:
path (str): the path
path (str): the path
.
Returns:
Returns:
Exception: when error
Exception: when error
.
"""
"""
if
not
isinstance
(
path
,
str
):
if
not
isinstance
(
path
,
str
):
raise
TypeError
(
"path: {} is not string"
.
format
(
path
))
raise
TypeError
(
"path: {} is not string"
.
format
(
path
))
...
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
...
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
"""
"""
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args:
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict
.
Returns:
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error
.
"""
"""
shuffle
,
sampler
=
param_dict
.
get
(
'shuffle'
),
param_dict
.
get
(
'sampler'
)
shuffle
,
sampler
=
param_dict
.
get
(
'shuffle'
),
param_dict
.
get
(
'sampler'
)
num_shards
,
shard_id
=
param_dict
.
get
(
'num_shards'
),
param_dict
.
get
(
'shard_id'
)
num_shards
,
shard_id
=
param_dict
.
get
(
'num_shards'
),
param_dict
.
get
(
'shard_id'
)
...
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
...
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
def
check_padding_options
(
param_dict
):
def
check_padding_options
(
param_dict
):
"""
"""
Check for valid padded_sample and num_padded of padded samples
Check for valid padded_sample and num_padded of padded samples
.
Args:
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict
.
Returns:
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error
.
"""
"""
columns_list
=
param_dict
.
get
(
'columns_list'
)
columns_list
=
param_dict
.
get
(
'columns_list'
)
...
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
...
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
Check if the input parameter is list or numpy.ndarray.
Check if the input parameter is list or numpy.ndarray.
Args:
Args:
param (list, nd.ndarray): param
param (list, nd.ndarray): param
.
param_name (str): param_name
param_name (str): param_name
.
Returns:
Returns:
Exception: TypeError if error
Exception: TypeError if error
.
"""
"""
type_check
(
param
,
(
list
,
np
.
ndarray
),
param_name
)
type_check
(
param
,
(
list
,
np
.
ndarray
),
param_name
)
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
bccfa485
...
@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
...
@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
type_check_list
([
pad_to_bucket_boundary
,
drop_remainder
],
(
bool
,),
nbool_param_list
)
type_check_list
([
pad_to_bucket_boundary
,
drop_remainder
],
(
bool
,),
nbool_param_list
)
# check column_names: must be list of string.
# check column_names: must be list of string.
if
not
column_names
:
check_columns
(
column_names
,
"column_names"
)
raise
ValueError
(
"column_names cannot be empty"
)
all_string
=
all
(
isinstance
(
item
,
str
)
for
item
in
column_names
)
if
not
all_string
:
raise
TypeError
(
"column_names should be a list of str."
)
if
element_length_function
is
None
and
len
(
column_names
)
!=
1
:
if
element_length_function
is
None
and
len
(
column_names
)
!=
1
:
raise
ValueError
(
"If element_length_function is not specified, exactly one column name should be passed."
)
raise
ValueError
(
"If element_length_function is not specified, exactly one column name should be passed."
)
...
...
tests/ut/python/dataset/test_bucket_batch_by_length.py
浏览文件 @
bccfa485
...
@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
...
@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
with
pytest
.
raises
(
TypeError
)
as
info
:
with
pytest
.
raises
(
TypeError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
invalid_column_names
,
bucket_boundaries
,
bucket_batch_sizes
)
_
=
dataset
.
bucket_batch_by_length
(
invalid_column_names
,
bucket_boundaries
,
bucket_batch_sizes
)
assert
"
column_names should be a list of str
"
in
str
(
info
.
value
)
assert
"
Argument column_names[0] with value 1 is not of type (<class 'str'>,).
"
in
str
(
info
.
value
)
with
pytest
.
raises
(
ValueError
)
as
info
:
with
pytest
.
raises
(
ValueError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
empty_bucket_boundaries
,
bucket_batch_sizes
)
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
empty_bucket_boundaries
,
bucket_batch_sizes
)
...
...
tests/ut/python/dataset/test_dataset_numpy_slices.py
浏览文件 @
bccfa485
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
numpy
as
np
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
de
import
mindspore.dataset
as
de
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
...
@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
...
@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
def
test_numpy_slices_sequential_sampler
():
def
test_numpy_slices_sequential_sampler
():
logger
.
info
(
"Test numpy_slices_dataset with SequentialSampler and repeat."
)
logger
.
info
(
"Test numpy_slices_dataset with SequentialSampler and repeat."
)
np_data
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
np_data
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
...
@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
...
@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
assert
np
.
equal
(
data
[
0
],
np_data
[
i
%
8
]).
all
()
assert
np
.
equal
(
data
[
0
],
np_data
[
i
%
8
]).
all
()
def
test_numpy_slices_invalid_column_names_type
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
TypeError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[
1
],
shuffle
=
False
)
assert
"Argument column_names[0] with value 1 is not of type (<class 'str'>,)."
in
str
(
err
.
value
)
def
test_numpy_slices_invalid_column_names_string
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[
""
],
shuffle
=
False
)
assert
"column_names[0] should not be empty"
in
str
(
err
.
value
)
def
test_numpy_slices_invalid_empty_column_names
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[],
shuffle
=
False
)
assert
"column_names should not be empty"
in
str
(
err
.
value
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_numpy_slices_list_1
()
test_numpy_slices_list_1
()
test_numpy_slices_list_2
()
test_numpy_slices_list_2
()
...
@@ -197,3 +224,6 @@ if __name__ == "__main__":
...
@@ -197,3 +224,6 @@ if __name__ == "__main__":
test_numpy_slices_num_samplers
()
test_numpy_slices_num_samplers
()
test_numpy_slices_distributed_sampler
()
test_numpy_slices_distributed_sampler
()
test_numpy_slices_sequential_sampler
()
test_numpy_slices_sequential_sampler
()
test_numpy_slices_invalid_column_names_type
()
test_numpy_slices_invalid_column_names_string
()
test_numpy_slices_invalid_empty_column_names
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录