Commit ae1ed327 authored by Cathy Wong

Cleanup dataset UT: Remove unneeded data files and tests

Parent b23fc4e4
@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
auto my_tree = std::make_shared<ExecutionTree>();
// Creating TFReaderOp
- std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+ std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
auto my_tree = std::make_shared<ExecutionTree>();
// Creating TFReaderOp
- std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+ std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
MS_LOG(INFO) << "UT test TestZipRepeat.";
auto my_tree = std::make_shared<ExecutionTree>();
- std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+ std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"label": {
"type": "int64",
"rank": 1,
"t_impl": "flex"
}
}
}
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"image": {
"type": "uint8",
"rank": 1,
"t_impl": "cvmat"
}
}
}
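Note: schema files like the two above are consumed by passing their path to ds.TFRecordDataset; the JSON declares the column names, types, ranks, and the expected row count (e.g. "numRows": 3). A minimal usage sketch, with hypothetical placeholder paths:

import mindspore.dataset as ds

# "train-0000-of-0001.data" and "datasetSchema.json" are placeholder paths for illustration.
data = ds.TFRecordDataset(["train-0000-of-0001.data"], "datasetSchema.json", shuffle=False)
for item in data.create_dict_iterator():  # one dict per row, keyed by the schema's columns
    print(item.keys())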
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_case_repeat():
"""
a simple repeat operation.
"""
logger.info("Test Simple Repeat")
# define parameters
repeat_count = 2
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
def test_case_shuffle():
"""
a simple shuffle operation.
"""
logger.info("Test Simple Shuffle")
# define parameters
buffer_size = 8
seed = 10
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
for item in data1.create_dict_iterator():
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
def test_case_0():
"""
Test Repeat then Shuffle
"""
logger.info("Test Repeat then Shuffle")
# define parameters
repeat_count = 2
buffer_size = 7
seed = 9
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
def test_case_0_reverse():
"""
Test Shuffle then Repeat
"""
logger.info("Test Shuffle then Repeat")
# define parameters
repeat_count = 2
buffer_size = 10
seed = 9
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
data1 = data1.repeat(repeat_count)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
def test_case_3():
"""
Test Map
"""
logger.info("Test Map Rescale and Resize, then Shuffle")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
# define data augmentation parameters
rescale = 1.0 / 255.0
shift = 0.0
resize_height, resize_width = 224, 224
# define map operations
decode_op = vision.Decode()
rescale_op = vision.Rescale(rescale, shift)
# resize_op = vision.Resize(resize_height, resize_width,
# InterpolationMode.DE_INTER_LINEAR) # Bilinear mode
resize_op = vision.Resize((resize_height, resize_width))
# apply map operations on images
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=rescale_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)
# apply one-hot encoding on labels
num_classes = 4
one_hot_encode = data_trans.OneHot(num_classes) # num_classes is input argument
data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
# apply dataset ops
buffer_size = 100
seed = 10
batch_size = 2
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size) # 10000 as in imageNet train script
data1 = data1.batch(batch_size, drop_remainder=True)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
if __name__ == '__main__':
logger.info('===========now test Repeat============')
# logger.info('Simple Repeat')
test_case_repeat()
logger.info('\n')
logger.info('===========now test Shuffle===========')
# logger.info('Simple Shuffle')
test_case_shuffle()
logger.info('\n')
# Note: batch cannot handle rows of different shapes, hence it is not applied to the images here
# logger.info('===========now test Batch=============')
# # logger.info('Simple Batch')
# test_case_batch()
# logger.info('\n')
logger.info('===========now test case 0============')
# logger.info('Repeat then Shuffle')
test_case_0()
logger.info('\n')
logger.info('===========now test case 0 reverse============')
# # logger.info('Shuffle then Repeat')
test_case_0_reverse()
logger.info('\n')
# logger.info('===========now test case 1============')
# # logger.info('Repeat with Batch')
# test_case_1()
# logger.info('\n')
# logger.info('===========now test case 2============')
# # logger.info('Batch with Shuffle')
# test_case_2()
# logger.info('\n')
# for image augmentation only
logger.info('===========now test case 3============')
logger.info('Map then Shuffle')
test_case_3()
logger.info('\n')
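Note on test_case_3 above: Rescale applies output = input * rescale + shift per pixel, so rescale = 1.0 / 255.0 with shift = 0.0 maps uint8 pixels into [0.0, 1.0]. An illustrative NumPy sketch of the same arithmetic (not the operator's implementation):

import numpy as np

pixels = np.array([0, 127, 255], dtype=np.uint8)
rescaled = pixels * (1.0 / 255.0) + 0.0  # same formula as vision.Rescale(1.0 / 255.0, 0.0)
print(rescaled)  # [0.         0.49803922 1.        ]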
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
def test_tf_file_normal():
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(1)
num_iter = 0
for _ in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12
if __name__ == '__main__':
logger.info('=======test normal=======')
test_tf_file_normal()
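Note on the assertion above: assuming each of the four shards listed in DATA_DIR carries the same three-image payload (consistent with the test_tf_file_3_images2 name), the expected total is 4 * 3 = 12:

num_files = 4       # shards listed in DATA_DIR
rows_per_file = 3   # assumed per-shard row count
assert num_files * rows_per_file == 12  # matches assert num_iter == 12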
@@ -13,12 +13,13 @@
# limitations under the License.
# ==============================================================================
"""
- Testing the one_hot op in DE
+ Testing the OneHot Op
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
+ import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
from util import diff_mse
@@ -37,15 +38,15 @@ def one_hot(index, depth):
def test_one_hot():
"""
- Test one_hot
+ Test OneHot Tensor Operator
"""
- logger.info("Test one_hot")
+ logger.info("test_one_hot")
depth = 10
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
- one_hot_op = data_trans.OneHot(depth)
+ one_hot_op = data_trans.OneHot(num_classes=depth)
data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])
# Second dataset
@@ -58,8 +59,54 @@ def test_one_hot():
label2 = one_hot(item2["label"][0], depth)
mse = diff_mse(label1, label2)
logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
assert mse == 0
num_iter += 1
assert num_iter == 3
+ def test_one_hot_post_aug():
+ """
+ Test One Hot Encoding after Multiple Data Augmentation Operators
+ """
+ logger.info("test_one_hot_post_aug")
+ data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+ # Define data augmentation parameters
+ rescale = 1.0 / 255.0
+ shift = 0.0
+ resize_height, resize_width = 224, 224
+ # Define map operations
+ decode_op = c_vision.Decode()
+ rescale_op = c_vision.Rescale(rescale, shift)
+ resize_op = c_vision.Resize((resize_height, resize_width))
+ # Apply map operations on images
+ data1 = data1.map(input_columns=["image"], operations=decode_op)
+ data1 = data1.map(input_columns=["image"], operations=rescale_op)
+ data1 = data1.map(input_columns=["image"], operations=resize_op)
+ # Apply one-hot encoding on labels
+ depth = 4
+ one_hot_encode = data_trans.OneHot(depth)
+ data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
+ # Apply dataset ops
+ buffer_size = 100
+ seed = 10
+ batch_size = 2
+ ds.config.set_seed(seed)
+ data1 = data1.shuffle(buffer_size=buffer_size)
+ data1 = data1.batch(batch_size, drop_remainder=True)
+ num_iter = 0
+ for item in data1.create_dict_iterator():
+ logger.info("image is: {}".format(item["image"]))
+ logger.info("label is: {}".format(item["label"]))
+ num_iter += 1
+ assert num_iter == 1
if __name__ == "__main__":
test_one_hot()
+ test_one_hot_post_aug()
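For reference, the NumPy side of the comparison in test_one_hot can be written with np.eye; this is an illustrative equivalent, not the file's elided one_hot helper:

import numpy as np

def one_hot_ref(index, depth):
    # Row `index` of the identity matrix is the one-hot vector for that index.
    return np.eye(depth)[index]

print(one_hot_ref(2, 4))  # [0. 0. 1. 0.]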
@@ -12,25 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Repeat Op
"""
import numpy as np
- from util import save_and_check
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
+ from util import save_and_check_dict
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
- GENERATE_GOLDEN = False
- IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
- IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+ DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+ GENERATE_GOLDEN = False
def test_tf_repeat_01():
"""
@@ -39,14 +38,13 @@ def test_tf_repeat_01():
logger.info("Test Simple Repeat")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
data1 = data1.repeat(repeat_count)
filename = "repeat_result.npz"
- save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+ save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_tf_repeat_02():
@@ -99,14 +97,13 @@ def test_tf_repeat_04():
logger.info("Test Simple Repeat Column List")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}
columns_list = ["col_sint64", "col_sint32"]
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
data1 = data1.repeat(repeat_count)
filename = "repeat_list_result.npz"
- save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+ save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def generator():
@@ -115,6 +112,7 @@ def generator():
def test_nested_repeat1():
+ logger.info("test_nested_repeat1")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
@@ -126,6 +124,7 @@ def test_nested_repeat1():
def test_nested_repeat2():
+ logger.info("test_nested_repeat2")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(1)
@@ -137,6 +136,7 @@ def test_nested_repeat2():
def test_nested_repeat3():
+ logger.info("test_nested_repeat3")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(2)
@@ -148,6 +148,7 @@ def test_nested_repeat3():
def test_nested_repeat4():
+ logger.info("test_nested_repeat4")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(1)
@@ -159,6 +160,7 @@ def test_nested_repeat4():
def test_nested_repeat5():
+ logger.info("test_nested_repeat5")
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(3)
data = data.repeat(2)
@@ -171,6 +173,7 @@ def test_nested_repeat5():
def test_nested_repeat6():
+ logger.info("test_nested_repeat6")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.batch(3)
@@ -183,6 +186,7 @@ def test_nested_repeat6():
def test_nested_repeat7():
+ logger.info("test_nested_repeat7")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
@@ -195,6 +199,7 @@ def test_nested_repeat7():
def test_nested_repeat8():
+ logger.info("test_nested_repeat8")
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(2, drop_remainder=False)
data = data.repeat(2)
@@ -210,6 +215,7 @@ def test_nested_repeat8():
def test_nested_repeat9():
+ logger.info("test_nested_repeat9")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat()
data = data.repeat(3)
@@ -221,6 +227,7 @@ def test_nested_repeat9():
def test_nested_repeat10():
+ logger.info("test_nested_repeat10")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(3)
data = data.repeat()
@@ -232,6 +239,7 @@ def test_nested_repeat10():
def test_nested_repeat11():
+ logger.info("test_nested_repeat11")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
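Note on the nested repeat tests above: the elided assertions presumably check that consecutive repeats compose multiplicatively, i.e. repeat(2) followed by repeat(3) behaves like repeat(6). A minimal sketch under that assumption, with a hypothetical three-row generator:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(3):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
data = data.repeat(2)
data = data.repeat(3)
# Assuming multiplicative composition: 3 rows * 2 * 3 = 18.
assert sum(1 for _ in data.create_dict_iterator()) == 18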
@@ -12,21 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test TFRecordDataset Ops
"""
import numpy as np
import pytest
- from util import save_and_check
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger
+ from util import save_and_check_dict
FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+ DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
+ "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
+ "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
+ "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
+ SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
GENERATE_GOLDEN = False
- def test_case_tf_shape():
+ def test_tfrecord_shape():
+ logger.info("test_tfrecord_shape")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
ds1 = ds1.batch(2)
@@ -36,7 +45,8 @@ def test_case_tf_shape():
assert len(output_shape[-1]) == 1
- def test_case_tf_read_all_dataset():
+ def test_tfrecord_read_all_dataset():
+ logger.info("test_tfrecord_read_all_dataset")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 12
@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
assert count == 12
- def test_case_num_samples():
+ def test_tfrecord_num_samples():
+ logger.info("test_tfrecord_num_samples")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
assert ds1.get_dataset_size() == 8
@@ -56,7 +67,8 @@ def test_case_num_samples():
assert count == 8
- def test_case_num_samples2():
+ def test_tfrecord_num_samples2():
+ logger.info("test_tfrecord_num_samples2")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 7
@@ -66,42 +78,41 @@ def test_case_num_samples2():
assert count == 7
- def test_case_tf_shape_2():
+ def test_tfrecord_shape2():
+ logger.info("test_tfrecord_shape2")
ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
ds1 = ds1.batch(2)
output_shape = ds1.output_shapes()
assert len(output_shape[-1]) == 2
- def test_case_tf_file():
- logger.info("reading data from: {}".format(FILES[0]))
- parameters = {"params": {}}
+ def test_tfrecord_files_basic():
+ logger.info("test_tfrecord_files_basic")
data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
filename = "tfreader_result.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_files_basic.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
- def test_case_tf_file_no_schema():
- logger.info("reading data from: {}".format(FILES[0]))
- parameters = {"params": {}}
+ def test_tfrecord_no_schema():
+ logger.info("test_tfrecord_no_schema")
data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
filename = "tf_file_no_schema.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_no_schema.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
- def test_case_tf_file_pad():
- logger.info("reading data from: {}".format(FILES[0]))
- parameters = {"params": {}}
+ def test_tfrecord_pad():
+ logger.info("test_tfrecord_pad")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
filename = "tf_file_padBytes10.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_pad_bytes10.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
- def test_tf_files():
+ def test_tfrecord_read_files():
+ logger.info("test_tfrecord_read_files")
pattern = DATASET_ROOT + "/test.data"
data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
assert sum([1 for _ in data]) == 12
@@ -123,7 +134,19 @@ def test_tf_files():
assert sum([1 for _ in data]) == 24
- def test_tf_record_schema():
+ def test_tfrecord_multi_files():
+ logger.info("test_tfrecord_multi_files")
+ data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
+ data1 = data1.repeat(1)
+ num_iter = 0
+ for _ in data1.create_dict_iterator():
+ num_iter += 1
+ assert num_iter == 12
+ def test_tfrecord_schema():
+ logger.info("test_tfrecord_schema")
schema = ds.Schema()
schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -142,7 +165,8 @@ def test_tf_record_schema():
assert np.array_equal(t1, t2)
- def test_tf_record_shuffle():
+ def test_tfrecord_shuffle():
+ logger.info("test_tfrecord_shuffle")
ds.config.set_seed(1)
data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
assert np.array_equal(t1, t2)
- def test_tf_record_shard():
+ def test_tfrecord_shard():
+ logger.info("test_tfrecord_shard")
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
"../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -181,7 +206,8 @@ def test_tf_record_shard():
assert set(worker2_res) == set(worker1_res)
- def test_tf_shard_equal_rows():
+ def test_tfrecord_shard_equal_rows():
+ logger.info("test_tfrecord_shard_equal_rows")
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
"../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
assert len(worker4_res) == 40
- def test_case_tf_file_no_schema_columns_list():
+ def test_tfrecord_no_schema_columns_list():
+ logger.info("test_tfrecord_no_schema_columns_list")
data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next()
assert row["col_sint16"] == [-32768]
@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
assert "col_sint32" in str(info.value)
- def test_tf_record_schema_columns_list():
+ def test_tfrecord_schema_columns_list():
+ logger.info("test_tfrecord_schema_columns_list")
schema = ds.Schema()
schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
assert "col_sint32" in str(info.value)
- def test_case_invalid_files():
+ def test_tfrecord_invalid_files():
+ logger.info("test_tfrecord_invalid_files")
valid_file = "../data/dataset/testTFTestAllTypes/test.data"
invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
files = [invalid_file, valid_file, SCHEMA_FILE]
@@ -266,19 +295,20 @@
if __name__ == '__main__':
- test_case_tf_shape()
- test_case_tf_read_all_dataset()
- test_case_num_samples()
- test_case_num_samples2()
- test_case_tf_shape_2()
- test_case_tf_file()
- test_case_tf_file_no_schema()
- test_case_tf_file_pad()
- test_tf_files()
- test_tf_record_schema()
- test_tf_record_shuffle()
- test_tf_record_shard()
- test_tf_shard_equal_rows()
- test_case_tf_file_no_schema_columns_list()
- test_tf_record_schema_columns_list()
- test_case_invalid_files()
+ test_tfrecord_shape()
+ test_tfrecord_read_all_dataset()
+ test_tfrecord_num_samples()
+ test_tfrecord_num_samples2()
+ test_tfrecord_shape2()
+ test_tfrecord_files_basic()
+ test_tfrecord_no_schema()
+ test_tfrecord_pad()
+ test_tfrecord_read_files()
+ test_tfrecord_multi_files()
+ test_tfrecord_schema()
+ test_tfrecord_shuffle()
+ test_tfrecord_shard()
+ test_tfrecord_shard_equal_rows()
+ test_tfrecord_no_schema_columns_list()
+ test_tfrecord_schema_columns_list()
+ test_tfrecord_invalid_files()
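Note on the shard tests above: num_shards and shard_id split the file list across workers, so each worker reads only its own subset of the files. A minimal sketch reusing file names from test_tfrecord_shard (row counts depend on the data, so none are asserted here):

import mindspore.dataset as ds

tf_files = ["../data/dataset/tf_file_dataset/test1.data",
            "../data/dataset/tf_file_dataset/test2.data"]
# Worker 0 of 2 sees only the files assigned to its shard.
worker0 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, shuffle=ds.Shuffle.FILES)
rows_worker0 = sum(1 for _ in worker0.create_dict_iterator())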