Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c56fe3aa
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c56fe3aa
编写于
4月 29, 2020
作者:
M
ms_yan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify take op with an operator
上级
37e35827
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
85 addition
and
70 deletion
+85
-70
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
+35
-55
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
+4
-15
tests/ut/python/dataset/test_take.py
tests/ut/python/dataset/test_take.py
+46
-0
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
浏览文件 @
c56fe3aa
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <utility>
#include <utility>
#include "common/utils.h"
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/db_connector.h"
...
@@ -25,7 +26,10 @@
...
@@ -25,7 +26,10 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
dataset
{
namespace
dataset
{
// Builder constructor. Creates the builder object.
// Builder constructor. Creates the builder object.
TakeOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_max_takes_
(
count
)
{}
TakeOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_max_takes_
(
count
)
{
std
::
shared_ptr
<
ConfigManager
>
cfg
=
GlobalContext
::
config_manager
();
builder_op_connector_size_
=
cfg
->
op_connector_size
();
}
Status
TakeOp
::
Builder
::
SanityCheck
()
const
{
Status
TakeOp
::
Builder
::
SanityCheck
()
const
{
if
(
build_max_takes_
<=
0
)
{
if
(
build_max_takes_
<=
0
)
{
...
@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
...
@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
// The builder "build" method creates the final object.
Status
TakeOp
::
Builder
::
Build
(
std
::
shared_ptr
<
TakeOp
>
*
ptr
)
{
Status
TakeOp
::
Builder
::
Build
(
std
::
shared_ptr
<
TakeOp
>
*
ptr
)
{
RETURN_IF_NOT_OK
(
SanityCheck
());
RETURN_IF_NOT_OK
(
SanityCheck
());
*
ptr
=
std
::
make_shared
<
TakeOp
>
(
build_max_takes_
);
*
ptr
=
std
::
make_shared
<
TakeOp
>
(
build_max_takes_
,
builder_op_connector_size_
);
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// Constructor of the TakeOp.
// Constructor of the TakeOp.
TakeOp
::
TakeOp
(
int32_t
count
)
:
PipelineOp
(
0
),
max_takes_
(
count
),
take_count_
(
0
)
{}
TakeOp
::
TakeOp
(
int32_t
count
,
int32_t
op_connector_size
)
:
PipelineOp
(
op_connector_size
),
max_takes_
(
count
),
take_count_
(
0
)
{}
// A print method typically used for debugging
// A print method typically used for debugging
void
TakeOp
::
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
{
void
TakeOp
::
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
{
...
@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
...
@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
}
}
}
}
// This function will be called multiple times to return the buffer; when it meets the required max take count or meets
// Main entry point for Take
// EOF buffer then this will stop.
Status
TakeOp
::
operator
()()
{
Status
TakeOp
::
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
p_buffer
,
int32_t
worker_id
,
bool
retry_if_eoe
)
{
TaskManager
::
FindMe
()
->
Post
();
if
(
child_
.
empty
())
{
RETURN_STATUS_UNEXPECTED
(
"TakeOp can't be the leaf node."
);
}
std
::
unique_ptr
<
DataBuffer
>
buf
;
std
::
unique_ptr
<
DataBuffer
>
buf
;
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
));
bool
last_repeat
=
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
);
while
(
buf
->
eof
()
==
false
)
{
if
(
take_count_
==
max_takes_
)
{
if
(
take_count_
==
max_takes_
)
{
if
(
state_
==
OpState
::
kDeOpRunning
)
{
// Do drain Operation
MS_LOG
(
DEBUG
)
<<
"Meet max count and push-back eoe buffer."
;
while
(
!
buf
->
eoe
()
&&
!
buf
->
eof
())
{
auto
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
));
*
p_buffer
=
std
::
move
(
eoe_buffer
);
state_
=
OpState
::
kDeOpIdle
;
// Reset the count and drain
if
(
!
last_repeat
)
{
take_count_
=
0
;
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
,
worker_id
,
true
));
while
(
!
buf
->
eoe
()
&&
!
buf
->
eof
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
,
worker_id
,
true
));
}
}
}
}
else
if
(
state_
==
OpState
::
kDeOpIdle
)
{
}
MS_LOG
(
DEBUG
)
<<
"Meet max count and push-back eof buffer."
;
auto
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
// Loop until non EOE is received
*
p_buffer
=
std
::
move
(
eof_buffer
);
if
(
buf
->
eoe
())
{
take_count_
=
0
;
take_count_
=
0
;
}
else
{
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
buf
)));
MS_LOG
(
WARNING
)
<<
"Invalid OpState: "
<<
state_
;
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
));
continue
;
}
}
return
Status
::
OK
();
}
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
,
worker_id
,
true
));
// Loop until non EOE is received
if
(
buf
->
eoe
())
{
take_count_
=
0
;
*
p_buffer
=
std
::
move
(
buf
);
return
Status
::
OK
();
}
// Check if the last buf is next eof
// Get buffer and push back when take_count is still small
if
(
buf
->
eof
())
{
if
(
take_count_
<
max_takes_
)
{
*
p_buffer
=
std
::
move
(
buf
);
std
::
unique_ptr
<
DataBuffer
>
p_buffer
;
return
Status
::
OK
();
RETURN_IF_NOT_OK
(
FillBuffer
(
&
buf
,
&
p_buffer
));
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
p_buffer
)));
}
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
));
}
}
// Get buffer and push back when take_count is still small
take_count_
=
0
;
if
(
take_count_
<
max_takes_
)
{
MS_LOG
(
DEBUG
)
<<
"Meet the end and push-back eof buffer."
;
RETURN_IF_NOT_OK
(
FillBuffer
(
&
buf
,
p_buffer
)
);
auto
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
}
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eof_buffer
)));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
...
@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overridden to
// ensure that it is not called by mistake (it will generate an error).
Status
TakeOp
::
operator
()()
{
RETURN_STATUS_UNEXPECTED
(
"Logic error. TakeOp is an inlined operator."
);
}
Status
TakeOp
::
PrepareNodePostAction
()
{
Status
TakeOp
::
PrepareNodePostAction
()
{
RETURN_IF_NOT_OK
(
PipelineOp
::
PrepareNodePostAction
());
RETURN_IF_NOT_OK
(
PipelineOp
::
PrepareNodePostAction
());
tree_
->
AddToRepeatStack
(
shared_from_this
());
tree_
->
AddToRepeatStack
(
shared_from_this
());
...
...
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
浏览文件 @
c56fe3aa
...
@@ -45,6 +45,7 @@ class TakeOp : public PipelineOp {
...
@@ -45,6 +45,7 @@ class TakeOp : public PipelineOp {
private:
private:
int32_t
build_max_takes_
;
int32_t
build_max_takes_
;
int32_t
builder_op_connector_size_
;
Status
SanityCheck
()
const
;
Status
SanityCheck
()
const
;
};
};
...
@@ -52,7 +53,7 @@ class TakeOp : public PipelineOp {
...
@@ -52,7 +53,7 @@ class TakeOp : public PipelineOp {
// Constructor of the TakeOp.
// Constructor of the TakeOp.
// @note The builder class should be used to call it
// @note The builder class should be used to call it
// @param count - The number of takes to do
// @param count - The number of takes to do
explicit
TakeOp
(
int32_t
count
);
explicit
TakeOp
(
int32_t
count
,
int32_t
op_connector_size
);
// Destructor
// Destructor
~
TakeOp
()
=
default
;
~
TakeOp
()
=
default
;
...
@@ -72,23 +73,11 @@ class TakeOp : public PipelineOp {
...
@@ -72,23 +73,11 @@ class TakeOp : public PipelineOp {
return
out
;
return
out
;
}
}
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// Most dataset ops operate by launching a thread (see ExecutionTree).
// provide the master loop that drives the logic for performing the work
// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overridden to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
// @return Status - The error code return
Status
operator
()()
override
;
Status
operator
()()
override
;
// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
// this function will retry to pop the connector again and will get the non-EOE buffer if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
p_buffer
,
int32_t
worker_id
,
bool
retry_if_eoe
)
override
;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// their role.
// @notes Derived versions of this function should always call its superclass version first
// @notes Derived versions of this function should always call its superclass version first
...
...
tests/ut/python/dataset/test_take.py
浏览文件 @
c56fe3aa
...
@@ -30,6 +30,12 @@ def generator_10():
...
@@ -30,6 +30,12 @@ def generator_10():
yield
np
.
array
([
i
]),
yield
np
.
array
([
i
]),
def filter_func_ge(data):
    """
    Filter predicate: return False when data is greater than 3, True otherwise.
    """
    # Negating the comparison yields the same plain True/False as the if/return pair.
    return not (data > 3)
def
test_take_01
():
def
test_take_01
():
"""
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
...
@@ -297,6 +303,44 @@ def test_take_16():
...
@@ -297,6 +303,44 @@ def test_take_16():
assert
sum
([
1
for
_
in
data1
])
==
5
assert
sum
([
1
for
_
in
data1
])
==
5
def test_take_17():
    """
    Test take: take first, then do filter operation
    """
    logger.info("test_take_17")
    pipeline = ds.GeneratorDataset(generator_10, ["data"])
    pipeline = pipeline.take(8).filter(predicate=filter_func_ge, num_parallel_workers=4)

    # The item at position idx must carry the value idx in its first column
    for idx, row in enumerate(pipeline):
        assert idx == row[0][0]

    # A second pass over the dataset counts the rows that remain
    row_count = sum(1 for _ in pipeline)
    assert row_count == 4
def test_take_18():
    """
    Test take: take first, then do filter, skip, batch and repeat operation
    """
    logger.info("test_take_18")
    pipeline = ds.GeneratorDataset(generator_10, ["data"])

    pipeline = (
        pipeline.take(8)
        .filter(predicate=filter_func_ge, num_parallel_workers=4)
        .skip(2)
        .batch(2)
        .repeat(2)
    )

    # The first element of the first column of every yielded item must equal 2
    for batch in pipeline:
        assert 2 == batch[0][0]

    # A second pass over the dataset counts the items it yields
    batch_count = sum(1 for _ in pipeline)
    assert batch_count == 2
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_take_01
()
test_take_01
()
test_take_02
()
test_take_02
()
...
@@ -314,4 +358,6 @@ if __name__ == '__main__':
...
@@ -314,4 +358,6 @@ if __name__ == '__main__':
test_take_14
()
test_take_14
()
test_take_15
()
test_take_15
()
test_take_16
()
test_take_16
()
test_take_17
()
test_take_18
()
logger
.
info
(
'== test take operation finished =='
)
logger
.
info
(
'== test take operation finished =='
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录