提交 e430b405 编写于 作者: T tinazhang

add tfrecord dataset to cpp api

fix to support schema=nullptr
上级 64ced295
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h" #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/voc_op.h" #include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif #endif
...@@ -1503,6 +1504,56 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() { ...@@ -1503,6 +1504,56 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
return node_ops; return node_ops;
} }
// Parameter validator for TFRecordDataset.
// No checks are performed in this override; it always reports the parameters as valid.
bool TFRecordDataset::ValidateParams() {
  return true;
}
// Function to build TFRecordDataset
std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Sort the datasets file in a lexicographical order
std::vector<std::string> sorted_dir_files = dataset_files_;
std::sort(sorted_dir_files.begin(), sorted_dir_files.end());
// Create Schema Object
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (!schema_path_.empty()) {
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
} else if (schema_obj_ != nullptr) {
std::string schema_json_string = schema_obj_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
}
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Create and initalize TFReaderOp
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op);
}
// Add TFReaderOp
node_ops.push_back(tf_reader_op);
return node_ops;
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Constructor for VOCDataset // Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "minddata/dataset/include/type_id.h" #include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/c_func_op.h" #include "minddata/dataset/kernels/c_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/text/vocab.h" #include "minddata/dataset/text/vocab.h"
#endif #endif
...@@ -69,6 +70,7 @@ class ManifestDataset; ...@@ -69,6 +70,7 @@ class ManifestDataset;
class MnistDataset; class MnistDataset;
class RandomDataset; class RandomDataset;
class TextFileDataset; class TextFileDataset;
class TFRecordDataset;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class VOCDataset; class VOCDataset;
#endif #endif
...@@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase ...@@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0); int32_t shard_id = 0);
/// \brief Function to create a TFRecordDataset
/// \param[in] dataset_files List of files to be read. Every file in the list must exist. The list
/// will be sorted in a lexicographical order.
/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
/// meta data from the TFData file is considered the schema.)
/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
/// when num_shards is also specified. (Default = 0)
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = false, which means the
/// number of rows in each shard may not be equal)
/// \return Shared pointer to the current TFRecordDataset
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0, bool shard_equal_rows = false) {
if (dataset_files.empty()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified.";
return nullptr;
}
for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist.";
return nullptr;
}
}
if (num_samples < 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples;
return nullptr;
}
if (num_shards <= 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards;
return nullptr;
}
if (shard_id < 0 || shard_id >= num_shards) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
return nullptr;
}
std::shared_ptr<TFRecordDataset> ds = nullptr;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
} else {
std::string schema_path = schema;
if (!schema_path.empty()) {
Path schema_file(schema_path);
if (!schema_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist.";
return nullptr;
}
}
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
}
return ds;
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Function to create a VOCDataset /// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns : /// \notes The generated dataset has multi-columns :
...@@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset { ...@@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset {
ShuffleMode shuffle_; ShuffleMode shuffle_;
}; };
/// \class TFRecordDataset
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordDataset : public Dataset {
public:
/// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file
TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_path_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Constructor
/// \note Parameter 'schema' is shared pointer to Schema object
TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_obj_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Destructor
~TFRecordDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
std::vector<std::string> columns_list_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
bool shard_equal_rows_;
};
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class VOCDataset : public Dataset { class VOCDataset : public Dataset {
public: public:
......
...@@ -3541,7 +3541,7 @@ class TFRecordDataset(SourceDataset): ...@@ -3541,7 +3541,7 @@ class TFRecordDataset(SourceDataset):
If the schema is not provided, the meta data from the TFData file is considered the schema. If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns) columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None). num_samples (int, optional): number of samples(rows) to read (default=None).
If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset; If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data num_parallel_workers (int, optional): number of workers to read the data
...@@ -3560,8 +3560,8 @@ class TFRecordDataset(SourceDataset): ...@@ -3560,8 +3560,8 @@ class TFRecordDataset(SourceDataset):
into (default=None). into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified. argument should be specified only when num_shards is also specified.
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
of rows of each shard may be not equal. is false, number of rows of each shard may be not equal.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended. The cache feature is under development and is not recommended.
Examples: Examples:
......
...@@ -107,9 +107,10 @@ SET(DE_UT_SRCS ...@@ -107,9 +107,10 @@ SET(DE_UT_SRCS
c_api_dataset_clue_test.cc c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc c_api_dataset_coco_test.cc
c_api_dataset_csv_test.cc c_api_dataset_csv_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_manifest_test.cc c_api_dataset_manifest_test.cc
c_api_dataset_randomdata_test.cc c_api_dataset_randomdata_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc c_api_dataset_voc_test.cc
c_api_datasets_test.cc c_api_datasets_test.cc
c_api_dataset_iterator_test.cc c_api_dataset_iterator_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
using namespace mindspore::dataset;
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::TensorShape;
using mindspore::dataset::DataType;
// Test fixture shared by the TFRecord dataset pipeline tests below; it adds nothing to
// UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
// Reads the "image" column of a TFRecord file using a schema file, then applies Repeat, Map and
// Batch and checks the number of rows produced.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasic.";

  // Create a TFRecord Dataset
  std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
  std::string schema_path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json";
  std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema_path, {"image"}, 0);
  // ASSERT rather than EXPECT: ds is dereferenced right below, so stop here on failure
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  int32_t repeat_num = 2;
  ds = ds->Repeat(repeat_num);
  ASSERT_NE(ds, nullptr);

  // Create objects for the tensor ops
  std::shared_ptr<TensorOperation> random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5);
  EXPECT_NE(random_horizontal_flip_op, nullptr);

  // Create a Map operation on ds
  ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"});
  ASSERT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check column
  EXPECT_EQ(row.size(), 1);
  EXPECT_NE(row.find("image"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
    i++;
  }

  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline
  iter->Stop();
}
// Checks that the list of data files is sorted in lexicographical order: a dataset built from
// {test4, test3, test2, test1} must start with the same values as a dataset built from test1 only.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle.";

  // Use a single worker for this test; the original value is restored at the end
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_num_parallel_workers(1);

  // Paths of the four data files
  std::string file1 = datasets_root_path_ + "/tf_file_dataset/test1.data";
  std::string file2 = datasets_root_path_ + "/tf_file_dataset/test2.data";
  std::string file3 = datasets_root_path_ + "/tf_file_dataset/test3.data";
  std::string file4 = datasets_root_path_ + "/tf_file_dataset/test4.data";

  // ds1 receives the files in reverse order, ds2 only the first file; neither is shuffled
  std::shared_ptr<Dataset> ds1 = TFRecord({file4, file3, file2, file1}, "", {"scalars"}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds1, nullptr);
  std::shared_ptr<Dataset> ds2 = TFRecord({file1}, "", {"scalars"}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds2, nullptr);

  // Creating the iterators triggers the creation of the Execution Trees and launches them
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  EXPECT_NE(iter1, nullptr);
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  EXPECT_NE(iter2, nullptr);

  // Fetch the first row of each pipeline
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row1;
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row2;
  iter1->GetNextRow(&row1);
  iter2->GetNextRow(&row2);

  // Walk both pipelines in lock step until either one runs out of rows
  int64_t value1 = 0;
  int64_t value2 = 0;
  uint64_t i = 0;
  while (!row1.empty() && !row2.empty()) {
    row1["scalars"]->GetItemAt(&value1, {0});
    row2["scalars"]->GetItemAt(&value2, {0});
    EXPECT_EQ(value1, value2);
    iter1->GetNextRow(&row1);
    iter2->GetNextRow(&row2);
    i++;
  }
  EXPECT_EQ(i, 10);

  // Manually terminate the pipelines
  iter1->Stop();
  iter2->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
