提交 0868720e 编写于 作者: T tinazhang

fix parameter type for repeat op in c++ api and added c++/python ut.

上级 f37a2fa4
......@@ -1165,7 +1165,7 @@ std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
return node_ops;
}
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {}
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
......@@ -1176,8 +1176,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
}
bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
if (repeat_count_ != -1 && repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be" << repeat_count_;
return false;
}
......
......@@ -692,7 +692,7 @@ class RenameDataset : public Dataset {
class RepeatDataset : public Dataset {
public:
/// \brief Constructor
explicit RepeatDataset(uint32_t count);
explicit RepeatDataset(int32_t count);
/// \brief Destructor
~RepeatDataset() = default;
......@@ -706,7 +706,7 @@ class RepeatDataset : public Dataset {
bool ValidateParams() override;
private:
uint32_t repeat_count_;
int32_t repeat_count_;
};
class ShuffleDataset : public Dataset {
......
......@@ -2104,7 +2104,7 @@ class RepeatDataset(DatasetOp):
Args:
input_dataset (Dataset): Input Dataset to be repeated.
count (int): Number of times the dataset should be repeated.
count (int): Number of times the dataset should be repeated (default=-1, repeat indefinitely).
"""
def __init__(self, input_dataset, count):
......
......@@ -597,7 +597,8 @@ def check_repeat(method):
type_check(count, (int, type(None)), "repeat")
if isinstance(count, int):
check_value(count, (-1, INT32_MAX), "count")
if (count <= 0 and count != -1) or count > INT32_MAX:
raise ValueError("count should be either -1 or positive integer.")
return method(self, *args, **kwargs)
return new_method
......
......@@ -431,6 +431,101 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRepeatDefault) {
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds,nullptr);
// Create a Repeat operation on ds
// Default value of repeat count is -1, expected to repeat infinitely
ds = ds->Repeat();
EXPECT_NE(ds,nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds,nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr <Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter,nullptr);
// Iterate the dataset and get each row
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size()!= 0) {
// manually stop
if(i==100){break;}
i++;
auto image = row["image"];
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i,100);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRepeatOne) {
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds,nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 1;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds,nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds,nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr <Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter,nullptr);
// Iterate the dataset and get each row
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size()!= 0) {
i++;
auto image = row["image"];
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i,10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRepeatFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatFail.";
// This case is expected to fail because the repeat count is invalid (<-1 && !=0).
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = -2;
ds = ds->Repeat(repeat_num);
EXPECT_EQ(ds, nullptr);
}
TEST_F(MindDataTestPipeline, TestShuffleDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleDataset.";
......
......@@ -16,7 +16,7 @@
Test Repeat Op
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
......@@ -295,6 +295,26 @@ def test_repeat_count2():
assert data1_size == 3
assert dataset_size == num1_iter == 8
def test_repeat_count0():
"""
Test Repeat with invalid count 0.
"""
logger.info("Test Repeat with invalid count 0")
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(0)
assert "count" in str(info)
def test_repeat_countneg2():
"""
Test Repeat with invalid count -2.
"""
logger.info("Test Repeat with invalid count -2")
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(-2)
assert "count" in str(info)
if __name__ == "__main__":
test_tf_repeat_01()
test_tf_repeat_02()
......@@ -313,3 +333,5 @@ if __name__ == "__main__":
test_nested_repeat11()
test_repeat_count1()
test_repeat_count2()
test_repeat_count0()
test_repeat_countneg2()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册