diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index cfb411e8c6fcb7c635fd2e29f9c62add778b0aea..6021b63365b90fb9300121bc9bb99eacd3fb95de 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -1165,7 +1165,7 @@ std::vector> RenameDataset::Build() { return node_ops; } -RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} +RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {} std::vector> RepeatDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create @@ -1176,8 +1176,8 @@ std::vector> RepeatDataset::Build() { } bool RepeatDataset::ValidateParams() { - if (repeat_count_ <= 0) { - MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative"; + if (repeat_count_ != -1 && repeat_count_ <= 0) { + MS_LOG(ERROR) << "Repeat: Repeat count cannot be" << repeat_count_; return false; } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index eb5ac1c398bcdba30162b13027b55f8c32bf604c..1904fa15d4e36619ef4ba70955e3793265ce1e88 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -692,7 +692,7 @@ class RenameDataset : public Dataset { class RepeatDataset : public Dataset { public: /// \brief Constructor - explicit RepeatDataset(uint32_t count); + explicit RepeatDataset(int32_t count); /// \brief Destructor ~RepeatDataset() = default; @@ -706,7 +706,7 @@ class RepeatDataset : public Dataset { bool ValidateParams() override; private: - uint32_t repeat_count_; + int32_t repeat_count_; }; class ShuffleDataset : public Dataset { diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 2a615e01f7791d454b66263845039a087da38252..dd813c04baa41b1e39be8752b3f0bd5f6a9ae132 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2104,7 +2104,7 @@ class RepeatDataset(DatasetOp): Args: input_dataset (Dataset): Input Dataset to be repeated. - count (int): Number of times the dataset should be repeated. + count (int): Number of times the dataset should be repeated (default=-1, repeat indefinitely). """ def __init__(self, input_dataset, count): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 2c9b97654f70853f432817864294b05ba0e45dcd..fcc5f267123207e68ed713b1c29751df8168dc61 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -597,7 +597,8 @@ def check_repeat(method): type_check(count, (int, type(None)), "repeat") if isinstance(count, int): - check_value(count, (-1, INT32_MAX), "count") + if (count <= 0 and count != -1) or count > INT32_MAX: + raise ValueError("count should be either -1 or positive integer.") return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index ce78ba2b62ee609a1f461c0b1666c215a7040d1c..1c900598d4ee299fa5558eb244ac83b8a1895cd3 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -431,6 +431,101 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestRepeatDefault) { + MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds,nullptr); + + // Create a Repeat operation on ds + // Default value of repeat count is -1, expected to repeat infinitely + ds = ds->Repeat(); + EXPECT_NE(ds,nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds,nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter,nullptr); + + // Iterate the dataset and get each row + std::unordered_map > row; + iter->GetNextRow(&row); + uint64_t i = 0; + while (row.size()!= 0) { + // manually stop + if(i==100){break;} + i++; + auto image = row["image"]; + MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i,100); + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRepeatOne) { + MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds,nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 1; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds,nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds,nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter,nullptr); + + // Iterate the dataset and get each row + std::unordered_map > row; + iter->GetNextRow(&row); + uint64_t i = 0; + while (row.size()!= 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i,10); + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRepeatFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatFail."; + // This case is expected to fail because the repeat count is invalid (<-1 && !=0). + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = -2; + ds = ds->Repeat(repeat_num); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestShuffleDataset) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleDataset."; diff --git a/tests/ut/python/dataset/test_repeat.py b/tests/ut/python/dataset/test_repeat.py index a059fc3a9c0b11fba50aeb6d8de4fd9193ff5df2..f55cb76afe702bfa9c26e20045e57648b876f74d 100644 --- a/tests/ut/python/dataset/test_repeat.py +++ b/tests/ut/python/dataset/test_repeat.py @@ -16,7 +16,7 @@ Test Repeat Op """ import numpy as np - +import pytest import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore import log as logger @@ -295,6 +295,26 @@ def test_repeat_count2(): assert data1_size == 3 assert dataset_size == num1_iter == 8 +def test_repeat_count0(): + """ + Test Repeat with invalid count 0. + """ + logger.info("Test Repeat with invalid count 0") + with pytest.raises(ValueError) as info: + data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) + data1.repeat(0) + assert "count" in str(info) + +def test_repeat_countneg2(): + """ + Test Repeat with invalid count -2. + """ + logger.info("Test Repeat with invalid count -2") + with pytest.raises(ValueError) as info: + data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) + data1.repeat(-2) + assert "count" in str(info) + if __name__ == "__main__": test_tf_repeat_01() test_tf_repeat_02() @@ -313,3 +333,5 @@ if __name__ == "__main__": test_nested_repeat11() test_repeat_count1() test_repeat_count2() + test_repeat_count0() + test_repeat_countneg2()