提交 e07f7436 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3299 Rename for C API

Merge pull request !3299 from MahdiRahmaniHanzaki/rename-c-api
......@@ -30,6 +30,7 @@
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
......@@ -40,12 +41,13 @@ namespace mindspore {
namespace dataset {
namespace api {
#define RETURN_NULL_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
return nullptr; \
} \
// Evaluates the Status expression _s exactly once. If it reports an error, the Status is logged
// through MS_LOG(ERROR) and the enclosing function returns an empty braced initializer (`return {}`).
// When there is no error, execution simply continues after the macro.
#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return {}; \
} \
} while (false)
// Function to create the iterator, which will build and launch the execution tree.
......@@ -55,8 +57,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator() {
iter = std::make_shared<Iterator>();
Status rc = iter->BuildAndLaunchTree(shared_from_this());
if (rc.IsError()) {
MS_LOG(ERROR) << rc;
MS_LOG(ERROR) << "CreateIterator failed.";
MS_LOG(ERROR) << "CreateIterator failed." << rc;
return nullptr;
}
......@@ -201,6 +202,20 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return ds;
}
// Function to create a RenameDataset.
// Builds a RenameDataset node from the two column lists, validates it, and on success attaches the
// current dataset as its child. Returns nullptr when validation fails.
std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
                                               const std::vector<std::string> &output_columns) {
  auto rename_node = std::make_shared<RenameDataset>(input_columns, output_columns);

  // Only a node whose parameters pass the derived class validation gets linked to this dataset.
  if (rename_node->ValidateParams()) {
    rename_node->children.push_back(shared_from_this());
    return rename_node;
  }

  return nullptr;
}
// Function to create a Zip dataset
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
// Default values
......@@ -244,7 +259,7 @@ bool ImageFolderDataset::ValidateParams() {
return true;
}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
......@@ -257,14 +272,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Bui
// This arg exists in ImageFolderOp, but is not externalized (in the Python API).
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR(
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_NULL_IF_ERROR(
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
......@@ -279,7 +294,7 @@ bool MnistDataset::ValidateParams() {
return true;
}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
......@@ -290,14 +305,14 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() {
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR(
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
......@@ -308,7 +323,7 @@ BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, st
cols_to_map_(cols_to_map),
pad_map_(pad_map) {}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> BatchDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
......@@ -320,11 +335,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() {
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
cols_to_map_, pad_map_));
#endif
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
bool BatchDataset::ValidateParams() {
if (batch_size_ <= 0) {
MS_LOG(ERROR) << "Batch: Batch size cannot be negative";
return false;
}
......@@ -333,16 +349,17 @@ bool BatchDataset::ValidateParams() {
// Constructor for RepeatDataset: stores the requested repeat count.
RepeatDataset::RepeatDataset(uint32_t count)
    : repeat_count_(count) {}
// Function to build the RepeatOp
// Creates the runtime op for this node: a single RepeatOp constructed from repeat_count_.
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> RepeatDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
return false;
}
......@@ -355,7 +372,7 @@ MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations,
output_columns_(output_columns),
project_columns_(project_columns) {}
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> MapDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
......@@ -380,11 +397,12 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() {
}
node_ops.push_back(map_op);
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
bool MapDataset::ValidateParams() {
if (operations_.empty()) {
MS_LOG(ERROR) << "Map: No operation is specified.";
return false;
}
......@@ -396,13 +414,13 @@ ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
// Function to build the ShuffleOp
// Creates the runtime op for this node: a single ShuffleOp constructed from the shuffle size, the
// shuffle seed, the connector queue size, the reset-every-epoch flag and the rows-per-buffer setting.
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ShuffleDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> ShuffleDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
// Function to validate the parameters for ShuffleDataset
......@@ -419,12 +437,12 @@ bool ShuffleDataset::ValidateParams() {
// Constructor for SkipDataset: stores the requested skip count.
SkipDataset::SkipDataset(int32_t count)
    : skip_count_(count) {}
// Function to build the SkipOp
// Creates the runtime op for this node: a single SkipOp constructed from skip_count_ and
// connector_que_size_.
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> SkipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
// Function to validate the parameters for SkipDataset
......@@ -454,7 +472,7 @@ bool Cifar10Dataset::ValidateParams() {
}
// Function to build CifarOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
......@@ -465,15 +483,15 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build()
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_NULL_IF_ERROR(
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
// Function to build ProjectOp
......@@ -487,12 +505,12 @@ bool ProjectDataset::ValidateParams() {
return true;
}
// Function to build the ProjectOp
// Creates the runtime op for this node: a single ProjectOp constructed from columns_.
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
// Function to build ZipOp
......@@ -500,12 +518,37 @@ ZipDataset::ZipDataset() {}
// Function to validate the parameters for ZipDataset
// Nothing is checked here; validation always succeeds.
bool ZipDataset::ValidateParams() {
  return true;
}
// Function to build the ZipOp
// Creates the runtime op for this node: a single ZipOp constructed from rows_per_buffer_ and
// connector_que_size_.
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ZipDataset::Build() {
std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
return node_ops;
}
// Constructor for RenameDataset
// Keeps a copy of the list of columns to rename and of the list of their new names.
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
                             const std::vector<std::string> &output_columns)
    : input_columns_(input_columns),
      output_columns_(output_columns) {}
bool RenameDataset::ValidateParams() {
if (input_columns_.empty() || output_columns_.empty()) {
MS_LOG(ERROR) << "input and output columns must be specified";
return false;
}
if (input_columns_.size() != output_columns_.size()) {
MS_LOG(ERROR) << "input and output columns must be the same size";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
} // namespace api
......
......@@ -56,7 +56,13 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
RETURN_STATUS_UNEXPECTED("Input is null pointer");
} else {
// Convert the current root node.
auto root_op = ds->Build()->front();
auto root_ops = ds->Build();
if (root_ops.empty()) {
RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
}
auto root_op = root_ops.front();
RETURN_UNEXPECTED_IF_NULL(root_op);
RETURN_IF_NOT_OK(tree_->AssociateNode(root_op));
......@@ -70,20 +76,22 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
// Iterate through all the direct children of the first element in our BFS queue
for (auto child : node_pair.first->children) {
auto child_ops = child->Build();
RETURN_UNEXPECTED_IF_NULL(child_ops);
if (child_ops.empty()) {
RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
}
auto node_op = node_pair.second;
// Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them
// with the execution tree and add the child and parent relationship between the nodes
// Note that some Dataset objects might return more than one DatasetOps
// e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset
for (auto child_op : *child_ops) {
for (auto child_op : child_ops) {
RETURN_IF_NOT_OK(tree_->AssociateNode(child_op));
RETURN_IF_NOT_OK(node_op->AddChild(child_op));
node_op = child_op;
}
// Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current
// execution tree) to the BFS queue
q.push(std::make_pair(child, child_ops->back()));
q.push(std::make_pair(child, child_ops.back()));
}
}
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
......
......@@ -50,6 +50,7 @@ class SkipDataset;
class Cifar10Dataset;
class ProjectDataset;
class ZipDataset;
class RenameDataset;
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
......@@ -98,8 +99,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
~Dataset() = default;
/// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
/// \return shared pointer to the list of newly created DatasetOps
virtual std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() = 0;
/// \return The list of shared pointers to the newly created DatasetOps
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool True if all the params are valid
......@@ -179,6 +180,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current Dataset
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
protected:
std::vector<std::shared_ptr<Dataset>> children;
std::shared_ptr<Dataset> parent;
......@@ -202,8 +211,8 @@ class ImageFolderDataset : public Dataset {
~ImageFolderDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -227,8 +236,8 @@ class MnistDataset : public Dataset {
~MnistDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -249,8 +258,8 @@ class BatchDataset : public Dataset {
~BatchDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -273,8 +282,8 @@ class RepeatDataset : public Dataset {
~RepeatDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -290,7 +299,7 @@ class ShuffleDataset : public Dataset {
~ShuffleDataset() = default;
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
std::vector<std::shared_ptr<DatasetOp>> Build() override;
bool ValidateParams() override;
......@@ -309,8 +318,8 @@ class SkipDataset : public Dataset {
~SkipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -330,8 +339,8 @@ class MapDataset : public Dataset {
~MapDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -353,8 +362,8 @@ class Cifar10Dataset : public Dataset {
~Cifar10Dataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -375,8 +384,8 @@ class ProjectDataset : public Dataset {
~ProjectDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
......@@ -395,12 +404,33 @@ class ZipDataset : public Dataset {
~ZipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
};
/// \brief Dataset node that renames columns of its input dataset.
class RenameDataset : public Dataset {
public:
/// \brief Constructor
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns (the new names)
explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
/// \brief Destructor
~RenameDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
// Names of the columns to be renamed.
std::vector<std::string> input_columns_;
// New names for the renamed columns.
std::vector<std::string> output_columns_;
};
} // namespace api
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册