Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
34bfa2f7
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
34bfa2f7
编写于
4月 29, 2020
作者:
J
jiangzhiwen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix skip
上级
9399dffe
变更
4
隐藏空白更改
内联
并排
Showing 4 changed files with 118 additions and 71 deletions
+118
-71
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+47
-53
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
+4
-15
tests/ut/cpp/dataset/skip_op_test.cc
tests/ut/cpp/dataset/skip_op_test.cc
+1
-1
tests/ut/python/dataset/test_skip.py
tests/ut/python/dataset/test_skip.py
+66
-2
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
浏览文件 @
34bfa2f7
...
...
@@ -16,6 +16,7 @@
#include <iostream>
#include <utility>
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
...
...
@@ -26,7 +27,10 @@
namespace
mindspore
{
namespace
dataset
{
// Builder constructor. Creates the builder object.
SkipOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_max_skips_
(
count
)
{}
// Builder constructor. Creates the builder object, stores the skip count and
// takes the op connector size from the global configuration manager.
SkipOp::Builder::Builder(int32_t count)
    : build_max_skips_(count),
      builder_op_connector_size_(GlobalContext::config_manager()->op_connector_size()) {}
Status
SkipOp
::
Builder
::
SanityCheck
()
const
{
if
(
build_max_skips_
<
0
)
{
...
...
@@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
// SanityCheck() is evaluated first through RETURN_IF_NOT_OK; afterwards a SkipOp
// made from the stored skip count and op connector size is written to *ptr.
// @param ptr - output pointer that receives the newly created SkipOp
// @return Status - The error code return
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  // Construct the op exactly once; building a one-argument SkipOp first and then
  // overwriting *ptr would only allocate an object that is thrown away at once.
  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
  return Status::OK();
}
// Constructor of the SkipOp taking only the skip count.
// The PipelineOp base is constructed with 0 and the running skip counter starts at 0.
// @param count - The number of skips to do
SkipOp::SkipOp(int32_t count)
    : PipelineOp(0),
      max_skips_(count),
      skip_count_(0) {}
// Constructor of the SkipOp.
// The running skip counter starts at 0.
// @param count - The number of skips to do
// @param op_connector_size - value handed to the PipelineOp base constructor
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size),
      max_skips_(count),
      skip_count_(0) {}
// Destructor: nothing to release explicitly.
SkipOp::~SkipOp() = default;
...
...
@@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
<<
"
\n
Current skip count: "
<<
skip_count_
<<
"
\n
Max skip count: "
<<
max_skips_
;
}
// Fetches the next buffer from the first child and removes the leading rows that
// still have to be skipped. A buffer may hold several rows, so it can be trimmed
// only partially; whatever is left is handed back through p_buffer.
// @param p_buffer - output pointer to the buffer that is returned
// @param worker_id - The worker id
// @param retry_if_eoe - not read here; the child is always asked with `true`
// @return Status - The error code return
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
  }

  std::unique_ptr<DataBuffer> fetched;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&fetched, worker_id, true));

  // Keep discarding until max_skips_ rows are gone or an eoe/eof buffer shows up.
  while (skip_count_ < max_skips_ && !(fetched->eoe() || fetched->eof())) {
    // Drop the whole buffer if that still leaves rows owed, otherwise only the remainder.
    TensorRow discarded;
    int rows_in_buffer = fetched->NumRows();
    int rows_to_drop;
    if (rows_in_buffer + skip_count_ < max_skips_) {
      rows_to_drop = rows_in_buffer;
    } else {
      rows_to_drop = max_skips_ - skip_count_;
    }
    skip_count_ += rows_to_drop;

    for (int dropped = 0; dropped < rows_to_drop; ++dropped) {
      RETURN_IF_NOT_OK(fetched->PopRow(&discarded));
    }

    // An emptied buffer is replaced by the next one from the child.
    if (fetched->NumRows() == 0) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&fetched, worker_id, true));
    }
  }

  // Handling eoe
  if (fetched->eoe()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
  }

  // Handling eof
  if (fetched->eof()) {
    RETURN_IF_NOT_OK(EofReceived(worker_id));
  }

  *p_buffer = std::move(fetched);
  return Status::OK();
}
// Base-class override for handling cases when an eoe is received.
Status
SkipOp
::
EoeReceived
(
int32_t
worker_id
)
{
skip_count_
=
0
;
...
...
@@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
return
Status
::
OK
();
}
// Class functor operator () override.
// Dataset ops are normally driven by a launched thread (see ExecutionTree), but
// SkipOp is defined as an inlined operator that runs inside another operator.
// Launching this functor is therefore invalid, and the override exists only to
// report an error if it is ever called by mistake.
Status SkipOp::operator()() {
  RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator.");
}
// Main entry point for skip: the master loop that drives this op.
// For every epoch the first max_skips_ rows coming from the input are dropped and
// the remaining data buffers are added to the output connector, followed by one
// EOE buffer. Once the input reports EOF a single EOF buffer is added and the
// loop ends.
// @return Status - The error code return
Status SkipOp::operator()() {
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> curr_buffer;
  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));

  while (!curr_buffer->eof()) {
    // Reset count: every epoch starts with nothing skipped yet.
    skip_count_ = 0;
    while (!curr_buffer->eoe() && !curr_buffer->eof()) {
      // Drop first count rows; stop early as soon as a control buffer arrives.
      while (skip_count_ < max_skips_ && !curr_buffer->eoe() && !curr_buffer->eof()) {
        // Consider the rows of buffer more than one: drop the whole buffer if rows
        // are still owed afterwards, otherwise only the outstanding remainder.
        TensorRow drop_row;
        int row_num = curr_buffer->NumRows();
        int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
        skip_count_ += drop_num;
        for (int i = 0; i < drop_num; i++) {
          RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
        }
        if (curr_buffer->NumRows() == 0) {
          RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
        }
      }

      // The skip loop can end on an EOE/EOF buffer (the epoch had no more rows
      // than were to be skipped). Such a buffer must not be forwarded as data,
      // and no further input may be read past it here; the code below emits the
      // matching EOE/EOF exactly once.
      if (curr_buffer->eoe() || curr_buffer->eof()) {
        break;
      }

      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
    }

    // EOF reached inside the epoch: leave the loop and emit EOF below instead of
    // reading beyond the end of the input.
    if (curr_buffer->eof()) {
      break;
    }

    // we got eoe, now try again until we got eof
    MS_LOG(DEBUG) << "Skip operator EOE Received.";
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
  }

  MS_LOG(DEBUG) << "Skip operator EOF Received.";
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
  return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status