// Checks that the content of the data is indeed shuffled under ShuffleMode::kGlobal, by comparing
// the values read against a fixed expected order for seed 155 with a single worker.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle2) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle2.";

  // Set configuration; the original values are restored at the end of the test
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(155);
  GlobalContext::config_manager()->set_num_parallel_workers(1);

  // Create a TFRecord Dataset
  std::string file = datasets_root_path_ + "/tf_file_dataset/test1.data";
  std::shared_ptr<Dataset> ds = TFRecord({file}, nullptr, {"scalars"}, 0, ShuffleMode::kGlobal);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  std::vector<int> expect = {9, 3, 4, 7, 2, 1, 6, 8, 10, 5};
  std::vector<int> actual = {};
  int64_t value = 0;
  uint64_t i = 0;
  while (row.size() != 0) {
    row["scalars"]->GetItemAt(&value, {});
    actual.push_back(value);
    iter->GetNextRow(&row);
    i++;
  }
  // EXPECT rather than ASSERT: a mismatch must not skip the pipeline shutdown and the
  // restoration of the global configuration below.
  EXPECT_EQ(actual, expect);
  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
// Reads two TFRecord files with a schema given as a file path and num_samples = 9, then checks
// the columns of the first row and the number of rows produced.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaPath) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaPath.";

  // Create a TFRecord Dataset
  std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::string file_path2 = datasets_root_path_ + "/testTFTestAllTypes/test2.data";
  std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  std::shared_ptr<Dataset> ds = TFRecord({file_path2, file_path1}, schema_path, {}, 9, ShuffleMode::kFalse);
  // ASSERT rather than EXPECT: ds is dereferenced right below, so stop here on failure
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check column
  EXPECT_EQ(row.size(), 8);
  EXPECT_NE(row.find("col_sint16"), row.end());
  EXPECT_NE(row.find("col_sint32"), row.end());
  EXPECT_NE(row.find("col_sint64"), row.end());
  EXPECT_NE(row.find("col_float"), row.end());
  EXPECT_NE(row.find("col_1d"), row.end());
  EXPECT_NE(row.find("col_2d"), row.end());
  EXPECT_NE(row.find("col_3d"), row.end());
  EXPECT_NE(row.find("col_binary"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    iter->GetNextRow(&row);
  }

  EXPECT_EQ(i, 9);

  // Manually terminate the pipeline
  iter->Stop();
}
// Reads a TFRecord file with a schema given as a SchemaObj describing three columns, then checks
// the shape, rank and type of each column in every row.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaObj) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaObj.";

  // Create a TFRecord Dataset
  std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::shared_ptr<SchemaObj> schema = Schema();
  // ASSERT rather than EXPECT: schema is dereferenced right below
  ASSERT_NE(schema, nullptr);
  schema->add_column("col_sint16", "int16", {1});
  schema->add_column("col_float", "float32", {1});
  schema->add_column("col_2d", "int64", {2, 2});
  std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check column
  EXPECT_EQ(row.size(), 3);
  EXPECT_NE(row.find("col_sint16"), row.end());
  EXPECT_NE(row.find("col_float"), row.end());
  EXPECT_NE(row.find("col_2d"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    auto col_sint16 = row["col_sint16"];
    auto col_float = row["col_float"];
    auto col_2d = row["col_2d"];

    // Shapes, ranks and types must match what was declared in the schema object
    EXPECT_EQ(col_sint16->shape(), TensorShape({1}));
    EXPECT_EQ(col_float->shape(), TensorShape({1}));
    EXPECT_EQ(col_2d->shape(), TensorShape({2, 2}));

    EXPECT_EQ(col_sint16->Rank(), 1);
    EXPECT_EQ(col_float->Rank(), 1);
    EXPECT_EQ(col_2d->Rank(), 2);

    EXPECT_EQ(col_sint16->type(), DataType::DE_INT16);
    EXPECT_EQ(col_float->type(), DataType::DE_FLOAT32);
    EXPECT_EQ(col_2d->type(), DataType::DE_INT64);

    iter->GetNextRow(&row);
    i++;
  }

  EXPECT_EQ(i, 12);

  // Manually terminate the pipeline
  iter->Stop();
}
// Reads a TFRecord file without any schema (schema = nullptr) and checks the columns of the
// first row and the number of rows produced.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetNoSchema) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetNoSchema.";

  // Create a TFRecord Dataset, passing a literal nullptr as the schema
  std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
  std::shared_ptr<Dataset> ds = TFRecord({file_path}, nullptr, {});
  // ASSERT rather than EXPECT: ds is dereferenced right below, so stop here on failure
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check column
  EXPECT_EQ(row.size(), 2);
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    auto image = row["image"];
    auto label = row["label"];
    MS_LOG(INFO) << "Shape of column [image]:" << image->shape();
    MS_LOG(INFO) << "Shape of column [label]:" << label->shape();
    iter->GetNextRow(&row);
    i++;
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();
}
// Reads only the "image" column of a TFRecord file by passing a columns list without a schema.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetColName) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetColName.";

  // Create a TFRecord Dataset
  // The dataset has two columns("image", "label") and 3 rows
  std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
  std::shared_ptr<Dataset> ds = TFRecord({file_path}, "", {"image"}, 0);
  // ASSERT rather than EXPECT: ds is dereferenced right below, so stop here on failure
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check column
  EXPECT_EQ(row.size(), 1);
  EXPECT_NE(row.find("image"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    iter->GetNextRow(&row);
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();
}
// Reads shard 1 of 2 from three TFRecord files, once with shard_equal_rows = true and once with
// shard_equal_rows = false, and checks the number of rows each pipeline produces.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShard.";

  // Create a TFRecord Dataset
  // Each file has two columns("image", "label") and 3 rows
  std::vector<std::string> files = {
    datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
    datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
    datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
  };
  // The file vector is passed directly; wrapping it in braces only made an extra copy
  std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
  // ASSERT rather than EXPECT: the datasets are dereferenced below, so stop here on failure
  ASSERT_NE(ds1, nullptr);
  std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
  ASSERT_NE(ds2, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  ASSERT_NE(iter1, nullptr);
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  ASSERT_NE(iter2, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row1;
  iter1->GetNextRow(&row1);
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row2;
  iter2->GetNextRow(&row2);

  uint64_t i = 0;
  uint64_t j = 0;
  while (row1.size() != 0) {
    i++;
    iter1->GetNextRow(&row1);
  }

  while (row2.size() != 0) {
    j++;
    iter2->GetNextRow(&row2);
  }

  EXPECT_EQ(i, 5);
  EXPECT_EQ(j, 3);

  // Manually terminate the pipeline
  iter1->Stop();
  iter2->Stop();
}
// Checks that TFRecord() returns nullptr for each kind of invalid argument.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception.";

  // This case is expected to fail because the list of dataset files cannot be empty.
  std::shared_ptr<Dataset> ds1 = TFRecord({});
  EXPECT_EQ(ds1, nullptr);

  // This case is expected to fail because one of the dataset files does not exist.
  std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::shared_ptr<Dataset> ds2 = TFRecord({file_path, "noexist.data"});
  EXPECT_EQ(ds2, nullptr);

  // This case is expected to fail because the schema file does not exist.
  // The schema path is the second argument; it must not be placed inside the file list.
  std::shared_ptr<Dataset> ds4 = TFRecord({file_path}, "notexist.json");
  EXPECT_EQ(ds4, nullptr);

  // This case is expected to fail because num_samples is negative.
  std::shared_ptr<Dataset> ds5 = TFRecord({file_path}, "", {}, -1);
  EXPECT_EQ(ds5, nullptr);

  // This case is expected to fail because num_shards is not positive.
  std::shared_ptr<Dataset> ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0);
  EXPECT_EQ(ds6, nullptr);

  // This case is expected to fail because shard_id is out of bounds (it must be less than num_shards).
  std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
  EXPECT_EQ(ds7, nullptr);
}
// Checks that creating an iterator fails when the requested column name does not exist.
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception2.";
  // This case is expected to fail because the input column name does not exist.

  std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  // Create a TFRecord Dataset
  // Column "image" does not exist in the dataset
  std::shared_ptr<Dataset> ds = TFRecord({file_path1}, schema_path, {"image"}, 10);
  // ASSERT rather than EXPECT: ds is dereferenced right below, so stop here on failure
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This attempts to create Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_EQ(iter, nullptr);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册