提交 81005a30 编写于 作者: C Cathy Wong

C++ API: Reorder code contents alphabetically

上级 e07f7436
...@@ -17,12 +17,14 @@ ...@@ -17,12 +17,14 @@
#include <fstream> #include <fstream>
#include "minddata/dataset/include/datasets.h" #include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/engine/dataset_iterator.h" #include "minddata/dataset/engine/dataset_iterator.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h" // Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
...@@ -31,6 +33,7 @@ ...@@ -31,6 +33,7 @@
#include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h" #include "minddata/dataset/engine/datasetops/rename_op.h"
// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
...@@ -79,6 +82,18 @@ Dataset::Dataset() { ...@@ -79,6 +82,18 @@ Dataset::Dataset() {
connector_que_size_ = cfg->op_connector_size(); connector_que_size_ = cfg->op_connector_size();
} }
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)
// Function to create a Cifar10Dataset.
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
std::shared_ptr<SamplerObj> sampler) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a ImageFolderDataset. // Function to create a ImageFolderDataset.
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode, std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions, std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
...@@ -101,14 +116,8 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam ...@@ -101,14 +116,8 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
} }
// Function to create a Cifar10Dataset. // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples, // (In alphabetical order)
std::shared_ptr<SamplerObj> sampler) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a Batch dataset // Function to create a Batch dataset
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) { std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
...@@ -127,14 +136,12 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai ...@@ -127,14 +136,12 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
return ds; return ds;
} }
// Function to create Repeat dataset. // Function to create a Map dataset.
std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) { std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
// Workaround for repeat == 1, do not inject repeat. std::vector<std::string> input_columns,
if (count == 1) { std::vector<std::string> output_columns,
return shared_from_this(); const std::vector<std::string> &project_columns) {
} auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
auto ds = std::make_shared<RepeatDataset>(count);
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
...@@ -145,13 +152,10 @@ std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) { ...@@ -145,13 +152,10 @@ std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
return ds; return ds;
} }
// Function to create a Map dataset. // Function to create a ProjectDataset.
std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations, std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
std::vector<std::string> input_columns, auto ds = std::make_shared<ProjectDataset>(columns);
std::vector<std::string> output_columns, // Call derived class validation method.
const std::vector<std::string> &project_columns) {
auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
} }
...@@ -161,11 +165,11 @@ std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOpera ...@@ -161,11 +165,11 @@ std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOpera
return ds; return ds;
} }
// Function to create a ShuffleOp // Function to create a RenameDataset.
std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) { std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
// Pass in reshuffle_each_epoch with true const std::vector<std::string> &output_columns) {
auto ds = std::make_shared<ShuffleDataset>(shuffle_size, true); auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
// Call derived class validation method.
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
} }
...@@ -175,11 +179,15 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) { ...@@ -175,11 +179,15 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return ds; return ds;
} }
// Function to create a SkipDataset. // Function to create Repeat dataset.
std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) { std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
auto ds = std::make_shared<SkipDataset>(count); // Workaround for repeat == 1, do not inject repeat.
if (count == 1) {
return shared_from_this();
}
auto ds = std::make_shared<RepeatDataset>(count);
// Call derived class validation method.
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
} }
...@@ -189,10 +197,11 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) { ...@@ -189,10 +197,11 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
return ds; return ds;
} }
// Function to create a ProjectDataset. // Function to create a ShuffleOp
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) { std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
auto ds = std::make_shared<ProjectDataset>(columns); // Pass in reshuffle_each_epoch with true
// Call derived class validation method. auto ds = std::make_shared<ShuffleDataset>(shuffle_size, true);
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
} }
...@@ -202,10 +211,10 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> ...@@ -202,10 +211,10 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return ds; return ds;
} }
// Function to create a RenameDataset. // Function to create a SkipDataset.
std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns, std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
const std::vector<std::string> &output_columns) { auto ds = std::make_shared<SkipDataset>(count);
auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
// Call derived class validation method. // Call derived class validation method.
if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
...@@ -231,6 +240,9 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas ...@@ -231,6 +240,9 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
return ds; return ds;
} }
// OTHER FUNCTIONS
// (In alphabetical order)
// Helper function to create default RandomSampler. // Helper function to create default RandomSampler.
std::shared_ptr<SamplerObj> CreateDefaultSampler() { std::shared_ptr<SamplerObj> CreateDefaultSampler() {
const int32_t num_samples = 0; // 0 means to sample all ids. const int32_t num_samples = 0; // 0 means to sample all ids.
...@@ -240,6 +252,48 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() { ...@@ -240,6 +252,48 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
/* ####################################### Derived Dataset classes ################################# */ /* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
// Constructor for Cifar10Dataset
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
bool Cifar10Dataset::ValidateParams() {
if (dataset_dir_.empty()) {
MS_LOG(ERROR) << "No dataset path is specified.";
return false;
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "Number of samples cannot be negative";
return false;
}
return true;
}
// Function to build CifarOp
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if (sampler_ == nullptr) {
sampler_ = CreateDefaultSampler();
}
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
return node_ops;
}
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
bool recursive, std::set<std::string> extensions, bool recursive, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing) std::map<std::string, int32_t> class_indexing)
...@@ -315,6 +369,9 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() { ...@@ -315,6 +369,9 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return node_ops; return node_ops;
} }
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map, BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
: batch_size_(batch_size), : batch_size_(batch_size),
...@@ -347,24 +404,6 @@ bool BatchDataset::ValidateParams() { ...@@ -347,24 +404,6 @@ bool BatchDataset::ValidateParams() {
return true; return true;
} }
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}
bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
return false;
}
return true;
}
MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns, MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
std::vector<std::string> output_columns, const std::vector<std::string> &project_columns) std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
: operations_(operations), : operations_(operations),
...@@ -409,6 +448,69 @@ bool MapDataset::ValidateParams() { ...@@ -409,6 +448,69 @@ bool MapDataset::ValidateParams() {
return true; return true;
} }
// Function to build ProjectOp
ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
bool ProjectDataset::ValidateParams() {
if (columns_.empty()) {
MS_LOG(ERROR) << "No columns are specified.";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}
// Function to build RenameOp
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {}
bool RenameDataset::ValidateParams() {
if (input_columns_.empty() || output_columns_.empty()) {
MS_LOG(ERROR) << "input and output columns must be specified";
return false;
}
if (input_columns_.size() != output_columns_.size()) {
MS_LOG(ERROR) << "input and output columns must be the same size";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}
bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
return false;
}
return true;
}
// Constructor for ShuffleDataset // Constructor for ShuffleDataset
ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
...@@ -455,64 +557,6 @@ bool SkipDataset::ValidateParams() { ...@@ -455,64 +557,6 @@ bool SkipDataset::ValidateParams() {
return true; return true;
} }
// Constructor for Cifar10Dataset
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
bool Cifar10Dataset::ValidateParams() {
if (dataset_dir_.empty()) {
MS_LOG(ERROR) << "No dataset path is specified.";
return false;
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "Number of samples cannot be negative";
return false;
}
return true;
}
// Function to build CifarOp
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if (sampler_ == nullptr) {
sampler_ = CreateDefaultSampler();
}
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
return node_ops;
}
// Function to build ProjectOp
ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
bool ProjectDataset::ValidateParams() {
if (columns_.empty()) {
MS_LOG(ERROR) << "No columns are specified.";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}
// Function to build ZipOp // Function to build ZipOp
ZipDataset::ZipDataset() {} ZipDataset::ZipDataset() {}
...@@ -526,31 +570,6 @@ std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() { ...@@ -526,31 +570,6 @@ std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
return node_ops; return node_ops;
} }
// Function to build RenameOp
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {}
bool RenameDataset::ValidateParams() {
if (input_columns_.empty() || output_columns_.empty()) {
MS_LOG(ERROR) << "input and output columns must be specified";
return false;
}
if (input_columns_.size() != output_columns_.size()) {
MS_LOG(ERROR) << "input and output columns must be the same size";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
} // namespace api } // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
...@@ -40,17 +40,29 @@ namespace api { ...@@ -40,17 +40,29 @@ namespace api {
class TensorOperation; class TensorOperation;
class SamplerObj; class SamplerObj;
// Datasets classes (in alphabetical order)
class Cifar10Dataset;
class ImageFolderDataset; class ImageFolderDataset;
class MnistDataset; class MnistDataset;
// Dataset Op classes (in alphabetical order)
class BatchDataset; class BatchDataset;
class RepeatDataset;
class MapDataset; class MapDataset;
class ProjectDataset;
class RenameDataset;
class RepeatDataset;
class ShuffleDataset; class ShuffleDataset;
class SkipDataset; class SkipDataset;
class Cifar10Dataset;
class ProjectDataset;
class ZipDataset; class ZipDataset;
class RenameDataset;
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
std::shared_ptr<SamplerObj> sampler);
/// \brief Function to create an ImageFolderDataset /// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories /// \notes A source dataset that reads images from a tree of directories
...@@ -76,16 +88,6 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de ...@@ -76,16 +88,6 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \return Shared pointer to the current MnistDataset /// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr); std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
std::shared_ptr<SamplerObj> sampler);
/// \class Dataset datasets.h /// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline. /// \brief A base class to represent a dataset in the data pipeline.
class Dataset : public std::enable_shared_from_this<Dataset> { class Dataset : public std::enable_shared_from_this<Dataset> {
...@@ -128,14 +130,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -128,14 +130,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BatchDataset /// \return Shared pointer to the current BatchDataset
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false); std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std::shared_ptr<Dataset> Repeat(int32_t count = -1);
/// \brief Function to create a MapDataset /// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset /// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
...@@ -156,6 +150,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -156,6 +150,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
std::vector<std::string> output_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &project_columns = {}); const std::vector<std::string> &project_columns = {});
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std::shared_ptr<Dataset> Repeat(int32_t count = -1);
/// \brief Function to create a Shuffle Dataset /// \brief Function to create a Shuffle Dataset
/// \notes Randomly shuffles the rows of this dataset /// \notes Randomly shuffles the rows of this dataset
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
...@@ -168,26 +184,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -168,26 +184,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current SkipDataset /// \return Shared pointer to the current SkipDataset
std::shared_ptr<SkipDataset> Skip(int32_t count); std::shared_ptr<SkipDataset> Skip(int32_t count);
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
/// \brief Function to create a Zip Dataset /// \brief Function to create a Zip Dataset
/// \notes Applies zip to the dataset /// \notes Applies zip to the dataset
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip /// \param[in] datasets A list of shared pointer to the datasets that we want to zip
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets); std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
protected: protected:
std::vector<std::shared_ptr<Dataset>> children; std::vector<std::shared_ptr<Dataset>> children;
std::shared_ptr<Dataset> parent; std::shared_ptr<Dataset> parent;
...@@ -199,6 +201,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -199,6 +201,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/* ####################################### Derived Dataset classes ################################# */ /* ####################################### Derived Dataset classes ################################# */
class Cifar10Dataset : public Dataset {
public:
/// \brief Constructor
Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~Cifar10Dataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::string dataset_dir_;
int32_t num_samples_;
std::shared_ptr<SamplerObj> sampler_;
};
/// \class ImageFolderDataset /// \class ImageFolderDataset
/// \brief A Dataset derived class to represent ImageFolder dataset /// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderDataset : public Dataset { class ImageFolderDataset : public Dataset {
...@@ -273,13 +297,14 @@ class BatchDataset : public Dataset { ...@@ -273,13 +297,14 @@ class BatchDataset : public Dataset {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_; std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
}; };
class RepeatDataset : public Dataset { class MapDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
explicit RepeatDataset(uint32_t count); MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {},
std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
/// \brief Destructor /// \brief Destructor
~RepeatDataset() = default; ~MapDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -290,32 +315,19 @@ class RepeatDataset : public Dataset { ...@@ -290,32 +315,19 @@ class RepeatDataset : public Dataset {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
uint32_t repeat_count_; std::vector<std::shared_ptr<TensorOperation>> operations_;
}; std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
class ShuffleDataset : public Dataset { std::vector<std::string> project_columns_;
public:
ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch);
~ShuffleDataset() = default;
std::vector<std::shared_ptr<DatasetOp>> Build() override;
bool ValidateParams() override;
private:
int32_t shuffle_size_;
uint32_t shuffle_seed_;
bool reset_every_epoch_;
}; };
class SkipDataset : public Dataset { class ProjectDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
explicit SkipDataset(int32_t count); explicit ProjectDataset(const std::vector<std::string> &columns);
/// \brief Destructor /// \brief Destructor
~SkipDataset() = default; ~ProjectDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -326,17 +338,16 @@ class SkipDataset : public Dataset { ...@@ -326,17 +338,16 @@ class SkipDataset : public Dataset {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
int32_t skip_count_; std::vector<std::string> columns_;
}; };
class MapDataset : public Dataset { class RenameDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {}, explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
/// \brief Destructor /// \brief Destructor
~MapDataset() = default; ~RenameDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -347,19 +358,17 @@ class MapDataset : public Dataset { ...@@ -347,19 +358,17 @@ class MapDataset : public Dataset {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
std::vector<std::shared_ptr<TensorOperation>> operations_;
std::vector<std::string> input_columns_; std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_; std::vector<std::string> output_columns_;
std::vector<std::string> project_columns_;
}; };
class Cifar10Dataset : public Dataset { class RepeatDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler); explicit RepeatDataset(uint32_t count);
/// \brief Destructor /// \brief Destructor
~Cifar10Dataset() = default; ~RepeatDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -370,38 +379,32 @@ class Cifar10Dataset : public Dataset { ...@@ -370,38 +379,32 @@ class Cifar10Dataset : public Dataset {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
std::string dataset_dir_; uint32_t repeat_count_;
int32_t num_samples_;
std::shared_ptr<SamplerObj> sampler_;
}; };
class ProjectDataset : public Dataset { class ShuffleDataset : public Dataset {
public: public:
/// \brief Constructor ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch);
explicit ProjectDataset(const std::vector<std::string> &columns);
/// \brief Destructor ~ShuffleDataset() = default;
~ProjectDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override; bool ValidateParams() override;
private: private:
std::vector<std::string> columns_; int32_t shuffle_size_;
uint32_t shuffle_seed_;
bool reset_every_epoch_;
}; };
class ZipDataset : public Dataset { class SkipDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
ZipDataset(); explicit SkipDataset(int32_t count);
/// \brief Destructor /// \brief Destructor
~ZipDataset() = default; ~SkipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -410,15 +413,18 @@ class ZipDataset : public Dataset { ...@@ -410,15 +413,18 @@ class ZipDataset : public Dataset {
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
bool ValidateParams() override; bool ValidateParams() override;
private:
int32_t skip_count_;
}; };
class RenameDataset : public Dataset { class ZipDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns); ZipDataset();
/// \brief Destructor /// \brief Destructor
~RenameDataset() = default; ~ZipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
...@@ -427,10 +433,6 @@ class RenameDataset : public Dataset { ...@@ -427,10 +433,6 @@ class RenameDataset : public Dataset {
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
bool ValidateParams() override; bool ValidateParams() override;
private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
}; };
} // namespace api } // namespace api
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册