Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
bccfa485
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bccfa485
编写于
7月 10, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2980 Prevent empty column names
Merge pull request !2980 from nhussain/empty_column_b
上级
9daeeb5a
2c7fd248
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
75 addition
and
36 deletion
+75
-36
mindspore/dataset/core/validator_helpers.py
mindspore/dataset/core/validator_helpers.py
+42
-28
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+1
-6
tests/ut/python/dataset/test_bucket_batch_by_length.py
tests/ut/python/dataset/test_bucket_batch_by_length.py
+1
-1
tests/ut/python/dataset/test_dataset_numpy_slices.py
tests/ut/python/dataset/test_dataset_numpy_slices.py
+31
-1
未找到文件。
mindspore/dataset/core/validator_helpers.py
浏览文件 @
bccfa485
...
...
@@ -123,25 +123,39 @@ def check_valid_detype(type_):
def
check_columns
(
columns
,
name
):
"""
Validate strings in column_names.
Args:
columns (list): list of column_names.
name (str): name of columns.
Returns:
Exception: when the value is not correct, otherwise nothing.
"""
type_check
(
columns
,
(
list
,
str
),
name
)
if
isinstance
(
columns
,
list
):
if
not
columns
:
raise
ValueError
(
"Column names should not be empty"
)
col_names
=
[
"col_{0}"
.
format
(
i
)
for
i
in
range
(
len
(
columns
))]
raise
ValueError
(
"{0} should not be empty"
.
format
(
name
))
for
i
,
column_name
in
enumerate
(
columns
):
if
not
column_name
:
raise
ValueError
(
"{0}[{1}] should not be empty"
.
format
(
name
,
i
))
col_names
=
[
"{0}[{1}]"
.
format
(
name
,
i
)
for
i
in
range
(
len
(
columns
))]
type_check_list
(
columns
,
(
str
,),
col_names
)
def
parse_user_args
(
method
,
*
args
,
**
kwargs
):
"""
Parse user arguments in a function
Parse user arguments in a function
.
Args:
method (method): a callable function
*args: user passed args
**kwargs: user passed kwargs
method (method): a callable function
.
*args: user passed args
.
**kwargs: user passed kwargs
.
Returns:
user_filled_args (list): values of what the user passed in for the arguments
,
user_filled_args (list): values of what the user passed in for the arguments
.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
"""
sig
=
inspect
.
signature
(
method
)
...
...
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
def
type_check_list
(
args
,
types
,
arg_names
):
"""
Check the type of each parameter in the list
Check the type of each parameter in the list
.
Args:
args (list, tuple): a list or tuple of any variable
types (tuple): tuple of all valid types for arg
arg_names (list, tuple of str): the names of args
args (list, tuple): a list or tuple of any variable
.
types (tuple): tuple of all valid types for arg
.
arg_names (list, tuple of str): the names of args
.
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing
.
"""
type_check
(
args
,
(
list
,
tuple
,),
arg_names
)
if
len
(
args
)
!=
len
(
arg_names
):
...
...
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
def
type_check
(
arg
,
types
,
arg_name
):
"""
Check the type of the parameter
Check the type of the parameter
.
Args:
arg : any variable
types (tuple): tuple of all valid types for arg
arg_name (str): the name of arg
arg : any variable
.
types (tuple): tuple of all valid types for arg
.
arg_name (str): the name of arg
.
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing
.
"""
# handle special case of booleans being a subclass of ints
print_value
=
'
\"\"
'
if
repr
(
arg
)
==
repr
(
''
)
else
arg
...
...
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
def
check_filename
(
path
):
"""
check the filename in the path
check the filename in the path
.
Args:
path (str): the path
path (str): the path
.
Returns:
Exception: when error
Exception: when error
.
"""
if
not
isinstance
(
path
,
str
):
raise
TypeError
(
"path: {} is not string"
.
format
(
path
))
...
...
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
"""
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict
.
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error
.
"""
shuffle
,
sampler
=
param_dict
.
get
(
'shuffle'
),
param_dict
.
get
(
'sampler'
)
num_shards
,
shard_id
=
param_dict
.
get
(
'num_shards'
),
param_dict
.
get
(
'shard_id'
)
...
...
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
def
check_padding_options
(
param_dict
):
"""
Check for valid padded_sample and num_padded of padded samples
Check for valid padded_sample and num_padded of padded samples
.
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict
.
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error
.
"""
columns_list
=
param_dict
.
get
(
'columns_list'
)
...
...
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
Check if the input parameter is list or numpy.ndarray.
Args:
param (list, nd.ndarray): param
param_name (str): param_name
param (list, nd.ndarray): param
.
param_name (str): param_name
.
Returns:
Exception: TypeError if error
Exception: TypeError if error
.
"""
type_check
(
param
,
(
list
,
np
.
ndarray
),
param_name
)
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
bccfa485
...
...
@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
type_check_list
([
pad_to_bucket_boundary
,
drop_remainder
],
(
bool
,),
nbool_param_list
)
# check column_names: must be list of string.
if
not
column_names
:
raise
ValueError
(
"column_names cannot be empty"
)
all_string
=
all
(
isinstance
(
item
,
str
)
for
item
in
column_names
)
if
not
all_string
:
raise
TypeError
(
"column_names should be a list of str."
)
check_columns
(
column_names
,
"column_names"
)
if
element_length_function
is
None
and
len
(
column_names
)
!=
1
:
raise
ValueError
(
"If element_length_function is not specified, exactly one column name should be passed."
)
...
...
tests/ut/python/dataset/test_bucket_batch_by_length.py
浏览文件 @
bccfa485
...
...
@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
with
pytest
.
raises
(
TypeError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
invalid_column_names
,
bucket_boundaries
,
bucket_batch_sizes
)
assert
"
column_names should be a list of str
"
in
str
(
info
.
value
)
assert
"
Argument column_names[0] with value 1 is not of type (<class 'str'>,).
"
in
str
(
info
.
value
)
with
pytest
.
raises
(
ValueError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
empty_bucket_boundaries
,
bucket_batch_sizes
)
...
...
tests/ut/python/dataset/test_dataset_numpy_slices.py
浏览文件 @
bccfa485
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
de
from
mindspore
import
log
as
logger
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
...
...
@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
def
test_numpy_slices_sequential_sampler
():
logger
.
info
(
"Test numpy_slices_dataset with SequentialSampler and repeat."
)
np_data
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
...
...
@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
assert
np
.
equal
(
data
[
0
],
np_data
[
i
%
8
]).
all
()
def
test_numpy_slices_invalid_column_names_type
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
TypeError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[
1
],
shuffle
=
False
)
assert
"Argument column_names[0] with value 1 is not of type (<class 'str'>,)."
in
str
(
err
.
value
)
def
test_numpy_slices_invalid_column_names_string
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[
""
],
shuffle
=
False
)
assert
"column_names[0] should not be empty"
in
str
(
err
.
value
)
def
test_numpy_slices_invalid_empty_column_names
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
column_names
=
[],
shuffle
=
False
)
assert
"column_names should not be empty"
in
str
(
err
.
value
)
if
__name__
==
"__main__"
:
test_numpy_slices_list_1
()
test_numpy_slices_list_2
()
...
...
@@ -197,3 +224,6 @@ if __name__ == "__main__":
test_numpy_slices_num_samplers
()
test_numpy_slices_distributed_sampler
()
test_numpy_slices_sequential_sampler
()
test_numpy_slices_invalid_column_names_type
()
test_numpy_slices_invalid_column_names_string
()
test_numpy_slices_invalid_empty_column_names
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录