Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e19d3824
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e19d3824
编写于
8月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3346 Maintain epoch/repeat count for ops
Merge pull request !3346 from lixiachen/repeat_rework
上级
2b5b35ea
ac85b77b
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
330 addition
and
101 deletion
+330
-101
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
...ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
...csrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
+2
-2
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
...pore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
+4
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
...ore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
+1
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
...re/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
+12
-3
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
...ore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
+29
-14
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
...ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
+6
-15
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc
...ore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc
+1
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
...ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
+1
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc
...re/ccsrc/minddata/dataset/engine/datasetops/project_op.cc
+3
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
...ore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
+10
-19
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
...pore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
+14
-5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
...rc/minddata/dataset/engine/datasetops/source/celeba_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
...src/minddata/dataset/engine/datasetops/source/cifar_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
...csrc/minddata/dataset/engine/datasetops/source/clue_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
...csrc/minddata/dataset/engine/datasetops/source/coco_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
...ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
...minddata/dataset/engine/datasetops/source/generator_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
...ddata/dataset/engine/datasetops/source/image_folder_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
.../minddata/dataset/engine/datasetops/source/manifest_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
...inddata/dataset/engine/datasetops/source/mindrecord_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
...src/minddata/dataset/engine/datasetops/source/mnist_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
...nddata/dataset/engine/datasetops/source/random_data_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
...minddata/dataset/engine/datasetops/source/text_file_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
...minddata/dataset/engine/datasetops/source/tf_reader_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
...ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
...spore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
+1
-0
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
...ore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
+98
-22
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
...pore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
+25
-6
tests/ut/python/dataset/test_epoch_ctrl.py
tests/ut/python/dataset/test_epoch_ctrl.py
+93
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
浏览文件 @
e19d3824
...
...
@@ -89,13 +89,14 @@ Status CacheBase::FetchSamplesToWorkers() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
// If repeat but the not last repeat, wait for reset.
if
(
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
&&
!
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
!
IsLastIteration
(
))
{
MS_LOG
(
DEBUG
)
<<
Name
()
<<
" Waiting for reset. Count "
<<
++
wait_cnt
<<
" Buffer sent "
<<
buf_cnt
;
RETURN_IF_NOT_OK
(
epoch_sync_
.
Wait
());
}
else
{
// We can break out from the loop.
break
;
}
UpdateRepeatAndEpochCounter
();
}
while
(
true
);
// Flow the eof before exit
RETURN_IF_NOT_OK
(
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
浏览文件 @
e19d3824
...
...
@@ -292,7 +292,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
Status
CacheMergeOp
::
EoeReceived
(
int32_t
worker_id
)
{
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if
(
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
)
{
if
(
op_total_repeats_
>
1
)
{
return
DatasetOp
::
EoeReceived
(
worker_id
);
}
return
Status
::
OK
();
...
...
@@ -304,7 +304,7 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
)
{
if
(
op_total_repeats_
==
1
)
{
MS_LOG
(
DEBUG
)
<<
"Cache merge sending eoe"
;
RETURN_IF_NOT_OK
(
DatasetOp
::
EoeReceived
(
worker_id
));
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
浏览文件 @
e19d3824
...
...
@@ -85,6 +85,10 @@ Status CacheOp::operator()() {
TaskManager
::
FindMe
()
->
Post
();
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK
(
WaitForCachingAllRows
());
// Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput.
// But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started.
op_current_repeats_
=
0
;
op_current_epochs_
=
0
;
RETURN_IF_NOT_OK
(
FetchSamplesToWorkers
());
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
浏览文件 @
e19d3824
...
...
@@ -85,6 +85,7 @@ Status ConcatOp::operator()() {
auto
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
}
UpdateRepeatAndEpochCounter
();
}
CHECK_FAIL_RETURN_UNEXPECTED
(
eof_count
==
children_num_
,
"Something went wrong, eof count does not match the number of children."
);
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
浏览文件 @
e19d3824
...
...
@@ -42,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
operator_id_
(
kInvalidOperatorId
),
tree_
(
nullptr
),
state_
(
OpState
::
kDeOpIdle
),
op_ctrl_flags_
(
kDeOpNone
),
op_total_repeats_
(
kInfiniteRepeat
),
op_num_repeats_per_epoch_
(
kInfiniteRepeat
),
op_current_repeats_
(
0
),
op_current_epochs_
(
0
),
out_connector_
(
nullptr
)
{
// The operator starts out with an invalid operator id. The only way to
// get it out of invalid state is to assign the operator to an execution tree.
...
...
@@ -237,8 +240,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
for
(
size_t
i
=
0
;
i
<
parent_
.
size
();
i
++
)
{
out
<<
"
\n
Parent["
<<
i
<<
"] id: "
<<
parent_
[
i
]
->
id
();
}
out
<<
"
\n
Connector queue size : "
<<
oc_queue_size_
<<
"
\n
Operator control flags : 0x"
<<
std
::
hex
<<
std
::
setw
(
8
)
<<
std
::
setfill
(
'0'
)
<<
op_ctrl_flags_
<<
std
::
dec
<<
std
::
setfill
(
' '
)
;
out
<<
"
\n
Connector queue size : "
<<
oc_queue_size_
<<
"
\n
Total repeats : "
<<
op_total_repeats_
<<
"
\n
Number repeats per epoch : "
<<
op_num_repeats_per_epoch_
;
if
(
sampler_
)
{
sampler_
->
Print
(
out
,
show_all
);
}
...
...
@@ -265,6 +268,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
RETURN_IF_NOT_OK
(
child
->
GetNextBuffer
(
&
buf
,
worker_id
));
// Loop until non EOE is received
while
(
buf
->
eoe
())
{
UpdateRepeatAndEpochCounter
();
RETURN_IF_NOT_OK
(
EoeReceived
(
worker_id
));
if
(
state_
==
OpState
::
kDeOpIdle
)
{
*
p_buffer
=
std
::
move
(
buf
);
...
...
@@ -408,5 +412,10 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
uint32_t
cache_crc
=
system
::
Crc32c
::
GetMaskCrc32cValue
(
ss_str
.
c_str
(),
ss_str
.
length
());
return
cache_crc
;
}
void
DatasetOp
::
UpdateRepeatAndEpochCounter
()
{
op_current_repeats_
++
;
if
(
op_current_repeats_
%
op_num_repeats_per_epoch_
==
0
)
op_current_epochs_
++
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
浏览文件 @
e19d3824
...
...
@@ -70,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
public:
static
constexpr
int32_t
kInvalidOperatorId
=
-
1
;
// Operator control flags
enum
OpControlFlags
{
kDeOpNone
=
0
,
kDeOpRepeated
=
1
,
// Operator is a node in a repeat path
kDeOpLastRepeat
=
1
<<
1
// We are in the last repeat loop
};
static
constexpr
int32_t
kInfiniteRepeat
=
-
1
;
// Flags that control operator runtime behaviours
enum
OpState
{
kDeOpRunning
=
0
,
kDeOpIdle
=
1
,
kDeOpTerminated
};
...
...
@@ -238,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return T/F if this is an inlined operator
bool
inlined
()
const
{
return
(
oc_queue_size_
==
0
);
}
/// \brief Setter function
/// \return Sets the control flags
void
set_control_flag
(
uint64_t
flag
)
{
BitSet
(
&
op_ctrl_flags_
,
flag
);
}
/// \brief Setter function, set the number of total repeats for the operator
void
set_total_repeats
(
int32_t
total_repeats
)
{
op_total_repeats_
=
total_repeats
;
}
/// \brief Setter function, set the number of repeats per epoch for the operator
void
set_num_repeats_per_epoch
(
int32_t
num_repeats_per_epoch
)
{
op_num_repeats_per_epoch_
=
num_repeats_per_epoch
;
}
/// \brief Setter function
/// \return Sets the control flags
void
ClearControlFlag
(
uint64_t
flag
)
{
BitClear
(
&
op_ctrl_flags_
,
flag
);
}
/// \brief Getter function
/// \return The number of required repeats for the operator
int32_t
op_total_repeats
()
{
return
op_total_repeats_
;
}
/// \brief Getter function
/// \return The number of required epochs for the operator
int32_t
op_total_epochs
()
{
return
op_total_repeats_
/
op_num_repeats_per_epoch_
;
}
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t
op_num_repeats_per_epoch
()
{
return
op_num_repeats_per_epoch_
;
}
/// \brief Register the internal worker connectors. No op unless it is a parallel op
/// \return Status
...
...
@@ -350,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return boolean returns true if it's a leaf
bool
IsLeaf
()
{
return
(
child_
.
empty
());
}
/// Checks if an operator has reached its last iteration
/// \return boolean returns true if it's last iteration
bool
IsLastIteration
()
{
return
op_total_repeats_
==
op_current_repeats_
+
1
;
}
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
...
...
@@ -368,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return - Status
virtual
Status
ComputeColMap
();
/// Increase op_current_repeats_ by 1 when one repeat finished.
/// If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1.
void
UpdateRepeatAndEpochCounter
();
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
child_
;
// Child nodes
std
::
vector
<
DatasetOp
*>
parent_
;
// Parent nodes. No ownership
std
::
shared_ptr
<
Sampler
>
sampler_
;
// Some leaf ops might have a sampler
...
...
@@ -375,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
int32_t
operator_id_
;
// Generated id for the node
ExecutionTree
*
tree_
;
// Back pointer to our tree.
OpState
state_
;
// The state of the operator, Running, Idle, Terminated
uint32_t
op_ctrl_flags_
;
// Flags for the operator
int32_t
op_total_repeats_
;
// Required number of repeats for the operator
int32_t
op_num_repeats_per_epoch_
;
// Total number of repeats per epoch for the operator
int32_t
op_current_repeats_
;
// Current number of repeats the operator has handled
int32_t
op_current_epochs_
;
// Current number of epochs the operator has handled
std
::
unique_ptr
<
DbConnector
>
out_connector_
;
// Output Connector
std
::
unordered_map
<
std
::
string
,
int32_t
>
column_name_id_map_
;
// Mapping between col index and col name
std
::
mutex
column_name_map_mutex_
;
// For protecting shared access to the column map
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
浏览文件 @
e19d3824
...
...
@@ -30,7 +30,7 @@ namespace dataset {
// The builder "build" method creates the final object.
Status
EpochCtrlOp
::
Builder
::
Build
(
std
::
shared_ptr
<
EpochCtrlOp
>
*
ptr
)
{
RETURN_IF_NOT_OK
(
SanityCheck
());
*
ptr
=
std
::
make_shared
<
EpochCtrlOp
>
(
build_
max
_repeats_
);
*
ptr
=
std
::
make_shared
<
EpochCtrlOp
>
(
build_
num
_repeats_
);
return
Status
::
OK
();
}
...
...
@@ -46,12 +46,12 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common 1-liner info
PipelineOp
::
Print
(
out
,
show_all
);
// Then show any custom derived-internal 1-liner info for this op
out
<<
" [epochs: "
<<
max
_repeats_
<<
"]
\n
"
;
out
<<
" [epochs: "
<<
num
_repeats_
<<
"]
\n
"
;
}
else
{
// Call the super class for displaying any common detailed info
PipelineOp
::
Print
(
out
,
show_all
);
// Then show any custom derived-internal stuff
out
<<
"
\n
Current epoch count: "
<<
repeat_count_
<<
"
\n
Max epoch count: "
<<
max
_repeats_
out
<<
"
\n
Current epoch count: "
<<
repeat_count_
<<
"
\n
Max epoch count: "
<<
num
_repeats_
<<
"
\n
Leaf Nodes in execution path:"
;
if
(
!
eoe_ops_
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
eoe_ops_
.
size
();
i
++
)
{
...
...
@@ -86,24 +86,15 @@ Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t
}
Status
EpochCtrlOp
::
EoeReceived
(
int32_t
worker_id
)
{
UpdateRepeatAndEpochCounter
();
repeat_count_
++
;
MS_LOG
(
DEBUG
)
<<
"Epoch Control operator received end of epoch. Epoch count is now: "
<<
repeat_count_
<<
". Repeated: "
<<
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
<<
". Max epochs: "
<<
max_repeats_
;
// If we've reached the requested epoch count, then flag the leaf nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if
(
max_repeats_
!=
kInfiniteRepeat
&&
repeat_count_
==
(
max_repeats_
-
1
))
{
for
(
auto
&
eoe_op
:
eoe_ops_
)
{
MS_LOG
(
DEBUG
)
<<
"EpochCtrl setting last repeat for eoe_op: "
<<
eoe_op
->
id
();
eoe_op
->
set_control_flag
(
kDeOpLastRepeat
);
}
}
<<
". Max epochs: "
<<
num_repeats_
;
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_
=
OpState
::
kDeOpIdle
;
if
(
repeat_count_
!=
max
_repeats_
)
{
if
(
repeat_count_
!=
num
_repeats_
)
{
for
(
auto
&
eoe_op
:
eoe_ops_
)
{
MS_LOG
(
DEBUG
)
<<
"Epoch Control driving reset to op: "
<<
eoe_op
->
id
();
RETURN_IF_NOT_OK
(
eoe_op
->
Reset
());
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc
浏览文件 @
e19d3824
...
...
@@ -117,6 +117,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
in_buffer
,
worker_id
));
if
(
in_buffer
->
eoe
())
{
filter_queues_
[
worker_id
]
->
EmplaceBack
(
std
::
make_pair
(
std
::
move
(
in_buffer
),
filterCtrl
::
kFilterEoe
));
UpdateRepeatAndEpochCounter
();
continue
;
}
else
if
(
in_buffer
->
eof
())
{
filter_queues_
[
worker_id
]
->
EmplaceBack
(
std
::
make_pair
(
std
::
move
(
in_buffer
),
filterCtrl
::
kFilterEof
));
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
浏览文件 @
e19d3824
...
...
@@ -231,6 +231,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
// Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
// with Performance Mode design.
if
(
in_buffer
->
eoe
())
{
UpdateRepeatAndEpochCounter
();
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK
(
EoeReceived
(
worker_id
));
// Fetch next data buffer and map job list
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc
浏览文件 @
e19d3824
...
...
@@ -74,6 +74,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t w
if
(
!
((
*
p_buffer
)
->
eoe
())
&&
!
((
*
p_buffer
)
->
eof
()))
{
RETURN_IF_NOT_OK
(
Project
(
p_buffer
));
}
if
((
*
p_buffer
)
->
eoe
())
{
UpdateRepeatAndEpochCounter
();
}
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
浏览文件 @
e19d3824
...
...
@@ -28,10 +28,10 @@
namespace
mindspore
{
namespace
dataset
{
// Builder constructor. Creates the builder object.
RepeatOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_
max
_repeats_
(
count
)
{}
RepeatOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_
num
_repeats_
(
count
)
{}
Status
RepeatOp
::
Builder
::
SanityCheck
()
const
{
if
(
build_
max_repeats_
<
kInfiniteRepeat
||
build_max
_repeats_
==
0
)
{
if
(
build_
num_repeats_
<
kInfiniteRepeat
||
build_num
_repeats_
==
0
)
{
std
::
string
err_msg
(
"Repeat count must be > 0 or -1."
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
...
...
@@ -41,12 +41,12 @@ Status RepeatOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status
RepeatOp
::
Builder
::
Build
(
std
::
shared_ptr
<
RepeatOp
>
*
ptr
)
{
RETURN_IF_NOT_OK
(
SanityCheck
());
*
ptr
=
std
::
make_shared
<
RepeatOp
>
(
build_
max
_repeats_
);
*
ptr
=
std
::
make_shared
<
RepeatOp
>
(
build_
num
_repeats_
);
return
Status
::
OK
();
}
// Constructor of the RepeatOp.
RepeatOp
::
RepeatOp
(
int32_t
count
)
:
PipelineOp
(
0
),
max
_repeats_
(
count
),
repeat_count_
(
0
)
{}
RepeatOp
::
RepeatOp
(
int32_t
count
)
:
PipelineOp
(
0
),
num
_repeats_
(
count
),
repeat_count_
(
0
)
{}
// Destructor
RepeatOp
::~
RepeatOp
()
{}
...
...
@@ -57,12 +57,12 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common 1-liner info
PipelineOp
::
Print
(
out
,
show_all
);
// Then show any custom derived-internal 1-liner info for this op
out
<<
" [repeats: "
<<
max
_repeats_
<<
"]
\n
"
;
out
<<
" [repeats: "
<<
num
_repeats_
<<
"]
\n
"
;
}
else
{
// Call the super class for displaying any common detailed info
PipelineOp
::
Print
(
out
,
show_all
);
// Then show any custom derived-internal stuff
out
<<
"
\n
Current repeat count: "
<<
repeat_count_
<<
"
\n
Max repeat count: "
<<
max
_repeats_
out
<<
"
\n
Current repeat count: "
<<
repeat_count_
<<
"
\n
Max repeat count: "
<<
num
_repeats_
<<
"
\n
Leaf Nodes in execution path:"
;
if
(
!
eoe_ops_
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
eoe_ops_
.
size
();
i
++
)
{
...
...
@@ -107,22 +107,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
// Base-class override for handling cases when an eoe is received.
Status
RepeatOp
::
EoeReceived
(
int32_t
worker_id
)
{
UpdateRepeatAndEpochCounter
();
repeat_count_
++
;
MS_LOG
(
DEBUG
)
<<
"Repeat operator ("
<<
operator_id_
<<
") end of epoch message received. Repeat count is now: "
<<
repeat_count_
<<
"."
;
bool
repeated
=
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
);
bool
last_repeat
=
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
);
// If we've reached the requested repeat count, then flag the eoe nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again. This happens in two cases:
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
// 2- We are not repeated
if
(
max_repeats_
!=
kInfiniteRepeat
&&
repeat_count_
==
(
max_repeats_
-
1
)
&&
(
!
repeated
||
last_repeat
))
{
for
(
auto
&
eoe_op
:
eoe_ops_
)
{
eoe_op
->
set_control_flag
(
kDeOpLastRepeat
);
}
}
if
(
repeat_count_
==
max_repeats_
)
{
if
(
repeat_count_
==
num_repeats_
)
{
repeat_count_
=
0
;
state_
=
OpState
::
kDeOpIdle
;
return
Status
::
OK
();
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
浏览文件 @
e19d3824
...
...
@@ -26,8 +26,6 @@ namespace mindspore {
namespace
dataset
{
class
RepeatOp
:
public
PipelineOp
{
public:
static
constexpr
int32_t
kInfiniteRepeat
=
-
1
;
// The nested builder class inside of the RepeatOp is used to help manage all of the arguments
// for constructing it. This repeat op is very simple though, so this builder is really just
// provided for a consistent look and feel for creators of Dataset operators overall.
...
...
@@ -47,7 +45,7 @@ class RepeatOp : public PipelineOp {
Status
Build
(
std
::
shared_ptr
<
RepeatOp
>
*
);
protected:
int32_t
build_
max
_repeats_
;
int32_t
build_
num
_repeats_
;
Status
SanityCheck
()
const
;
};
...
...
@@ -131,13 +129,24 @@ class RepeatOp : public PipelineOp {
// @return Name of the current Op
std
::
string
Name
()
const
override
{
return
kRepeatOp
;
}
/// \brief Getter function
/// \return The number of repeats that the user requested
int32_t
num_repeats
()
{
return
num_repeats_
;
}
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void
AddToEoeList
(
std
::
shared_ptr
<
DatasetOp
>
eoe_op
)
{
eoe_ops_
.
push_back
(
std
::
move
(
eoe_op
));
}
protected:
int32_t
max_repeats_
;
// The number of repeats that the user requested
int32_t
repeat_count_
;
// A counter for the current number of executed repeats
// The number of repeats that the user requested.
// Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class.
// For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4),
// num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6.
int32_t
num_repeats_
;
// A counter for the current number of executed repeats.
// Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class
// because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
int32_t
repeat_count_
;
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
eoe_ops_
;
// List of operators that can generate EOE underneath this repeat.
};
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
浏览文件 @
e19d3824
...
...
@@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buff_count
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buff_count
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
...
...
@@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
data_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
浏览文件 @
e19d3824
...
...
@@ -120,7 +120,7 @@ Status CifarOp::operator()() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
...
...
@@ -137,6 +137,7 @@ Status CifarOp::operator()() {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
&
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
浏览文件 @
e19d3824
...
...
@@ -271,13 +271,14 @@ Status ClueOp::operator()() {
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
finished_reading_dataset_
=
true
;
NotifyToFillIOBlockQueue
();
}
else
{
jagged_buffer_connector_
->
DoReset
();
buffer_id
=
0
;
}
UpdateRepeatAndEpochCounter
();
}
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eof_buffer
)));
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
浏览文件 @
e19d3824
...
...
@@ -167,7 +167,7 @@ Status CocoOp::operator()() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
std
::
unique_ptr
<
IOBlock
>
eoe_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
);
std
::
unique_ptr
<
IOBlock
>
eof_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEof
);
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
move
(
eoe_block
)));
...
...
@@ -184,6 +184,7 @@ Status CocoOp::operator()() {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
&
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
浏览文件 @
e19d3824
...
...
@@ -472,13 +472,14 @@ Status CsvOp::operator()() {
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
finished_reading_dataset_
=
true
;
NotifyToFillIOBlockQueue
();
}
else
{
jagged_buffer_connector_
->
DoReset
();
buffer_id
=
0
;
}
UpdateRepeatAndEpochCounter
();
}
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eof_buffer
)));
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
浏览文件 @
e19d3824
...
...
@@ -216,7 +216,7 @@ Status GeneratorOp::operator()() {
MS_LOG
(
DEBUG
)
<<
"Generator operator sends out EOE."
;
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
// If last repeat or not repeated, push out EOF and exit master loop
MS_LOG
(
DEBUG
)
<<
"Generator operator sends out EOF."
;
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
...
...
@@ -231,6 +231,7 @@ Status GeneratorOp::operator()() {
// Clear the status of the wait post
wp_
.
Clear
();
}
UpdateRepeatAndEpochCounter
();
}
}
return
Status
::
OK
();
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
e19d3824
...
...
@@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
keys
,
IOBlock
::
kDeIoBlockNone
)));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
std
::
unique_ptr
<
IOBlock
>
eoe_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
);
std
::
unique_ptr
<
IOBlock
>
eof_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEof
);
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
move
(
eoe_block
)));
...
...
@@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
&
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
e19d3824
...
...
@@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
...
...
@@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
e19d3824
...
...
@@ -378,7 +378,7 @@ Status MindRecordOp::operator()() {
RETURN_IF_NOT_OK
(
io_blk_queues_
[
buf_cnt_
++
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
RETURN_IF_NOT_OK
(
io_blk_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
...
...
@@ -396,6 +396,7 @@ Status MindRecordOp::operator()() {
RETURN_IF_NOT_OK
(
shard_reader_wait_post_
.
Wait
());
shard_reader_wait_post_
.
Clear
();
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
浏览文件 @
e19d3824
...
...
@@ -111,7 +111,7 @@ Status MnistOp::operator()() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
...
...
@@ -128,6 +128,7 @@ Status MnistOp::operator()() {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
&
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
浏览文件 @
e19d3824
...
...
@@ -219,7 +219,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
all_out_
.
Wait
();
// If we are not in a repeat loop, or that was the last repeat already, then setup our exit
// condition from the master loop.
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
*
quitting
=
true
;
}
...
...
@@ -229,6 +229,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
if
(
last_guy_in
)
{
MS_LOG
(
INFO
)
<<
"RandomDataOp worker "
<<
worker_id
<<
" is the last one to sync. eoe sent as worker "
<<
eoe_worker_id_
;
UpdateRepeatAndEpochCounter
();
// Prepare for sync
all_out_
.
Clear
();
// Always flow eoe at the end
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
浏览文件 @
e19d3824
...
...
@@ -419,13 +419,14 @@ Status TextFileOp::operator()() {
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
finished_reading_dataset_
=
true
;
NotifyToFillIOBlockQueue
();
}
else
{
jagged_buffer_connector_
->
DoReset
();
buffer_id
=
0
;
}
UpdateRepeatAndEpochCounter
();
}
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
浏览文件 @
e19d3824
...
...
@@ -308,13 +308,14 @@ Status TFReaderOp::operator()() {
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
finished_reading_dataset_
=
true
;
NotifyToFillIOBlockQueue
();
}
else
{
jagged_buffer_connector_
->
DoReset
();
buffer_id
=
0
;
}
UpdateRepeatAndEpochCounter
();
}
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
e19d3824
...
...
@@ -145,7 +145,7 @@ Status VOCOp::operator()() {
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
}
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
if
(
IsLastIteration
(
))
{
std
::
unique_ptr
<
IOBlock
>
eoe_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
);
std
::
unique_ptr
<
IOBlock
>
eof_block
=
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEof
);
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
move
(
eoe_block
)));
...
...
@@ -162,6 +162,7 @@ Status VOCOp::operator()() {
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNextSample
(
&
sampler_buffer
));
}
UpdateRepeatAndEpochCounter
();
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
浏览文件 @
e19d3824
...
...
@@ -82,6 +82,7 @@ Status TakeOp::operator()() {
// Loop until non EOE is received
if
(
buf
->
eoe
())
{
UpdateRepeatAndEpochCounter
();
take_count_
=
0
;
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
buf
)));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
));
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
浏览文件 @
e19d3824
...
...
@@ -25,18 +25,44 @@
namespace
mindspore
{
namespace
dataset
{
RepeatPass
::
RepeatPass
()
:
is_repeated_
(
false
),
nested_repeats_
(
0
),
is_merge_
(
false
),
cache_lookup_
(
nullptr
)
{}
RepeatPass
::
RepeatPass
()
:
is_repeated_
(
false
),
nested_repeats_
(
0
),
num_repeats_
(
1
),
num_epochs_
(
1
),
is_merge_
(
false
),
is_cached_
(
false
),
cache_lookup_
(
nullptr
)
{}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status
RepeatPass
::
PreRunOnNode
(
std
::
shared_ptr
<
RepeatOp
>
node
,
bool
*
modified
)
{
// Create a new stack for eoe operators and push onto our stack of stacks.
std
::
unique_ptr
<
eoe_op_stack
>
new_stack
=
std
::
make_unique
<
eoe_
op_stack
>
();
std
::
unique_ptr
<
op_stack
>
new_stack
=
std
::
make_unique
<
op_stack
>
();
eoe_op_stacks_
.
push
(
std
::
move
(
new_stack
));
// If we are already repeated, then this is a nested repeat.
if
(
is_repeated_
)
{
nested_repeats_
++
;
}
is_repeated_
=
true
;
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
if
(
node
->
num_repeats
()
==
DatasetOp
::
kInfiniteRepeat
&&
num_repeats_
<
0
)
{
num_repeats_
=
-
num_repeats_
;
}
// This RepeatOp and its descendent nodes should be repeated for another num_repeats() times.
//
// Consider this example:
// tfreader --> map --> repeat(2) --> epoch ctrl(3)
// num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3),
// meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op.
//
// Another example:
// tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4)
// num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
// meaning repeat2 and map op should be set to read 8 times (2*4).
// Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
num_repeats_
*=
node
->
num_repeats
();
return
Status
::
OK
();
}
...
...
@@ -46,9 +72,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie
// that RepeatOp does. However, epoch control is actually simpler because it can
// only exist as the root node so it doesn't need all the nested code.
// Create a new stack for eoe operators and push onto our stack of stacks.
std
::
unique_ptr
<
eoe_op_stack
>
new_stack
=
std
::
make_unique
<
eoe_
op_stack
>
();
std
::
unique_ptr
<
op_stack
>
new_stack
=
std
::
make_unique
<
op_stack
>
();
eoe_op_stacks_
.
push
(
std
::
move
(
new_stack
));
is_repeated_
=
true
;
// Get the total number of epochs from the EpochCtrlOp parameter
num_epochs_
=
node
->
num_repeats
();
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
// For example: tfreader --> epoch ctrl(3)
// num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
// meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op.
num_repeats_
*=
num_epochs_
;
return
Status
::
OK
();
}
...
...
@@ -59,6 +92,13 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
return
Status
::
OK
();
}
// Identifies the subtree below this node as being cached
Status
RepeatPass
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
// Turn on the flag that we're under a merge op
is_cached_
=
true
;
return
Status
::
OK
();
}
// Hooks up any identified eoe nodes under this repeat.
Status
RepeatPass
::
RunOnNode
(
std
::
shared_ptr
<
RepeatOp
>
node
,
bool
*
modified
)
{
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
...
...
@@ -71,7 +111,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
// at this time, so we can pop it to get rid of it.
eoe_
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
if
(
!
current_stack
->
empty
())
{
RETURN_STATUS_UNEXPECTED
(
"The eoe op stack should be empty right now!"
);
}
...
...
@@ -82,14 +122,14 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// from the save area, because the merge op above us may also take action on it later for a different
// case when there is no repeat in the merge leg.
if
(
is_merge_
&&
cache_lookup_
)
{
cache_lookup_
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
cache_lookup_
->
set_total_repeats
(
num_repeats_
);
cache_lookup_
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
node
->
AddToEoeList
(
std
::
move
(
cache_lookup_
));
}
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
if
(
nested_repeats_
>
0
)
{
node
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
AddToEOEOpStack
(
node
);
nested_repeats_
--
;
}
else
{
...
...
@@ -99,7 +139,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
}
is_repeated_
=
false
;
}
if
(
is_cached_
)
{
AddToCachedOpStack
(
node
);
}
node
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
// We finish the walk of this RepeatOp's descendent nodes.
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
// so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
num_repeats_
/=
node
->
num_repeats
();
return
Status
::
OK
();
}
...
...
@@ -112,13 +161,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified)
leaf_op
=
PopFromEOEOpStack
();
}
is_repeated_
=
false
;
node
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
// We finish the walk of this EpochCtrl's descendent nodes.
num_repeats_
/=
node
->
num_repeats
();
return
Status
::
OK
();
}
// CacheOp removes previous leaf ops and replaces them with itself
Status
RepeatPass
::
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
is_cached_
=
false
;
if
(
is_repeated_
)
{
node
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
// if we are a cache within a repeat path of the tree, then there will be
// eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
// repeat or epoch ctrl operators can work with them for repeat activity during runtime.
...
...
@@ -130,13 +183,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// the repeating behaviours shall be invoked against the cache op.
std
::
shared_ptr
<
DatasetOp
>
leaf_op
=
PopFromEOEOpStack
();
while
(
leaf_op
!=
nullptr
)
{
leaf_op
->
ClearControlFlag
(
DatasetOp
::
kDeOpLastRepeat
);
leaf_op
->
ClearControlFlag
(
DatasetOp
::
kDeOpRepeated
);
leaf_op
=
PopFromEOEOpStack
();
}
AddToEOEOpStack
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
// adjust the total epochs and total repeats for ops under this cache op
std
::
shared_ptr
<
DatasetOp
>
cached_op
=
PopFromCachedOpStack
();
while
(
cached_op
!=
nullptr
)
{
int32_t
cached_op_total_repeats
=
cached_op
->
op_total_repeats
()
/
num_repeats_
;
cached_op
->
set_total_repeats
(
cached_op_total_repeats
);
// Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
cached_op
->
set_num_repeats_per_epoch
(
cached_op_total_repeats
);
cached_op
=
PopFromCachedOpStack
();
}
}
node
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
return
Status
::
OK
();
}
...
...
@@ -145,13 +208,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
Status
RepeatPass
::
RunOnNode
(
std
::
shared_ptr
<
DatasetOp
>
node
,
bool
*
modified
)
{
// If we are in a repeat path, then set our repeated flag
if
(
is_repeated_
)
{
node
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
// if we are a leaf node then save ourself in a stack for the repeat operator above us
if
(
node
->
IsLeaf
())
{
AddToEOEOpStack
(
node
);
}
}
if
(
is_cached_
)
{
AddToCachedOpStack
(
node
);
}
// Set total repeats and total epochs for the node
node
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
return
Status
::
OK
();
}
...
...
@@ -159,13 +226,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
Status
RepeatPass
::
RunOnNode
(
std
::
shared_ptr
<
CacheMergeOp
>
node
,
bool
*
modified
)
{
// Setting the flag is needed since we didn't call the base class DatasetOp version
if
(
is_repeated_
)
{
node
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
if
(
cache_lookup_
)
{
cache_lookup_
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
AddToEOEOpStack
(
std
::
move
(
cache_lookup_
));
}
}
node
->
set_total_repeats
(
num_repeats_
);
node
->
set_num_repeats_per_epoch
(
num_repeats_
/
num_epochs_
);
cache_lookup_
.
reset
();
// If we are not repeated then the saved lookup is no longer needed or used
is_merge_
=
false
;
return
Status
::
OK
();
...
...
@@ -178,13 +248,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
RETURN_STATUS_UNEXPECTED
(
"CacheLookupOp must be a leaf node!"
);
}
// If we are in a repeat path already, then there must be a repeat above the merge op
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
if
(
is_repeated_
)
{
node
->
set_control_flag
(
DatasetOp
::
kDeOpRepeated
);
// Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
}
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
...
...
@@ -197,19 +260,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
// Adds an operator to the eoe operator stack save area
void
RepeatPass
::
AddToEOEOpStack
(
std
::
shared_ptr
<
DatasetOp
>
dataset_op
)
{
eoe_
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
current_stack
->
push
(
dataset_op
);
}
// Pops an operator from the eoe operator stack save area
std
::
shared_ptr
<
DatasetOp
>
RepeatPass
::
PopFromEOEOpStack
()
{
std
::
shared_ptr
<
DatasetOp
>
top_op
=
nullptr
;
eoe_
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
op_stack
*
current_stack
=
eoe_op_stacks_
.
top
().
get
();
if
(
current_stack
!=
nullptr
&&
!
current_stack
->
empty
())
{
top_op
=
current_stack
->
top
();
current_stack
->
pop
();
}
return
top_op
;
}
// Adds an operator to the cached operator stack save area
void
RepeatPass
::
AddToCachedOpStack
(
std
::
shared_ptr
<
DatasetOp
>
dataset_op
)
{
cached_op_stacks_
.
push
(
dataset_op
);
}
// Pops an operator from the cached operator stack save area
std
::
shared_ptr
<
DatasetOp
>
RepeatPass
::
PopFromCachedOpStack
()
{
std
::
shared_ptr
<
DatasetOp
>
top_op
=
nullptr
;
if
(
!
cached_op_stacks_
.
empty
())
{
top_op
=
cached_op_stacks_
.
top
();
cached_op_stacks_
.
pop
();
}
return
top_op
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
浏览文件 @
e19d3824
...
...
@@ -30,7 +30,7 @@ namespace dataset {
/// to the eoe-producing (typically leaf) nodes underneath it.
class
RepeatPass
:
public
NodePass
{
public:
using
eoe_
op_stack
=
std
::
stack
<
std
::
shared_ptr
<
DatasetOp
>>
;
using
op_stack
=
std
::
stack
<
std
::
shared_ptr
<
DatasetOp
>>
;
/// \brief Constructor
RepeatPass
();
...
...
@@ -56,6 +56,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheMergeOp
>
node
,
bool
*
modified
)
override
;
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
...
...
@@ -103,11 +109,24 @@ class RepeatPass : public NodePass {
/// \return shared_ptr to the popped operator
std
::
shared_ptr
<
DatasetOp
>
PopFromEOEOpStack
();
bool
is_repeated_
;
// T/F if we are processing under a repeat
bool
is_merge_
;
// T/F if we are processing under a cache merge op
int32_t
nested_repeats_
;
// A counter for nested repeats
std
::
stack
<
std
::
unique_ptr
<
eoe_op_stack
>>
eoe_op_stacks_
;
// A save area for leaf/eoe ops (with nesting)
std
::
shared_ptr
<
DatasetOp
>
cache_lookup_
;
// A save area for a cache lookup op
/// \brief Adds an operator to the cached operator stack save area
/// \param op - The dataset op to work add to cached stack
/// \return Status - The error code return
void
AddToCachedOpStack
(
std
::
shared_ptr
<
DatasetOp
>
dataset_op
);
/// \brief Pops an operator from the cached operator stack save area
/// \return shared_ptr to the popped operator
std
::
shared_ptr
<
DatasetOp
>
PopFromCachedOpStack
();
bool
is_repeated_
;
// T/F if we are processing under a repeat
bool
is_merge_
;
// T/F if we are processing under a cache merge op
bool
is_cached_
;
// T/F is we are processing under a cache op
int32_t
nested_repeats_
;
// A counter for nested repeats
int32_t
num_repeats_
;
// A multiplier to the total number of repeats
int32_t
num_epochs_
;
// To save the total number of epochs
std
::
stack
<
std
::
unique_ptr
<
op_stack
>>
eoe_op_stacks_
;
// A save area for leaf/eoe ops (with nesting)
op_stack
cached_op_stacks_
;
// A save area for ops under a cache op
std
::
shared_ptr
<
DatasetOp
>
cache_lookup_
;
// A save area for a cache lookup op
};
}
// namespace dataset
}
// namespace mindspore
...
...
tests/ut/python/dataset/test_epoch_ctrl.py
浏览文件 @
e19d3824
...
...
@@ -565,6 +565,99 @@ def test_generator_tuple_repeat_repeat_3():
# rely on garbage collector to destroy iter1
def
test_generator_tuple_infinite_repeat_repeat_1
():
"""
test generator tuple infinite repeat repeat 1
"""
logger
.
info
(
"Test 1D Generator : 0 - 63"
)
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
])
data1
=
data1
.
repeat
()
data1
=
data1
.
repeat
(
3
)
iter1
=
data1
.
create_tuple_iterator
(
num_epochs
=
11
)
i
=
0
for
item
in
iter1
:
# each data is a dictionary
golden
=
np
.
array
([
i
%
64
])
np
.
testing
.
assert_array_equal
(
item
[
0
],
golden
)
i
=
i
+
1
if
i
==
100
:
break
# rely on garbage collector to destroy iter1
def
test_generator_tuple_infinite_repeat_repeat_2
():
"""
test generator tuple infinite repeat repeat 2
"""
logger
.
info
(
"Test 1D Generator : 0 - 63"
)
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
])
data1
=
data1
.
repeat
(
3
)
data1
=
data1
.
repeat
()
iter1
=
data1
.
create_tuple_iterator
(
num_epochs
=
11
)
i
=
0
for
item
in
iter1
:
# each data is a dictionary
golden
=
np
.
array
([
i
%
64
])
np
.
testing
.
assert_array_equal
(
item
[
0
],
golden
)
i
=
i
+
1
if
i
==
100
:
break
# rely on garbage collector to destroy iter1
def
test_generator_tuple_infinite_repeat_repeat_3
():
"""
test generator tuple infinite repeat repeat 3
"""
logger
.
info
(
"Test 1D Generator : 0 - 63"
)
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
])
data1
=
data1
.
repeat
()
data1
=
data1
.
repeat
()
iter1
=
data1
.
create_tuple_iterator
(
num_epochs
=
11
)
i
=
0
for
item
in
iter1
:
# each data is a dictionary
golden
=
np
.
array
([
i
%
64
])
np
.
testing
.
assert_array_equal
(
item
[
0
],
golden
)
i
=
i
+
1
if
i
==
100
:
break
# rely on garbage collector to destroy iter1
def
test_generator_tuple_infinite_repeat_repeat_4
():
"""
test generator tuple infinite repeat repeat 4
"""
logger
.
info
(
"Test 1D Generator : 0 - 63"
)
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
])
data1
=
data1
.
repeat
()
data1
=
data1
.
repeat
()
iter1
=
data1
.
create_tuple_iterator
()
i
=
0
for
item
in
iter1
:
# each data is a dictionary
golden
=
np
.
array
([
i
%
64
])
np
.
testing
.
assert_array_equal
(
item
[
0
],
golden
)
i
=
i
+
1
if
i
==
100
:
break
# rely on garbage collector to destroy iter1
def
test_generator_reusedataset
():
"""
test generator reusedataset
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录