Commit ae1ed327 authored by Cathy Wong

Cleanup dataset UT: Remove unneeded data files and tests

Parent: b23fc4e4
...
@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
          .SetDatasetFilesList({dataset_path})
...
@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
   MS_LOG(INFO) << "UT test TestZipRepeat.";
   auto my_tree = std::make_shared<ExecutionTree>();
-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
...
{
  "datasetType": "TF",
  "numRows": 3,
  "columns": {
    "label": {
      "type": "int64",
      "rank": 1,
      "t_impl": "flex"
    }
  }
}
...
{
  "datasetType": "TF",
  "numRows": 3,
  "columns": {
    "image": {
      "type": "uint8",
      "rank": 1,
      "t_impl": "cvmat"
    }
  }
}
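The two schemas above describe a 3-row TFRecord dataset with an int64 "label" column and a uint8 "image" column. As a minimal sketch of how such a schema is consumed (paths taken from the tests below; the snippet itself is illustrative, not part of the commit):

import mindspore.dataset as ds

DATA_FILES = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# Each row comes back as a dict keyed by the schema's column names.
data = ds.TFRecordDataset(DATA_FILES, SCHEMA, shuffle=False)
for row in data.create_dict_iterator():
    print(row["label"], row["image"])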
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_case_repeat():
    """
    Test a simple repeat operation.
    """
    logger.info("Test Simple Repeat")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_shuffle():
    """
    Test a simple shuffle operation.
    """
    logger.info("Test Simple Shuffle")
    # define parameters
    buffer_size = 8
    seed = 10

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    for item in data1.create_dict_iterator():
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))


def test_case_0():
    """
    Test Repeat then Shuffle
    """
    logger.info("Test Repeat then Shuffle")
    # define parameters
    repeat_count = 2
    buffer_size = 7
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    num_iter = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_0_reverse():
    """
    Test Shuffle then Repeat
    """
    logger.info("Test Shuffle then Repeat")
    # define parameters
    repeat_count = 2
    buffer_size = 10
    seed = 9

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))


def test_case_3():
    """
    Test Map: Rescale and Resize, then Shuffle
    """
    logger.info("Test Map Rescale and Resize, then Shuffle")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224

    # define map operations
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(rescale, shift)
    # resize_op = vision.Resize(resize_height, resize_width,
    #                           InterpolationMode.DE_INTER_LINEAR)  # Bilinear mode
    resize_op = vision.Resize((resize_height, resize_width))

    # apply map operations on images
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=rescale_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)

    # apply one-hot encoding on labels
    num_classes = 4
    one_hot_encode = data_trans.OneHot(num_classes)  # num_classes is the input argument
    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)

    # apply dataset ops
    buffer_size = 100
    seed = 10
    batch_size = 2
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)  # the ImageNet train script uses a 10000 buffer
    data1 = data1.batch(batch_size, drop_remainder=True)

    num_iter = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))


if __name__ == '__main__':
    logger.info('===========now test Repeat============')
    test_case_repeat()
    logger.info('\n')

    logger.info('===========now test Shuffle===========')
    test_case_shuffle()
    logger.info('\n')

    # Note: batch alone cannot work with rows of different shapes, hence no batch case for images

    logger.info('===========now test case 0============')
    test_case_0()
    logger.info('\n')

    logger.info('===========now test case 0 reverse============')
    test_case_0_reverse()
    logger.info('\n')

    # for image augmentation only
    logger.info('===========now test case 3============')
    logger.info('Map then Shuffle')
    test_case_3()
    logger.info('\n')
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
            "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"


def test_tf_file_normal():
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(1)
    num_iter = 0
    for _ in data1.create_dict_iterator():  # each data is a dictionary
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


if __name__ == '__main__':
    logger.info('=======test normal=======')
    test_tf_file_normal()
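A quick cross-check of the same 12-row expectation without iterating, assuming get_dataset_size() honors the multi-file input as the TFRecord tests later in this commit exercise (illustrative sketch, not part of the committed file):

data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
assert data2.get_dataset_size() == 12  # 4 data files contributing 12 rows in total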
...
@@ -13,12 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing the one_hot op in DE
+Testing the OneHot Op
 """
 import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
+import mindspore.dataset.transforms.vision.c_transforms as c_vision
 from mindspore import log as logger
 from util import diff_mse
@@ -37,15 +38,15 @@ def one_hot(index, depth):
 def test_one_hot():
     """
-    Test one_hot
+    Test OneHot Tensor Operator
     """
-    logger.info("Test one_hot")
+    logger.info("test_one_hot")
     depth = 10
     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    one_hot_op = data_trans.OneHot(depth)
+    one_hot_op = data_trans.OneHot(num_classes=depth)
     data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])
     # Second dataset
@@ -58,8 +59,54 @@ def test_one_hot():
         label2 = one_hot(item2["label"][0], depth)
         mse = diff_mse(label1, label2)
         logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
+        assert mse == 0
         num_iter += 1
+    assert num_iter == 3
+
+
+def test_one_hot_post_aug():
+    """
+    Test One Hot Encoding after Multiple Data Augmentation Operators
+    """
+    logger.info("test_one_hot_post_aug")
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    # Define data augmentation parameters
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    resize_height, resize_width = 224, 224
+
+    # Define map operations
+    decode_op = c_vision.Decode()
+    rescale_op = c_vision.Rescale(rescale, shift)
+    resize_op = c_vision.Resize((resize_height, resize_width))
+
+    # Apply map operations on images
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=rescale_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+
+    # Apply one-hot encoding on labels
+    depth = 4
+    one_hot_encode = data_trans.OneHot(depth)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
+
+    # Apply datasets ops
+    buffer_size = 100
+    seed = 10
+    batch_size = 2
+    ds.config.set_seed(seed)
+    data1 = data1.shuffle(buffer_size=buffer_size)
+    data1 = data1.batch(batch_size, drop_remainder=True)
+
+    num_iter = 0
+    for item in data1.create_dict_iterator():
+        logger.info("image is: {}".format(item["image"]))
+        logger.info("label is: {}".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 1
 if __name__ == "__main__":
     test_one_hot()
+    test_one_hot_post_aug()
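The hunk above elides the body of the NumPy reference helper one_hot(index, depth) that test_one_hot compares against. A minimal sketch of such a helper, consistent with the mse == 0 assertion (this body is an assumption, not the committed code):

import numpy as np

def one_hot(index, depth):
    # Reference one-hot vector: length `depth`, with a 1 at position `index`.
    vec = np.zeros(depth, dtype=np.int64)
    vec[index] = 1
    return vec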
...
@@ -12,25 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Test Repeat Op
+"""
 import numpy as np
-from util import save_and_check
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
+from util import save_and_check_dict
 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-              "col_sint16", "col_sint32", "col_sint64"]
-GENERATE_GOLDEN = False
-IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+GENERATE_GOLDEN = False
 def test_tf_repeat_01():
     """
@@ -39,14 +38,13 @@ def test_tf_repeat_01():
     logger.info("Test Simple Repeat")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
     data1 = data1.repeat(repeat_count)
     filename = "repeat_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_tf_repeat_02():
@@ -99,14 +97,13 @@ def test_tf_repeat_04():
     logger.info("Test Simple Repeat Column List")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
     columns_list = ["col_sint64", "col_sint32"]
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
     data1 = data1.repeat(repeat_count)
     filename = "repeat_list_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def generator():
@@ -115,6 +112,7 @@ def generator():
 def test_nested_repeat1():
+    logger.info("test_nested_repeat1")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -126,6 +124,7 @@ def test_nested_repeat1():
 def test_nested_repeat2():
+    logger.info("test_nested_repeat2")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(1)
@@ -137,6 +136,7 @@ def test_nested_repeat2():
 def test_nested_repeat3():
+    logger.info("test_nested_repeat3")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(2)
@@ -148,6 +148,7 @@ def test_nested_repeat3():
 def test_nested_repeat4():
+    logger.info("test_nested_repeat4")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(1)
@@ -159,6 +160,7 @@ def test_nested_repeat4():
 def test_nested_repeat5():
+    logger.info("test_nested_repeat5")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(3)
     data = data.repeat(2)
@@ -171,6 +173,7 @@ def test_nested_repeat5():
 def test_nested_repeat6():
+    logger.info("test_nested_repeat6")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.batch(3)
@@ -183,6 +186,7 @@ def test_nested_repeat6():
 def test_nested_repeat7():
+    logger.info("test_nested_repeat7")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -195,6 +199,7 @@ def test_nested_repeat7():
 def test_nested_repeat8():
+    logger.info("test_nested_repeat8")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(2, drop_remainder=False)
     data = data.repeat(2)
@@ -210,6 +215,7 @@ def test_nested_repeat8():
 def test_nested_repeat9():
+    logger.info("test_nested_repeat9")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat()
     data = data.repeat(3)
@@ -221,6 +227,7 @@ def test_nested_repeat9():
 def test_nested_repeat10():
+    logger.info("test_nested_repeat10")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(3)
     data = data.repeat()
@@ -232,6 +239,7 @@ def test_nested_repeat10():
 def test_nested_repeat11():
+    logger.info("test_nested_repeat11")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
...
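For context on the nested-repeat hunks above: chained repeats compose multiplicatively, which is what these tests assert. A minimal sketch, assuming a hypothetical 3-element generator (the committed generator() body is elided from the diff):

import numpy as np
import mindspore.dataset as ds

def gen():
    # hypothetical source; the real generator() body is not shown in the diff
    for i in range(3):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
data = data.repeat(2)
data = data.repeat(3)
assert sum(1 for _ in data) == 3 * 2 * 3  # nested repeats multiply the passes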
...
@@ -12,21 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Test TFRecordDataset Ops
+"""
 import numpy as np
 import pytest
-from util import save_and_check
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict
 FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
 DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
 SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
+SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
 GENERATE_GOLDEN = False
-def test_case_tf_shape():
+def test_tfrecord_shape():
+    logger.info("test_tfrecord_shape")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)
@@ -36,7 +45,8 @@ def test_case_tf_shape():
     assert len(output_shape[-1]) == 1
-def test_case_tf_read_all_dataset():
+def test_tfrecord_read_all_dataset():
+    logger.info("test_tfrecord_read_all_dataset")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 12
@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
     assert count == 12
-def test_case_num_samples():
+def test_tfrecord_num_samples():
+    logger.info("test_tfrecord_num_samples")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
     assert ds1.get_dataset_size() == 8
@@ -56,7 +67,8 @@ def test_case_num_samples():
     assert count == 8
-def test_case_num_samples2():
+def test_tfrecord_num_samples2():
+    logger.info("test_tfrecord_num_samples2")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 7
@@ -66,42 +78,41 @@ def test_case_num_samples2():
     assert count == 7
-def test_case_tf_shape_2():
+def test_tfrecord_shape2():
+    logger.info("test_tfrecord_shape2")
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
     output_shape = ds1.output_shapes()
     assert len(output_shape[-1]) == 2
-def test_case_tf_file():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_files_basic():
+    logger.info("test_tfrecord_files_basic")
     data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
-    filename = "tfreader_result.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_files_basic.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
-def test_case_tf_file_no_schema():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_no_schema():
+    logger.info("test_tfrecord_no_schema")
     data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_no_schema.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_no_schema.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
-def test_case_tf_file_pad():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_pad():
+    logger.info("test_tfrecord_pad")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
     data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_padBytes10.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_pad_bytes10.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
-def test_tf_files():
+def test_tfrecord_read_files():
+    logger.info("test_tfrecord_read_files")
     pattern = DATASET_ROOT + "/test.data"
     data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
     assert sum([1 for _ in data]) == 12
@@ -123,7 +134,19 @@ def test_tf_files():
     assert sum([1 for _ in data]) == 24
-def test_tf_record_schema():
+def test_tfrecord_multi_files():
+    logger.info("test_tfrecord_multi_files")
+    data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
+    data1 = data1.repeat(1)
+    num_iter = 0
+    for _ in data1.create_dict_iterator():
+        num_iter += 1
+    assert num_iter == 12
+
+
+def test_tfrecord_schema():
+    logger.info("test_tfrecord_schema")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -142,7 +165,8 @@ def test_tf_record_schema():
     assert np.array_equal(t1, t2)
-def test_tf_record_shuffle():
+def test_tfrecord_shuffle():
+    logger.info("test_tfrecord_shuffle")
     ds.config.set_seed(1)
     data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
     data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)
-def test_tf_record_shard():
+def test_tfrecord_shard():
+    logger.info("test_tfrecord_shard")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -181,7 +206,8 @@ def test_tf_record_shard():
     assert set(worker2_res) == set(worker1_res)
-def test_tf_shard_equal_rows():
+def test_tfrecord_shard_equal_rows():
+    logger.info("test_tfrecord_shard_equal_rows")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
     assert len(worker4_res) == 40
-def test_case_tf_file_no_schema_columns_list():
+def test_tfrecord_no_schema_columns_list():
+    logger.info("test_tfrecord_no_schema_columns_list")
     data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
     row = data.create_dict_iterator().get_next()
     assert row["col_sint16"] == [-32768]
@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
     assert "col_sint32" in str(info.value)
-def test_tf_record_schema_columns_list():
+def test_tfrecord_schema_columns_list():
+    logger.info("test_tfrecord_schema_columns_list")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
     assert "col_sint32" in str(info.value)
-def test_case_invalid_files():
+def test_tfrecord_invalid_files():
+    logger.info("test_tfrecord_invalid_files")
     valid_file = "../data/dataset/testTFTestAllTypes/test.data"
     invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
     files = [invalid_file, valid_file, SCHEMA_FILE]
@@ -266,19 +295,20 @@ def test_case_invalid_files():
 if __name__ == '__main__':
-    test_case_tf_shape()
-    test_case_tf_read_all_dataset()
-    test_case_num_samples()
-    test_case_num_samples2()
-    test_case_tf_shape_2()
-    test_case_tf_file()
-    test_case_tf_file_no_schema()
-    test_case_tf_file_pad()
-    test_tf_files()
-    test_tf_record_schema()
-    test_tf_record_shuffle()
-    test_tf_record_shard()
-    test_tf_shard_equal_rows()
-    test_case_tf_file_no_schema_columns_list()
-    test_tf_record_schema_columns_list()
-    test_case_invalid_files()
+    test_tfrecord_shape()
+    test_tfrecord_read_all_dataset()
+    test_tfrecord_num_samples()
+    test_tfrecord_num_samples2()
+    test_tfrecord_shape2()
+    test_tfrecord_files_basic()
+    test_tfrecord_no_schema()
+    test_tfrecord_pad()
+    test_tfrecord_read_files()
+    test_tfrecord_multi_files()
+    test_tfrecord_schema()
+    test_tfrecord_shuffle()
+    test_tfrecord_shard()
+    test_tfrecord_shard_equal_rows()
+    test_tfrecord_no_schema_columns_list()
+    test_tfrecord_schema_columns_list()
+    test_tfrecord_invalid_files()
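The shard tests above drive TFRecordDataset's num_shards/shard_id arguments, whose bodies are elided from the diff. A minimal usage sketch (file list as in the tests, two workers for illustration):

import mindspore.dataset as ds

tf_files = ["../data/dataset/tf_file_dataset/test1.data",
            "../data/dataset/tf_file_dataset/test2.data"]

# Each worker reads a disjoint shard of the input files.
worker0 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, shuffle=False)
worker1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, shuffle=False)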