Commit dc0491ca
Authored on Apr 22, 2020 by mindspore-ci-bot; committed via Gitee on Apr 22, 2020.
!508 [Dataset] Adding sync_wait operator for dataset
Merge pull request !508 from EricZ/master
Parents: b0f4b36f, cd945187

Showing 13 changed files with 868 additions and 10 deletions (+868, -10):
mindspore/ccsrc/dataset/api/de_pipeline.cc (+25, -0)
mindspore/ccsrc/dataset/api/de_pipeline.h (+3, -0)
mindspore/ccsrc/dataset/api/python_bindings.cc (+1, -0)
mindspore/ccsrc/dataset/core/client.h (+1, -0)
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt (+1, -0)
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc (+235, -0)
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h (+172, -0)
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h (+9, -9)
mindspore/dataset/engine/datasets.py (+183, -1)
mindspore/dataset/engine/iterators.py (+2, -0)
mindspore/dataset/engine/validators.py (+16, -0)
tests/ut/python/dataset/test_config.py (+38, -0)
tests/ut/python/dataset/test_sync_wait.py (+182, -0)
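Before the per-file diffs, here is a minimal end-to-end sketch of the new API, distilled from the sync_wait docstring and the tests added in this PR (the generator, column name, and condition name mirror the test code; this is an illustration, not part of the patch):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(100):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, column_names=["input"])
# Insert a barrier named "policy"; by default one batch flows per release.
data = data.sync_wait(condition_name="policy")
data = data.batch(4)
for batch_data in data.create_dict_iterator():
    # ... consume batch_data, e.g. compute a loss ...
    # Release the next batch of rows through the barrier.
    data.sync_update(condition_name="policy")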
mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kMap, &DEPipeline::ParseMapOp},
    {kFilter, &DEPipeline::ParseFilterOp},
    {kBatch, &DEPipeline::ParseBatchOp},
    {kBarrier, &DEPipeline::ParseBarrierOp},
    {kRepeat, &DEPipeline::ParseRepeatOp},
    {kSkip, &DEPipeline::ParseSkipOp},
    {kZip, &DEPipeline::ParseZipOp},

@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
  return Status::OK();
}

Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
  // Right now barrier should only take num_rows_per_buffer = 1
  // The reason for this is because having it otherwise can lead to blocking issues
  // See barrier_op.h for more details
  (void)builder->SetRowsPerBuffer(1);
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "condition_name") {
        (void)builder->SetConditionName(ToString(value));
      } else if (key == "condition_func") {
        (void)builder->SetConditionFunc(value.cast<py::function>());
      }
    }
  }

  std::shared_ptr<BarrierOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  int32_t prefetch_size = 0;
  if (args.contains("prefetch_size")) {
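On the Python side, the key/value pairs that ParseBarrierOp walks over come from SyncWaitDataset.get_args() (see the datasets.py diff below). A rough sketch of the dict it receives, with illustrative variable names only ("pair" stands for the SyncWaitDataset's BlockReleasePair instance):

# Illustrative only; built by SyncWaitDataset.get_args():
args = {
    "condition_name": "policy",         # label that sync_update uses to pick this barrier
    "condition_func": pair.block_func,  # Python callable invoked by BarrierOp::blockCond()
}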
mindspore/ccsrc/dataset/api/de_pipeline.h

@@ -40,6 +40,7 @@ enum OpName {
  kShuffle,
  kMindrecord,
  kBatch,
  kBarrier,
  kCache,
  kRepeat,
  kSkip,

@@ -115,6 +116,8 @@ class DEPipeline {
  Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -481,6 +481,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("STORAGE", OpName::kStorage)
    .value("SHUFFLE", OpName::kShuffle)
    .value("BATCH", OpName::kBatch)
    .value("BARRIER", OpName::kBarrier)
    .value("MINDRECORD", OpName::kMindrecord)
    .value("CACHE", OpName::kCache)
    .value("REPEAT", OpName::kRepeat)
mindspore/ccsrc/dataset/core/client.h

@@ -25,6 +25,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt

@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
    dataset_op.cc
    parallel_op.cc
    pipeline_op.cc
    barrier_op.cc
    batch_op.cc
    device_queue_op.cc
    map_op.cc
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/barrier_op.h"
#include <utility>
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
BarrierOp::Builder::Builder() {
  // Some arguments to the BarrierOp constructor have a default argument that is taken
  // from the client config.
  // The user may choose to change these values for the construction of the BarrierOp by
  // using the various builder set methods.
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_rows_per_buffer_ = cfg->rows_per_buffer();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }

Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
                                     builder_condition_func_);
  return Status::OK();
}

// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
                     py::function condition_func)
    : PipelineOp(op_connector_size),
      rows_per_buffer_(rows_per_buffer),
      buffer_id_(0),
      clean_up_(false),
      eof_(false),
      condition_name_(condition_name),
      condition_function_(condition_func) {}

// destructor
BarrierOp::~BarrierOp() {}

// Entry point for Barrier, called by launch()
Status BarrierOp::operator()() {
  // The children_num_ parameter needs to be put here
  // Synchronize with TaskManager once the thread is created.
  TaskManager::FindMe()->Post();

  // create child iterator, right now this barrier is a pipeline operator
  int32_t worker_id = 0;
  int32_t child_idx = 0;
  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);

  // Loop until eof is true
  while (!eof_) {
    // Create new table to put the new tensor rows
    std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
    RETURN_IF_NOT_OK(prepare(curr_table.get()));

    // If an eof got picked up during the above prepare, then we're done
    if (eof_) {
      break;
    }

    // we have to output new buffer with possibly different buffer size, possibly one row
    while (!clean_up_) {
      // 1. If a previous loop iteration sent the current table out, then create a new one.
      if (curr_table == nullptr) {
        curr_table = std::make_unique<TensorQTable>();
      }

      // 2. fill the table. Note: clean_up mode might get turned on if epoch is finished
      RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));

      // 3. create and update buffer and send it to the out connector
      if (!curr_table->empty()) {
        std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
        curr_buffer->set_tensor_table(std::move(curr_table));
        curr_buffer->set_column_name_map(col_name_id_map_);
        MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
                      << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
        buffer_id_++;
      }
    }

    // 4. handle drain state.
    if (clean_up_) {
      MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
      // Send the eoe up.
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
    }
  }

  // 5. handle eof
  // propagate eof here.
  MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  return Status::OK();
}

// Handles preprocessing of the main loop, used when starting new epoch
Status BarrierOp::prepare(TensorQTable *const table) {
  MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
  clean_up_ = false;
  buffer_id_ = 0;
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
  }
  // fill initial row
  TensorRow new_row = {};
  // use iterator to get next row and invoke pyfunc wait
  RETURN_IF_NOT_OK(getNextTensorRow(&new_row));

  // If the first row fetching resulted in eof, then we are done.
  if (eof_) {
    return Status::OK();
  }
  if (new_row.empty()) {
    // This epoch is empty
    return Status::OK();
  }
  // Pack this first row into our tensor table
  // first row we also have to check if we should block
  RETURN_IF_NOT_OK(blockCond());

  table->push_back(std::move(new_row));

  // At this point we have 1 row produced, we take the old column map id and use it in the new table
  // Initializing col_name_id_map_ from the first data buffer.
  col_name_id_map_ = child_iterator_->col_name_id_map();
  // the update code below shouldn't do anything bad if the column name already exists.
  return Status::OK();
}

// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
  }
  TensorRow new_row = {};
  while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
    RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
    // Early exit the loop if we got empty row from any of our child iterations
    if (new_row.empty()) {
      return Status::OK();
    }
    // else we got a row so pack it into the tensor table.
    RETURN_IF_NOT_OK(blockCond());

    table->push_back(std::move(new_row));
  }
  return Status::OK();
}

// function executes a py_func and blocks until condition becomes true.
Status BarrierOp::blockCond() {
  {
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    // we have condition name, however the flexibility is in python today
    try {
      // Invoke python function
      py::object ret_py_obj = condition_function_();
      // Process the return value
      if (!py::isinstance<py::bool_>(ret_py_obj)) {
        return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
      }
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}

// fetches next Barrier buffer row
Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
  // iterate over all iterators and generate a row
  RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
  // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  if (new_row->empty()) {
    // If we did not get a row from any of the children, then it's the end of an epoch and we can move
    // to drain state.
    MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
    clean_up_ = true;
    // If we picked up an eof here, then we are completely done.
    if ((child_iterator_)->eof_handled()) {
      MS_LOG(INFO) << "Barrier operator iterator got EOF.";
      eof_ = true;
    }
    return Status::OK();
  }
  return Status::OK();
}

// A function that prints info about the Operator
void BarrierOp::Print(std::ostream &out, bool show_all) const {
  // Call base class printer first
  PipelineOp::Print(out, show_all);
  out << "\nBarrierOp:\n"
      << "\nCondition " << condition_name_ << "\n\n";
}

// overwrite function and handle eof
Status BarrierOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
  return Status::OK();
}

// overwrite function and handle eoe
Status BarrierOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
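Note the contract blockCond() enforces on the Python callable: it runs under the GIL, may block internally, and must return a bool (anything else fails with kPyFuncException). A minimal conforming callable, sketched here for illustration (BlockReleasePair.block_func in the datasets.py diff below is the real one wired in by sync_wait):

def my_condition():
    # Hypothetical stand-in for BlockReleasePair.block_func: BarrierOp calls this
    # once per row; it may wait internally but must eventually return a bool.
    return True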
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
// Forward declare
class DataBuffer;

class ExecutionTree;

// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from the python layer. The current barrier design respects the
// rows per buffer design and will only output a buffer with rows once it has received rows per buffer
// signals from python.
class BarrierOp : public PipelineOp {
 public:
  // The nested builder class inside of the BarrierOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method.
    // @param int32_t rows_per_buffer
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @param int32_t op_connector_size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @param const std::string &condition_name
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionName(const std::string &condition_name) {
      builder_condition_name_ = condition_name;
      return *this;
    }

    // Setter method.
    // @param py::function condition_func - blocking condition function
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionFunc(py::function condition_func) {
      builder_condition_func_ = condition_func;
      return *this;
    }

    // The builder "build" method creates the BarrierOp dataset Operator.
    // @return shared_ptr to the new BarrierOp object
    Status Build(std::shared_ptr<BarrierOp> *);

   private:
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::string builder_condition_name_;
    py::function builder_condition_func_;

    Status SanityCheck() const;
  };

  // Constructor for BarrierOp
  // @param rows_per_buffer - number of rows in output buffer
  // @param op_connector_size - connector size
  // @param condition_name - the condition name associated with this operator
  // @param condition_func - the blocking condition check per row
  // @note - currently rows_per_buffer should be 1 for barrier.
  // The reason for this is that other values would complicate how the pipeline behaves with other operators.
  // One example of such a case is having batch after barrier. Batch would be waiting for data, and having
  // rows per buffer in this case can result in hanging.
  BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
            py::function condition_func);

  // Destructor
  ~BarrierOp();

  Status EofReceived(int32_t) override;

  Status EoeReceived(int32_t) override;

  // Print function for Barrier
  // @param out - output stream to print to
  // @param show_all - if it should print everything
  void Print(std::ostream &out, bool show_all) const override;

  // Provide stream operator for displaying it
  friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
    bo.Print(out, false);
    return out;
  }

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code returned
  Status operator()() override;

  // Handles preprocessing of the main loop, used when starting new epoch
  // @param table - a table of tensors to be moved into a buffer
  Status prepare(TensorQTable *const table);

  // This function takes a table and repeatedly adds rows to it.
  // @param table - a table of tensors to be moved into a buffer
  Status fillBuffer(TensorQTable *const table);

  // Gets next tensor row and sets control signals
  Status getNextTensorRow(TensorRow *new_row);

  // This function runs the wait function on the condition
  Status blockCond();

 private:
  // clean up variable to return incomplete buffer
  bool clean_up_;
  // end of file state, we stop reading data and shut down
  bool eof_;
  // rows per buffer
  int32_t rows_per_buffer_;
  // buffer_id
  int32_t buffer_id_;
  // local variable to keep track of the buffer information
  std::unordered_map<std::string, int32_t> col_name_id_map_;
  // iterator to pull new rows, we only have one child
  std::unique_ptr<ChildIterator> child_iterator_;
  // condition name, to support multiple barriers
  std::string condition_name_;
  // Function pointer of the blocking function
  py::function condition_function_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h

@@ -34,7 +34,7 @@ class DataBuffer;
 class ZipOp : public PipelineOp {
  public:
-  // The nested builder class inside of the BatchOp is used to help manage all of
+  // The nested builder class inside of the ZipOp is used to help manage all of
   // the arguments for constructing it. Use the builder by setting each argument
   // with the provided set methods, and then finally call the build method to execute
   // the actual construction.

@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
   };

   // Constructor for ZipOp
-  // @param rows_per_buffer number of rows in output buffer
-  // @param op_connector_size connector
+  // @param rows_per_buffer - number of rows in output buffer
+  // @param op_connector_size - connector size
   ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);

   // Destructor

@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
   Status EoeReceived(int32_t) override;

   // Print function for Zip
-  // @param out output stream to print to
-  // @param show_all if it should print everything
+  // @param out - output stream to print to
+  // @param show_all - if it should print everything
   void Print(std::ostream &out, bool show_all) const override;

   // Provide stream operator for displaying it

@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
   Status fillBuffer(TensorQTable *const table);

   // Special handle case where an empty row has been received from child iterator
-  // @note we need to drain eoe signals from all children connectors.
-  // @details when this function is called, then we encountered eoe at child iterator
+  // @note - we need to drain eoe signals from all children connectors.
+  // @details - when this function is called, then we encountered eoe at child iterator
   // we have to drain rows from other child iterators until we hit eoe from all other child iterators
   Status drainPipeline();

   // Merges 1 row from each childIterator together
-  // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
-  // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
+  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
+  // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
   // @details merge rows from iterator together. This is the main functionality for ZipOp
   // this function takes one row and fills it with tensors from rows fetched
   // from childIterators.
mindspore/dataset/engine/datasets.py

@@ -28,6 +28,7 @@ import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading

import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \

@@ -40,7 +41,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

@@ -141,6 +142,7 @@ class Dataset:
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._sync = False

    def get_args(self):
        """

@@ -198,6 +200,30 @@ class Dataset:
        """
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)

    @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        '''
        Add a blocking condition to the input Dataset.

        Args:
            input_dataset (Dataset): Input dataset to apply flow control.
            num_batch (int): the number of batches without blocking at the start of each epoch.
            condition_name (str): The condition name that is used to toggle sending next row.
            callback (function): The callback function that will be invoked when sync_update is called.

        Raises:
            RuntimeError: If condition name already exists.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> data = data.sync_wait("callback1")
            >>> data = data.batch(batch_size)
            >>> for batch_data in data.create_dict_iterator():
            >>>     data = data.sync_update("callback1")
        '''
        return SyncWaitDataset(self, condition_name, num_batch, callback)

    @check_shuffle
    def shuffle(self, buffer_size):
        """

@@ -220,6 +246,9 @@ class Dataset:
        Returns:
            ShuffleDataset, dataset shuffled.

        Raises:
            RuntimeError: If there are sync operators before shuffle.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object

@@ -821,6 +850,9 @@ class Dataset:
        self._input_indexs = value

    def _get_pipeline_info(self):
        """
        Gets pipeline information.
        """
        device_iter = TupleIterator(self)
        self._output_shapes = device_iter.get_output_shapes()
        self._output_types = device_iter.get_output_types()

@@ -875,6 +907,30 @@ class Dataset:
            return self.input[0].num_classes()
        return None

    def get_sync_notifiers(self):
        if self.input:
            return self.input[0].get_sync_notifiers()
        return {}

    def is_sync(self):
        if self.input:
            return self.input[0].is_sync()
        return False

    def sync_update(self, condition_name, num_batch=None, data=None):
        """
        condition_name (str): The condition name that is used to toggle sending next row.
        num_batch (int or None): The number of batches (rows) that are released.
            When num_batch is None, it will release the same number as sync_wait specified.
        data (dict or None): The data passed to the callback.
        """
        notifiers_dict = self.get_sync_notifiers()
        if condition_name not in notifiers_dict:
            raise RuntimeError("Condition name not found")
        if num_batch is not None:
            num_batch *= self.get_batch_size()
        notifiers_dict[condition_name](num_batch, data)

    def get_batch_size(self):
        """
        Get the size of a batch.

@@ -978,6 +1034,8 @@ class BatchDataset(DatasetOp):
        if BatchDataset._is_ancestor_of_repeat(input_dataset):
            logger.warning("Repeat is located before batch, data from two epochs can be batched together.")

        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)

        self.batch_size = batch_size
        self.drop_remainder = drop_remainder
        self.per_batch_map = per_batch_map

@@ -1034,6 +1092,20 @@ class BatchDataset(DatasetOp):
            flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
        return flag

    @staticmethod
    def _update_batch_size_for_syncwait(dataset, batch_size):
        """
        Utility function to notify the batch size to sync_wait.

        Args:
            dataset (Dataset): dataset to be checked.
            batch_size (int): batch size to notify.
        """
        if isinstance(dataset, SyncWaitDataset):
            dataset.update_sync_batch_size(batch_size)
        for input_dataset in dataset.input:
            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)


class BatchInfo(CBatchInfo):
    """

@@ -1058,6 +1130,108 @@ class BatchInfo(CBatchInfo):
        """
        return


class BlockReleasePair:
    """
    The blocking condition class used by SyncWaitDataset.

    Args:
        init_release_rows (int): Number of lines to allow through the pipeline.
        callback (function): The callback function that will be called when release is called.
    """
    def __init__(self, init_release_rows, callback=None):
        self.row_count = -init_release_rows
        self.cv = threading.Condition()
        self.callback = callback
        self.default_rows = init_release_rows

    def __deepcopy__(self, memodict):
        if id(self) in memodict:
            return memodict[id(self)]
        memodict[id(self)] = self
        # condition variable and callback are the same, but reset the counter
        self.reset()
        return self

    def reset(self):
        with self.cv:
            self.row_count = -self.default_rows
            self.cv.notify_all()

    def update_batched_size(self, batch_size):
        # should only be used before the pipeline is created
        self.row_count *= batch_size
        self.default_rows *= batch_size

    def block_func(self):
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1
        return True

    def release_func(self, pass_rows=None, data=None):
        with self.cv:
            if pass_rows is None:
                pass_rows = self.default_rows
            self.row_count -= pass_rows
            if self.callback is not None:
                self.callback(data)
            self.cv.notify_all()


class SyncWaitDataset(DatasetOp):
    """
    The result of adding a blocking condition to the input Dataset.

    Args:
        input_dataset (Dataset): Input dataset to apply flow control.
        num_batch (int): the number of batches without blocking at the start of each epoch.
        condition_name (str): The condition name that is used to toggle sending next row.
        callback (function): The callback function that will be invoked when sync_update is called.

    Raises:
        RuntimeError: If condition name already exists.
    """

    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
        super().__init__()
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        # set to the default value, waiting for the batch to update it
        self._condition_name = condition_name
        self._pair = BlockReleasePair(num_batch, callback)
        if self._condition_name in self.input[0].get_sync_notifiers():
            raise RuntimeError("Condition name is already in use")

    def get_sync_notifiers(self):
        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}

    def is_sync(self):
        return True

    def get_args(self):
        args = super().get_args()
        args["condition_name"] = self._condition_name
        args["condition_func"] = self._pair.block_func
        return args

    def update_sync_batch_size(self, batch_size):
        self._pair.update_batched_size(batch_size)

    @staticmethod
    def _is_ancestor_of_batch(dataset):
        """
        Utility function to find the case where sync_wait is used before batch.

        Args:
            dataset (Dataset): dataset to be checked.
        Return:
            True or False.
        """
        if isinstance(dataset, BatchDataset):
            return True
        flag = False
        for input_dataset in dataset.input:
            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
        return flag


class ShuffleDataset(DatasetOp):
    """

@@ -1066,6 +1240,9 @@ class ShuffleDataset(DatasetOp):
    Args:
        input_dataset (Dataset): Input Dataset to be shuffled.
        buffer_size (int): The size of the buffer.

    Raises:
        RuntimeError: If there are sync operators before shuffle.
    """

@@ -1074,6 +1251,8 @@ class ShuffleDataset(DatasetOp):
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs
        if self.is_sync():
            raise RuntimeError("No shuffle after sync operators")

    def get_args(self):
        args = super().get_args()

@@ -1427,6 +1606,9 @@ class ZipDataset(DatasetOp):
        """
        return None

    def is_sync(self):
        return any([c.is_sync() for c in self.input])

    def get_args(self):
        args = super().get_args()
        return args
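To make the BlockReleasePair counter logic concrete, here is a self-contained sketch (an illustration, not part of the patch) of what happens for sync_wait(num_batch=1) followed by batch(4), where _update_batch_size_for_syncwait scales the counter by the batch size:

import threading

class MiniPair:
    # Condensed restatement of BlockReleasePair's counter logic, for illustration.
    def __init__(self, init_release_rows):
        self.row_count = -init_release_rows   # rows allowed through before blocking
        self.default_rows = init_release_rows
        self.cv = threading.Condition()

    def block(self):                          # BarrierOp calls this once per row
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1
        return True

    def release(self, pass_rows=None):        # sync_update ends up here
        with self.cv:
            self.row_count -= self.default_rows if pass_rows is None else pass_rows
            self.cv.notify_all()

pair = MiniPair(1)         # sync_wait(num_batch=1): row_count = -1
pair.row_count *= 4        # update_batched_size(4): one free batch = 4 rows
pair.default_rows *= 4
for _ in range(4):         # the first 4 rows pass as row_count climbs -4 .. -1
    pair.block()
pair.release()             # sync_update: row_count back to -4, next batch flows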
mindspore/dataset/engine/iterators.py

@@ -129,6 +129,8 @@ class Iterator:
            op_type = OpName.MINDRECORD
        elif isinstance(dataset, de.BatchDataset):
            op_type = OpName.BATCH
        elif isinstance(dataset, de.SyncWaitDataset):
            op_type = OpName.BARRIER
        elif isinstance(dataset, de.ZipDataset):
            op_type = OpName.ZIP
        elif isinstance(dataset, de.MapDataset):
mindspore/dataset/engine/validators.py

@@ -652,6 +652,22 @@ def check_batch(method):
    return new_method


def check_sync_wait(method):
    """check the input arguments of sync_wait."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_str = ['condition_name']
        nreq_param_int = ['step_size']

        check_param_type(nreq_param_int, param_dict, int)

        check_param_type(nreq_param_str, param_dict, str)

        return method(*args, **kwargs)

    return new_method


def check_shuffle(method):
    """check the input arguments of shuffle."""
tests/ut/python/dataset/test_config.py

@@ -12,8 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing configuration manager
"""
import filecmp
import glob
import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_basic():
    ds.config.load('../data/dataset/declient.cfg')

@@ -36,6 +46,34 @@ def test_basic():
    assert ds.config.get_prefetch_size() == 4
    assert ds.config.get_seed() == 5


def test_pipeline():
    """
    Test that our configuration pipeline works when we set parameters at dataset interval
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_num_parallel_workers(2)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    ds.serialize(data1, "testpipeline.json")

    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_num_parallel_workers(4)
    data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
    ds.serialize(data2, "testpipeline2.json")

    # check that the generated output is different
    assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
    # this test passes currently because our num_parallel_workers don't get updated.

    # remove the generated json files
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


if __name__ == '__main__':
    test_basic()
    test_pipeline()
tests/ut/python/dataset/test_sync_wait.py (new file, mode 100644)

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import time
import numpy as np


def gen():
    for i in range(100):
        yield np.array(i),


class Augment:
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input):
        return input

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        assert (data["input"][0] == count)
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        # time.sleep(0.5)
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators
    """
    logger.info("test_two_sync")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")

    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    for epochs in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator():
            assert (data["input"][0] == count)
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.shuffle(shuffle_size)
    except BaseException as e:
        assert "shuffle" in str(e)
    dataset = dataset.batch(batch_size)


def test_sync_exception_02():
    """
    Test sync: with a duplicated condition name
    """
    logger.info("test_sync_exception_02")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
    except BaseException as e:
        assert "name" in str(e)
    dataset = dataset.batch(batch_size)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_epoch()
\ No newline at end of file