SkipOp
::
EofReceived
(
int32_t
worker_id
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
浏览文件 @
34bfa2f7
...
...
@@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
private:
int32_t
build_max_skips_
;
int32_t
builder_op_connector_size_
;
Status
SanityCheck
()
const
;
};
...
...
@@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
// Constructor of the SkipOp.
// @note The builder class should be used to call it
// @param count - The number of skips to do
explicit
SkipOp
(
int32_t
count
);
explicit
SkipOp
(
int32_t
count
,
int32_t
op_connector_size
);
// Destructor
~
SkipOp
();
...
...
@@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as an inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status
operator
()()
override
;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
// a buffer from our child.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
p_buffer
,
int32_t
worker_id
,
bool
retry_if_eoe
)
override
;
// Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id
Status
EoeReceived
(
int32_t
worker_id
)
override
;
...
...
tests/ut/cpp/dataset/skip_op_test.cc
浏览文件 @
34bfa2f7
...
...
@@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
ASSERT_TRUE
(
rc
.
IsOk
());
// SkipOp
std
::
shared_ptr
<
SkipOp
>
skip_op
=
std
::
make_shared
<
SkipOp
>
(
5
);
std
::
shared_ptr
<
SkipOp
>
skip_op
=
std
::
make_shared
<
SkipOp
>
(
5
,
2
);
rc
=
my_tree
->
AssociateNode
(
skip_op
);
ASSERT_TRUE
(
rc
.
IsOk
());
...
...
tests/ut/python/dataset/test_skip.py
浏览文件 @
34bfa2f7
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
...
...
@@ -51,7 +50,7 @@ def generator_md():
def
test_generator_skip
():
ds1
=
ds
.
GeneratorDataset
(
generator_md
,
[
"data"
])
ds1
=
ds
.
GeneratorDataset
(
generator_md
,
[
"data"
]
,
num_parallel_workers
=
4
)
# Here ds1 should be [3, 4]
ds1
=
ds1
.
skip
(
3
)
...
...
@@ -60,6 +59,7 @@ def test_generator_skip():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
2
assert
buf
==
[
3
,
4
]
def
test_skip_1
():
...
...
@@ -72,6 +72,7 @@ def test_skip_1():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
0
assert
buf
==
[]
def
test_skip_2
():
...
...
@@ -84,6 +85,7 @@ def test_skip_2():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
5
assert
buf
==
[
0
,
1
,
2
,
3
,
4
]
def
test_skip_repeat_1
():
...
...
@@ -99,6 +101,7 @@ def test_skip_repeat_1():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
7
assert
buf
==
[
3
,
4
,
0
,
1
,
2
,
3
,
4
]
def
test_skip_repeat_2
():
...
...
@@ -114,6 +117,7 @@ def test_skip_repeat_2():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
4
assert
buf
==
[
3
,
4
,
3
,
4
]
def
test_skip_repeat_3
():
...
...
@@ -132,6 +136,62 @@ def test_skip_repeat_3():
for
data
in
ds1
:
buf
.
append
(
data
[
0
][
0
])
assert
len
(
buf
)
==
6
assert
buf
==
[
3
,
4
,
3
,
4
,
3
,
4
]
def test_skip_take_1():
    """take(4) followed by skip(2) on the generator dataset must yield [2, 3]."""
    pipeline = ds.GeneratorDataset(generator_md, ["data"])

    # After take(4) the dataset should be [0, 1, 2, 3]
    pipeline = pipeline.take(4)
    # After skip(2) the dataset should be [2, 3]
    pipeline = pipeline.skip(2)

    collected = [row[0][0] for row in pipeline]
    assert len(collected) == 2
    assert collected == [2, 3]
