diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 2a31f300158513245100b270de955d92d4998beb..a6f8237f2f7c1f0272a15938469179e500358099 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -32,6 +32,7 @@ #include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/voc_op.h" #endif @@ -1503,6 +1504,56 @@ std::vector> TextFileDataset::Build() { return node_ops; } +// Validator for TFRecordDataset +bool TFRecordDataset::ValidateParams() { return true; } + +// Function to build TFRecordDataset +std::vector> TFRecordDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // Sort the datasets file in a lexicographical order + std::vector sorted_dir_files = dataset_files_; + std::sort(sorted_dir_files.begin(), sorted_dir_files.end()); + + // Create Schema Object + std::unique_ptr data_schema = std::make_unique(); + if (!schema_path_.empty()) { + RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_)); + } else if (schema_obj_ != nullptr) { + std::string schema_json_string = schema_obj_->to_json(); + RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_)); + } + + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + + // Create and initalize TFReaderOp + std::shared_ptr tf_reader_op = std::make_shared( + num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema), + connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr); + + RETURN_EMPTY_IF_ERROR(tf_reader_op->Init()); + + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp + + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset + RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files)); + + // Add the shuffle op after this op + RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, + rows_per_buffer_, &shuffle_op)); + node_ops.push_back(shuffle_op); + } + + // Add TFReaderOp + node_ops.push_back(tf_reader_op); + return node_ops; +} + #ifndef ENABLE_ANDROID // Constructor for VOCDataset VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index de507ac8ba43ed67e192151e2888a70e0c149a77..23c93d4aecfbcad09e7c0df63c610900e60affbb 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -32,6 +32,7 @@ #include "minddata/dataset/include/type_id.h" #include "minddata/dataset/kernels/c_func_op.h" #include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/path.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/text/vocab.h" #endif @@ -69,6 +70,7 @@ class ManifestDataset; class MnistDataset; class RandomDataset; class TextFileDataset; +class TFRecordDataset; #ifndef ENABLE_ANDROID class VOCDataset; #endif @@ -320,6 +322,80 @@ std::shared_ptr TextFile(const std::vector &datase ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); +/// \brief Function to create a TFRecordDataset +/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list +/// will be sorted in a lexicographical order. +/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the +/// meta data from the TFData file is considered the schema.) +/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns) +/// \param[in] num_samples The number of samples to be included in the dataset. +/// (Default = 0 means all samples.) +/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset; +/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows; +/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. +/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal) +/// Can be any of: +/// ShuffleMode::kFalse - No shuffling is performed. +/// ShuffleMode::kFiles - Shuffle files only. +/// ShuffleMode::kGlobal - Shuffle both the files and samples. +/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) +/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only +/// when num_shards is also specified. (Default = 0) +/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of +/// each shard may be not equal) +/// \return Shared pointer to the current TFRecordDataset +template > +std::shared_ptr TFRecord(const std::vector &dataset_files, const T &schema = nullptr, + const std::vector &columns_list = {}, int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, + int32_t shard_id = 0, bool shard_equal_rows = false) { + if (dataset_files.empty()) { + MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified."; + return nullptr; + } + + for (auto f : dataset_files) { + Path dataset_file(f); + if (!dataset_file.Exists()) { + MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist."; + return nullptr; + } + } + + if (num_samples < 0) { + MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples; + return nullptr; + } + + if (num_shards <= 0) { + MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards; + return nullptr; + } + + if (shard_id < 0 || shard_id >= num_shards) { + MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; + return nullptr; + } + std::shared_ptr ds = nullptr; + if constexpr (std::is_same::value || std::is_same>::value) { + std::shared_ptr schema_obj = schema; + ds = std::make_shared(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, + shard_id, shard_equal_rows); + } else { + std::string schema_path = schema; + if (!schema_path.empty()) { + Path schema_file(schema_path); + if (!schema_file.Exists()) { + MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist."; + return nullptr; + } + } + ds = std::make_shared(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, + shard_id, shard_equal_rows); + } + return ds; +} + #ifndef ENABLE_ANDROID /// \brief Function to create a VOCDataset /// \notes The generated dataset has multi-columns : @@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset { ShuffleMode shuffle_; }; +/// \class TFRecordDataset +/// \brief A Dataset derived class to represent TFRecord dataset +class TFRecordDataset : public Dataset { + public: + /// \brief Constructor + /// \note Parameter 'schema' is the path to the schema file + TFRecordDataset(const std::vector &dataset_files, std::string schema, + const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, bool shard_equal_rows) + : dataset_files_(dataset_files), + schema_path_(schema), + columns_list_(columns_list), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + shard_equal_rows_(shard_equal_rows) {} + + /// \brief Constructor + /// \note Parameter 'schema' is shared pointer to Schema object + TFRecordDataset(const std::vector &dataset_files, std::shared_ptr schema, + const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, bool shard_equal_rows) + : dataset_files_(dataset_files), + schema_obj_(schema), + columns_list_(columns_list), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + shard_equal_rows_(shard_equal_rows) {} + + /// \brief Destructor + ~TFRecordDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector dataset_files_; + std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string + std::shared_ptr schema_obj_; // schema_obj_ schema object. + std::vector columns_list_; + int64_t num_samples_; + ShuffleMode shuffle_; + int32_t num_shards_; + int32_t shard_id_; + bool shard_equal_rows_; +}; + #ifndef ENABLE_ANDROID class VOCDataset : public Dataset { public: diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index bf478e527269b10b049c63e384a1ab554c4827f7..992b54953576f6f9e7d31a882c8e3099b1b0cb58 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3541,7 +3541,7 @@ class TFRecordDataset(SourceDataset): If the schema is not provided, the meta data from the TFData file is considered the schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) num_samples (int, optional): number of samples(rows) to read (default=None). - If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset; + If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset; If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. num_parallel_workers (int, optional): number of workers to read the data @@ -3560,8 +3560,8 @@ class TFRecordDataset(SourceDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. - shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number - of rows of each shard may be not equal. + shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows + is false, number of rows of each shard may be not equal. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). The cache feature is under development and is not recommended. Examples: diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 401214804d178489405a904ea0ffc682064f0a95..bbdf93f62796f14cb490a71495c8e17d4acf0ac1 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -107,9 +107,10 @@ SET(DE_UT_SRCS c_api_dataset_clue_test.cc c_api_dataset_coco_test.cc c_api_dataset_csv_test.cc - c_api_dataset_textfile_test.cc c_api_dataset_manifest_test.cc c_api_dataset_randomdata_test.cc + c_api_dataset_textfile_test.cc + c_api_dataset_tfrecord_test.cc c_api_dataset_voc_test.cc c_api_datasets_test.cc c_api_dataset_iterator_test.cc diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..83d2ed2333266eb83cb83f979a2d97e4079442f4 --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -0,0 +1,444 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" + +using namespace mindspore::dataset; +using namespace mindspore::dataset::api; +using mindspore::dataset::Tensor; +using mindspore::dataset::ShuffleMode; +using mindspore::dataset::TensorShape; +using mindspore::dataset::DataType; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasic."; + + // Create a TFRecord Dataset + std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; + std::string schema_path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json"; + std::shared_ptr ds = TFRecord({file_path}, schema_path, {"image"}, 0); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5); + EXPECT_NE(random_horizontal_flip_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check column + EXPECT_EQ(row.size(), 1); + EXPECT_NE(row.find("image"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle."; + // This case is to verify if the list of datafiles are sorted in lexicographical order. + // Set configuration + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_num_parallel_workers(1); + + // Create a TFRecord Dataset + std::string file1 = datasets_root_path_ + "/tf_file_dataset/test1.data"; + std::string file2 = datasets_root_path_ + "/tf_file_dataset/test2.data"; + std::string file3 = datasets_root_path_ + "/tf_file_dataset/test3.data"; + std::string file4 = datasets_root_path_ + "/tf_file_dataset/test4.data"; + std::shared_ptr ds1 = TFRecord({file4, file3, file2, file1}, "", {"scalars"}, 0, ShuffleMode::kFalse); + EXPECT_NE(ds1, nullptr); + std::shared_ptr ds2 = TFRecord({file1}, "", {"scalars"}, 0, ShuffleMode::kFalse); + EXPECT_NE(ds2, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter1 = ds1->CreateIterator(); + EXPECT_NE(iter1, nullptr); + std::shared_ptr iter2 = ds2->CreateIterator(); + EXPECT_NE(iter2, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row1; + iter1->GetNextRow(&row1); + std::unordered_map> row2; + iter2->GetNextRow(&row2); + + uint64_t i = 0; + int64_t value1 = 0; + int64_t value2 = 0; + while (row1.size() != 0 && row2.size() != 0) { + row1["scalars"]->GetItemAt(&value1, {0}); + row2["scalars"]->GetItemAt(&value2, {0}); + EXPECT_EQ(value1, value2); + iter1->GetNextRow(&row1); + iter2->GetNextRow(&row2); + i++; + } + EXPECT_EQ(i, 10); + // Manually terminate the pipeline + iter1->Stop(); + iter2->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle2."; + // This case is to verify the content of the data is indeed shuffled. + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(155); + GlobalContext::config_manager()->set_num_parallel_workers(1); + + // Create a TFRecord Dataset + std::string file = datasets_root_path_ + "/tf_file_dataset/test1.data"; + std::shared_ptr ds = TFRecord({file}, nullptr, {"scalars"}, 0, ShuffleMode::kGlobal); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + std::vector expect = {9, 3, 4, 7, 2, 1, 6, 8, 10, 5}; + std::vector actual = {}; + int64_t value = 0; + uint64_t i = 0; + while (row.size() != 0) { + row["scalars"]->GetItemAt(&value, {}); + actual.push_back(value); + iter->GetNextRow(&row); + i++; + } + ASSERT_EQ(actual, expect); + EXPECT_EQ(i, 10); + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaPath) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaPath."; + + // Create a TFRecord Dataset + std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::string file_path2 = datasets_root_path_ + "/testTFTestAllTypes/test2.data"; + std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::shared_ptr ds = TFRecord({file_path2, file_path1}, schema_path, {}, 9, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check column + EXPECT_EQ(row.size(), 8); + EXPECT_NE(row.find("col_sint16"), row.end()); + EXPECT_NE(row.find("col_sint32"), row.end()); + EXPECT_NE(row.find("col_sint64"), row.end()); + EXPECT_NE(row.find("col_float"), row.end()); + EXPECT_NE(row.find("col_1d"), row.end()); + EXPECT_NE(row.find("col_2d"), row.end()); + EXPECT_NE(row.find("col_3d"), row.end()); + EXPECT_NE(row.find("col_binary"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 9); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaObj) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaObj."; + + // Create a TFRecord Dataset + std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::shared_ptr schema = Schema(); + schema->add_column("col_sint16", "int16", {1}); + schema->add_column("col_float", "float32", {1}); + schema->add_column("col_2d", "int64", {2, 2}); + std::shared_ptr ds = TFRecord({file_path}, schema); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check column + EXPECT_EQ(row.size(), 3); + EXPECT_NE(row.find("col_sint16"), row.end()); + EXPECT_NE(row.find("col_float"), row.end()); + EXPECT_NE(row.find("col_2d"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + auto col_sint16 = row["col_sint16"]; + auto col_float = row["col_float"]; + auto col_2d = row["col_2d"]; + + EXPECT_EQ(col_sint16->shape(), TensorShape({1})); + EXPECT_EQ(col_float->shape(), TensorShape({1})); + EXPECT_EQ(col_2d->shape(), TensorShape({2, 2})); + + EXPECT_EQ(col_sint16->Rank(), 1); + EXPECT_EQ(col_float->Rank(), 1); + EXPECT_EQ(col_2d->Rank(), 2); + + EXPECT_EQ(col_sint16->type(), DataType::DE_INT16); + EXPECT_EQ(col_float->type(), DataType::DE_FLOAT32); + EXPECT_EQ(col_2d->type(), DataType::DE_INT64); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 12); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetNoSchema) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetNoSchema."; + + // Create a TFRecord Dataset + std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; + std::shared_ptr schema = nullptr; + std::shared_ptr ds = TFRecord({file_path}, nullptr, {}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check column + EXPECT_EQ(row.size(), 2); + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto label = row["label"]; + + MS_LOG(INFO) << "Shape of column [image]:" << image->shape(); + MS_LOG(INFO) << "Shape of column [label]:" << label->shape(); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetColName) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetColName."; + + // Create a TFRecord Dataset + // The dataset has two columns("image", "label") and 3 rows + std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; + std::shared_ptr ds = TFRecord({file_path}, "", {"image"}, 0); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check column + EXPECT_EQ(row.size(), 1); + EXPECT_NE(row.find("image"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShard."; + + // Create a TFRecord Dataset + // Each file has two columns("image", "label") and 3 rows + std::vector files = { + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data" + }; + std::shared_ptr ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); + EXPECT_NE(ds1, nullptr); + std::shared_ptr ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); + EXPECT_NE(ds2, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter1 = ds1->CreateIterator(); + EXPECT_NE(iter1, nullptr); + std::shared_ptr iter2 = ds2->CreateIterator(); + EXPECT_NE(iter2, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row1; + iter1->GetNextRow(&row1); + std::unordered_map> row2; + iter2->GetNextRow(&row2); + + uint64_t i = 0; + uint64_t j = 0; + while (row1.size() != 0) { + i++; + iter1->GetNextRow(&row1); + } + + while (row2.size() != 0) { + j++; + iter2->GetNextRow(&row2); + } + + EXPECT_EQ(i, 5); + EXPECT_EQ(j, 3); + // Manually terminate the pipeline + iter1->Stop(); + iter2->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception."; + + // This case expected to fail because the list of dir_path cannot be empty. + std::shared_ptr ds1 = TFRecord({}); + EXPECT_EQ(ds1, nullptr); + + // This case expected to fail because the file in dir_path is not exist. + std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::shared_ptr ds2 = TFRecord({file_path, "noexist.data"}); + EXPECT_EQ(ds2, nullptr); + + // This case expected to fail because the file of schema is not exist. + std::shared_ptr ds4 = TFRecord({file_path, "notexist.json"}); + EXPECT_EQ(ds4, nullptr); + + // This case expected to fail because num_samples is negative. + std::shared_ptr ds5 = TFRecord({file_path}, "", {}, -1); + EXPECT_EQ(ds5, nullptr); + + // This case expected to fail because num_shards is negative. + std::shared_ptr ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0); + EXPECT_EQ(ds6, nullptr); + + // This case expected to fail because shard_id is out_of_bound. + std::shared_ptr ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3); + EXPECT_EQ(ds7, nullptr); +} + +TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception2."; + // This case expected to fail because the input column name does not exist. + + std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + // Create a TFRecord Dataset + // Column "image" does not exist in the dataset + std::shared_ptr ds = TFRecord({file_path1}, schema_path, {"image"}, 10); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This attempts to create Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +}