Project: mindspore (forked from MindSpore / mindspore)

Commit ae1ed327
Authored July 13, 2020 by Cathy Wong

Cleanup dataset UT: Remove unneeded data files and tests

Parent: b23fc4e4
19 changed files with 148 additions and 329 deletions:
tests/ut/cpp/dataset/rename_op_test.cc                                   +1    -1
tests/ut/cpp/dataset/zip_op_test.cc                                      +2    -2
tests/ut/data/dataset/golden/repeat_list_result.npz                      +0    -0
tests/ut/data/dataset/golden/repeat_result.npz                           +0    -0
tests/ut/data/dataset/golden/tf_file_no_schema.npz                       +0    -0
tests/ut/data/dataset/golden/tf_file_padBytes10.npz                      +0    -0
tests/ut/data/dataset/golden/tfreader_result.npz                         +0    -0
tests/ut/data/dataset/golden/tfrecord_files_basic.npz                    +0    -0
tests/ut/data/dataset/golden/tfrecord_no_schema.npz                      +0    -0
tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz                    +0    -0
tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json         +0    -11
tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data    +0    -0
tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json         +0    -11
tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data    +0    -0
tests/ut/python/dataset/test_datasets_imagenet.py                        +0    -204
tests/ut/python/dataset/test_datasets_imagenet_distribution.py           +0    -40
tests/ut/python/dataset/test_onehot_op.py                                +51   -4
tests/ut/python/dataset/test_repeat.py                                   +19   -11
tests/ut/python/dataset/test_tfreader_op.py                              +75   -45
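Several of these changes replace the old save_and_check() golden-file helper with save_and_check_dict() and rename the golden .npz files to match the new test names. As a rough sketch of the save-and-compare pattern these helpers implement (an illustration inferred from the call sites in this diff, not the actual tests/ut/python/dataset/util.py code):

# Hypothetical sketch of the golden-file pattern used by these tests.
# The real helper lives in tests/ut/python/dataset/util.py; the names
# and details below are assumptions for illustration only.
import os
import numpy as np

GOLDEN_DIR = "../data/dataset/golden"  # assumed location of the .npz goldens

def save_and_check_dict(data, filename, generate_golden=False):
    # Collect every row the dataset pipeline produces, column by column.
    result = {}
    for item in data.create_dict_iterator():
        for column, value in item.items():
            result.setdefault(column, []).append(value)
    arrays = {k: np.array(v) for k, v in result.items()}

    golden_path = os.path.join(GOLDEN_DIR, filename)
    if generate_golden:
        # Regenerate the golden file (run once with GENERATE_GOLDEN = True).
        np.savez(golden_path, **arrays)
        return
    golden = np.load(golden_path)
    for column, value in arrays.items():
        np.testing.assert_array_equal(value, golden[column])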
tests/ut/cpp/dataset/rename_op_test.cc

@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
            .SetDatasetFilesList({dataset_path})
tests/ut/cpp/dataset/zip_op_test.cc

@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
   MS_LOG(INFO) << "UT test TestZipRepeat.";
   auto my_tree = std::make_shared<ExecutionTree>();
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
tests/ut/data/dataset/golden/repeat_list_result.npz
Binary file changed (cannot preview this file type).

tests/ut/data/dataset/golden/repeat_result.npz
Binary file changed (cannot preview this file type).

tests/ut/data/dataset/golden/tf_file_no_schema.npz
Deleted (file mode 100644 → 0).

tests/ut/data/dataset/golden/tf_file_padBytes10.npz
Deleted (file mode 100644 → 0).

tests/ut/data/dataset/golden/tfreader_result.npz
Deleted (file mode 100644 → 0).

tests/ut/data/dataset/golden/tfrecord_files_basic.npz
Added (file mode 0 → 100644).

tests/ut/data/dataset/golden/tfrecord_no_schema.npz
Added (file mode 0 → 100644).

tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz
Added (file mode 0 → 100644).
tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json
Deleted (file mode 100644 → 0):

{
  "datasetType": "TF",
  "numRows": 3,
  "columns": {
    "label": {
      "type": "int64",
      "rank": 1,
      "t_impl": "flex"
    }
  }
}

tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data
Deleted (file mode 100644 → 0). Binary file deleted.

tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json
Deleted (file mode 100644 → 0):

{
  "datasetType": "TF",
  "numRows": 3,
  "columns": {
    "image": {
      "type": "uint8",
      "rank": 1,
      "t_impl": "cvmat"
    }
  }
}

tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data
Deleted (file mode 100644 → 0). Binary file deleted.
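The deleted datasetSchema.json files above describe the layout of their companion TFRecord .data files (column names, dtypes, ranks, plus a numRows cap). A schema like this is passed as the second argument to ds.TFRecordDataset, as the tests in this commit do; a minimal usage sketch with illustrative paths:

# Minimal sketch: reading a TFRecord file with an explicit JSON schema.
# Paths are illustrative; the pattern mirrors the tests in this commit.
import mindspore.dataset as ds

DATA_FILE = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# With the schema supplied, column names, dtypes, and ranks are fixed up
# front instead of being inferred from the serialized records.
data = ds.TFRecordDataset(DATA_FILE, SCHEMA, shuffle=False)
for item in data.create_dict_iterator():
    print(list(item.keys()))  # e.g. ["image", "label"]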
tests/ut/python/dataset/test_datasets_imagenet.py
Deleted (file mode 100644 → 0):

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_case_repeat():
    """
    a simple repeat operation.
    """
    logger.info("Test Simple Repeat")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_shuffle():
    """
    a simple shuffle operation.
    """
    logger.info("Test Simple Shuffle")
    # define parameters
    buffer_size = 8
    seed = 10

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    for item in data1.create_dict_iterator():
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))


def test_case_0():
    """
    Test Repeat then Shuffle
    """
    logger.info("Test Repeat then Shuffle")
    # define parameters
    repeat_count = 2
    buffer_size = 7
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_0_reverse():
    """
    Test Shuffle then Repeat
    """
    logger.info("Test Shuffle then Repeat")
    # define parameters
    repeat_count = 2
    buffer_size = 10
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_3():
    """
    Test Map
    """
    logger.info("Test Map Rescale and Resize, then Shuffle")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224

    # define map operations
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(rescale, shift)
    # resize_op = vision.Resize(resize_height, resize_width,
    #                           InterpolationMode.DE_INTER_LINEAR)  # Bilinear mode
    resize_op = vision.Resize((resize_height, resize_width))

    # apply map operations on images
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=rescale_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)

    # apply one-hot encoding on labels
    num_classes = 4
    one_hot_encode = data_trans.OneHot(num_classes)
    # num_classes is input argument
    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)

    # apply Datasets
    buffer_size = 100
    seed = 10
    batch_size = 2
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)
    # 10000 as in imageNet train script
    data1 = data1.batch(batch_size, drop_remainder=True)

    num_iter = 0
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))


if __name__ == '__main__':
    logger.info('===========now test Repeat============')
    # logger.info('Simple Repeat')
    test_case_repeat()
    logger.info('\n')

    logger.info('===========now test Shuffle===========')
    # logger.info('Simple Shuffle')
    test_case_shuffle()
    logger.info('\n')

    # Note: cannot work with different shapes, hence not for image
    # logger.info('===========now test Batch=============')
    # # logger.info('Simple Batch')
    # test_case_batch()
    # logger.info('\n')

    logger.info('===========now test case 0============')
    # logger.info('Repeat then Shuffle')
    test_case_0()
    logger.info('\n')

    logger.info('===========now test case 0 reverse============')
    # # logger.info('Shuffle then Repeat')
    test_case_0_reverse()
    logger.info('\n')

    # logger.info('===========now test case 1============')
    # # logger.info('Repeat with Batch')
    # test_case_1()
    # logger.info('\n')

    # logger.info('===========now test case 2============')
    # # logger.info('Batch with Shuffle')
    # test_case_2()
    # logger.info('\n')

    # for image augmentation only
    logger.info('===========now test case 3============')
    logger.info('Map then Shuffle')
    test_case_3()
    logger.info('\n')
tests/ut/python/dataset/test_datasets_imagenet_distribution.py
Deleted (file mode 100644 → 0):

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]

SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"


def test_tf_file_normal():
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(1)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        # each data is a dictionary
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


if __name__ == '__main__':
    logger.info('=======test normal=======')
    test_tf_file_normal()
tests/ut/python/dataset/test_onehot_op.py

@@ -13,12 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing the one_hot op in DE
+Testing the OneHot Op
 """
 import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
+import mindspore.dataset.transforms.vision.c_transforms as c_vision
 from mindspore import log as logger
 from util import diff_mse
@@ -37,15 +38,15 @@ def one_hot(index, depth):
 def test_one_hot():
     """
-    Test one_hot
+    Test OneHot Tensor Operator
     """
-    logger.info("Test one_hot")
+    logger.info("test_one_hot")
     depth = 10
 
     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    one_hot_op = data_trans.OneHot(depth)
+    one_hot_op = data_trans.OneHot(num_classes=depth)
     data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])
 
     # Second dataset
@@ -58,8 +59,54 @@ def test_one_hot():
         label2 = one_hot(item2["label"][0], depth)
         mse = diff_mse(label1, label2)
         logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
+        assert mse == 0
         num_iter += 1
+    assert num_iter == 3
+
+
+def test_one_hot_post_aug():
+    """
+    Test One Hot Encoding after Multiple Data Augmentation Operators
+    """
+    logger.info("test_one_hot_post_aug")
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    # Define data augmentation parameters
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    resize_height, resize_width = 224, 224
+
+    # Define map operations
+    decode_op = c_vision.Decode()
+    rescale_op = c_vision.Rescale(rescale, shift)
+    resize_op = c_vision.Resize((resize_height, resize_width))
+
+    # Apply map operations on images
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=rescale_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+
+    # Apply one-hot encoding on labels
+    depth = 4
+    one_hot_encode = data_trans.OneHot(depth)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
+
+    # Apply datasets ops
+    buffer_size = 100
+    seed = 10
+    batch_size = 2
+    ds.config.set_seed(seed)
+    data1 = data1.shuffle(buffer_size=buffer_size)
+    data1 = data1.batch(batch_size, drop_remainder=True)
+
+    num_iter = 0
+    for item in data1.create_dict_iterator():
+        logger.info("image is: {}".format(item["image"]))
+        logger.info("label is: {}".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 1
+
+
 if __name__ == "__main__":
     test_one_hot()
+    test_one_hot_post_aug()
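The rewritten test above checks the OneHot tensor op against a NumPy reference via diff_mse, both imported from util. A hedged sketch of what those helpers likely look like (the real definitions live in tests/ut/python/dataset/util.py and may differ):

# Assumed shape of the util helpers referenced by test_one_hot.
import numpy as np

def one_hot(index, depth):
    # NumPy reference: a zero vector of length `depth` with a 1 at `index`.
    vector = np.zeros(depth, dtype=np.float32)
    vector[index] = 1.0
    return vector

def diff_mse(in1, in2):
    # Mean squared error between two arrays; 0 means an exact match,
    # which is what `assert mse == 0` in the test demands.
    a = np.asarray(in1, dtype=np.float32)
    b = np.asarray(in2, dtype=np.float32)
    return float(np.mean((a - b) ** 2))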
tests/ut/python/dataset/test_repeat.py

@@ -12,25 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Test Repeat Op
+"""
 import numpy as np
-from util import save_and_check
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
+from util import save_and_check_dict
 
 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-              "col_sint16", "col_sint32", "col_sint64"]
-GENERATE_GOLDEN = False
-
-IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+GENERATE_GOLDEN = False
@@ -39,14 +38,13 @@ def test_tf_repeat_01():
     logger.info("Test Simple Repeat")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
 
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
     data1 = data1.repeat(repeat_count)
 
     filename = "repeat_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
 def test_tf_repeat_02():
@@ -99,14 +97,13 @@ def test_tf_repeat_04():
     logger.info("Test Simple Repeat Column List")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
     columns_list = ["col_sint64", "col_sint32"]
 
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
     data1 = data1.repeat(repeat_count)
 
     filename = "repeat_list_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
 def generator():
@@ -115,6 +112,7 @@ def generator():
 
+
 def test_nested_repeat1():
     logger.info("test_nested_repeat1")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -126,6 +124,7 @@ def test_nested_repeat1():
 
+
 def test_nested_repeat2():
     logger.info("test_nested_repeat2")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(1)
@@ -137,6 +136,7 @@ def test_nested_repeat2():
 
+
 def test_nested_repeat3():
     logger.info("test_nested_repeat3")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(2)
@@ -148,6 +148,7 @@ def test_nested_repeat3():
 
+
 def test_nested_repeat4():
     logger.info("test_nested_repeat4")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(1)
@@ -159,6 +160,7 @@ def test_nested_repeat4():
 
+
 def test_nested_repeat5():
     logger.info("test_nested_repeat5")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(3)
     data = data.repeat(2)
@@ -171,6 +173,7 @@ def test_nested_repeat5():
 
+
 def test_nested_repeat6():
     logger.info("test_nested_repeat6")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.batch(3)
@@ -183,6 +186,7 @@ def test_nested_repeat6():
 
+
 def test_nested_repeat7():
     logger.info("test_nested_repeat7")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -195,6 +199,7 @@ def test_nested_repeat7():
 
+
 def test_nested_repeat8():
     logger.info("test_nested_repeat8")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(2, drop_remainder=False)
     data = data.repeat(2)
@@ -210,6 +215,7 @@ def test_nested_repeat8():
 
+
 def test_nested_repeat9():
     logger.info("test_nested_repeat9")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat()
     data = data.repeat(3)
@@ -221,6 +227,7 @@ def test_nested_repeat9():
 
+
 def test_nested_repeat10():
     logger.info("test_nested_repeat10")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(3)
     data = data.repeat()
@@ -232,6 +239,7 @@ def test_nested_repeat10():
 
+
 def test_nested_repeat11():
     logger.info("test_nested_repeat11")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
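The nested-repeat tests above exercise chained repeat() calls, which compose multiplicatively: repeat(a).repeat(b) makes b passes over a dataset that is itself repeated a times. A standalone sketch of the expected row count, assuming a 3-row generator like the one these tests use:

# Sketch of nested-repeat semantics; the 3-row generator is an assumption.
import numpy as np
import mindspore.dataset as ds

def generator():
    for i in range(3):
        yield (np.array([i]),)

data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)   # 3 rows -> 6 rows
data = data.repeat(3)   # 6 rows -> 18 rows

num_rows = sum(1 for _ in data.create_dict_iterator())
assert num_rows == 3 * 2 * 3  # 18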
tests/ut/python/dataset/test_tfreader_op.py

@@ -12,21 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Test TFRecordDataset Ops
+"""
 import numpy as np
 import pytest
-from util import save_and_check
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict
 
 FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
 DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
 SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
+SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
 GENERATE_GOLDEN = False
 
 
-def test_case_tf_shape():
+def test_tfrecord_shape():
+    logger.info("test_tfrecord_shape")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)
@@ -36,7 +45,8 @@ def test_case_tf_shape():
     assert len(output_shape[-1]) == 1
 
 
-def test_case_tf_read_all_dataset():
+def test_tfrecord_read_all_dataset():
+    logger.info("test_tfrecord_read_all_dataset")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 12
@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
     assert count == 12
 
 
-def test_case_num_samples():
+def test_tfrecord_num_samples():
+    logger.info("test_tfrecord_num_samples")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
     assert ds1.get_dataset_size() == 8
@@ -56,7 +67,8 @@ def test_case_num_samples():
     assert count == 8
 
 
-def test_case_num_samples2():
+def test_tfrecord_num_samples2():
+    logger.info("test_tfrecord_num_samples2")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 7
@@ -66,42 +78,41 @@ def test_case_num_samples2():
     assert count == 7
 
 
-def test_case_tf_shape_2():
+def test_tfrecord_shape2():
+    logger.info("test_tfrecord_shape2")
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
     output_shape = ds1.output_shapes()
     assert len(output_shape[-1]) == 2
 
 
-def test_case_tf_file():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_files_basic():
+    logger.info("test_tfrecord_files_basic")
 
     data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
-    filename = "tfreader_result.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_files_basic.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_case_tf_file_no_schema():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_no_schema():
+    logger.info("test_tfrecord_no_schema")
 
     data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_no_schema.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_no_schema.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_case_tf_file_pad():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_pad():
+    logger.info("test_tfrecord_pad")
 
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
    data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_padBytes10.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_pad_bytes10.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_tf_files():
+def test_tfrecord_read_files():
+    logger.info("test_tfrecord_read_files")
     pattern = DATASET_ROOT + "/test.data"
     data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
     assert sum([1 for _ in data]) == 12
@@ -123,7 +134,19 @@ def test_tf_files():
     assert sum([1 for _ in data]) == 24
 
 
-def test_tf_record_schema():
+def test_tfrecord_multi_files():
+    logger.info("test_tfrecord_multi_files")
+    data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
+    data1 = data1.repeat(1)
+    num_iter = 0
+    for _ in data1.create_dict_iterator():
+        num_iter += 1
+
+    assert num_iter == 12
+
+
+def test_tfrecord_schema():
+    logger.info("test_tfrecord_schema")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -142,7 +165,8 @@ def test_tf_record_schema():
     assert np.array_equal(t1, t2)
 
 
-def test_tf_record_shuffle():
+def test_tfrecord_shuffle():
+    logger.info("test_tfrecord_shuffle")
     ds.config.set_seed(1)
     data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
     data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)
 
 
-def test_tf_record_shard():
+def test_tfrecord_shard():
+    logger.info("test_tfrecord_shard")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -181,7 +206,8 @@ def test_tf_record_shard():
     assert set(worker2_res) == set(worker1_res)
 
 
-def test_tf_shard_equal_rows():
+def test_tfrecord_shard_equal_rows():
+    logger.info("test_tfrecord_shard_equal_rows")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
     assert len(worker4_res) == 40
 
 
-def test_case_tf_file_no_schema_columns_list():
+def test_tfrecord_no_schema_columns_list():
+    logger.info("test_tfrecord_no_schema_columns_list")
     data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
     row = data.create_dict_iterator().get_next()
     assert row["col_sint16"] == [-32768]
@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
     assert "col_sint32" in str(info.value)
 
 
-def test_tf_record_schema_columns_list():
+def test_tfrecord_schema_columns_list():
+    logger.info("test_tfrecord_schema_columns_list")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
     assert "col_sint32" in str(info.value)
 
 
-def test_case_invalid_files():
+def test_tfrecord_invalid_files():
+    logger.info("test_tfrecord_invalid_files")
     valid_file = "../data/dataset/testTFTestAllTypes/test.data"
     invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
     files = [invalid_file, valid_file, SCHEMA_FILE]
@@ -266,19 +295,20 @@ def test_case_invalid_files():
 if __name__ == '__main__':
-    test_case_tf_shape()
-    test_case_tf_read_all_dataset()
-    test_case_num_samples()
-    test_case_num_samples2()
-    test_case_tf_shape_2()
-    test_case_tf_file()
-    test_case_tf_file_no_schema()
-    test_case_tf_file_pad()
-    test_tf_files()
-    test_tf_record_schema()
-    test_tf_record_shuffle()
-    test_tf_record_shard()
-    test_tf_shard_equal_rows()
-    test_case_tf_file_no_schema_columns_list()
-    test_tf_record_schema_columns_list()
-    test_case_invalid_files()
+    test_tfrecord_shape()
+    test_tfrecord_read_all_dataset()
+    test_tfrecord_num_samples()
+    test_tfrecord_num_samples2()
+    test_tfrecord_shape2()
+    test_tfrecord_files_basic()
+    test_tfrecord_no_schema()
+    test_tfrecord_pad()
+    test_tfrecord_read_files()
+    test_tfrecord_multi_files()
+    test_tfrecord_schema()
+    test_tfrecord_shuffle()
+    test_tfrecord_shard()
+    test_tfrecord_shard_equal_rows()
+    test_tfrecord_no_schema_columns_list()
+    test_tfrecord_schema_columns_list()
+    test_tfrecord_invalid_files()
登录