提交 8921a609 编写于 作者: C Cathy Wong

C++ API Support for Skip Dataset Op and UTs

上级 4bbbf2dc
......@@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
......@@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return ds;
}
// Function to create a SkipDataset.
std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
auto ds = std::make_shared<SkipDataset>(count);
// Call derived class validation method.
if (!ds->ValidateParams()) {
return nullptr;
}
ds->children.push_back(shared_from_this());
return ds;
}
// Function to create a ProjectDataset.
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
auto ds = std::make_shared<ProjectDataset>(columns);
......@@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() {
return true;
}
// Constructor for SkipDataset
SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
// Function to build the SkipOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
}
// Function to validate the parameters for SkipDataset
bool SkipDataset::ValidateParams() {
if (skip_count_ <= -1) {
MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_;
return false;
}
return true;
}
// Constructor for Cifar10Dataset
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
......
......@@ -46,6 +46,7 @@ class BatchDataset;
class RepeatDataset;
class MapDataset;
class ShuffleDataset;
class SkipDataset;
class Cifar10Dataset;
class ProjectDataset;
class ZipDataset;
......@@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current ShuffleDataset
std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size);
/// \brief Function to create a SkipDataset
/// \notes Skips count elements in this dataset.
/// \param[in] count Number of elements the dataset to be skipped.
/// \return Shared pointer to the current SkipDataset
std::shared_ptr<SkipDataset> Skip(int32_t count);
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
......@@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset {
bool reset_every_epoch_;
};
class SkipDataset : public Dataset {
public:
/// \brief Constructor
explicit SkipDataset(int32_t count);
/// \brief Destructor
~SkipDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
int32_t skip_count_;
};
class MapDataset : public Dataset {
public:
/// \brief Constructor
......
......@@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp):
The result of applying Skip operator to the input Dataset.
Args:
input_dataset (tuple): A tuple of datasets to be skipped.
count (int): Number of rows the dataset should be skipped.
input_dataset (Dataset): Input dataset to have rows skipped.
count (int): Number of rows in the dataset to be skipped.
"""
def __init__(self, input_dataset, count):
......
......@@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSkipDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_TRUE(ds != nullptr);
// Create a Skip operation on ds
int32_t count = 3;
ds = ds->Skip(count);
EXPECT_TRUE(ds != nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_TRUE(iter != nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
MS_LOG(INFO) << "Number of rows: " << i;
// Expect 10-3=7 rows
EXPECT_TRUE(i == 7);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSkipDatasetError1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_TRUE(ds != nullptr);
// Create a Skip operation on ds with invalid count input
int32_t count = -1;
ds = ds->Skip(count);
// Expect nullptr for invalid input skip_count
EXPECT_TRUE(ds == nullptr);
}
TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
// Create a Cifar10 Dataset
......
......@@ -13,9 +13,12 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
......@@ -196,6 +199,29 @@ def test_skip_filter_2():
assert buf == [5, 6, 7, 8, 9, 10]
def test_skip_exception_1():
data1 = ds.GeneratorDataset(generator_md, ["data"])
try:
data1 = data1.skip(count=-1)
num_iter = 0
for _ in data1.create_dict_iterator():
num_iter += 1
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Skip count must be positive integer or 0." in str(e)
def test_skip_exception_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
with pytest.raises(ValueError) as e:
ds1 = ds1.skip(-2)
assert "Input count is not within the required interval" in str(e.value)
if __name__ == "__main__":
test_tf_skip()
test_generator_skip()
......@@ -208,3 +234,5 @@ if __name__ == "__main__":
test_skip_take_2()
test_skip_filter_1()
test_skip_filter_2()
test_skip_exception_1()
test_skip_exception_2()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册