提交 2f3684d4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4872 Add C++ API support for TFRecordDataset

Merge pull request !4872 from TinaMengtingZhang/cpp-api-tfrecord-dataset
......@@ -32,6 +32,7 @@
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
......@@ -1508,6 +1509,56 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
return node_ops;
}
// Validator for TFRecordDataset
bool TFRecordDataset::ValidateParams() { return true; }
// Function to build TFRecordDataset
std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Sort the datasets file in a lexicographical order
std::vector<std::string> sorted_dir_files = dataset_files_;
std::sort(sorted_dir_files.begin(), sorted_dir_files.end());
// Create Schema Object
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (!schema_path_.empty()) {
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
} else if (schema_obj_ != nullptr) {
std::string schema_json_string = schema_obj_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
}
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Create and initalize TFReaderOp
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op);
}
// Add TFReaderOp
node_ops.push_back(tf_reader_op);
return node_ops;
}
#ifndef ENABLE_ANDROID
// Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
......
......@@ -32,6 +32,7 @@
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/c_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/text/vocab.h"
#endif
......@@ -69,6 +70,7 @@ class ManifestDataset;
class MnistDataset;
class RandomDataset;
class TextFileDataset;
class TFRecordDataset;
#ifndef ENABLE_ANDROID
class VOCDataset;
#endif
......@@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create a TFRecordDataset
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
/// meta data from the TFData file is considered the schema.)
/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
/// when num_shards is also specified. (Default = 0)
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
/// each shard may be not equal)
/// \return Shared pointer to the current TFRecordDataset
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0, bool shard_equal_rows = false) {
if (dataset_files.empty()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified.";
return nullptr;
}
for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist.";
return nullptr;
}
}
if (num_samples < 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples;
return nullptr;
}
if (num_shards <= 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards;
return nullptr;
}
if (shard_id < 0 || shard_id >= num_shards) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
return nullptr;
}
std::shared_ptr<TFRecordDataset> ds = nullptr;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
} else {
std::string schema_path = schema;
if (!schema_path.empty()) {
Path schema_file(schema_path);
if (!schema_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist.";
return nullptr;
}
}
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
}
return ds;
}
#ifndef ENABLE_ANDROID
/// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns :
......@@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset {
ShuffleMode shuffle_;
};
/// \class TFRecordDataset
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordDataset : public Dataset {
public:
/// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file
TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_path_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Constructor
/// \note Parameter 'schema' is shared pointer to Schema object
TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_obj_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Destructor
~TFRecordDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
std::vector<std::string> columns_list_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
bool shard_equal_rows_;
};
#ifndef ENABLE_ANDROID
class VOCDataset : public Dataset {
public:
......
......@@ -3548,7 +3548,7 @@ class TFRecordDataset(SourceDataset):
If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None).
If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data
......@@ -3567,8 +3567,8 @@ class TFRecordDataset(SourceDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number
of rows of each shard may be not equal.
shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
is false, number of rows of each shard may be not equal.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Examples:
......
......@@ -107,9 +107,10 @@ SET(DE_UT_SRCS
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
using namespace mindspore::dataset;
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::TensorShape;
using mindspore::dataset::DataType;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasic.";
// Create a TFRecord Dataset
std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
std::string schema_path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json";
std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema_path, {"image"}, 0);
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5);
EXPECT_NE(random_horizontal_flip_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Check column
EXPECT_EQ(row.size(), 1);
EXPECT_NE(row.find("image"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
i++;
}
EXPECT_EQ(i, 6);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle.";
// This case is to verify if the list of datafiles are sorted in lexicographical order.
// Set configuration
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a TFRecord Dataset
std::string file1 = datasets_root_path_ + "/tf_file_dataset/test1.data";
std::string file2 = datasets_root_path_ + "/tf_file_dataset/test2.data";
std::string file3 = datasets_root_path_ + "/tf_file_dataset/test3.data";
std::string file4 = datasets_root_path_ + "/tf_file_dataset/test4.data";
std::shared_ptr<Dataset> ds1 = TFRecord({file4, file3, file2, file1}, "", {"scalars"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = TFRecord({file1}, "", {"scalars"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_NE(iter1, nullptr);
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
EXPECT_NE(iter2, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row1;
iter1->GetNextRow(&row1);
std::unordered_map<std::string, std::shared_ptr<Tensor>> row2;
iter2->GetNextRow(&row2);
uint64_t i = 0;
int64_t value1 = 0;
int64_t value2 = 0;
while (row1.size() != 0 && row2.size() != 0) {
row1["scalars"]->GetItemAt(&value1, {0});
row2["scalars"]->GetItemAt(&value2, {0});
EXPECT_EQ(value1, value2);
iter1->GetNextRow(&row1);
iter2->GetNextRow(&row2);
i++;
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter1->Stop();
iter2->Stop();
// Restore configuration
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle2.";
// This case is to verify the content of the data is indeed shuffled.
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(155);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a TFRecord Dataset
std::string file = datasets_root_path_ + "/tf_file_dataset/test1.data";
std::shared_ptr<Dataset> ds = TFRecord({file}, nullptr, {"scalars"}, 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
std::vector<int> expect = {9, 3, 4, 7, 2, 1, 6, 8, 10, 5};
std::vector<int> actual = {};
int64_t value = 0;
uint64_t i = 0;
while (row.size() != 0) {
row["scalars"]->GetItemAt(&value, {});
actual.push_back(value);
iter->GetNextRow(&row);
i++;
}
ASSERT_EQ(actual, expect);
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaPath) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaPath.";
// Create a TFRecord Dataset
std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::string file_path2 = datasets_root_path_ + "/testTFTestAllTypes/test2.data";
std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
std::shared_ptr<Dataset> ds = TFRecord({file_path2, file_path1}, schema_path, {}, 9, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Check column
EXPECT_EQ(row.size(), 8);
EXPECT_NE(row.find("col_sint16"), row.end());
EXPECT_NE(row.find("col_sint32"), row.end());
EXPECT_NE(row.find("col_sint64"), row.end());
EXPECT_NE(row.find("col_float"), row.end());
EXPECT_NE(row.find("col_1d"), row.end());
EXPECT_NE(row.find("col_2d"), row.end());
EXPECT_NE(row.find("col_3d"), row.end());
EXPECT_NE(row.find("col_binary"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 9);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetSchemaObj) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetSchemaObj.";
// Create a TFRecord Dataset
std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::shared_ptr<SchemaObj> schema = Schema();
schema->add_column("col_sint16", "int16", {1});
schema->add_column("col_float", "float32", {1});
schema->add_column("col_2d", "int64", {2, 2});
std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Check column
EXPECT_EQ(row.size(), 3);
EXPECT_NE(row.find("col_sint16"), row.end());
EXPECT_NE(row.find("col_float"), row.end());
EXPECT_NE(row.find("col_2d"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto col_sint16 = row["col_sint16"];
auto col_float = row["col_float"];
auto col_2d = row["col_2d"];
EXPECT_EQ(col_sint16->shape(), TensorShape({1}));
EXPECT_EQ(col_float->shape(), TensorShape({1}));
EXPECT_EQ(col_2d->shape(), TensorShape({2, 2}));
EXPECT_EQ(col_sint16->Rank(), 1);
EXPECT_EQ(col_float->Rank(), 1);
EXPECT_EQ(col_2d->Rank(), 2);
EXPECT_EQ(col_sint16->type(), DataType::DE_INT16);
EXPECT_EQ(col_float->type(), DataType::DE_FLOAT32);
EXPECT_EQ(col_2d->type(), DataType::DE_INT64);
iter->GetNextRow(&row);
i++;
}
EXPECT_EQ(i, 12);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetNoSchema) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetNoSchema.";
// Create a TFRecord Dataset
std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
std::shared_ptr<SchemaObj> schema = nullptr;
std::shared_ptr<Dataset> ds = TFRecord({file_path}, nullptr, {});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Check column
EXPECT_EQ(row.size(), 2);
EXPECT_NE(row.find("image"), row.end());
EXPECT_NE(row.find("label"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto image = row["image"];
auto label = row["label"];
MS_LOG(INFO) << "Shape of column [image]:" << image->shape();
MS_LOG(INFO) << "Shape of column [label]:" << label->shape();
iter->GetNextRow(&row);
i++;
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetColName) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetColName.";
// Create a TFRecord Dataset
// The dataset has two columns("image", "label") and 3 rows
std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
std::shared_ptr<Dataset> ds = TFRecord({file_path}, "", {"image"}, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Check column
EXPECT_EQ(row.size(), 1);
EXPECT_NE(row.find("image"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
i++;
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShard.";
// Create a TFRecord Dataset
// Each file has two columns("image", "label") and 3 rows
std::vector<std::string> files = {
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
};
std::shared_ptr<Dataset> ds1 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = TFRecord({files}, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_NE(iter1, nullptr);
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
EXPECT_NE(iter2, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row1;
iter1->GetNextRow(&row1);
std::unordered_map<std::string, std::shared_ptr<Tensor>> row2;
iter2->GetNextRow(&row2);
uint64_t i = 0;
uint64_t j = 0;
while (row1.size() != 0) {
i++;
iter1->GetNextRow(&row1);
}
while (row2.size() != 0) {
j++;
iter2->GetNextRow(&row2);
}
EXPECT_EQ(i, 5);
EXPECT_EQ(j, 3);
// Manually terminate the pipeline
iter1->Stop();
iter2->Stop();
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception.";
// This case expected to fail because the list of dir_path cannot be empty.
std::shared_ptr<Dataset> ds1 = TFRecord({});
EXPECT_EQ(ds1, nullptr);
// This case expected to fail because the file in dir_path is not exist.
std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::shared_ptr<Dataset> ds2 = TFRecord({file_path, "noexist.data"});
EXPECT_EQ(ds2, nullptr);
// This case expected to fail because the file of schema is not exist.
std::shared_ptr<Dataset> ds4 = TFRecord({file_path, "notexist.json"});
EXPECT_EQ(ds4, nullptr);
// This case expected to fail because num_samples is negative.
std::shared_ptr<Dataset> ds5 = TFRecord({file_path}, "", {}, -1);
EXPECT_EQ(ds5, nullptr);
// This case expected to fail because num_shards is negative.
std::shared_ptr<Dataset> ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0);
EXPECT_EQ(ds6, nullptr);
// This case expected to fail because shard_id is out_of_bound.
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
EXPECT_EQ(ds7, nullptr);
}
TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetExeception2.";
// This case expected to fail because the input column name does not exist.
std::string file_path1 = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::string schema_path = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
// Create a TFRecord Dataset
// Column "image" does not exist in the dataset
std::shared_ptr<Dataset> ds = TFRecord({file_path1}, schema_path, {"image"}, 10);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This attempts to create Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册