Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
30de261c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
30de261c
编写于
4月 14, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
!243 Support nested repeat
Merge pull request !243 from h.farahat/nested_repeat
上级
9a781025
0fc23eee
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
233 addition
and
62 deletion
+233
-62
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
+10
-7
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+8
-2
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
+13
-3
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
+12
-3
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
+18
-18
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
+5
-5
mindspore/ccsrc/dataset/engine/execution_tree.cc
mindspore/ccsrc/dataset/engine/execution_tree.cc
+10
-15
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+2
-0
tests/ut/cpp/dataset/repeat_op_test.cc
tests/ut/cpp/dataset/repeat_op_test.cc
+19
-8
tests/ut/python/dataset/test_repeat.py
tests/ut/python/dataset/test_repeat.py
+136
-1
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
浏览文件 @
30de261c
...
...
@@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
return
(
out_connector_
->
Add
(
static_cast
<
int
>
(
worker_id
),
std
::
move
(
eof_buffer
)));
}
// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific
pre-
operations to perform depending on
// their role.
Status
DatasetOp
::
PrepareNodeAction
()
{
Status
DatasetOp
::
PrepareNodePreAction
()
{
if
(
BitTest
(
tree_
->
PrepareFlags
(),
ExecutionTree
::
kDePrepRepeat
))
set_control_flag
(
kDeOpRepeated
);
return
Status
::
OK
();
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status
DatasetOp
::
PrepareNodePostAction
()
{
// If this op does not have any children and it is in a repeat path of the tree...
if
(
child_
.
size
()
==
0
&&
BitTest
(
tree_
->
PrepareFlags
(),
ExecutionTree
::
kDePrepRepeat
))
{
// Then, flag this operator as a leaf node in a repeat path of tree execution.
BitSet
(
&
op_ctrl_flags_
,
kDeOpRepeated
);
// Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator
if
(
child_
.
empty
()
&&
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
))
{
// push ourselves onto the tree repeat stack. Later, the repeat operator
// above us will consume them.
tree_
->
AddToRepeatStack
(
shared_from_this
());
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
浏览文件 @
30de261c
...
...
@@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
return
Status
::
OK
();
}
// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific
pre-
operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
virtual
Status
PrepareNodeAction
();
virtual
Status
PrepareNodePreAction
();
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
virtual
Status
PrepareNodePostAction
();
// Getter function
// @return The operator id
...
...
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
浏览文件 @
30de261c
...
...
@@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
return
out
;
}
// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific
pre-
operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
// @return Status - The error return code
Status
PrepareNodeAction
()
override
{
Status
PrepareNode
Pre
Action
()
override
{
// Run common code from super class before adding ParallelOp specific logic
return
(
DatasetOp
::
PrepareNodeAction
());
return
(
DatasetOp
::
PrepareNodePreAction
());
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
// @return Status - The error return code
Status
PrepareNodePostAction
()
override
{
// Run common code from super class before adding ParallelOp specific logic
return
(
DatasetOp
::
PrepareNodePostAction
());
}
// Override base class reset to provide reset actions specific to the ParallelOp class.
...
...
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
浏览文件 @
30de261c
...
...
@@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
// @return The number of threads that push data to the output connector
int32_t
num_producers
()
const
override
{
return
1
;
}
// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific
pre-
operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status
PrepareNodeAction
()
override
{
Status
PrepareNode
Pre
Action
()
override
{
// Run common code from super class before adding PipelineOp specific logic
return
(
DatasetOp
::
PrepareNodeAction
());
return
(
DatasetOp
::
PrepareNodePreAction
());
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status
PrepareNodePostAction
()
override
{
// Run common code from super class before adding PipelineOp specific logic
return
(
DatasetOp
::
PrepareNodePostAction
());
}
protected:
...
...
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
浏览文件 @
30de261c
...
...
@@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
out
<<
"RepeatOp:"
<<
"
\n
Current repeat count: "
<<
repeat_count_
<<
"
\n
Max repeat count: "
<<
max_repeats_
<<
"
\n
Leaf Nodes in my execution path:"
;
if
(
!
leaf
_ops_
.
empty
())
{
if
(
!
eoe
_ops_
.
empty
())
{
out
<<
"
\n
"
;
for
(
size_t
i
=
0
;
i
<
leaf
_ops_
.
size
();
i
++
)
{
out
<<
" Operator: "
<<
leaf
_ops_
[
i
]
->
id
()
<<
"
\n
"
;
for
(
size_t
i
=
0
;
i
<
eoe
_ops_
.
size
();
i
++
)
{
out
<<
" Operator: "
<<
eoe
_ops_
[
i
]
->
id
()
<<
"
\n
"
;
}
}
else
{
out
<<
" kNone."
;
...
...
@@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status
RepeatOp
::
PrepareNodeAction
()
{
Status
RepeatOp
::
PrepareNode
Post
Action
()
{
// Run any common code from super class first before adding our own specific logic
RETURN_IF_NOT_OK
(
PipelineOp
::
PrepareNodeAction
());
RETURN_IF_NOT_OK
(
PipelineOp
::
PrepareNode
Post
Action
());
std
::
shared_ptr
<
DatasetOp
>
leaf_op
=
tree_
->
PopFromRepeatStack
();
while
(
leaf_op
!=
nullptr
)
{
// Track the leaf operators that are under this repeat op.
leaf_ops_
.
push_back
(
leaf_op
);
// Special case. If the repeat count is 1, then pre-flag the leaf nodes
// to tell them they are already at their last op:
if
(
max_repeats_
==
1
)
{
leaf_op
->
set_control_flag
(
kDeOpLastRepeat
);
}
eoe_ops_
.
push_back
(
leaf_op
);
leaf_op
=
tree_
->
PopFromRepeatStack
();
}
// Push ourselves to the stack in case one of our ascendants is repeat too.
tree_
->
AddToRepeatStack
(
shared_from_this
());
return
Status
::
OK
();
}
...
...
@@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
Status
RepeatOp
::
EoeReceived
(
int32_t
worker_id
)
{
repeat_count_
++
;
MS_LOG
(
INFO
)
<<
"Repeat operator end of epoch message received. Repeat count is now: "
<<
repeat_count_
<<
"."
;
// If we've reached the requested repeat count, then flag the leaf nodes
bool
repeated
=
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
);
bool
last_repeat
=
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
);
// If we've reached the requested repeat count, then flag the eoe nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if
(
max_repeats_
!=
kInfiniteRepeat
&&
repeat_count_
==
(
max_repeats_
-
1
))
{
for
(
size_t
i
=
0
;
i
<
leaf_ops_
.
size
();
i
++
)
{
leaf_ops_
[
i
]
->
set_control_flag
(
kDeOpLastRepeat
);
// of the last epoch, they quit rather than loop again. This happens in two cases:
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
// 2- We are not repeated
if
(
max_repeats_
!=
kInfiniteRepeat
&&
repeat_count_
==
(
max_repeats_
-
1
)
&&
(
!
repeated
||
last_repeat
))
{
for
(
auto
&
eoe_op
:
eoe_ops_
)
{
eoe_op
->
set_control_flag
(
kDeOpLastRepeat
);
}
}
if
(
repeat_count_
==
max_repeats_
)
{
repeat_count_
=
0
;
state_
=
OpState
::
kDeOpIdle
;
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
浏览文件 @
30de261c
...
...
@@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
uint32_t
PrepareFlags
()
const
override
;
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status
PrepareNodeAction
()
override
;
// during the execution tree p
ost-p
repare phase when it is visiting this operator.
Status
PrepareNode
Post
Action
()
override
;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
...
...
@@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
int32_t
num_producers
()
const
override
;
private:
int32_t
max_repeats_
;
// The number of repeats that the user requested
int32_t
repeat_count_
;
// A counter for the current number of executed repeats
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
leaf_ops_
;
// List of leaf operators
underneath this repeat.
int32_t
max_repeats_
;
// The number of repeats that the user requested
int32_t
repeat_count_
;
// A counter for the current number of executed repeats
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
eoe_ops_
;
// List of operators that can generate EOE
underneath this repeat.
};
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/engine/execution_tree.cc
浏览文件 @
30de261c
...
...
@@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
Status
ExecutionTree
::
PrepareNode
(
const
std
::
shared_ptr
<
DatasetOp
>
&
dataset_op
)
{
int32_t
num_children
=
dataset_op
->
child_
.
size
();
// execute PreAction
RETURN_IF_NOT_OK
(
dataset_op
->
PrepareNodePreAction
());
// Before going down into children, make any prepare flags updates based on this
// operator.
// Before going down into children, make any prepare flags updates based on this operator.
uint32_t
op_prep_flags
=
dataset_op
->
PrepareFlags
();
// Sanity check. In future we can support nested repeats. for now it's not allowed.
// If somebody above us already set the repeat flag, and now we are another repeat...
if
(
BitTest
(
op_prep_flags
,
kDePrepRepeat
)
&&
BitTest
(
prepare_flags_
,
kDePrepRepeat
))
{
std
::
string
err_msg
(
"Nested RepeatOp detected! This is not supported yet."
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
BitSet
(
&
prepare_flags_
,
op_prep_flags
);
// Now, descend to children
for
(
int32_t
i
=
0
;
i
<
num_children
;
++
i
)
{
RETURN_IF_NOT_OK
(
this
->
PrepareNode
(
dataset_op
->
child_
[
i
]
));
for
(
const
auto
&
i
:
dataset_op
->
child_
)
{
RETURN_IF_NOT_OK
(
this
->
PrepareNode
(
i
));
}
// No more children, now we execute any prepare actions before going back up the
// the tree on recursive function exit
RETURN_IF_NOT_OK
(
dataset_op
->
PrepareNodeAction
());
// Then clear the flags from this op now that we have prepared it.
BitClear
(
&
prepare_flags_
,
op_prep_flags
);
// No more children, now we execute any prepare actions before going back up the
// the tree on recursive function
RETURN_IF_NOT_OK
(
dataset_op
->
PrepareNodePostAction
());
return
Status
::
OK
();
}
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
30de261c
...
...
@@ -419,6 +419,8 @@ class Dataset:
>>> repeat_and_shuffle = data.repeat(50)
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
"""
if
count
==
1
:
return
self
return
RepeatDataset
(
self
,
count
)
@
check_zip_dataset
...
...
tests/ut/cpp/dataset/repeat_op_test.cc
浏览文件 @
30de261c
...
...
@@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
auto
my_tree
=
std
::
make_shared
<
ExecutionTree
>
();
std
::
shared_ptr
<
DatasetOp
>
parent_op
=
std
::
make_shared
<
RepeatOp
>
(
32
);
std
::
shared_ptr
<
DatasetOp
>
leaf_op
=
std
::
make_shared
<
RepeatOp
>
(
16
);
std
::
string
dataset_path
;
dataset_path
=
datasets_root_path_
+
"/testTFTestAllTypes/test.data"
;
// TFReaderOp
std
::
shared_ptr
<
TFReaderOp
>
my_tfreader_op
;
TFReaderOp
::
Builder
builder
;
builder
.
SetDatasetFilesList
({
dataset_path
})
.
SetRowsPerBuffer
(
16
)
.
SetWorkerConnectorSize
(
16
)
.
SetNumWorkers
(
16
);
Status
rc
=
builder
.
Build
(
&
my_tfreader_op
);
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
my_tree
->
AssociateNode
(
my_tfreader_op
);
ASSERT_TRUE
(
rc
.
IsOk
());
my_tree
->
AssociateNode
(
parent_op
);
my_tree
->
AssociateNode
(
leaf_op
);
ASSERT_NE
(
parent_op
,
nullptr
);
ASSERT_NE
(
leaf_op
,
nullptr
);
parent_op
->
AddChild
(
std
::
move
(
leaf_op
));
parent_op
->
Print
(
std
::
cout
,
false
);
parent_op
->
PrepareNodeAction
();
ASSERT_NE
(
my_tfreader_op
,
nullptr
);
parent_op
->
AddChild
(
std
::
move
(
my_tfreader_op
));
MS_LOG
(
INFO
)
<<
parent_op
;
my_tree
->
Prepare
();
RepeatOp
RepeatOpOp
();
std
::
shared_ptr
<
RepeatOp
>
repeat_op
;
Status
rc
=
RepeatOp
::
Builder
(
3
).
Build
(
&
repeat_op
);
rc
=
RepeatOp
::
Builder
(
3
).
Build
(
&
repeat_op
);
ASSERT_NE
(
repeat_op
,
nullptr
);
}
tests/ut/python/dataset/test_repeat.py
浏览文件 @
30de261c
...
...
@@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
from
util
import
save_and_check
import
mindspore.dataset
as
ds
import
numpy
as
np
from
mindspore
import
log
as
logger
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
...
...
@@ -95,6 +96,141 @@ def test_tf_repeat_03():
assert
num_iter
==
2
def
generator
():
for
i
in
range
(
3
):
yield
np
.
array
([
i
]),
def
test_nested_repeat1
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
assert
sum
([
1
for
_
in
data
])
==
2
*
3
*
3
def
test_nested_repeat2
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
1
)
data
=
data
.
repeat
(
1
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
assert
sum
([
1
for
_
in
data
])
==
3
def
test_nested_repeat3
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
1
)
data
=
data
.
repeat
(
2
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
assert
sum
([
1
for
_
in
data
])
==
2
*
3
def
test_nested_repeat4
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
1
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
assert
sum
([
1
for
_
in
data
])
==
2
*
3
def
test_nested_repeat5
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
batch
(
3
)
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
def
test_nested_repeat6
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
2
)
data
=
data
.
batch
(
3
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
def
test_nested_repeat7
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
data
=
data
.
batch
(
3
)
for
i
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
def
test_nested_repeat8
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
batch
(
2
,
drop_remainder
=
False
)
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
if
i
%
2
==
0
:
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
]]))
else
:
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
*
2
def
test_nested_repeat9
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
()
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
if
i
==
10
:
break
def
test_nested_repeat10
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
3
)
data
=
data
.
repeat
()
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
if
i
==
10
:
break
def
test_nested_repeat11
():
data
=
ds
.
GeneratorDataset
(
generator
,
[
"data"
])
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
data
=
data
.
repeat
(
4
)
data
=
data
.
repeat
(
5
)
for
i
,
d
in
enumerate
(
data
):
assert
i
%
3
==
d
[
0
][
0
]
assert
sum
([
1
for
_
in
data
])
==
2
*
3
*
4
*
5
*
3
if
__name__
==
"__main__"
:
logger
.
info
(
"--------test tf repeat 01---------"
)
# test_repeat_01()
...
...
@@ -104,4 +240,3 @@ if __name__ == "__main__":
logger
.
info
(
"--------test tf repeat 03---------"
)
test_tf_repeat_03
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录