diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index d8191f462fcd7fee85adbfe1d695ab99ca080c7a..92222e541625381fea9473b7899f3c78f429ab35 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -86,9 +86,16 @@ Dataset::Dataset() { // (In alphabetical order) // Function to create a Cifar10Dataset. -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler) { - auto ds = std::make_shared(dataset_dir, num_samples, sampler); +std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Cifar100Dataset. +std::shared_ptr Cifar100(const std::string &dataset_dir, std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -250,28 +257,27 @@ std::shared_ptr CreateDefaultSampler() { return std::make_shared(replacement, num_samples); } +// Helper function to validate dataset params +bool ValidateCommonDatasetParams(std::string dataset_dir) { + if (dataset_dir.empty()) { + MS_LOG(ERROR) << "No dataset path is specified"; + return false; + } + return true; +} + /* ####################################### Derived Dataset classes ################################# */ // DERIVED DATASET CLASSES LEAF-NODE DATASETS // (In alphabetical order) // Constructor for Cifar10Dataset -Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} +Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool Cifar10Dataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - if (num_samples_ < 0) { - MS_LOG(ERROR) << "Number of samples cannot be negative"; - return false; - } - return true; -} +bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } -// Function to build CifarOp +// Function to build CifarOp for Cifar10 std::vector> Cifar10Dataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -294,6 +300,37 @@ std::vector> Cifar10Dataset::Build() { return node_ops; } +// Constructor for Cifar100Dataset +Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), sampler_(sampler) {} + +bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } + +// Function to build CifarOp for Cifar100 +std::vector> Cifar100Dataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_, + dataset_dir_, connector_que_size_, std::move(schema), + std::move(sampler_->Build()))); + return node_ops; +} + ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, std::set extensions, std::map class_indexing) @@ -304,14 +341,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std class_indexing_(class_indexing), exts_(extensions) {} -bool ImageFolderDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - - return true; -} +bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } std::vector> ImageFolderDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create @@ -339,14 +369,7 @@ std::vector> ImageFolderDataset::Build() { MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr sampler) : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool MnistDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - - return true; -} +bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } std::vector> MnistDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 595404dd038d84495aa90f66927417b0df8931a0..14f8233ef293a466be4d72ca108c691dea60c6c5 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -42,6 +42,7 @@ class TensorOperation; class SamplerObj; // Datasets classes (in alphabetical order) class Cifar10Dataset; +class Cifar100Dataset; class ImageFolderDataset; class MnistDataset; // Dataset Op classes (in alphabetical order) @@ -57,12 +58,19 @@ class ZipDataset; /// \brief Function to create a Cifar10 Dataset /// \notes The generated dataset has two columns ['image', 'label'] /// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] num_samples The number of images to be included in the dataset /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler); +std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler = nullptr); + +/// \brief Function to create a Cifar100 Dataset +/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr Cifar100(const std::string &dataset_dir, + std::shared_ptr sampler = nullptr); /// \brief Function to create an ImageFolderDataset /// \notes A source dataset that reads images from a tree of directories @@ -204,7 +212,7 @@ class Dataset : public std::enable_shared_from_this { class Cifar10Dataset : public Dataset { public: /// \brief Constructor - Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler); + Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler); /// \brief Destructor ~Cifar10Dataset() = default; @@ -219,7 +227,27 @@ class Cifar10Dataset : public Dataset { private: std::string dataset_dir_; - int32_t num_samples_; + std::shared_ptr sampler_; +}; + +class Cifar100Dataset : public Dataset { + public: + /// \brief Constructor + Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler); + + /// \brief Destructor + ~Cifar100Dataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; std::shared_ptr sampler_; }; diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 5ff1862cfd68f227a2c4762bef3a77d9c14bce56..8544c5c569c6b405785f1d7cd16b23246ccd9829 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -84,6 +84,12 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestMnistFail1) { + // Create a Mnist Dataset + std::shared_ptr ds = Mnist("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; @@ -274,6 +280,12 @@ TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestImageFolderFail1) { + // Create an ImageFolder Dataset + std::shared_ptr ds = ImageFolder("", true, nullptr); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { std::shared_ptr sampl = DistributedSampler(2, 1); EXPECT_NE(sampl, nullptr); @@ -630,17 +642,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, 0, RandomSampler(false, 10)); - EXPECT_NE(ds, nullptr); - - // Create a Repeat operation on ds - int32_t repeat_num = 2; - ds = ds->Repeat(repeat_num); - EXPECT_NE(ds, nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 2; - ds = ds->Batch(batch_size); + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -652,6 +654,9 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { std::unordered_map> row; iter->GetNextRow(&row); + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + uint64_t i = 0; while (row.size() != 0) { i++; @@ -666,6 +671,54 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) { + + // Create a Cifar10 Dataset + std::shared_ptr ds = Cifar10("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestCifar100Dataset) { + + // Create a Cifar100 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; + std::shared_ptr ds = Cifar100(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("coarse_label"), row.end()); + EXPECT_NE(row.find("fine_label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) { + + // Create a Cifar100 Dataset + std::shared_ptr ds = Cifar100("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; @@ -843,7 +896,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) { EXPECT_NE(ds1, nullptr); folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds2 = Cifar10(folder_path, 0, RandomSampler(false, 10)); + std::shared_ptr ds2 = Cifar10(folder_path, RandomSampler(false, 10)); EXPECT_NE(ds2, nullptr); // Create a Project operation on ds