MindSpore / mindspore

Commit ae1ed327
Authored July 13, 2020 by Cathy Wong

Cleanup dataset UT: Remove unneeded data files and tests

Parent: b23fc4e4
Showing 19 changed files with 148 additions and 329 deletions (+148, -329)
Changed files:

  tests/ut/cpp/dataset/rename_op_test.cc                                  +1    -1
  tests/ut/cpp/dataset/zip_op_test.cc                                     +2    -2
  tests/ut/data/dataset/golden/repeat_list_result.npz                     +0    -0
  tests/ut/data/dataset/golden/repeat_result.npz                          +0    -0
  tests/ut/data/dataset/golden/tf_file_no_schema.npz                      +0    -0
  tests/ut/data/dataset/golden/tf_file_padBytes10.npz                     +0    -0
  tests/ut/data/dataset/golden/tfreader_result.npz                        +0    -0
  tests/ut/data/dataset/golden/tfrecord_files_basic.npz                   +0    -0
  tests/ut/data/dataset/golden/tfrecord_no_schema.npz                     +0    -0
  tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz                   +0    -0
  tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json        +0    -11
  tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data   +0    -0
  tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json        +0    -11
  tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data   +0    -0
  tests/ut/python/dataset/test_datasets_imagenet.py                       +0    -204
  tests/ut/python/dataset/test_datasets_imagenet_distribution.py          +0    -40
  tests/ut/python/dataset/test_onehot_op.py                               +51   -4
  tests/ut/python/dataset/test_repeat.py                                  +19   -11
  tests/ut/python/dataset/test_tfreader_op.py                             +75   -45
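Several of the golden .npz files above are renamed to match the new test names (for example tfreader_result.npz becomes tfrecord_files_basic.npz), and the Python diffs below switch the tests from save_and_check(...) to save_and_check_dict(...). For readers unfamiliar with the golden-file pattern, the following is a minimal, hypothetical sketch of what such a helper could look like. The real implementation lives in tests/ut/python/dataset/util.py and is not part of this commit; the GOLDEN_DIR constant and the comparison details here are assumptions made only for illustration.

# Hypothetical sketch only; not the repository's util.py implementation.
import os
import numpy as np

GOLDEN_DIR = "../data/dataset/golden"  # assumed location of the golden .npz files


def save_and_check_dict(data, filename, generate_golden=False):
    # Drain the dataset pipeline and group the produced values by column name.
    columns = {}
    for row in data.create_dict_iterator():
        for name, value in row.items():
            columns.setdefault(name, []).append(np.asarray(value))

    golden_path = os.path.join(GOLDEN_DIR, filename)
    if generate_golden:
        # Regenerate the golden file instead of comparing against it.
        np.savez(golden_path, **{name: np.stack(vals) for name, vals in columns.items()})
        return

    # Compare the freshly produced rows against the stored golden arrays.
    golden = np.load(golden_path)
    for name, vals in columns.items():
        np.testing.assert_array_equal(np.stack(vals), golden[name])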
tests/ut/cpp/dataset/rename_op_test.cc

@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder().SetDatasetFilesList({dataset_path})
tests/ut/cpp/dataset/zip_op_test.cc

@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()

@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
   MS_LOG(INFO) << "UT test TestZipRepeat.";
   auto my_tree = std::make_shared<ExecutionTree>();
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
tests/ut/data/dataset/golden/repeat_list_result.npz
  Changed (binary file; this file type cannot be previewed).

tests/ut/data/dataset/golden/repeat_result.npz
  Changed (binary file; this file type cannot be previewed).

tests/ut/data/dataset/golden/tf_file_no_schema.npz
  Deleted (100644 → 0).

tests/ut/data/dataset/golden/tf_file_padBytes10.npz
  Deleted (100644 → 0).

tests/ut/data/dataset/golden/tfreader_result.npz
  Deleted (100644 → 0).

tests/ut/data/dataset/golden/tfrecord_files_basic.npz
  Added (0 → 100644).

tests/ut/data/dataset/golden/tfrecord_no_schema.npz
  Added (0 → 100644).

tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz
  Added (0 → 100644).
tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json
  Deleted (100644 → 0). Former contents:

{
    "datasetType": "TF",
    "numRows": 3,
    "columns": {
        "label": {
            "type": "int64",
            "rank": 1,
            "t_impl": "flex"
        }
    }
}
tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data
  Deleted (100644 → 0). Binary file.
tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json
  Deleted (100644 → 0). Former contents:

{
    "datasetType": "TF",
    "numRows": 3,
    "columns": {
        "image": {
            "type": "uint8",
            "rank": 1,
            "t_impl": "cvmat"
        }
    }
}
tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data
  Deleted (100644 → 0). Binary file.
tests/ut/python/dataset/test_datasets_imagenet.py
  Deleted (100644 → 0). Former contents:

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_case_repeat():
    """
    a simple repeat operation.
    """
    logger.info("Test Simple Repeat")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_shuffle():
    """
    a simple shuffle operation.
    """
    logger.info("Test Simple Shuffle")
    # define parameters
    buffer_size = 8
    seed = 10

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    for item in data1.create_dict_iterator():
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))


def test_case_0():
    """
    Test Repeat then Shuffle
    """
    logger.info("Test Repeat then Shuffle")
    # define parameters
    repeat_count = 2
    buffer_size = 7
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_0_reverse():
    """
    Test Shuffle then Repeat
    """
    logger.info("Test Shuffle then Repeat")
    # define parameters
    repeat_count = 2
    buffer_size = 10
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_3():
    """
    Test Map
    """
    logger.info("Test Map Rescale and Resize, then Shuffle")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224

    # define map operations
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(rescale, shift)
    # resize_op = vision.Resize(resize_height, resize_width,
    #                           InterpolationMode.DE_INTER_LINEAR)  # Bilinear mode
    resize_op = vision.Resize((resize_height, resize_width))

    # apply map operations on images
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=rescale_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)

    # # apply ont-hot encoding on labels
    num_classes = 4
    one_hot_encode = data_trans.OneHot(num_classes)  # num_classes is input argument
    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)

    #
    # # apply Datasets
    buffer_size = 100
    seed = 10
    batch_size = 2
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)  # 10000 as in imageNet train script
    data1 = data1.batch(batch_size, drop_remainder=True)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


if __name__ == '__main__':
    logger.info('===========now test Repeat============')
    # logger.info('Simple Repeat')
    test_case_repeat()
    logger.info('\n')

    logger.info('===========now test Shuffle===========')
    # logger.info('Simple Shuffle')
    test_case_shuffle()
    logger.info('\n')

    # Note: cannot work with different shapes, hence not for image
    # logger.info('===========now test Batch=============')
    # # logger.info('Simple Batch')
    # test_case_batch()
    # logger.info('\n')

    logger.info('===========now test case 0============')
    # logger.info('Repeat then Shuffle')
    test_case_0()
    logger.info('\n')

    logger.info('===========now test case 0 reverse============')
    # # logger.info('Shuffle then Repeat')
    test_case_0_reverse()
    logger.info('\n')

    # logger.info('===========now test case 1============')
    # # logger.info('Repeat with Batch')
    # test_case_1()
    # logger.info('\n')

    # logger.info('===========now test case 2============')
    # # logger.info('Batch with Shuffle')
    # test_case_2()
    # logger.info('\n')

    # for image augmentation only
    logger.info('===========now test case 3============')
    logger.info('Map then Shuffle')
    test_case_3()
    logger.info('\n')
tests/ut/python/dataset/test_datasets_imagenet_distribution.py
  Deleted (100644 → 0). Former contents:

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]

SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"


def test_tf_file_normal():
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(1)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        # each data is a dictionary
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


if __name__ == '__main__':
    logger.info('=======test normal=======')
    test_tf_file_normal()
tests/ut/python/dataset/test_onehot_op.py

@@ -13,12 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing the one_hot op in DE
+Testing the OneHot Op
 """
 import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
 import mindspore.dataset.transforms.vision.c_transforms as c_vision
 from mindspore import log as logger
 from util import diff_mse

@@ -37,15 +38,15 @@ def one_hot(index, depth):
 def test_one_hot():
     """
-    Test one_hot
+    Test OneHot Tensor Operator
     """
-    logger.info("Test one_hot")
+    logger.info("test_one_hot")
     depth = 10

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    one_hot_op = data_trans.OneHot(depth)
+    one_hot_op = data_trans.OneHot(num_classes=depth)
     data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])

     # Second dataset

@@ -58,8 +59,54 @@ def test_one_hot():
         label2 = one_hot(item2["label"][0], depth)
         mse = diff_mse(label1, label2)
         logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
         assert mse == 0
         num_iter += 1
     assert num_iter == 3
+
+
+def test_one_hot_post_aug():
+    """
+    Test One Hot Encoding after Multiple Data Augmentation Operators
+    """
+    logger.info("test_one_hot_post_aug")
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    # Define data augmentation parameters
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    resize_height, resize_width = 224, 224
+
+    # Define map operations
+    decode_op = c_vision.Decode()
+    rescale_op = c_vision.Rescale(rescale, shift)
+    resize_op = c_vision.Resize((resize_height, resize_width))
+
+    # Apply map operations on images
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=rescale_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+
+    # Apply one-hot encoding on labels
+    depth = 4
+    one_hot_encode = data_trans.OneHot(depth)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
+
+    # Apply datasets ops
+    buffer_size = 100
+    seed = 10
+    batch_size = 2
+    ds.config.set_seed(seed)
+    data1 = data1.shuffle(buffer_size=buffer_size)
+    data1 = data1.batch(batch_size, drop_remainder=True)
+
+    num_iter = 0
+    for item in data1.create_dict_iterator():
+        logger.info("image is: {}".format(item["image"]))
+        logger.info("label is: {}".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 1
+
+
 if __name__ == "__main__":
     test_one_hot()
+    test_one_hot_post_aug()
tests/ut/python/dataset/test_repeat.py

@@ -12,25 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """
 Test Repeat Op
 """
 import numpy as np
-from util import save_and_check
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
+from util import save_and_check_dict

 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
 COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
               "col_sint16", "col_sint32", "col_sint64"]
 GENERATE_GOLDEN = False
 IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 GENERATE_GOLDEN = False

@@ -39,14 +38,13 @@ def test_tf_repeat_01():
     logger.info("Test Simple Repeat")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}

     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
     data1 = data1.repeat(repeat_count)

     filename = "repeat_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)

@@ -99,14 +97,13 @@ def test_tf_repeat_04():
     logger.info("Test Simple Repeat Column List")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
     columns_list = ["col_sint64", "col_sint32"]

     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
     data1 = data1.repeat(repeat_count)

     filename = "repeat_list_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


 def generator():

@@ -115,6 +112,7 @@ def generator():
 def test_nested_repeat1():
+    logger.info("test_nested_repeat1")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)

@@ -126,6 +124,7 @@ def test_nested_repeat1():
 def test_nested_repeat2():
+    logger.info("test_nested_repeat2")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(1)

@@ -137,6 +136,7 @@ def test_nested_repeat2():
 def test_nested_repeat3():
+    logger.info("test_nested_repeat3")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(2)

@@ -148,6 +148,7 @@ def test_nested_repeat3():
 def test_nested_repeat4():
+    logger.info("test_nested_repeat4")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(1)

@@ -159,6 +160,7 @@ def test_nested_repeat4():
 def test_nested_repeat5():
+    logger.info("test_nested_repeat5")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(3)
     data = data.repeat(2)

@@ -171,6 +173,7 @@ def test_nested_repeat5():
 def test_nested_repeat6():
+    logger.info("test_nested_repeat6")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.batch(3)

@@ -183,6 +186,7 @@ def test_nested_repeat6():
 def test_nested_repeat7():
+    logger.info("test_nested_repeat7")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)

@@ -195,6 +199,7 @@ def test_nested_repeat7():
 def test_nested_repeat8():
+    logger.info("test_nested_repeat8")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(2, drop_remainder=False)
     data = data.repeat(2)

@@ -210,6 +215,7 @@ def test_nested_repeat8():
 def test_nested_repeat9():
+    logger.info("test_nested_repeat9")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat()
     data = data.repeat(3)

@@ -221,6 +227,7 @@ def test_nested_repeat9():
 def test_nested_repeat10():
+    logger.info("test_nested_repeat10")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(3)
     data = data.repeat()

@@ -232,6 +239,7 @@ def test_nested_repeat10():
 def test_nested_repeat11():
+    logger.info("test_nested_repeat11")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
tests/ut/python/dataset/test_tfreader_op.py

@@ -12,21 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Test TFRecordDataset Ops
+"""
 import numpy as np
 import pytest
-from util import save_and_check
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict

 FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
 DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
 SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
+SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
 GENERATE_GOLDEN = False


-def test_case_tf_shape():
+def test_tfrecord_shape():
+    logger.info("test_tfrecord_shape")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)

@@ -36,7 +45,8 @@ def test_case_tf_shape():
     assert len(output_shape[-1]) == 1


-def test_case_tf_read_all_dataset():
+def test_tfrecord_read_all_dataset():
+    logger.info("test_tfrecord_read_all_dataset")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 12

@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
     assert count == 12


-def test_case_num_samples():
+def test_tfrecord_num_samples():
+    logger.info("test_tfrecord_num_samples")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
     assert ds1.get_dataset_size() == 8

@@ -56,7 +67,8 @@ def test_case_num_samples():
     assert count == 8


-def test_case_num_samples2():
+def test_tfrecord_num_samples2():
+    logger.info("test_tfrecord_num_samples2")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 7

@@ -66,42 +78,41 @@ def test_case_num_samples2():
     assert count == 7


-def test_case_tf_shape_2():
+def test_tfrecord_shape2():
+    logger.info("test_tfrecord_shape2")
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
     output_shape = ds1.output_shapes()
     assert len(output_shape[-1]) == 2


-def test_case_tf_file():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_files_basic():
+    logger.info("test_tfrecord_files_basic")
     data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
-    filename = "tfreader_result.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_files_basic.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


-def test_case_tf_file_no_schema():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_no_schema():
+    logger.info("test_tfrecord_no_schema")
     data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_no_schema.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_no_schema.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


-def test_case_tf_file_pad():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_pad():
+    logger.info("test_tfrecord_pad")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
     data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_padBytes10.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_pad_bytes10.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


-def test_tf_files():
+def test_tfrecord_read_files():
+    logger.info("test_tfrecord_read_files")
     pattern = DATASET_ROOT + "/test.data"
     data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
     assert sum([1 for _ in data]) == 12

@@ -123,7 +134,19 @@ def test_tf_files():
     assert sum([1 for _ in data]) == 24


-def test_tf_record_schema():
+def test_tfrecord_multi_files():
+    logger.info("test_tfrecord_multi_files")
+    data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
+    data1 = data1.repeat(1)
+    num_iter = 0
+    for _ in data1.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 12
+
+
+def test_tfrecord_schema():
+    logger.info("test_tfrecord_schema")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])

@@ -142,7 +165,8 @@ def test_tf_record_schema():
     assert np.array_equal(t1, t2)


-def test_tf_record_shuffle():
+def test_tfrecord_shuffle():
+    logger.info("test_tfrecord_shuffle")
     ds.config.set_seed(1)
     data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
     data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)


-def test_tf_record_shard():
+def test_tfrecord_shard():
+    logger.info("test_tfrecord_shard")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

@@ -181,7 +206,8 @@ def test_tf_record_shard():
     assert set(worker2_res) == set(worker1_res)


-def test_tf_shard_equal_rows():
+def test_tfrecord_shard_equal_rows():
+    logger.info("test_tfrecord_shard_equal_rows")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
     assert len(worker4_res) == 40


-def test_case_tf_file_no_schema_columns_list():
+def test_tfrecord_no_schema_columns_list():
+    logger.info("test_tfrecord_no_schema_columns_list")
     data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
     row = data.create_dict_iterator().get_next()
     assert row["col_sint16"] == [-32768]

@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
     assert "col_sint32" in str(info.value)


-def test_tf_record_schema_columns_list():
+def test_tfrecord_schema_columns_list():
+    logger.info("test_tfrecord_schema_columns_list")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])

@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
     assert "col_sint32" in str(info.value)


-def test_case_invalid_files():
+def test_tfrecord_invalid_files():
+    logger.info("test_tfrecord_invalid_files")
     valid_file = "../data/dataset/testTFTestAllTypes/test.data"
     invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
     files = [invalid_file, valid_file, SCHEMA_FILE]

@@ -266,19 +295,20 @@ def test_case_invalid_files():
 if __name__ == '__main__':
-    test_case_tf_shape()
-    test_case_tf_read_all_dataset()
-    test_case_num_samples()
-    test_case_num_samples2()
-    test_case_tf_shape_2()
-    test_case_tf_file()
-    test_case_tf_file_no_schema()
-    test_case_tf_file_pad()
-    test_tf_files()
-    test_tf_record_schema()
-    test_tf_record_shuffle()
-    test_tf_record_shard()
-    test_tf_shard_equal_rows()
-    test_case_tf_file_no_schema_columns_list()
-    test_tf_record_schema_columns_list()
-    test_case_invalid_files()
+    test_tfrecord_shape()
+    test_tfrecord_read_all_dataset()
+    test_tfrecord_num_samples()
+    test_tfrecord_num_samples2()
+    test_tfrecord_shape2()
+    test_tfrecord_files_basic()
+    test_tfrecord_no_schema()
+    test_tfrecord_pad()
+    test_tfrecord_read_files()
+    test_tfrecord_multi_files()
+    test_tfrecord_schema()
+    test_tfrecord_shuffle()
+    test_tfrecord_shard()
+    test_tfrecord_shard_equal_rows()
+    test_tfrecord_no_schema_columns_list()
+    test_tfrecord_schema_columns_list()
+    test_tfrecord_invalid_files()
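With the functions renamed to a consistent test_tfrecord_* prefix, the whole group can be selected by keyword when running the suite. A hypothetical invocation from tests/ut/python/dataset is sketched below; the exact working directory and options depend on how the repository's CI drives pytest, so treat the paths and flags as assumptions.

import pytest

# Run only the renamed TFRecord tests from test_tfreader_op.py (illustrative only).
pytest.main(["-v", "-k", "tfrecord", "test_tfreader_op.py"])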