def test_skip_take_2():
    """skip(2) followed by take(2) on the generator dataset must yield [2, 3]."""
    pipeline = ds.GeneratorDataset(generator_md, ["data"])

    # After skip(2) the dataset should be [2, 3, 4]
    pipeline = pipeline.skip(2)
    # After take(2) the dataset should be [2, 3]
    pipeline = pipeline.take(2)

    collected = [row[0][0] for row in pipeline]
    assert len(collected) == 2
    assert collected == [2, 3]
def generator_1d():
    """Yield 64 one-element tuples, each holding the array [i] for i = 0..63."""
    for value in range(64):
        yield (np.array([value]),)
def test_skip_filter_1():
    """skip(5) followed by a parallel filter (data < 11) must yield 5..10."""
    pipeline = ds.GeneratorDataset(generator_1d, ['data'])
    pipeline = pipeline.skip(5)
    pipeline = pipeline.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

    collected = [row[0][0] for row in pipeline]
    assert collected == [5, 6, 7, 8, 9, 10]
def test_skip_filter_2():
    """A parallel filter (data < 11) followed by skip(5) must yield 5..10."""
    pipeline = ds.GeneratorDataset(generator_1d, ['data'])
    pipeline = pipeline.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    pipeline = pipeline.skip(5)

    collected = [row[0][0] for row in pipeline]
    assert collected == [5, 6, 7, 8, 9, 10]
if
__name__
==
"__main__"
:
...
...
@@ -142,3 +202,7 @@ if __name__ == "__main__":
test_skip_repeat_1
()
test_skip_repeat_2
()
test_skip_repeat_3
()
test_skip_take_1
()
test_skip_take_2
()
test_skip_filter_1
()
test_skip_filter_2
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录