提交 e07f7436 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3299 Rename for C API

Merge pull request !3299 from MahdiRahmaniHanzaki/rename-c-api
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "minddata/dataset/engine/datasetops/skip_op.h" #include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
...@@ -40,12 +41,13 @@ namespace mindspore { ...@@ -40,12 +41,13 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace api { namespace api {
#define RETURN_NULL_IF_ERROR(_s) \ #define RETURN_EMPTY_IF_ERROR(_s) \
do { \ do { \
Status __rc = (_s); \ Status __rc = (_s); \
if (__rc.IsError()) { \ if (__rc.IsError()) { \
return nullptr; \ MS_LOG(ERROR) << __rc; \
} \ return {}; \
} \
} while (false) } while (false)
// Function to create the iterator, which will build and launch the execution tree. // Function to create the iterator, which will build and launch the execution tree.
...@@ -55,8 +57,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator() { ...@@ -55,8 +57,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator() {
iter = std::make_shared<Iterator>(); iter = std::make_shared<Iterator>();
Status rc = iter->BuildAndLaunchTree(shared_from_this()); Status rc = iter->BuildAndLaunchTree(shared_from_this());
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << rc; MS_LOG(ERROR) << "CreateIterator failed." << rc;
MS_LOG(ERROR) << "CreateIterator failed.";
return nullptr; return nullptr;
} }
...@@ -201,6 +202,20 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> ...@@ -201,6 +202,20 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return ds; return ds;
} }
// Function to create a RenameDataset.
std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) {
auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
// Call derived class validation method.
if (!ds->ValidateParams()) {
return nullptr;
}
ds->children.push_back(shared_from_this());
return ds;
}
// Function to create a Zip dataset // Function to create a Zip dataset
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) { std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
// Default values // Default values
...@@ -244,7 +259,7 @@ bool ImageFolderDataset::ValidateParams() { ...@@ -244,7 +259,7 @@ bool ImageFolderDataset::ValidateParams() {
return true; return true;
} }
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
...@@ -257,14 +272,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Bui ...@@ -257,14 +272,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Bui
// This arg is exist in ImageFolderOp, but not externalized (in Python API). // This arg is exist in ImageFolderOp, but not externalized (in Python API).
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_NULL_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema), recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
...@@ -279,7 +294,7 @@ bool MnistDataset::ValidateParams() { ...@@ -279,7 +294,7 @@ bool MnistDataset::ValidateParams() {
return true; return true;
} }
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
...@@ -290,14 +305,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() { ...@@ -290,14 +305,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() {
// Do internal Schema generation. // Do internal Schema generation.
auto schema = std::make_unique<DataSchema>(); auto schema = std::make_unique<DataSchema>();
RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_->Build()))); std::move(schema), std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map, BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
...@@ -308,7 +323,7 @@ BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, st ...@@ -308,7 +323,7 @@ BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, st
cols_to_map_(cols_to_map), cols_to_map_(cols_to_map),
pad_map_(pad_map) {} pad_map_(pad_map) {}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> BatchDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
...@@ -320,11 +335,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() { ...@@ -320,11 +335,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() {
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
cols_to_map_, pad_map_)); cols_to_map_, pad_map_));
#endif #endif
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
bool BatchDataset::ValidateParams() { bool BatchDataset::ValidateParams() {
if (batch_size_ <= 0) { if (batch_size_ <= 0) {
MS_LOG(ERROR) << "Batch: Batch size cannot be negative";
return false; return false;
} }
...@@ -333,16 +349,17 @@ bool BatchDataset::ValidateParams() { ...@@ -333,16 +349,17 @@ bool BatchDataset::ValidateParams() {
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> RepeatDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_)); node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
bool RepeatDataset::ValidateParams() { bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) { if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
return false; return false;
} }
...@@ -355,7 +372,7 @@ MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, ...@@ -355,7 +372,7 @@ MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations,
output_columns_(output_columns), output_columns_(output_columns),
project_columns_(project_columns) {} project_columns_(project_columns) {}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> MapDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
...@@ -380,11 +397,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() { ...@@ -380,11 +397,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() {
} }
node_ops.push_back(map_op); node_ops.push_back(map_op);
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
bool MapDataset::ValidateParams() { bool MapDataset::ValidateParams() {
if (operations_.empty()) { if (operations_.empty()) {
MS_LOG(ERROR) << "Map: No operation is specified.";
return false; return false;
} }
...@@ -396,13 +414,13 @@ ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) ...@@ -396,13 +414,13 @@ ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
// Function to build the ShuffleOp // Function to build the ShuffleOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ShuffleDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> ShuffleDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_)); rows_per_buffer_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
// Function to validate the parameters for ShuffleDataset // Function to validate the parameters for ShuffleDataset
...@@ -419,12 +437,12 @@ bool ShuffleDataset::ValidateParams() { ...@@ -419,12 +437,12 @@ bool ShuffleDataset::ValidateParams() {
SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
// Function to build the SkipOp // Function to build the SkipOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> SkipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_)); node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
// Function to validate the parameters for SkipDataset // Function to validate the parameters for SkipDataset
...@@ -454,7 +472,7 @@ bool Cifar10Dataset::ValidateParams() { ...@@ -454,7 +472,7 @@ bool Cifar10Dataset::ValidateParams() {
} }
// Function to build CifarOp // Function to build CifarOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build() { std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
...@@ -465,15 +483,15 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build() ...@@ -465,15 +483,15 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build()
// Do internal Schema generation. // Do internal Schema generation.
auto schema = std::make_unique<DataSchema>(); auto schema = std::make_unique<DataSchema>();
RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_, node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema), dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
// Function to build ProjectOp // Function to build ProjectOp
...@@ -487,12 +505,12 @@ bool ProjectDataset::ValidateParams() { ...@@ -487,12 +505,12 @@ bool ProjectDataset::ValidateParams() {
return true; return true;
} }
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_)); node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
} }
// Function to build ZipOp // Function to build ZipOp
...@@ -500,12 +518,37 @@ ZipDataset::ZipDataset() {} ...@@ -500,12 +518,37 @@ ZipDataset::ZipDataset() {}
bool ZipDataset::ValidateParams() { return true; } bool ZipDataset::ValidateParams() { return true; }
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ZipDataset::Build() { std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_)); node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); return node_ops;
}
// Function to build RenameOp
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {}
bool RenameDataset::ValidateParams() {
if (input_columns_.empty() || output_columns_.empty()) {
MS_LOG(ERROR) << "input and output columns must be specified";
return false;
}
if (input_columns_.size() != output_columns_.size()) {
MS_LOG(ERROR) << "input and output columns must be the same size";
return false;
}
return true;
}
// Function to build the RenameOp
// Creates one RenameOp from the stored input/output column lists and connector_que_size_
// (not declared in this class; presumably inherited from Dataset - confirm) and returns it
// as the only element of the vector.
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
} }
} // namespace api } // namespace api
......
...@@ -56,7 +56,13 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { ...@@ -56,7 +56,13 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
RETURN_STATUS_UNEXPECTED("Input is null pointer"); RETURN_STATUS_UNEXPECTED("Input is null pointer");
} else { } else {
// Convert the current root node. // Convert the current root node.
auto root_op = ds->Build()->front(); auto root_ops = ds->Build();
if (root_ops.empty()) {
RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
}
auto root_op = root_ops.front();
RETURN_UNEXPECTED_IF_NULL(root_op); RETURN_UNEXPECTED_IF_NULL(root_op);
RETURN_IF_NOT_OK(tree_->AssociateNode(root_op)); RETURN_IF_NOT_OK(tree_->AssociateNode(root_op));
...@@ -70,20 +76,22 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { ...@@ -70,20 +76,22 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
// Iterate through all the direct children of the first element in our BFS queue // Iterate through all the direct children of the first element in our BFS queue
for (auto child : node_pair.first->children) { for (auto child : node_pair.first->children) {
auto child_ops = child->Build(); auto child_ops = child->Build();
RETURN_UNEXPECTED_IF_NULL(child_ops); if (child_ops.empty()) {
RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
}
auto node_op = node_pair.second; auto node_op = node_pair.second;
// Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them
// with the execution tree and add the child and parent relationship between the nodes // with the execution tree and add the child and parent relationship between the nodes
// Note that some Dataset objects might return more than one DatasetOps // Note that some Dataset objects might return more than one DatasetOps
// e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset
for (auto child_op : *child_ops) { for (auto child_op : child_ops) {
RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); RETURN_IF_NOT_OK(tree_->AssociateNode(child_op));
RETURN_IF_NOT_OK(node_op->AddChild(child_op)); RETURN_IF_NOT_OK(node_op->AddChild(child_op));
node_op = child_op; node_op = child_op;
} }
// Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current
// execution tree) to the BFS queue // execution tree) to the BFS queue
q.push(std::make_pair(child, child_ops->back())); q.push(std::make_pair(child, child_ops.back()));
} }
} }
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
......
...@@ -50,6 +50,7 @@ class SkipDataset; ...@@ -50,6 +50,7 @@ class SkipDataset;
class Cifar10Dataset; class Cifar10Dataset;
class ProjectDataset; class ProjectDataset;
class ZipDataset; class ZipDataset;
class RenameDataset;
/// \brief Function to create an ImageFolderDataset /// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories /// \notes A source dataset that reads images from a tree of directories
...@@ -98,8 +99,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -98,8 +99,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
~Dataset() = default; ~Dataset() = default;
/// \brief Pure virtual function to convert a Dataset class into a runtime dataset object /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
virtual std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() = 0; virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
/// \brief Pure virtual function for derived class to implement parameters validation /// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool True if all the params are valid /// \return bool True if all the params are valid
...@@ -179,6 +180,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -179,6 +180,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets); std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename; must not be empty
/// \param[in] output_columns List of the output columns; must not be empty and must have the same number of
///     entries as input_columns
/// \return Shared pointer to the newly created RenameDataset, which has the current Dataset as its child,
///     or nullptr if the parameters fail validation
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
protected: protected:
std::vector<std::shared_ptr<Dataset>> children; std::vector<std::shared_ptr<Dataset>> children;
std::shared_ptr<Dataset> parent; std::shared_ptr<Dataset> parent;
...@@ -202,8 +211,8 @@ class ImageFolderDataset : public Dataset { ...@@ -202,8 +211,8 @@ class ImageFolderDataset : public Dataset {
~ImageFolderDataset() = default; ~ImageFolderDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -227,8 +236,8 @@ class MnistDataset : public Dataset { ...@@ -227,8 +236,8 @@ class MnistDataset : public Dataset {
~MnistDataset() = default; ~MnistDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -249,8 +258,8 @@ class BatchDataset : public Dataset { ...@@ -249,8 +258,8 @@ class BatchDataset : public Dataset {
~BatchDataset() = default; ~BatchDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -273,8 +282,8 @@ class RepeatDataset : public Dataset { ...@@ -273,8 +282,8 @@ class RepeatDataset : public Dataset {
~RepeatDataset() = default; ~RepeatDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -290,7 +299,7 @@ class ShuffleDataset : public Dataset { ...@@ -290,7 +299,7 @@ class ShuffleDataset : public Dataset {
~ShuffleDataset() = default; ~ShuffleDataset() = default;
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
bool ValidateParams() override; bool ValidateParams() override;
...@@ -309,8 +318,8 @@ class SkipDataset : public Dataset { ...@@ -309,8 +318,8 @@ class SkipDataset : public Dataset {
~SkipDataset() = default; ~SkipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -330,8 +339,8 @@ class MapDataset : public Dataset { ...@@ -330,8 +339,8 @@ class MapDataset : public Dataset {
~MapDataset() = default; ~MapDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -353,8 +362,8 @@ class Cifar10Dataset : public Dataset { ...@@ -353,8 +362,8 @@ class Cifar10Dataset : public Dataset {
~Cifar10Dataset() = default; ~Cifar10Dataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -375,8 +384,8 @@ class ProjectDataset : public Dataset { ...@@ -375,8 +384,8 @@ class ProjectDataset : public Dataset {
~ProjectDataset() = default; ~ProjectDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
...@@ -395,12 +404,33 @@ class ZipDataset : public Dataset { ...@@ -395,12 +404,33 @@ class ZipDataset : public Dataset {
~ZipDataset() = default; ~ZipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
};
/// \brief Dataset node that holds a list of input column names and a list of output column names;
///     Build() turns it into a RenameOp.
class RenameDataset : public Dataset {
public:
/// \brief Constructor
/// \param[in] input_columns List of the input columns to rename (copied into this object)
/// \param[in] output_columns List of the output columns (copied into this object)
explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
/// \brief Destructor
~RenameDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation /// \brief Parameters validation
/// \return bool true if all the params are valid /// \return bool true if all the params are valid
bool ValidateParams() override; bool ValidateParams() override;
private:
// Column names passed to the constructor as input_columns.
std::vector<std::string> input_columns_;
// Column names passed to the constructor as output_columns.
std::vector<std::string> output_columns_;
}; };
} // namespace api } // namespace api
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册