Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
05b2a57d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05b2a57d
编写于
7月 10, 2020
作者:
N
nhussain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix validation errors, and fix try catch error tests
上级
089623ad
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
102 addition
and
55 deletion
+102
-55
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+6
-6
mindspore/dataset/text/validators.py
mindspore/dataset/text/validators.py
+2
-1
tests/ut/python/dataset/test_bucket_batch_by_length.py
tests/ut/python/dataset/test_bucket_batch_by_length.py
+6
-1
tests/ut/python/dataset/test_concatenate_op.py
tests/ut/python/dataset/test_concatenate_op.py
+5
-5
tests/ut/python/dataset/test_dataset_numpy_slices.py
tests/ut/python/dataset/test_dataset_numpy_slices.py
+34
-2
tests/ut/python/dataset/test_fill_op.py
tests/ut/python/dataset/test_fill_op.py
+3
-3
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+4
-4
tests/ut/python/dataset/test_nlp.py
tests/ut/python/dataset/test_nlp.py
+20
-1
tests/ut/python/dataset/test_sync_wait.py
tests/ut/python/dataset/test_sync_wait.py
+16
-26
tests/ut/python/dataset/test_uniform_augment.py
tests/ut/python/dataset/test_uniform_augment.py
+6
-6
未找到文件。
mindspore/dataset/engine/validators.py
浏览文件 @
05b2a57d
...
...
@@ -25,7 +25,7 @@ from mindspore._c_expression import typing
from
..core.validator_helpers
import
parse_user_args
,
type_check
,
type_check_list
,
check_value
,
\
INT32_MAX
,
check_valid_detype
,
check_dir
,
check_file
,
check_sampler_shuffle_shard_options
,
\
validate_dataset_param_value
,
check_padding_options
,
check_gnn_list_or_ndarray
,
check_num_parallel_workers
,
\
check_columns
,
check_pos
itive
,
check_pos
_int32
check_columns
,
check_pos_int32
from
.
import
datasets
from
.
import
samplers
...
...
@@ -319,10 +319,9 @@ def check_generatordataset(method):
# These two parameters appear together.
raise
ValueError
(
"num_shards and shard_id need to be passed in together"
)
if
num_shards
is
not
None
:
type_check
(
num_shards
,
(
int
,),
"num_shards"
)
check_positive
(
num_shards
,
"num_shards"
)
check_pos_int32
(
num_shards
,
"num_shards"
)
if
shard_id
>=
num_shards
:
raise
ValueError
(
"shard_id should be less than num_shards"
)
raise
ValueError
(
"shard_id should be less than num_shards
.
"
)
sampler
=
param_dict
.
get
(
"sampler"
)
if
sampler
is
not
None
:
...
...
@@ -417,7 +416,7 @@ def check_bucket_batch_by_length(method):
all_non_negative
=
all
(
item
>
0
for
item
in
bucket_boundaries
)
if
not
all_non_negative
:
raise
ValueError
(
"bucket_boundaries
cannot contain any nega
tive numbers."
)
raise
ValueError
(
"bucket_boundaries
must only contain posi
tive numbers."
)
for
i
in
range
(
len
(
bucket_boundaries
)
-
1
):
if
not
bucket_boundaries
[
i
+
1
]
>
bucket_boundaries
[
i
]:
...
...
@@ -1044,7 +1043,8 @@ def check_numpyslicesdataset(method):
data
=
param_dict
.
get
(
"data"
)
column_names
=
param_dict
.
get
(
"column_names"
)
if
not
data
:
raise
ValueError
(
"Argument data cannot be empty"
)
type_check
(
data
,
(
list
,
tuple
,
dict
,
np
.
ndarray
),
"data"
)
if
isinstance
(
data
,
tuple
):
type_check
(
data
[
0
],
(
list
,
np
.
ndarray
),
"data[0]"
)
...
...
mindspore/dataset/text/validators.py
浏览文件 @
05b2a57d
...
...
@@ -62,6 +62,7 @@ def check_from_file(method):
def
new_method
(
self
,
*
args
,
**
kwargs
):
[
file_path
,
delimiter
,
vocab_size
,
special_tokens
,
special_first
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
if
special_tokens
is
not
None
:
check_unique_list_of_words
(
special_tokens
,
"special_tokens"
)
type_check_list
([
file_path
,
delimiter
],
(
str
,),
[
"file_path"
,
"delimiter"
])
if
vocab_size
is
not
None
:
...
...
tests/ut/python/dataset/test_bucket_batch_by_length.py
浏览文件 @
05b2a57d
...
...
@@ -45,6 +45,7 @@ def test_bucket_batch_invalid_input():
bucket_boundaries
=
[
1
,
2
,
3
]
empty_bucket_boundaries
=
[]
invalid_bucket_boundaries
=
[
"1"
,
"2"
,
"3"
]
zero_start_bucket_boundaries
=
[
0
,
2
,
3
]
negative_bucket_boundaries
=
[
1
,
2
,
-
3
]
decreasing_bucket_boundaries
=
[
3
,
2
,
1
]
non_increasing_bucket_boundaries
=
[
1
,
2
,
2
]
...
...
@@ -69,9 +70,13 @@ def test_bucket_batch_invalid_input():
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
invalid_bucket_boundaries
,
bucket_batch_sizes
)
assert
"bucket_boundaries should be a list of int"
in
str
(
info
.
value
)
with
pytest
.
raises
(
ValueError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
zero_start_bucket_boundaries
,
bucket_batch_sizes
)
assert
"bucket_boundaries must only contain positive numbers."
in
str
(
info
.
value
)
with
pytest
.
raises
(
ValueError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
negative_bucket_boundaries
,
bucket_batch_sizes
)
assert
"bucket_boundaries
cannot contain any negative numbers
"
in
str
(
info
.
value
)
assert
"bucket_boundaries
must only contain positive numbers.
"
in
str
(
info
.
value
)
with
pytest
.
raises
(
ValueError
)
as
info
:
_
=
dataset
.
bucket_batch_by_length
(
column_names
,
decreasing_bucket_boundaries
,
bucket_batch_sizes
)
...
...
tests/ut/python/dataset/test_concatenate_op.py
浏览文件 @
05b2a57d
...
...
@@ -108,7 +108,7 @@ def test_concatenate_op_type_mismatch():
with
pytest
.
raises
(
RuntimeError
)
as
error_info
:
for
_
in
data
:
pass
assert
"Tensor types do not match"
in
rep
r
(
error_info
.
value
)
assert
"Tensor types do not match"
in
st
r
(
error_info
.
value
)
def
test_concatenate_op_type_mismatch2
():
...
...
@@ -123,7 +123,7 @@ def test_concatenate_op_type_mismatch2():
with
pytest
.
raises
(
RuntimeError
)
as
error_info
:
for
_
in
data
:
pass
assert
"Tensor types do not match"
in
rep
r
(
error_info
.
value
)
assert
"Tensor types do not match"
in
st
r
(
error_info
.
value
)
def
test_concatenate_op_incorrect_dim
():
...
...
@@ -138,13 +138,13 @@ def test_concatenate_op_incorrect_dim():
with
pytest
.
raises
(
RuntimeError
)
as
error_info
:
for
_
in
data
:
pass
assert
"Only 1D tensors supported"
in
rep
r
(
error_info
.
value
)
assert
"Only 1D tensors supported"
in
st
r
(
error_info
.
value
)
def
test_concatenate_op_wrong_axis
():
with
pytest
.
raises
(
ValueError
)
as
error_info
:
data_trans
.
Concatenate
(
2
)
assert
"only 1D concatenation supported."
in
rep
r
(
error_info
.
value
)
assert
"only 1D concatenation supported."
in
st
r
(
error_info
.
value
)
def
test_concatenate_op_negative_axis
():
...
...
@@ -167,7 +167,7 @@ def test_concatenate_op_incorrect_input_dim():
with
pytest
.
raises
(
ValueError
)
as
error_info
:
data_trans
.
Concatenate
(
0
,
prepend_tensor
)
assert
"can only prepend 1D arrays."
in
rep
r
(
error_info
.
value
)
assert
"can only prepend 1D arrays."
in
st
r
(
error_info
.
value
)
if
__name__
==
"__main__"
:
...
...
tests/ut/python/dataset/test_dataset_numpy_slices.py
浏览文件 @
05b2a57d
...
...
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
sys
import
pytest
import
numpy
as
np
import
pandas
as
pd
import
mindspore.dataset
as
de
from
mindspore
import
log
as
logger
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
pandas
as
pd
def
test_numpy_slices_list_1
():
...
...
@@ -173,6 +174,25 @@ def test_numpy_slices_distributed_sampler():
assert
sum
([
1
for
_
in
ds
])
==
2
def
test_numpy_slices_distributed_shard_limit
():
logger
.
info
(
"Test Slicing a 1D list."
)
np_data
=
[
1
,
2
,
3
]
num
=
sys
.
maxsize
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
num_shards
=
num
,
shard_id
=
0
,
shuffle
=
False
)
assert
"Input num_shards is not within the required interval of (1 to 2147483647)."
in
str
(
err
.
value
)
def
test_numpy_slices_distributed_zero_shard
():
logger
.
info
(
"Test Slicing a 1D list."
)
np_data
=
[
1
,
2
,
3
]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
num_shards
=
0
,
shard_id
=
0
,
shuffle
=
False
)
assert
"Input num_shards is not within the required interval of (1 to 2147483647)."
in
str
(
err
.
value
)
def
test_numpy_slices_sequential_sampler
():
logger
.
info
(
"Test numpy_slices_dataset with SequentialSampler and repeat."
)
...
...
@@ -210,6 +230,15 @@ def test_numpy_slices_invalid_empty_column_names():
assert
"column_names should not be empty"
in
str
(
err
.
value
)
def
test_numpy_slices_invalid_empty_data_column
():
logger
.
info
(
"Test incorrect column_names input"
)
np_data
=
[]
with
pytest
.
raises
(
ValueError
)
as
err
:
de
.
NumpySlicesDataset
(
np_data
,
shuffle
=
False
)
assert
"Argument data cannot be empty"
in
str
(
err
.
value
)
if
__name__
==
"__main__"
:
test_numpy_slices_list_1
()
test_numpy_slices_list_2
()
...
...
@@ -223,7 +252,10 @@ if __name__ == "__main__":
test_numpy_slices_csv_dict
()
test_numpy_slices_num_samplers
()
test_numpy_slices_distributed_sampler
()
test_numpy_slices_distributed_shard_limit
()
test_numpy_slices_distributed_zero_shard
()
test_numpy_slices_sequential_sampler
()
test_numpy_slices_invalid_column_names_type
()
test_numpy_slices_invalid_column_names_string
()
test_numpy_slices_invalid_empty_column_names
()
test_numpy_slices_invalid_empty_data_column
()
tests/ut/python/dataset/test_fill_op.py
浏览文件 @
05b2a57d
...
...
@@ -82,9 +82,9 @@ def test_fillop_error_handling():
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
with
pytest
.
raises
(
RuntimeError
)
as
error_info
:
for
data_row
in
data
:
p
rint
(
data_row
)
assert
"Types do not match"
in
rep
r
(
error_info
.
value
)
for
_
in
data
:
p
ass
assert
"Types do not match"
in
st
r
(
error_info
.
value
)
if
__name__
==
"__main__"
:
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
05b2a57d
...
...
@@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
rep
r
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
st
r
(
error_info
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
...
...
@@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
rep
r
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
st
r
(
error_info
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
...
...
@@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
rep
r
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
st
r
(
error_info
)
with
pytest
.
raises
(
Exception
)
as
error_info
:
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
2
,
5
)
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
rep
r
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
st
r
(
error_info
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
tests/ut/python/dataset/test_nlp.py
浏览文件 @
05b2a57d
...
...
@@ -39,8 +39,27 @@ def test_on_tokenized_line():
res
=
np
.
array
([[
10
,
1
,
11
,
1
,
12
,
1
,
15
,
1
,
13
,
1
,
14
],
[
11
,
1
,
12
,
1
,
10
,
1
,
14
,
1
,
13
,
1
,
15
]],
dtype
=
np
.
int32
)
for
i
,
d
in
enumerate
(
data
.
create_dict_iterator
()):
_
=
(
np
.
testing
.
assert_array_equal
(
d
[
"text"
],
res
[
i
]),
i
)
np
.
testing
.
assert_array_equal
(
d
[
"text"
],
res
[
i
])
def
test_on_tokenized_line_with_no_special_tokens
():
data
=
ds
.
TextFileDataset
(
"../data/dataset/testVocab/lines.txt"
,
shuffle
=
False
)
jieba_op
=
text
.
JiebaTokenizer
(
HMM_FILE
,
MP_FILE
,
mode
=
text
.
JiebaMode
.
MP
)
with
open
(
VOCAB_FILE
,
'r'
)
as
f
:
for
line
in
f
:
word
=
line
.
split
(
','
)[
0
]
jieba_op
.
add_word
(
word
)
data
=
data
.
map
(
input_columns
=
[
"text"
],
operations
=
jieba_op
)
vocab
=
text
.
Vocab
.
from_file
(
VOCAB_FILE
,
","
)
lookup
=
text
.
Lookup
(
vocab
,
"not"
)
data
=
data
.
map
(
input_columns
=
[
"text"
],
operations
=
lookup
)
res
=
np
.
array
([[
8
,
0
,
9
,
0
,
10
,
0
,
13
,
0
,
11
,
0
,
12
],
[
9
,
0
,
10
,
0
,
8
,
0
,
12
,
0
,
11
,
0
,
13
]],
dtype
=
np
.
int32
)
for
i
,
d
in
enumerate
(
data
.
create_dict_iterator
()):
np
.
testing
.
assert_array_equal
(
d
[
"text"
],
res
[
i
])
if
__name__
==
'__main__'
:
test_on_tokenized_line
()
test_on_tokenized_line_with_no_special_tokens
()
tests/ut/python/dataset/test_sync_wait.py
浏览文件 @
05b2a57d
...
...
@@ -14,7 +14,7 @@
# ==============================================================================
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
...
...
@@ -163,7 +163,6 @@ def test_sync_exception_01():
"""
logger
.
info
(
"test_sync_exception_01"
)
shuffle_size
=
4
batch_size
=
10
dataset
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"input"
])
...
...
@@ -171,11 +170,9 @@ def test_sync_exception_01():
dataset
=
dataset
.
sync_wait
(
condition_name
=
"policy"
,
callback
=
aug
.
update
)
dataset
=
dataset
.
map
(
input_columns
=
[
"input"
],
operations
=
[
aug
.
preprocess
])
try
:
dataset
=
dataset
.
shuffle
(
shuffle_size
)
except
Exception
as
e
:
assert
"shuffle"
in
str
(
e
)
dataset
=
dataset
.
batch
(
batch_size
)
with
pytest
.
raises
(
RuntimeError
)
as
e
:
dataset
.
shuffle
(
shuffle_size
)
assert
"No shuffle after sync operators"
in
str
(
e
.
value
)
def
test_sync_exception_02
():
...
...
@@ -183,7 +180,6 @@ def test_sync_exception_02():
Test sync: with duplicated condition name
"""
logger
.
info
(
"test_sync_exception_02"
)
batch_size
=
6
dataset
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"input"
])
...
...
@@ -192,11 +188,9 @@ def test_sync_exception_02():
dataset
=
dataset
.
map
(
input_columns
=
[
"input"
],
operations
=
[
aug
.
preprocess
])
try
:
dataset
=
dataset
.
sync_wait
(
num_batch
=
2
,
condition_name
=
"every batch"
)
except
Exception
as
e
:
assert
"name"
in
str
(
e
)
dataset
=
dataset
.
batch
(
batch_size
)
with
pytest
.
raises
(
RuntimeError
)
as
e
:
dataset
.
sync_wait
(
num_batch
=
2
,
condition_name
=
"every batch"
)
assert
"Condition name is already in use"
in
str
(
e
.
value
)
def
test_sync_exception_03
():
...
...
@@ -209,12 +203,9 @@ def test_sync_exception_03():
aug
=
Augment
(
0
)
# try to create dataset with batch_size < 0
try
:
dataset
=
dataset
.
sync_wait
(
condition_name
=
"every batch"
,
num_batch
=-
1
,
callback
=
aug
.
update
)
except
Exception
as
e
:
assert
"num_batch"
in
str
(
e
)
dataset
=
dataset
.
map
(
input_columns
=
[
"input"
],
operations
=
[
aug
.
preprocess
])
with
pytest
.
raises
(
ValueError
)
as
e
:
dataset
.
sync_wait
(
condition_name
=
"every batch"
,
num_batch
=-
1
,
callback
=
aug
.
update
)
assert
"num_batch need to be greater than 0."
in
str
(
e
.
value
)
def
test_sync_exception_04
():
...
...
@@ -230,14 +221,13 @@ def test_sync_exception_04():
dataset
=
dataset
.
sync_wait
(
condition_name
=
"every batch"
,
callback
=
aug
.
update
)
dataset
=
dataset
.
map
(
input_columns
=
[
"input"
],
operations
=
[
aug
.
preprocess
])
count
=
0
try
:
with
pytest
.
raises
(
RuntimeError
)
as
e
:
for
_
in
dataset
.
create_dict_iterator
():
count
+=
1
data
=
{
"loss"
:
count
}
# dataset.disable_sync()
dataset
.
sync_update
(
condition_name
=
"every batch"
,
num_batch
=-
1
,
data
=
data
)
except
Exception
as
e
:
assert
"batch"
in
str
(
e
)
assert
"Sync_update batch size can only be positive"
in
str
(
e
.
value
)
def
test_sync_exception_05
():
"""
...
...
@@ -251,15 +241,15 @@ def test_sync_exception_05():
# try to create dataset with batch_size < 0
dataset
=
dataset
.
sync_wait
(
condition_name
=
"every batch"
,
callback
=
aug
.
update
)
dataset
=
dataset
.
map
(
input_columns
=
[
"input"
],
operations
=
[
aug
.
preprocess
])
try
:
with
pytest
.
raises
(
RuntimeError
)
as
e
:
for
_
in
dataset
.
create_dict_iterator
():
dataset
.
disable_sync
()
count
+=
1
data
=
{
"loss"
:
count
}
dataset
.
disable_sync
()
dataset
.
sync_update
(
condition_name
=
"every"
,
data
=
data
)
except
Exception
as
e
:
assert
"name"
in
str
(
e
)
assert
"Condition name not found"
in
str
(
e
.
value
)
if
__name__
==
"__main__"
:
test_simple_sync_wait
()
...
...
tests/ut/python/dataset/test_uniform_augment.py
浏览文件 @
05b2a57d
...
...
@@ -16,6 +16,7 @@
Testing UniformAugment in DE
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset.engine
as
de
import
mindspore.dataset.transforms.vision.c_transforms
as
C
...
...
@@ -164,14 +165,13 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
C
.
RandomRotation
(
degrees
=
45
),
F
.
Invert
()]
try
:
with
pytest
.
raises
(
TypeError
)
as
e
:
_
=
C
.
UniformAugment
(
operations
=
transforms_ua
,
num_ops
=
num_ops
)
except
Exception
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Argument tensor_op_5 with value"
\
" <mindspore.dataset.transforms.vision.py_transforms.Invert"
in
str
(
e
)
assert
"is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)"
in
str
(
e
)
" <mindspore.dataset.transforms.vision.py_transforms.Invert"
in
str
(
e
.
valu
e
)
assert
"is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)"
in
str
(
e
.
valu
e
)
def
test_cpp_uniform_augment_exception_large_numops
(
num_ops
=
6
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录