Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0868720e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0868720e
编写于
8月 12, 2020
作者:
T
tinazhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix parameter type for repeat op in c++ api and added c++/python ut.
上级
f37a2fa4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
126 addition
and
8 deletion
+126
-8
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+3
-3
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+2
-2
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+1
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+2
-1
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
+95
-0
tests/ut/python/dataset/test_repeat.py
tests/ut/python/dataset/test_repeat.py
+23
-1
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
0868720e
...
...
@@ -1165,7 +1165,7 @@ std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
return
node_ops
;
}
RepeatDataset
::
RepeatDataset
(
u
int32_t
count
)
:
repeat_count_
(
count
)
{}
RepeatDataset
::
RepeatDataset
(
int32_t
count
)
:
repeat_count_
(
count
)
{}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RepeatDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
...
...
@@ -1176,8 +1176,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
}
bool
RepeatDataset
::
ValidateParams
()
{
if
(
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be
negative"
;
if
(
repeat_count_
!=
-
1
&&
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be
"
<<
repeat_count_
;
return
false
;
}
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
0868720e
...
...
@@ -692,7 +692,7 @@ class RenameDataset : public Dataset {
class
RepeatDataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
RepeatDataset
(
u
int32_t
count
);
explicit
RepeatDataset
(
int32_t
count
);
/// \brief Destructor
~
RepeatDataset
()
=
default
;
...
...
@@ -706,7 +706,7 @@ class RepeatDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
u
int32_t
repeat_count_
;
int32_t
repeat_count_
;
};
class
ShuffleDataset
:
public
Dataset
{
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
0868720e
...
...
@@ -2104,7 +2104,7 @@ class RepeatDataset(DatasetOp):
Args:
input_dataset (Dataset): Input Dataset to be repeated.
count (int): Number of times the dataset should be repeated.
count (int): Number of times the dataset should be repeated
(default=-1, repeat indefinitely)
.
"""
def
__init__
(
self
,
input_dataset
,
count
):
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
0868720e
...
...
@@ -597,7 +597,8 @@ def check_repeat(method):
type_check
(
count
,
(
int
,
type
(
None
)),
"repeat"
)
if
isinstance
(
count
,
int
):
check_value
(
count
,
(
-
1
,
INT32_MAX
),
"count"
)
if
(
count
<=
0
and
count
!=
-
1
)
or
count
>
INT32_MAX
:
raise
ValueError
(
"count should be either -1 or positive integer."
)
return
method
(
self
,
*
args
,
**
kwargs
)
return
new_method
...
...
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
浏览文件 @
0868720e
...
...
@@ -431,6 +431,101 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestRepeatDefault
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRepeatDefault."
;
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
ImageFolder
(
folder_path
,
true
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
// Default value of repeat count is -1, expected to repeat infinitely
ds
=
ds
->
Repeat
();
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
1
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
// manually stop
if
(
i
==
100
){
break
;}
i
++
;
auto
image
=
row
[
"image"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
iter
->
GetNextRow
(
&
row
);
}
EXPECT_EQ
(
i
,
100
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestRepeatOne
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRepeatOne."
;
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
ImageFolder
(
folder_path
,
true
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
int32_t
repeat_num
=
1
;
ds
=
ds
->
Repeat
(
repeat_num
);
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
1
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
i
++
;
auto
image
=
row
[
"image"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
iter
->
GetNextRow
(
&
row
);
}
EXPECT_EQ
(
i
,
10
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestRepeatFail
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRepeatFail."
;
// This case is expected to fail because the repeat count is invalid (<-1 && !=0).
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
ImageFolder
(
folder_path
,
true
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
int32_t
repeat_num
=
-
2
;
ds
=
ds
->
Repeat
(
repeat_num
);
EXPECT_EQ
(
ds
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestShuffleDataset
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestShuffleDataset."
;
...
...
tests/ut/python/dataset/test_repeat.py
浏览文件 @
0868720e
...
...
@@ -16,7 +16,7 @@
Test Repeat Op
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
mindspore
import
log
as
logger
...
...
@@ -295,6 +295,26 @@ def test_repeat_count2():
assert
data1_size
==
3
assert
dataset_size
==
num1_iter
==
8
def
test_repeat_count0
():
"""
Test Repeat with invalid count 0.
"""
logger
.
info
(
"Test Repeat with invalid count 0"
)
with
pytest
.
raises
(
ValueError
)
as
info
:
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF2
,
SCHEMA_DIR_TF2
,
shuffle
=
False
)
data1
.
repeat
(
0
)
assert
"count"
in
str
(
info
)
def
test_repeat_countneg2
():
"""
Test Repeat with invalid count -2.
"""
logger
.
info
(
"Test Repeat with invalid count -2"
)
with
pytest
.
raises
(
ValueError
)
as
info
:
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF2
,
SCHEMA_DIR_TF2
,
shuffle
=
False
)
data1
.
repeat
(
-
2
)
assert
"count"
in
str
(
info
)
if
__name__
==
"__main__"
:
test_tf_repeat_01
()
test_tf_repeat_02
()
...
...
@@ -313,3 +333,5 @@ if __name__ == "__main__":
test_nested_repeat11
()
test_repeat_count1
()
test_repeat_count2
()
test_repeat_count0
()
test_repeat_countneg2
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录