Commit abca62f4 (magicwindyyd/mindspore, forked from MindSpore/mindspore)
Author: Yang
Authored: May 22, 2020
Parent: 93fc82b8

    10:00 26/5 clean pylint

Showing 43 changed files with 217 additions and 239 deletions.

tests/perf_test/mindrecord/imagenet/perf_read_imagenet.py              +12 -12
tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py                       +7  -7
tests/ut/python/dataset/prep_data.py                                   +6  -6
tests/ut/python/dataset/test_Tensor.py                                 +2  -2
tests/ut/python/dataset/test_apply.py                                  +1  -1
tests/ut/python/dataset/test_cifarop.py                                +4  -4
tests/ut/python/dataset/test_config.py                                 +2  -2
tests/ut/python/dataset/test_datasets_celeba.py                        +6  -6
tests/ut/python/dataset/test_datasets_imagefolder.py                   +21 -21
tests/ut/python/dataset/test_datasets_imagenet.py                      +0  -2
tests/ut/python/dataset/test_datasets_imagenet_distribution.py         +1  -3
tests/ut/python/dataset/test_datasets_manifestop.py                    +14 -15
tests/ut/python/dataset/test_datasets_sharding.py                      +8  -8
tests/ut/python/dataset/test_datasets_textfileop.py                    +12 -12
tests/ut/python/dataset/test_decode.py                                 +1  -2
tests/ut/python/dataset/test_filterop.py                               +6  -8
tests/ut/python/dataset/test_generator.py                              +5  -6
tests/ut/python/dataset/test_iterator.py                               +0  -1
tests/ut/python/dataset/test_minddataset.py                            +3  -3
tests/ut/python/dataset/test_minddataset_exception.py                  +3  -3
tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py   +1  -1
tests/ut/python/dataset/test_minddataset_sampler.py                    +0  -6
tests/ut/python/dataset/test_mixup_label_smoothing.py                  +1  -1
tests/ut/python/dataset/test_normalizeOp.py                            +1  -1
tests/ut/python/dataset/test_onehot_op.py                              +0  -1
tests/ut/python/dataset/test_pad.py                                    +1  -2
tests/ut/python/dataset/test_pad_batch.py                              +6  -7
tests/ut/python/dataset/test_random_crop_and_resize.py                 +2  -1
tests/ut/python/dataset/test_random_crop_decode_resize.py              +1  -1
tests/ut/python/dataset/test_random_dataset.py                         +2  -4
tests/ut/python/dataset/test_random_rotation.py                        +1  -1
tests/ut/python/dataset/test_rename.py                                 +1  -1
tests/ut/python/dataset/test_rgb_hsv.py                                +1  -1
tests/ut/python/dataset/test_serdes_dataset.py                         +25 -25
tests/ut/python/dataset/test_shuffle.py                                +1  -1
tests/ut/python/dataset/test_skip.py                                   +1  -2
tests/ut/python/dataset/test_sync_wait.py                              +12 -13
tests/ut/python/dataset/test_take.py                                   +25 -25
tests/ut/python/dataset/test_tfreader_op.py                            +12 -12
tests/ut/python/dataset/test_var_batch_map.py                          +6  -6
tests/ut/python/mindrecord/test_mindrecord_base.py                     +1  -1
tests/ut/python/mindrecord/test_mindrecord_multi_images.py             +1  -1
tests/ut/python/mindrecord/test_mnist_to_mr.py                         +1  -1

tests/perf_test/mindrecord/imagenet/perf_read_imagenet.py

@@ -34,7 +34,7 @@ def use_filereader(mindrecord):
                          num_consumer=4,
                          columns=columns_list)
     num_iter = 0
-    for index, item in enumerate(reader.get_next()):
+    for _, _ in enumerate(reader.get_next()):
         num_iter += 1
     print_log(num_iter)
     end = time.time()
@@ -48,7 +48,7 @@ def use_minddataset(mindrecord):
                               columns_list=columns_list,
                               num_parallel_workers=4)
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     print_log(num_iter)
     end = time.time()
@@ -64,7 +64,7 @@ def use_tfrecorddataset(tfrecord):
                                   shuffle=ds.Shuffle.GLOBAL)
     data_set = data_set.shuffle(10000)
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     print_log(num_iter)
     end = time.time()
@@ -87,7 +87,7 @@ def use_tensorflow_tfrecorddataset(tfrecord):
                                        num_parallel_reads=4)
     data_set = data_set.map(_parse_record, num_parallel_calls=4)
     num_iter = 0
-    for item in data_set.__iter__():
+    for _ in data_set.__iter__():
         num_iter += 1
     print_log(num_iter)
     end = time.time()
@@ -96,18 +96,18 @@ def use_tensorflow_tfrecorddataset(tfrecord):
 if __name__ == '__main__':
     # use MindDataset
-    mindrecord = './imagenet.mindrecord00'
-    use_minddataset(mindrecord)
+    mindrecord_test = './imagenet.mindrecord00'
+    use_minddataset(mindrecord_test)
     # use TFRecordDataset
-    tfrecord = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
-                'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
-                'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
-                'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
-    use_tfrecorddataset(tfrecord)
+    tfrecord_test = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
+                     'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
+                     'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
+                     'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
+    use_tfrecorddataset(tfrecord_test)
     # use TensorFlow TFRecordDataset
-    use_tensorflow_tfrecorddataset(tfrecord)
+    use_tensorflow_tfrecorddataset(tfrecord_test)
     # use FileReader
     # use_filereader(mindrecord)

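Editor's note: the loop rewrites above are the standard pylint `unused-variable` (W0612) cleanup — loop variables that are never read are renamed to `_` so the linter knows they are discarded on purpose. A minimal self-contained sketch of the pattern; the `records` list is a hypothetical stand-in for the `reader.get_next()` / `create_dict_iterator()` iterators used above:

    records = ["a", "b", "c"]  # stand-in for a dataset iterator

    num_iter = 0
    for _ in records:      # was: for index, item in enumerate(records)
        num_iter += 1      # the loop only counts; index/item were never read
    print(num_iter)
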
tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py

@@ -29,7 +29,7 @@ def test_case_0():
     # apply dataset operations
     ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x: x + x))
+    ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x: x + x))
     print("************** Output Tensor *****************")
     for data in ds1.create_dict_iterator():  # each data is a dictionary
@@ -49,7 +49,7 @@ def test_case_1():
     # apply dataset operations
     ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1"], operation=(lambda x: (x, x + x)))
+    ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1"], operations=(lambda x: (x, x + x)))
     print("************** Output Tensor *****************")
     for data in ds1.create_dict_iterator():  # each data is a dictionary
@@ -72,7 +72,7 @@ def test_case_2():
     # apply dataset operations
     ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x, y: x + y))
+    ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x, y: x + y))
     print("************** Output Tensor *****************")
     for data in ds1.create_dict_iterator():  # each data is a dictionary
@@ -93,8 +93,8 @@ def test_case_3():
     # apply dataset operations
     ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"],
-                  operation=(lambda x, y: (x, x + y, x + x + y)))
+    ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"],
+                  operations=(lambda x, y: (x, x + y, x + x + y)))
     print("************** Output Tensor *****************")
     for data in ds1.create_dict_iterator():  # each data is a dictionary
@@ -119,8 +119,8 @@ def test_case_4():
     # apply dataset operations
     ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"], num_parallel_workers=4,
-                  operation=(lambda x, y: (x, x + y, x + x + y)))
+    ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
+                  operations=(lambda x, y: (x, x + y, x + x + y)))
     print("************** Output Tensor *****************")
     for data in ds1.create_dict_iterator():  # each data is a dictionary

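Editor's note: all five test cases make the same mechanical substitution in `Dataset.map` — `input_column_names`, `output_column_names`, and `operation` become `input_columns`, `output_columns`, and `operations`, matching the keyword names used by the rest of the dataset API in this commit. A sketch of the call shape after the rename; the paths and column name are hypothetical placeholders, not taken from this diff:

    import mindspore.dataset as ds

    DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]    # placeholder
    SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"  # placeholder

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds1 = ds1.map(input_columns=["col0"],          # renamed keyword arguments
                  output_columns="out",
                  operations=(lambda x: x + x))
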
tests/ut/python/dataset/prep_data.py

@@ -22,11 +22,11 @@ def create_data_cache_dir():
     cwd = os.getcwd()
     target_directory = os.path.join(cwd, "data_cache")
     try:
-        if not (os.path.exists(target_directory)):
+        if not os.path.exists(target_directory):
             os.mkdir(target_directory)
     except OSError:
         print("Creation of the directory %s failed" % target_directory)
-    return target_directory;
+    return target_directory

 def download_and_uncompress(files, source_url, target_directory, is_tar=False):
@@ -53,13 +53,13 @@ def download_and_uncompress(files, source_url, target_directory, is_tar=False):
 def download_mnist(target_directory=None):
-    if target_directory == None:
+    if target_directory is None:
         target_directory = create_data_cache_dir()
     ##create mnst directory
     target_directory = os.path.join(target_directory, "mnist")
     try:
-        if not (os.path.exists(target_directory)):
+        if not os.path.exists(target_directory):
             os.mkdir(target_directory)
     except OSError:
         print("Creation of the directory %s failed" % target_directory)
@@ -78,14 +78,14 @@ CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"
 def download_cifar(target_directory, files, directory_from_tar):
-    if target_directory == None:
+    if target_directory is None:
         target_directory = create_data_cache_dir()
     download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)
     ##if target dir was specify move data from directory created by tar
     ##and put data into target dir
-    if target_directory != None:
+    if target_directory is not None:
         tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
         all_files = os.path.join(tar_dir_full_path, "*")
         cmd = "mv " + all_files + " " + target_directory

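Editor's note: two pylint findings drive the prep_data.py hunks — `singleton-comparison` (C0121), which prefers identity checks against `None`, and `unnecessary-semicolon` (W0301) for the stray trailing semicolon. A minimal sketch of the preferred forms, with a hypothetical default path:

    def resolve_dir(target_directory=None):
        # None is a singleton: compare with `is` / `is not`, never == / !=
        if target_directory is None:
            target_directory = "/tmp/data_cache"  # hypothetical default
        return target_directory                   # note: no trailing semicolon
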
tests/ut/python/dataset/test_Tensor.py

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import mindspore._c_dataengine as cde
 import numpy as np
+import mindspore._c_dataengine as cde

 def test_shape():
     x = [2, 3]

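Editor's note: the import swap in test_Tensor.py follows pylint's `wrong-import-order` (C0411) grouping: standard library imports first, then third-party packages such as numpy, then the project's own package (mindspore counts as first-party inside this repository). A sketch of the resulting grouping, assuming that convention:

    import os                               # standard library
    import numpy as np                      # third-party
    import mindspore._c_dataengine as cde   # first-party (project package)
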
tests/ut/python/dataset/test_apply.py

@@ -221,7 +221,7 @@ def test_apply_exception_case():
     try:
         data2 = data1.apply(dataset_fn)
         data3 = data1.apply(dataset_fn)
-        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
             pass
         assert False
     except ValueError:

tests/ut/python/dataset/test_cifarop.py

@@ -35,10 +35,10 @@ def test_case_dataset_cifar10():
     data1 = ds.Cifar10Dataset(DATA_DIR_10, 100)

     num_iter = 0
-    for item in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         # in this example, each dictionary has keys "image" and "label"
         num_iter += 1
-    assert (num_iter == 100)
+    assert num_iter == 100

 def test_case_dataset_cifar100():
@@ -50,10 +50,10 @@ def test_case_dataset_cifar100():
     data1 = ds.Cifar100Dataset(DATA_DIR_100, 100)

     num_iter = 0
-    for item in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         # in this example, each dictionary has keys "image" and "label"
         num_iter += 1
-    assert (num_iter == 100)
+    assert num_iter == 100

 if __name__ == '__main__':

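Editor's note: the assert rewrites recurring through this commit clear pylint's `superfluous-parens` (C0325). `assert` is a statement, not a function, so the parentheses add nothing — and the parenthesized form invites the classic bug where a message is added and the whole thing becomes a non-empty (always truthy) tuple. Sketch:

    num_iter = 100

    assert num_iter == 100                 # preferred: no parentheses
    # assert (num_iter == 100)             # flagged by pylint C0325
    # assert (num_iter == 100, "msg")      # bug: a 2-tuple is always truthy
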
tests/ut/python/dataset/test_config.py

@@ -15,10 +15,10 @@
 """
 Testing configuration manager
 """
-import os
 import filecmp
 import glob
 import numpy as np
+import os

 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -89,7 +89,7 @@ def test_pipeline():
     ds.serialize(data2, "testpipeline2.json")

     # check that the generated output is different
-    assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
+    assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')

     # this test passes currently because our num_parallel_workers don't get updated.

tests/ut/python/dataset/test_datasets_celeba.py

@@ -33,9 +33,9 @@ def test_celeba_dataset_label():
         logger.info("----------attr--------")
         logger.info(item["attr"])
         for index in range(len(expect_labels[count])):
-            assert (item["attr"][index] == expect_labels[count][index])
+            assert item["attr"][index] == expect_labels[count][index]
         count = count + 1
-    assert (count == 2)
+    assert count == 2

 def test_celeba_dataset_op():
@@ -54,7 +54,7 @@ def test_celeba_dataset_op():
         logger.info("----------image--------")
         logger.info(item["image"])
         count = count + 1
-    assert (count == 4)
+    assert count == 4

 def test_celeba_dataset_ext():
@@ -69,9 +69,9 @@ def test_celeba_dataset_ext():
         logger.info("----------attr--------")
         logger.info(item["attr"])
         for index in range(len(expect_labels[count])):
-            assert (item["attr"][index] == expect_labels[count][index])
+            assert item["attr"][index] == expect_labels[count][index]
         count = count + 1
-    assert (count == 1)
+    assert count == 1

 def test_celeba_dataset_distribute():
@@ -83,7 +83,7 @@ def test_celeba_dataset_distribute():
         logger.info("----------attr--------")
         logger.info(item["attr"])
         count = count + 1
-    assert (count == 1)
+    assert count == 1

 if __name__ == '__main__':

tests/ut/python/dataset/test_datasets_imagefolder.py

@@ -35,7 +35,7 @@ def test_imagefolder_basic():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 44)
+    assert num_iter == 44

 def test_imagefolder_numsamples():
@@ -55,7 +55,7 @@ def test_imagefolder_numsamples():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 10)
+    assert num_iter == 10

 def test_imagefolder_numshards():
@@ -75,7 +75,7 @@ def test_imagefolder_numshards():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 11)
+    assert num_iter == 11

 def test_imagefolder_shardid():
@@ -95,7 +95,7 @@ def test_imagefolder_shardid():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 11)
+    assert num_iter == 11

 def test_imagefolder_noshuffle():
@@ -115,7 +115,7 @@ def test_imagefolder_noshuffle():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 44)
+    assert num_iter == 44

 def test_imagefolder_extrashuffle():
@@ -136,7 +136,7 @@ def test_imagefolder_extrashuffle():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 88)
+    assert num_iter == 88

 def test_imagefolder_classindex():
@@ -157,11 +157,11 @@ def test_imagefolder_classindex():
         # in this example, each dictionary has keys "image" and "label"
         logger.info("image is {}".format(item["image"]))
         logger.info("label is {}".format(item["label"]))
-        assert (item["label"] == golden[num_iter])
+        assert item["label"] == golden[num_iter]
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 22)
+    assert num_iter == 22

 def test_imagefolder_negative_classindex():
@@ -182,11 +182,11 @@ def test_imagefolder_negative_classindex():
         # in this example, each dictionary has keys "image" and "label"
         logger.info("image is {}".format(item["image"]))
         logger.info("label is {}".format(item["label"]))
-        assert (item["label"] == golden[num_iter])
+        assert item["label"] == golden[num_iter]
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 22)
+    assert num_iter == 22

 def test_imagefolder_extensions():
@@ -207,7 +207,7 @@ def test_imagefolder_extensions():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 44)
+    assert num_iter == 44

 def test_imagefolder_decode():
@@ -228,7 +228,7 @@ def test_imagefolder_decode():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 44)
+    assert num_iter == 44

 def test_sequential_sampler():
@@ -255,7 +255,7 @@ def test_sequential_sampler():
         num_iter += 1
     logger.info("Result: {}".format(result))
-    assert (result == golden)
+    assert result == golden

 def test_random_sampler():
@@ -276,7 +276,7 @@ def test_random_sampler():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 44)
+    assert num_iter == 44

 def test_distributed_sampler():
@@ -297,7 +297,7 @@ def test_distributed_sampler():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 5)
+    assert num_iter == 5

 def test_pk_sampler():
@@ -318,7 +318,7 @@ def test_pk_sampler():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 12)
+    assert num_iter == 12

 def test_subset_random_sampler():
@@ -340,7 +340,7 @@ def test_subset_random_sampler():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 12)
+    assert num_iter == 12

 def test_weighted_random_sampler():
@@ -362,7 +362,7 @@ def test_weighted_random_sampler():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 11)
+    assert num_iter == 11

 def test_imagefolder_rename():
@@ -382,7 +382,7 @@ def test_imagefolder_rename():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 10)
+    assert num_iter == 10

     data1 = data1.rename(input_columns=["image"], output_columns="image2")
@@ -394,7 +394,7 @@ def test_imagefolder_rename():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 10)
+    assert num_iter == 10

 def test_imagefolder_zip():
@@ -419,7 +419,7 @@ def test_imagefolder_zip():
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
-    assert (num_iter == 10)
+    assert num_iter == 10

 if __name__ == '__main__':

tests/ut/python/dataset/test_datasets_imagenet.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import pytest
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
 import mindspore.dataset.transforms.vision.c_transforms as vision

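Editor's note: dropping `import pytest` here (and `import copy`, `matplotlib`, `collections`, etc. in the files below) is pylint's `unused-import` (W0611) — nothing in the module references the name, so the import only slows startup and obscures the real dependencies. One way to check a single file for just this message, assuming pylint is installed:

    pylint --disable=all --enable=W0611 tests/ut/python/dataset/test_datasets_imagenet.py
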
tests/ut/python/dataset/test_datasets_imagenet_distribution.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import pytest
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -30,7 +28,7 @@ def test_tf_file_normal():
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
     data1 = data1.repeat(1)
     num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))

tests/ut/python/dataset/test_datasets_manifestop.py

@@ -16,7 +16,6 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger

 DATA_FILE = "../data/dataset/testManifestData/test.manifest"
@@ -34,9 +33,9 @@ def test_manifest_dataset_train():
             cat_count = cat_count + 1
         elif item["label"].size == 1 and item["label"] == 1:
             dog_count = dog_count + 1
-    assert (cat_count == 2)
-    assert (dog_count == 1)
-    assert (count == 4)
+    assert cat_count == 2
+    assert dog_count == 1
+    assert count == 4

 def test_manifest_dataset_eval():
@@ -46,36 +45,36 @@ def test_manifest_dataset_eval():
         logger.info("item[image] is {}".format(item["image"]))
         count = count + 1
         if item["label"] != 0 and item["label"] != 1:
-            assert (0)
-    assert (count == 2)
+            assert 0
+    assert count == 2

 def test_manifest_dataset_class_index():
     class_indexing = {"dog": 11}
     data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
     out_class_indexing = data.get_class_indexing()
-    assert (out_class_indexing == {"dog": 11})
+    assert out_class_indexing == {"dog": 11}
     count = 0
     for item in data.create_dict_iterator():
         logger.info("item[image] is {}".format(item["image"]))
         count = count + 1
         if item["label"] != 11:
-            assert (0)
-    assert (count == 1)
+            assert 0
+    assert count == 1

 def test_manifest_dataset_get_class_index():
     data = ds.ManifestDataset(DATA_FILE, decode=True)
     class_indexing = data.get_class_indexing()
-    assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
+    assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
     data = data.shuffle(4)
     class_indexing = data.get_class_indexing()
-    assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
+    assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
     count = 0
     for item in data.create_dict_iterator():
         logger.info("item[image] is {}".format(item["image"]))
         count = count + 1
-    assert (count == 4)
+    assert count == 4

 def test_manifest_dataset_multi_label():
@@ -83,10 +82,10 @@ def test_manifest_dataset_multi_label():
     count = 0
     expect_label = [1, 0, 0, [0, 2]]
     for item in data.create_dict_iterator():
-        assert (item["label"].tolist() == expect_label[count])
+        assert item["label"].tolist() == expect_label[count]
         logger.info("item[image] is {}".format(item["image"]))
         count = count + 1
-    assert (count == 4)
+    assert count == 4

 def multi_label_hot(x):
@@ -109,7 +108,7 @@ def test_manifest_dataset_multi_label_onehot():
     data = data.batch(2)
     count = 0
     for item in data.create_dict_iterator():
-        assert (item["label"].tolist() == expect_label[count])
+        assert item["label"].tolist() == expect_label[count]
         logger.info("item[image] is {}".format(item["image"]))
         count = count + 1

tests/ut/python/dataset/test_datasets_sharding.py

@@ -27,7 +27,7 @@ def test_imagefolder_shardings(print_res=False):
         res = []
         for item in data1.create_dict_iterator():  # each data is a dictionary
             res.append(item["label"].item())
-        if (print_res):
+        if print_res:
            logger.info("labels of dataset: {}".format(res))
         return res
@@ -39,12 +39,12 @@ def test_imagefolder_shardings(print_res=False):
     assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])  # 22 rows
     assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])  # 22 rows
     # total 22 in dataset rows because of class indexing which takes only 2 folders
-    assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
-    assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
+    assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
+    assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
     # test with repeat
     assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
     assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
-    assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
+    assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20

 def test_tfrecord_shardings1(print_res=False):
@@ -176,8 +176,8 @@ def test_voc_shardings(print_res=False):
     # then takes the first 2 bc num_samples = 2
     assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4)
     # test that each epoch, each shard_worker returns a different sample
-    assert (len(sharding_config(2, 0, None, True, 1)) == 5)
-    assert (len(set(sharding_config(11, 0, None, True, 10))) > 1)
+    assert len(sharding_config(2, 0, None, True, 1)) == 5
+    assert len(set(sharding_config(11, 0, None, True, 10))) > 1

 def test_cifar10_shardings(print_res=False):
@@ -196,8 +196,8 @@ def test_cifar10_shardings(print_res=False):
     # 60000 rows in total. CIFAR reads everything in memory which would make each test case very slow
     # therefore, only 2 test cases for now.
-    assert (sharding_config(10000, 9999, 7, False, 1) == [9])
-    assert (sharding_config(10000, 0, 4, False, 3) == [0, 0, 0])
+    assert sharding_config(10000, 9999, 7, False, 1) == [9]
+    assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]

 def test_cifar100_shardings(print_res=False):

tests/ut/python/dataset/test_datasets_textfileop.py

@@ -27,7 +27,7 @@ def test_textline_dataset_one_file():
     for i in data.create_dict_iterator():
         logger.info("{}".format(i["text"]))
         count += 1
-    assert (count == 3)
+    assert count == 3

 def test_textline_dataset_all_file():
@@ -36,7 +36,7 @@ def test_textline_dataset_all_file():
     for i in data.create_dict_iterator():
         logger.info("{}".format(i["text"]))
         count += 1
-    assert (count == 5)
+    assert count == 5

 def test_textline_dataset_totext():
@@ -46,8 +46,8 @@ def test_textline_dataset_totext():
     line = ["This is a text file.", "Another file.",
             "Be happy every day.", "End of file.", "Good luck to everyone."]
     for i in data.create_dict_iterator():
-        str = i["text"].item().decode("utf8")
-        assert (str == line[count])
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
         count += 1
     assert (count == 5)
     # Restore configuration num_parallel_workers
@@ -57,17 +57,17 @@ def test_textline_dataset_totext():
 def test_textline_dataset_num_samples():
     data = ds.TextFileDataset(DATA_FILE, num_samples=2)
     count = 0
-    for i in data.create_dict_iterator():
+    for _ in data.create_dict_iterator():
         count += 1
-    assert (count == 2)
+    assert count == 2

 def test_textline_dataset_distribution():
     data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
     count = 0
-    for i in data.create_dict_iterator():
+    for _ in data.create_dict_iterator():
         count += 1
-    assert (count == 3)
+    assert count == 3

 def test_textline_dataset_repeat():
@@ -78,16 +78,16 @@ def test_textline_dataset_repeat():
             "This is a text file.", "Be happy every day.", "Good luck to everyone.",
             "This is a text file.", "Be happy every day.", "Good luck to everyone."]
     for i in data.create_dict_iterator():
-        str = i["text"].item().decode("utf8")
-        assert (str == line[count])
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
         count += 1
-    assert (count == 9)
+    assert count == 9

 def test_textline_dataset_get_datasetsize():
     data = ds.TextFileDataset(DATA_FILE)
     size = data.get_dataset_size()
-    assert (size == 3)
+    assert size == 3

 if __name__ == "__main__":

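Editor's note: renaming the local `str` to `strs` clears pylint's `redefined-builtin` (W0622). Binding `str` shadows the built-in type for the rest of the scope, so any later `str(...)` call in that scope would blow up. A minimal reproduction of the hazard:

    def bad():
        str = b"text".decode("utf8")   # shadows the built-in (pylint W0622)
        return str(42)                 # TypeError: 'str' object is not callable

    def good():
        strs = b"text".decode("utf8")  # distinct name leaves str() usable
        return strs + str(42)
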
tests/ut/python/dataset/test_decode.py

@@ -15,9 +15,8 @@
 """
 Testing Decode op in DE
 """
-import cv2
 import numpy as np
+from util import diff_mse
+import cv2

 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision

tests/ut/python/dataset/test_filterop.py

@@ -88,7 +88,7 @@ def test_filter_by_generator_with_repeat():
         ret_data.append(item["data"])
     assert num_iter == 44
     for i in range(4):
-        for ii in range(len(expected_rs)):
+        for ii, _ in enumerate(expected_rs):
             index = i * len(expected_rs) + ii
             assert ret_data[index] == expected_rs[ii]
@@ -106,7 +106,7 @@ def test_filter_by_generator_with_repeat_after():
         ret_data.append(item["data"])
     assert num_iter == 44
     for i in range(4):
-        for ii in range(len(expected_rs)):
+        for ii, _ in enumerate(expected_rs):
             index = i * len(expected_rs) + ii
             assert ret_data[index] == expected_rs[ii]
@@ -167,7 +167,7 @@ def test_filter_by_generator_with_shuffle():
     dataset_s = dataset.shuffle(4)
     dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
     num_iter = 0
-    for item in dataset_f.create_dict_iterator():
+    for _ in dataset_f.create_dict_iterator():
         num_iter += 1
     assert num_iter == 21
@@ -184,7 +184,7 @@ def test_filter_by_generator_with_shuffle_after():
     dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
     dataset_s = dataset_f.shuffle(4)
     num_iter = 0
-    for item in dataset_s.create_dict_iterator():
+    for _ in dataset_s.create_dict_iterator():
         num_iter += 1
     assert num_iter == 21
@@ -258,8 +258,7 @@ def filter_func_map(col1, col2):
 def filter_func_map_part(col1):
     if col1 < 3:
         return True
-    else:
-        return False
+    return False

 def filter_func_map_all(col1, col2):
@@ -276,7 +275,7 @@ def func_map(data_col1, data_col2):
 def func_map_part(data_col1):
-    return (data_col1)
+    return data_col1

 # test with map
@@ -473,7 +472,6 @@ def test_filte_case_dataset_cifar10():
     ds.config.load('../data/dataset/declient_filter.cfg')
     dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
     dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
-    num_iter = 0
     for item in dataset_f1.create_dict_iterator():
         # in this example, each dictionary has keys "image" and "label"
         assert item["label"] % 3 == 0

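Editor's note: the `filter_func_map_part` change is pylint's `no-else-return` (R1705) — when the `if` branch returns, the `else` is dead weight and the fallback can sit at function level. Equivalently:

    def filter_func_map_part(col1):
        # R1705: the else after a return is redundant
        if col1 < 3:
            return True
        return False

    # which could in turn be collapsed to:  return col1 < 3
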
tests/ut/python/dataset/test_generator.py

@@ -184,7 +184,7 @@ def test_case_6():
     de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                 mstype.uint64, mstype.float32, mstype.float64]
-    for i in range(len(np_types)):
+    for i, _ in enumerate(np_types):
         type_tester_with_type_check(np_types[i], de_types[i])
@@ -219,7 +219,7 @@ def test_case_7():
     de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                 mstype.uint64, mstype.float32, mstype.float64]
-    for i in range(len(np_types)):
+    for i, _ in enumerate(np_types):
         type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
@@ -526,7 +526,7 @@ def test_sequential_sampler():
 def test_random_sampler():
     source = [(np.array([x]),) for x in range(64)]
     ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
-    for data in ds1.create_dict_iterator():  # each data is a dictionary
+    for _ in ds1.create_dict_iterator():  # each data is a dictionary
         pass
@@ -611,7 +611,7 @@ def test_schema():
     de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                 mstype.uint64, mstype.float32, mstype.float64]
-    for i in range(len(np_types)):
+    for i, _ in enumerate(np_types):
         type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
@@ -630,8 +630,7 @@ def manual_test_keyborad_interrupt():
         return 1024

     ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
-    i = 0
-    for data in ds1.create_dict_iterator():  # each data is a dictionary
+    for _ in ds1.create_dict_iterator():  # each data is a dictionary
         pass

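Editor's note: the loops over `np_types` show pylint's `consider-using-enumerate` (C0200). The commit keeps the index via `enumerate` and discards the value; since a parallel `de_types` list is indexed too, `zip` would arguably be the more direct fix. A sketch with hypothetical stand-in lists (the real ones are numpy dtypes and mstype constants):

    np_types = ["int8", "int16"]   # stand-in for the numpy dtype list above
    de_types = ["INT8", "INT16"]   # stand-in for the mstype list above

    for i, _ in enumerate(np_types):             # what the commit does (C0200 fix)
        print(np_types[i], de_types[i])

    for np_t, de_t in zip(np_types, de_types):   # equivalent, index-free
        print(np_t, de_t)
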
tests/ut/python/dataset/test_iterator.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
 import numpy as np
 import pytest

tests/ut/python/dataset/test_minddataset.py

@@ -320,7 +320,7 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file):
     data = data.shuffle(2)
     data = data.repeat(9)
     num_iter = 0
-    for item in data.create_dict_iterator():
+    for _ in data.create_dict_iterator():
         num_iter += 1
     assert num_iter == 18
@@ -572,7 +572,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
     num_readers = 4
     data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
     assert data_set.get_dataset_size() == 10
-    for epoch in range(5):
+    for _ in range(5):
         num_iter = 0
         for data in data_set:
             logger.info("data is {}".format(data))
@@ -603,7 +603,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
     data_set = data_set.batch(2)
     assert data_set.get_dataset_size() == 5
-    for epoch in range(5):
+    for _ in range(5):
         num_iter = 0
         for data in data_set:
             logger.info("data is {}".format(data))

tests/ut/python/dataset/test_minddataset_exception.py

@@ -91,7 +91,7 @@ def test_invalid_mindrecord():
     with pytest.raises(Exception, match="MindRecordOp init failed"):
         data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1
     assert num_iter == 0
     os.remove('dummy.mindrecord')
@@ -105,7 +105,7 @@ def test_minddataset_lack_db():
     with pytest.raises(Exception, match="MindRecordOp init failed"):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1
     assert num_iter == 0
     os.remove(CV_FILE_NAME)
@@ -119,7 +119,7 @@ def test_cv_minddataset_pk_sample_error_class_column():
     with pytest.raises(Exception, match="MindRecordOp launch failed"):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))

tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py

@@ -15,8 +15,8 @@
 """
 This is the test module for mindrecord
 """
-import numpy as np
 import os
+import numpy as np
 import mindspore.dataset as ds
 from mindspore import log as logger
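The import swap above likely addresses pylint's wrong-import-order check (C0411), which enforces the PEP 8 grouping: standard-library modules first, then third-party packages. A hedged sketch of the convention (module choice is illustrative):

    # Standard library first...
    import os
    # ...then third-party packages such as NumPy.
    import numpy as np

    print(os.path.sep, np.zeros(2))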
tests/ut/python/dataset/test_minddataset_sampler.py

@@ -15,16 +15,10 @@
 """
 This is the test module for mindrecord
 """
-import collections
-import json
-import numpy as np
 import os
 import pytest
-import re
-import string
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 from mindspore.dataset.transforms.vision import Inter
 from mindspore.dataset.text import to_str
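Here whole import lines are deleted rather than reordered: modules such as collections, json, re, and string were imported but never referenced, which pylint reports as unused-import (W0611). A minimal sketch of the cleanup (names are illustrative):

    # Before (each line triggers W0611 because nothing below uses it):
    #     import collections
    #     import json
    #     import re
    # After: keep only what the module actually calls.
    import os

    print(os.path.basename("/tmp/sample.mindrecord"))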
tests/ut/python/dataset/test_mixup_label_smoothing.py

@@ -49,7 +49,7 @@ def test_one_hot_op():
         label = data["label"]
         logger.info("label is {}".format(label))
         logger.info("golden_label is {}".format(golden_label))
-        assert (label.all() == golden_label.all())
+        assert label.all() == golden_label.all()
         logger.info("====test one hot op ok====")
tests/ut/python/dataset/test_normalizeOp.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-import matplotlib.pyplot as plt
 import numpy as np
 import mindspore.dataset as ds

@@ -50,6 +49,7 @@ def get_normalized(image_id):
         if num_iter == image_id:
             return normalize_np(image)
         num_iter += 1
+    return None


 def test_normalize_op():
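The added `return None` in get_normalized makes every exit path of the function return explicitly; when some paths `return value` and others fall off the end, pylint raises inconsistent-return-statements (R1710). A self-contained sketch of the same fix (function and data are illustrative, not from this commit):

    def find_first_even(values):
        """Return the first even value, or None when there is none."""
        for value in values:
            if value % 2 == 0:
                return value
        # Explicit None keeps all return paths consistent (R1710).
        return None

    assert find_first_even([1, 3, 4]) == 4
    assert find_first_even([1, 3]) is None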
tests/ut/python/dataset/test_onehot_op.py

@@ -19,7 +19,6 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
tests/ut/python/dataset/test_pad.py

@@ -15,7 +15,6 @@
 """
 Testing Pad op in DE
 """
-import matplotlib.pyplot as plt
 import numpy as np
 from util import diff_mse

@@ -118,7 +117,7 @@ def test_pad_grayscale():
     for shape1, shape2 in zip(dataset_shape_1, dataset_shape_2):
         # validate that the first two dimensions are the same
         # we have a little inconsistency here because the third dimension is 1 after py_vision.Grayscale
-        assert (shape1[0:1] == shape2[0:1])
+        assert shape1[0:1] == shape2[0:1]


 if __name__ == "__main__":
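`assert (x == y)` and `assert x == y` compile to the same thing; `assert` is a statement, not a function, so pylint flags the parentheses as superfluous-parens (C0325). Dropping them also keeps the code away from the classic pitfall `assert (x == y, "msg")`, which asserts a non-empty tuple and therefore always passes. A tiny sketch:

    x, y = 3, 3
    assert x == y                    # preferred
    # assert (x == y, "mismatch")    # BUG: a 2-tuple is always truthy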
tests/ut/python/dataset/test_pad_batch.py

@@ -13,8 +13,8 @@
 # limitations under the License.
 # ==============================================================================
-import numpy as np
 import time
+import numpy as np
 import mindspore.dataset as ds

@@ -117,8 +117,7 @@ def batch_padding_performance_3d():
     data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
     start_time = time.time()
     num_batches = 0
-    ret = []
-    for data in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num_batches += 1
     res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)

@@ -134,7 +133,7 @@ def batch_padding_performance_1d():
     data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
     start_time = time.time()
     num_batches = 0
-    for data in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num_batches += 1
     res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)

@@ -150,7 +149,7 @@ def batch_pyfunc_padding_3d():
     data1 = data1.batch(batch_size=24, drop_remainder=True)
     start_time = time.time()
     num_batches = 0
-    for data in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num_batches += 1
     res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)

@@ -165,7 +164,7 @@ def batch_pyfunc_padding_1d():
     data1 = data1.batch(batch_size=24, drop_remainder=True)
     start_time = time.time()
     num_batches = 0
-    for data in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num_batches += 1
     res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)

@@ -197,7 +196,7 @@ def test_pad_via_map():
     res_from_map = pad_map_config()
     res_from_batch = pad_batch_config()
     assert len(res_from_batch) == len(res_from_batch)
-    for i in range(len(res_from_map)):
+    for i, _ in enumerate(res_from_map):
         assert np.array_equal(res_from_map[i], res_from_batch[i])
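The last hunk above trades `range(len(...))` indexing for `enumerate`, which is what pylint's consider-using-enumerate check (C0200) suggests. The test still needs the index `i` to address two parallel lists, so the element slot is discarded with `_`. A runnable sketch (list contents are illustrative):

    expected = ["pad", "batch", "map"]
    produced = ["pad", "batch", "map"]
    # Instead of: for i in range(len(expected)):
    for i, _ in enumerate(expected):
        assert expected[i] == produced[i]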
tests/ut/python/dataset/test_random_crop_and_resize.py

@@ -15,8 +15,9 @@
 """
 Testing RandomCropAndResize op in DE
 """
-import cv2
 import numpy as np
+import cv2
 import mindspore.dataset.transforms.vision.c_transforms as c_vision
 import mindspore.dataset.transforms.vision.py_transforms as py_vision
 import mindspore.dataset.transforms.vision.utils as mode
tests/ut/python/dataset/test_random_crop_decode_resize.py

@@ -15,9 +15,9 @@
 """
 Testing RandomCropDecodeResize op in DE
 """
-import cv2
 import matplotlib.pyplot as plt
 import numpy as np
+import cv2
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
tests/ut/python/dataset/test_random_dataset.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from pathlib import Path
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger

@@ -39,7 +37,7 @@ def test_randomdataset_basic1():
         num_iter += 1
     logger.info("Number of data in ds1: ", num_iter)
-    assert (num_iter == 200)
+    assert num_iter == 200


 # Another simple test

@@ -65,7 +63,7 @@ def test_randomdataset_basic2():
         num_iter += 1
     logger.info("Number of data in ds1: ", num_iter)
-    assert (num_iter == 40)
+    assert num_iter == 40


 if __name__ == '__main__':
tests/ut/python/dataset/test_random_rotation.py

@@ -15,9 +15,9 @@
 """
 Testing RandomRotation op in DE
 """
-import cv2
 import matplotlib.pyplot as plt
 import numpy as np
+import cv2
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as c_vision
tests/ut/python/dataset/test_rename.py

@@ -34,7 +34,7 @@ def test_rename():
     num_iter = 0
-    for i, item in enumerate(data.create_dict_iterator()):
+    for _, item in enumerate(data.create_dict_iterator()):
         logger.info("item[mask] is {}".format(item["masks"]))
         np.testing.assert_equal(item["masks"], item["input_ids"])
         logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
tests/ut/python/dataset/test_rgb_hsv.py

@@ -159,7 +159,7 @@ def test_rgb_hsv_pipeline():
         ori_img = data1["image"]
         cvt_img = data2["image"]
         assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
-        assert (ori_img.shape == cvt_img.shape)
+        assert ori_img.shape == cvt_img.shape


 if __name__ == "__main__":
tests/ut/python/dataset/test_serdes_dataset.py

@@ -57,7 +57,7 @@ def test_imagefolder(remove_json_files=True):
     # data1 should still work after saving.
     ds.serialize(data1, "imagenet_dataset_pipeline.json")
     ds1_dict = ds.serialize(data1)
-    assert (validate_jsonfile("imagenet_dataset_pipeline.json") is True)
+    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

     # Print the serialized pipeline to stdout
     ds.show(data1)

@@ -68,8 +68,8 @@ def test_imagefolder(remove_json_files=True):
     # Serialize the pipeline we just deserialized.
     # The content of the json file should be the same to the previous serialize.
     ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
-    assert (validate_jsonfile("imagenet_dataset_pipeline_1.json") is True)
-    assert (filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json'))
+    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
+    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

     # Deserialize the latest json file again
     data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")

@@ -78,16 +78,16 @@ def test_imagefolder(remove_json_files=True):
     # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
     for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                           data3.create_dict_iterator(), data4.create_dict_iterator()):
-        assert (np.array_equal(item1['image'], item2['image']))
-        assert (np.array_equal(item1['image'], item3['image']))
-        assert (np.array_equal(item1['label'], item2['label']))
-        assert (np.array_equal(item1['label'], item3['label']))
-        assert (np.array_equal(item3['image'], item4['image']))
-        assert (np.array_equal(item3['label'], item4['label']))
+        assert np.array_equal(item1['image'], item2['image'])
+        assert np.array_equal(item1['image'], item3['image'])
+        assert np.array_equal(item1['label'], item2['label'])
+        assert np.array_equal(item1['label'], item3['label'])
+        assert np.array_equal(item3['image'], item4['image'])
+        assert np.array_equal(item3['label'], item4['label'])
         num_samples += 1

     logger.info("Number of data in data1: {}".format(num_samples))
-    assert (num_samples == 6)
+    assert num_samples == 6

     # Remove the generated json file
     if remove_json_files:

@@ -106,26 +106,26 @@ def test_mnist_dataset(remove_json_files=True):
     data1 = data1.batch(batch_size=10, drop_remainder=True)

     ds.serialize(data1, "mnist_dataset_pipeline.json")
-    assert (validate_jsonfile("mnist_dataset_pipeline.json") is True)
+    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

     data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
     ds.serialize(data2, "mnist_dataset_pipeline_1.json")
-    assert (validate_jsonfile("mnist_dataset_pipeline_1.json") is True)
-    assert (filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json'))
+    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
+    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

     data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

     num = 0
     for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                    data3.create_dict_iterator()):
-        assert (np.array_equal(data1['image'], data2['image']))
-        assert (np.array_equal(data1['image'], data3['image']))
-        assert (np.array_equal(data1['label'], data2['label']))
-        assert (np.array_equal(data1['label'], data3['label']))
+        assert np.array_equal(data1['image'], data2['image'])
+        assert np.array_equal(data1['image'], data3['image'])
+        assert np.array_equal(data1['label'], data2['label'])
+        assert np.array_equal(data1['label'], data3['label'])
         num += 1

     logger.info("mnist total num samples is {}".format(str(num)))
-    assert (num == 10)
+    assert num == 10

     if remove_json_files:
         delete_json_files()

@@ -146,13 +146,13 @@ def test_zip_dataset(remove_json_files=True):
                            "column_1d", "column_2d", "column_3d", "column_binary"])
     data3 = ds.zip((data1, data2))
     ds.serialize(data3, "zip_dataset_pipeline.json")
-    assert (validate_jsonfile("zip_dataset_pipeline.json") is True)
-    assert (validate_jsonfile("zip_dataset_pipeline_typo.json") is False)
+    assert validate_jsonfile("zip_dataset_pipeline.json") is True
+    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

     data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
     ds.serialize(data4, "zip_dataset_pipeline_1.json")
-    assert (validate_jsonfile("zip_dataset_pipeline_1.json") is True)
-    assert (filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json'))
+    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
+    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

     rows = 0
     for d0, d3, d4 in zip(ds0, data3, data4):

@@ -165,7 +165,7 @@ def test_zip_dataset(remove_json_files=True):
             assert np.array_equal(t1, d4[offset + num_cols])
             offset += 1
         rows += 1
-    assert (rows == 12)
+    assert rows == 12

     if remove_json_files:
         delete_json_files()

@@ -197,7 +197,7 @@ def test_random_crop():
     for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                      data2.create_dict_iterator()):
-        assert (np.array_equal(item1['image'], item1_1['image']))
+        assert np.array_equal(item1['image'], item1_1['image'])
         image2 = item2["image"]

@@ -250,7 +250,7 @@ def test_minddataset(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME)
     assert data_set.get_dataset_size() == 5
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     assert num_iter == 5
tests/ut/python/dataset/test_shuffle.py

@@ -120,7 +120,7 @@ def test_shuffle_05():
 def test_shuffle_06():
     """
     Test shuffle: with set seed, both datasets
     """
     logger.info("test_shuffle_06")
     # define parameters
tests/ut/python/dataset/test_skip.py

@@ -16,7 +16,6 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
-from mindspore import log as logger

 DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

@@ -36,7 +35,7 @@ def test_tf_skip():
     data1 = data1.skip(2)

     num_iter = 0
-    for item in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num_iter += 1
     assert num_iter == 1
tests/ut/python/dataset/test_sync_wait.py

@@ -14,7 +14,6 @@
 # ==============================================================================
 import numpy as np
-import time
 import mindspore.dataset as ds
 from mindspore import log as logger

@@ -22,7 +21,7 @@ from mindspore import log as logger
 def gen():
     for i in range(100):
-        yield np.array(i),
+        yield (np.array(i),)


 class Augment:

@@ -38,7 +37,7 @@ class Augment:
 def test_simple_sync_wait():
     """
     Test simple sync wait: test sync in dataset pipeline
     """
     logger.info("test_simple_sync_wait")
     batch_size = 4

@@ -51,7 +50,7 @@ def test_simple_sync_wait():
     count = 0
     for data in dataset.create_dict_iterator():
-        assert (data["input"][0] == count)
+        assert data["input"][0] == count
         count += batch_size
         data = {"loss": count}
         dataset.sync_update(condition_name="policy", data=data)

@@ -59,7 +58,7 @@ def test_simple_sync_wait():
 def test_simple_shuffle_sync():
     """
     Test simple shuffle sync: test shuffle before sync
     """
     logger.info("test_simple_shuffle_sync")
     shuffle_size = 4

@@ -83,7 +82,7 @@ def test_simple_shuffle_sync():
 def test_two_sync():
     """
     Test two sync: dataset pipeline with with two sync_operators
     """
     logger.info("test_two_sync")
     batch_size = 6

@@ -111,7 +110,7 @@ def test_two_sync():
 def test_sync_epoch():
     """
     Test sync wait with epochs: test sync with epochs in dataset pipeline
     """
     logger.info("test_sync_epoch")
     batch_size = 30

@@ -122,11 +121,11 @@ def test_sync_epoch():
     dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
     dataset = dataset.batch(batch_size, drop_remainder=True)

-    for epochs in range(3):
+    for _ in range(3):
         aug.update({"loss": 0})
         count = 0
         for data in dataset.create_dict_iterator():
-            assert (data["input"][0] == count)
+            assert data["input"][0] == count
             count += batch_size
             data = {"loss": count}
             dataset.sync_update(condition_name="policy", data=data)

@@ -134,7 +133,7 @@ def test_sync_epoch():
 def test_multiple_iterators():
     """
     Test sync wait with multiple iterators: will start multiple
     """
     logger.info("test_sync_epoch")
     batch_size = 30

@@ -153,7 +152,7 @@ def test_multiple_iterators():
     dataset2 = dataset2.batch(batch_size, drop_remainder=True)

     for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
-        assert (item1["input"][0] == item2["input"][0])
+        assert item1["input"][0] == item2["input"][0]
         data1 = {"loss": item1["input"][0]}
         data2 = {"loss": item2["input"][0]}
         dataset.sync_update(condition_name="policy", data=data1)

@@ -162,7 +161,7 @@ def test_multiple_iterators():
 def test_sync_exception_01():
     """
     Test sync: with shuffle in sync mode
     """
     logger.info("test_sync_exception_01")
     shuffle_size = 4

@@ -183,7 +182,7 @@ def test_sync_exception_01():
 def test_sync_exception_02():
     """
     Test sync: with duplicated condition name
     """
     logger.info("test_sync_exception_02")
     batch_size = 6
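The generator change above wraps the yielded value in explicit parentheses. `yield np.array(i),` and `yield (np.array(i),)` are equivalent, since the trailing comma already builds a 1-tuple, but the parenthesized form makes the tuple unmistakable, and these tests feed each yielded tuple to the dataset pipeline as one row. A runnable sketch:

    import numpy as np

    def gen():
        for i in range(3):
            # The trailing comma makes a 1-tuple; the parentheses
            # only make that explicit.
            yield (np.array(i),)

    for (row,) in gen():
        print(row)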
tests/ut/python/dataset/test_take.py

@@ -21,13 +21,13 @@ from mindspore import log as logger
 # In generator dataset: Number of rows is 3, its value is 0, 1, 2
 def generator():
     for i in range(3):
-        yield np.array([i]),
+        yield (np.array([i]),)


 # In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
 def generator_10():
     for i in range(10):
-        yield np.array([i]),
+        yield (np.array([i]),)


 def filter_func_ge(data):

@@ -47,8 +47,8 @@ def test_take_01():
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
-    for i, d in enumerate(data1):
-        assert 0 == d[0][0]
+    for _, d in enumerate(data1):
+        assert d[0][0] == 0

     assert sum([1 for _ in data1]) == 2

@@ -97,7 +97,7 @@ def test_take_04():
     data1 = data1.take(4)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i % 3 == d[0][0]

@@ -113,7 +113,7 @@ def test_take_05():
     data1 = data1.take(2)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i == d[0][0]

@@ -130,7 +130,7 @@ def test_take_06():
     data1 = data1.repeat(2)
     data1 = data1.take(4)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i % 3 == d[0][0]

@@ -171,7 +171,7 @@ def test_take_09():
     data1 = data1.repeat(2)
     data1 = data1.take(-1)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i % 3 == d[0][0]

@@ -188,7 +188,7 @@ def test_take_10():
     data1 = data1.take(-1)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i % 3 == d[0][0]

@@ -206,7 +206,7 @@ def test_take_11():
     data1 = data1.repeat(2)
     data1 = data1.take(-1)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert 2 * (i % 2) == d[0][0]

@@ -224,9 +224,9 @@ def test_take_12():
     data1 = data1.batch(2)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
-    for i, d in enumerate(data1):
-        assert 0 == d[0][0]
+    for _, d in enumerate(data1):
+        assert d[0][0] == 0

     assert sum([1 for _ in data1]) == 2

@@ -243,9 +243,9 @@ def test_take_13():
     data1 = data1.batch(2)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
-    for i, d in enumerate(data1):
-        assert 2 == d[0][0]
+    for _, d in enumerate(data1):
+        assert d[0][0] == 2

     assert sum([1 for _ in data1]) == 2

@@ -262,9 +262,9 @@ def test_take_14():
     data1 = data1.skip(1)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
-    for i, d in enumerate(data1):
-        assert 2 == d[0][0]
+    for _, d in enumerate(data1):
+        assert d[0][0] == 2

     assert sum([1 for _ in data1]) == 2

@@ -279,7 +279,7 @@ def test_take_15():
     data1 = data1.take(6)
     data1 = data1.skip(2)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert (i + 2) == d[0][0]

@@ -296,7 +296,7 @@ def test_take_16():
     data1 = data1.skip(3)
     data1 = data1.take(5)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert (i + 3) == d[0][0]

@@ -313,7 +313,7 @@ def test_take_17():
     data1 = data1.take(8)
     data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)

     # Here i refers to index, d refers to data element
     for i, d in enumerate(data1):
         assert i == d[0][0]

@@ -334,9 +334,9 @@ def test_take_18():
     data1 = data1.batch(2)
     data1 = data1.repeat(2)

     # Here i refers to index, d refers to data element
-    for i, d in enumerate(data1):
-        assert 2 == d[0][0]
+    for _, d in enumerate(data1):
+        assert d[0][0] == 2

     assert sum([1 for _ in data1]) == 2
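test_take.py above also flips assertions such as `assert 0 == d[0][0]` into `assert d[0][0] == 0`: putting the computed value on the left and the constant on the right reads the way the assertion is spoken, and pylint versions of this era flag the reversed "Yoda" form as misplaced-comparison-constant (C0122). A one-line sketch:

    d = [[2]]
    assert d[0][0] == 2   # rather than: assert 2 == d[0][0]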
tests/ut/python/dataset/test_tfreader_op.py

@@ -33,7 +33,7 @@ def test_case_tf_shape():
     for data in ds1.create_dict_iterator():
         logger.info(data)
     output_shape = ds1.output_shapes()
-    assert (len(output_shape[-1]) == 1)
+    assert len(output_shape[-1]) == 1


 def test_case_tf_read_all_dataset():

@@ -41,7 +41,7 @@ def test_case_tf_read_all_dataset():
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 12
     count = 0
-    for data in ds1.create_tuple_iterator():
+    for _ in ds1.create_tuple_iterator():
         count += 1
     assert count == 12

@@ -51,7 +51,7 @@ def test_case_num_samples():
     ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
     assert ds1.get_dataset_size() == 8
     count = 0
-    for data in ds1.create_dict_iterator():
+    for _ in ds1.create_dict_iterator():
         count += 1
     assert count == 8

@@ -61,7 +61,7 @@ def test_case_num_samples2():
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 7
     count = 0
-    for data in ds1.create_dict_iterator():
+    for _ in ds1.create_dict_iterator():
         count += 1
     assert count == 7

@@ -70,7 +70,7 @@ def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
     output_shape = ds1.output_shapes()
-    assert (len(output_shape[-1]) == 2)
+    assert len(output_shape[-1]) == 2


 def test_case_tf_file():

@@ -175,10 +175,10 @@ def test_tf_record_shard():
     assert len(worker1_res) == 48
     assert len(worker1_res) == len(worker2_res)
     # check criteria 1
-    for i in range(len(worker1_res)):
-        assert (worker1_res[i] != worker2_res[i])
+    for i, _ in enumerate(worker1_res):
+        assert worker1_res[i] != worker2_res[i]
     # check criteria 2
-    assert (set(worker2_res) == set(worker1_res))
+    assert set(worker2_res) == set(worker1_res)


 def test_tf_shard_equal_rows():

@@ -197,16 +197,16 @@ def test_tf_shard_equal_rows():
     worker2_res = get_res(3, 1, 2)
     worker3_res = get_res(3, 2, 2)
     # check criteria 1
-    for i in range(len(worker1_res)):
-        assert (worker1_res[i] != worker2_res[i])
-        assert (worker2_res[i] != worker3_res[i])
+    for i, _ in enumerate(worker1_res):
+        assert worker1_res[i] != worker2_res[i]
+        assert worker2_res[i] != worker3_res[i]
     # Confirm each worker gets same number of rows
     assert len(worker1_res) == 28
     assert len(worker1_res) == len(worker2_res)
     assert len(worker2_res) == len(worker3_res)

     worker4_res = get_res(1, 0, 1)
-    assert (len(worker4_res) == 40)
+    assert len(worker4_res) == 40


 def test_case_tf_file_no_schema_columns_list():
tests/ut/python/dataset/test_var_batch_map.py

@@ -59,7 +59,7 @@ def test_batch_corner_cases():
 # to a pyfunc which makes a deep copy of the row
 def test_variable_size_batch():
     def check_res(arr1, arr2):
-        for ind in range(len(arr1)):
+        for ind, _ in enumerate(arr1):
             if not np.array_equal(arr1[ind], np.array(arr2[ind])):
                 return False
         return len(arr1) == len(arr2)

@@ -143,7 +143,7 @@ def test_variable_size_batch():
 def test_basic_batch_map():
     def check_res(arr1, arr2):
-        for ind in range(len(arr1)):
+        for ind, _ in enumerate(arr1):
             if not np.array_equal(arr1[ind], np.array(arr2[ind])):
                 return False
         return len(arr1) == len(arr2)

@@ -176,7 +176,7 @@ def test_basic_batch_map():
 def test_batch_multi_col_map():
     def check_res(arr1, arr2):
-        for ind in range(len(arr1)):
+        for ind, _ in enumerate(arr1):
             if not np.array_equal(arr1[ind], np.array(arr2[ind])):
                 return False
         return len(arr1) == len(arr2)

@@ -224,7 +224,7 @@ def test_batch_multi_col_map():
 def test_var_batch_multi_col_map():
     def check_res(arr1, arr2):
-        for ind in range(len(arr1)):
+        for ind, _ in enumerate(arr1):
             if not np.array_equal(arr1[ind], np.array(arr2[ind])):
                 return False
         return len(arr1) == len(arr2)

@@ -269,7 +269,7 @@ def test_var_batch_var_resize():
         return ([np.copy(c[0:s, 0:s, :]) for c in col],)

     def add_one(batchInfo):
-        return (batchInfo.get_batch_num() + 1)
+        return batchInfo.get_batch_num() + 1

     data1 = ds.ImageFolderDatasetV2("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True)
     data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize)

@@ -303,7 +303,7 @@ def test_exception():
     data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func)
     try:
-        for item in data2.create_dict_iterator():
+        for _ in data2.create_dict_iterator():
             pass
         assert False
     except RuntimeError:
tests/ut/python/mindrecord/test_mindrecord_base.py

@@ -13,9 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """test mindrecord base"""
-import numpy as np
 import os
 import uuid
+import numpy as np
 from utils import get_data, get_nlp_data
 from mindspore import log as logger
tests/ut/python/mindrecord/test_mindrecord_multi_images.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """test write multiple images"""
-import numpy as np
 import os
+import numpy as np
 from utils import get_two_bytes_data, get_multi_bytes_data
 from mindspore import log as logger
tests/ut/python/mindrecord/test_mnist_to_mr.py

@@ -14,9 +14,9 @@
 """test mnist to mindrecord tool"""
 import gzip
 import os
-import numpy as np
 import cv2
+import numpy as np
 import pytest
 from mindspore import log as logger