magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit c22eac74
Authored June 23, 2020 by Jamie Nisbet
Parent: 6ba89d43

subtree creation in python apis

updates fix fix cpplint fix
Showing 20 changed files with 378 additions and 522 deletions.
mindspore/ccsrc/dataset/api/de_pipeline.cc (+281, −89)
mindspore/ccsrc/dataset/api/de_pipeline.h (+50, −30)
mindspore/ccsrc/dataset/api/python_bindings.cc (+3, −3)
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc (+4, −5)
mindspore/ccsrc/dataset/engine/datasetops/map_op.h (+2, −18)
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc (+4, −9)
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h (+1, −15)
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc (+3, −9)
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h (+1, −15)
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc (+3, −7)
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h (+1, −15)
mindspore/ccsrc/dataset/engine/execution_tree.cc (+6, −4)
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt (+1, −3)
mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc (+0, −98)
mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h (+0, −35)
mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc (+0, −51)
mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h (+0, −35)
mindspore/dataset/engine/iterators.py (+3, −3)
tests/ut/cpp/dataset/map_op_test.cc (+0, −69)
tests/ut/python/dataset/test_opt_pass.py (+15, −9)
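The thread running through all of these files: DEPipeline's parse functions can now build a small subtree (for example, a source op with a shuffle op injected above it) instead of exactly one node, and they report both ends of that subtree to the caller; the commit title, "subtree creation in python apis", refers to this. The standalone MapColumnReorder and GlobalShufflePass tree passes are deleted, with their node-injection work evidently moving to parse time. A minimal, self-contained sketch of the top/bottom pattern (the Node type and all names below are illustrative, not MindSpore's):

// Illustrative sketch only, not MindSpore source: a parse step that may build
// a two-node subtree and report both ends, so a caller can hang children off
// the bottom node and attach parents to the top node.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

void ParseSource(bool inject_shuffle, std::shared_ptr<Node> *top, std::shared_ptr<Node> *bottom) {
  auto source = std::make_shared<Node>("source");
  *top = *bottom = source;  // single-node case: both ends are the same op
  if (inject_shuffle) {
    auto shuffle = std::make_shared<Node>("shuffle");
    shuffle->children.push_back(source);
    *top = shuffle;  // parents now attach to the injected shuffle op
  }
}

int main() {
  std::shared_ptr<Node> top, bottom;
  ParseSource(true, &top, &bottom);
  std::cout << "top=" << top->name << " bottom=" << bottom->name << "\n";  // top=shuffle bottom=source
  return 0;
}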
mindspore/ccsrc/dataset/api/de_pipeline.cc
(This diff is collapsed in the web view and not shown here.)
mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -77,7 +77,7 @@ class DEPipeline {
   ~DEPipeline();
 
   // Function to add a Node to the Execution Tree.
-  Status AddNodeToTree(const OpName &op_name, const py::dict &args, DsOpPtr *out);
+  Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);
 
   // Function to add a child and parent relationship.
   static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);

@@ -104,73 +104,74 @@ class DEPipeline {
   int GetRepeatCount() const;
 
-  Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
   Status BuildMindrecordSamplerChain(const py::handle &handle,
                                      std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
                                      int num_padded);
 
-  Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+                                    std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
   void PrintTree();
 
   int32_t GetNumClasses() const;
 
-  Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
   Status SetBatchParameters(const py::dict &args);
 
-  Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
-  Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+  Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
  private:
  // Execution tree that links the dataset operators.

@@ -180,6 +181,25 @@ class DEPipeline {
   static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
 
+  /// \brief Helper function to inject a shuffle operator over top of the current operation being built.
+  /// \param[in] shuffle_size The size to use in the shuffle buffer
+  /// \param[in] input_op The operator to build shuffle on top of
+  /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
+  ///     the shuffle operator
+  /// \return Status return code
+  Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
+                      std::shared_ptr<DatasetOp> *shuffle_op);
+
+  /// \brief Helper function to compute the shuffle size
+  /// \param[in] num_files The number of files in the dataset
+  /// \param[in] num_devices The number of devices in the dataset
+  /// \param[in] num_rows The number of rows in the dataset
+  /// \param[in] total_rows An upper bound on the total rows in the dataset
+  /// \param[out] shuffle_size The resultant computed shuffle size
+  /// \return Status return code
+  Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
+                            int64_t *shuffle_size);
+
   int batch_size_;
   int repeat_num_;
   int num_rows_;
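ComputeShuffleSize is only declared here; its definition sits in de_pipeline.cc, whose diff is collapsed above. The sketch below is a guess at the heuristic, assuming the commit keeps the formula of the GlobalShufflePass it deletes (shown later in this diff): average rows per file, scaled by 4, with a floor of 10000. The num_devices and total_rows parameters of the real helper are omitted here.

// Hedged sketch, NOT the repo's implementation: an assumed shuffle-size
// heuristic carried over from the deleted GlobalShufflePass below.
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t ComputeShuffleSizeSketch(int64_t num_files, int64_t num_rows) {
  const int64_t avg_rows_per_file = num_rows / std::max<int64_t>(num_files, 1);
  return std::max<int64_t>(avg_rows_per_file * 4, 10000);  // same floor as the deleted pass
}

int main() {
  std::cout << ComputeShuffleSizeSketch(8, 100000) << "\n";  // prints 50000
  return 0;
}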
mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -116,9 +116,9 @@ void bindDEPipeline(py::module *m) {
     .def("AddNodeToTree",
          [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
-           DsOpPtr op;
-           THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &op));
-           return op;
+           py::dict out;
+           THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
+           return out;
          },
          py::return_value_policy::reference)
     .def_static("AddChildToParentNode",
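The binding now returns a py::dict rather than a single DsOpPtr; judging from iterators.py later in this commit, the dict carries the subtree ends under the keys "top" and "bottom". A minimal pybind11 sketch of handing such a dict to Python (the module and function names are made up, and strings stand in for the real DatasetOp handles):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Toy stand-in for the AddNodeToTree result: a {"top", "bottom"} dict.
py::dict MakeSubtreeResult() {
  py::dict out;
  out["top"] = "shuffle_op";    // the real binding stores a DatasetOp handle here
  out["bottom"] = "source_op";  // and here
  return out;
}

PYBIND11_MODULE(subtree_demo, m) {  // hypothetical demo module
  m.def("make_subtree_result", &MakeSubtreeResult);
}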
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
@@ -54,20 +54,19 @@ Status MapOp::Builder::sanityCheck() const {
 Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
   RETURN_IF_NOT_OK(sanityCheck());
   *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
-                                 std::move(build_tensor_funcs_), std::move(build_col_order_), build_num_workers_,
-                                 build_op_connector_size_, build_perf_mode_);
+                                 std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_,
+                                 build_perf_mode_);
   return Status::OK();
 }
 
 // Constructor of MapOp
 MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
-             int32_t num_workers, int32_t op_connector_size, bool perf_mode)
+             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
+             bool perf_mode)
     : ParallelOp(num_workers, op_connector_size),
       tfuncs_(std::move(tensor_funcs)),
       in_columns_(in_col_names),
       out_columns_(out_col_names),
-      columns_order_(columns_order),
       perf_mode_(perf_mode) {
   // If caller didn't specify the out_col_names, assume they are same as the in_columns.
   if (out_columns_.empty() || out_columns_[0].empty()) {
mindspore/ccsrc/dataset/engine/datasetops/map_op.h
@@ -93,13 +93,6 @@ class MapOp : public ParallelOp {
       return *this;
     }
 
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetColOrder(const std::vector<std::string> &col_order_) {
-      build_col_order_ = col_order_;
-      return *this;
-    }
-
     // Setter method.
     // @return Builder setter method returns reference to the builder.
     Builder &SetNumWorkers(int32_t num_workers) {

@@ -130,7 +123,6 @@ class MapOp : public ParallelOp {
     std::vector<std::string> build_in_col_names_;
     std::vector<std::string> build_out_col_names_;
     std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
-    std::vector<std::string> build_col_order_;
     int32_t build_num_workers_;
     int32_t build_op_connector_size_;
     bool build_perf_mode_;  // Default true.

@@ -145,12 +137,11 @@ class MapOp : public ParallelOp {
   // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs).
   // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs).
   // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data.
-  // @param columns_order names A full list of column names (should match the whole dataset view post \p tensorFuncs).
   // @param num_workers The number of worker threads.
   // @param op_connector_size The size of each queue in the connector.
   MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
-        int32_t num_workers, int32_t op_connector_size, bool perf_mode);
+        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
+        bool perf_mode);
 
   // Destructor
   ~MapOp() = default;

@@ -190,10 +181,6 @@ class MapOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "MapOp"; }
 
-  // Columns order getter
-  // @return The post map columns order
-  std::vector<std::string> const &ColumnsOrder() const { return columns_order_; }
-
  private:
   // Local queues where worker threads can pop from.
   // Popping directly from the Connector can block if the previous designated threads haven't pop.

@@ -215,9 +202,6 @@ class MapOp : public ParallelOp {
   // Indices of the columns to process.
   std::vector<size_t> to_process_indices_;
 
-  // Variable to store the column_order of all columns post tensorOps
-  std::vector<std::string> columns_order_;
-
   // Performance mode is when the main thread creates local queues, pulls databuffers from the previous
   // op's Connector and distributes them to the local queues. Workers pull from the local queues.
   // If this flag is false, each worker pulls directly from the Connector. This use less resources
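The "performance mode" comment kept in the last hunk describes a distributor pattern: the main thread pulls data buffers from the previous op's Connector and spreads them across per-worker local queues, so each worker blocks only on its own queue. A self-contained sketch of that pattern with toy types (plain std::thread and a minimal queue, not MindSpore's Connector):

// Sketch of MapOp's "performance mode" using illustrative types: a distributor
// pushes buffers round-robin into per-worker queues; each worker blocks only
// on its own queue rather than a shared connector.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class LocalQueue {
 public:
  void Push(int v) {
    { std::lock_guard<std::mutex> lk(mu_); items_.push_back(v); }
    cv_.notify_one();
  }
  int Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !items_.empty(); });
    int v = items_.front();
    items_.pop_front();
    return v;
  }

 private:
  std::deque<int> items_;
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  constexpr int kWorkers = 2, kBuffers = 6, kStop = -1;
  std::vector<LocalQueue> queues(kWorkers);
  std::vector<std::thread> workers;
  for (int w = 0; w < kWorkers; ++w) {
    workers.emplace_back([&queues, w] {
      for (int v; (v = queues[w].Pop()) != kStop;) {
        // In MapOp this is where a worker would run its TensorOps on the buffer.
      }
    });
  }
  for (int b = 0; b < kBuffers; ++b) queues[b % kWorkers].Push(b);  // round-robin distribute
  for (int w = 0; w < kWorkers; ++w) queues[w].Push(kStop);
  for (auto &t : workers) t.join();
  std::cout << "all buffers distributed and consumed\n";
  return 0;
}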
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
@@ -31,11 +31,7 @@
 namespace mindspore {
 namespace dataset {
 ClueOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false),
-      builder_shuffle_global_(false) {
+    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();

@@ -66,8 +62,8 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
-    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_shuffle_global_,
-    builder_num_devices_, builder_device_id_);
+    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
+    builder_device_id_);
   RETURN_IF_NOT_OK(clue_op->Init());
   *op = std::move(clue_op);

@@ -87,7 +83,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
 ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
                ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-               bool shuffle_files, bool shuffle_global, int32_t num_device, int32_t device_id)
+               bool shuffle_files, int32_t num_device, int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       rows_per_buffer_(rows_per_buffer),
       num_rows_per_shard_(0),

@@ -98,7 +94,6 @@ ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples
       load_jagged_connector_(true),
       cols_to_keyword_(cols_to_keyword),
       shuffle_files_(shuffle_files),
-      shuffle_global_(shuffle_global),
       finished_reading_dataset_(false),
       num_devices_(num_device),
       device_id_(device_id),
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
@@ -104,13 +104,6 @@ class ClueOp : public ParallelOp {
       return *this;
     }
 
-    // Setter method.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetShuffleGlobal(bool shuffle_global) {
-      builder_shuffle_global_ = shuffle_global;
-      return *this;
-    }
-
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetNumSamples(int64_t num_samples) {

@@ -139,15 +132,13 @@ class ClueOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_clue_files_list_;
     bool builder_shuffle_files_;
-    bool builder_shuffle_global_;
     std::map<std::string, std::string> builder_cols_to_keyword_;
   };
 
   // Constructor of ClueOp
-  // @param shuffle_global - whether or not to shuffle the entire dataset.
   ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
          ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-         bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
+         bool shuffle_files, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~ClueOp() = default;

@@ -182,10 +173,6 @@ class ClueOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return clue_files_list_; }
 
-  // Global shuffle flag getter
-  // @return Bool - whether this Op requires global shuffle
-  bool RequireGlobalShuffle() { return shuffle_global_; }
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -269,7 +256,6 @@ class ClueOp : public ParallelOp {
   int32_t device_id_;
   bool shuffle_files_;
-  bool shuffle_global_;
   bool finished_reading_dataset_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
@@ -33,11 +33,7 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false),
-      builder_shuffle_global_(false) {
+    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();

@@ -68,7 +64,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
-    builder_shuffle_global_, builder_num_devices_, builder_device_id_);
+    builder_num_devices_, builder_device_id_);
   RETURN_IF_NOT_OK(text_file_op->Init());
   *op = std::move(text_file_op);

@@ -77,8 +73,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
 TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
-                       int32_t op_connector_size, bool shuffle_files, bool shuffle_global, int32_t num_device,
-                       int32_t device_id)
+                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),

@@ -86,7 +81,6 @@ TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t tot
       total_rows_(total_rows),
       text_files_list_(std::move(text_files_list)),
       shuffle_files_(shuffle_files),
-      shuffle_global_(shuffle_global),
       data_schema_(std::move(schema)),
       all_num_rows_(0),
       num_rows_per_shard_(0),
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
@@ -105,13 +105,6 @@ class TextFileOp : public ParallelOp {
       return *this;
     }
 
-    // Setter method.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetShuffleGlobal(bool shuffle_global) {
-      builder_shuffle_global_ = shuffle_global;
-      return *this;
-    }
-
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetTotalRows(int64_t total_rows) {

@@ -129,7 +122,6 @@ class TextFileOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
-    bool builder_shuffle_global_;
     std::unique_ptr<DataSchema> builder_schema_;
   };

@@ -143,11 +135,10 @@ class TextFileOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
-  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
              std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
+             bool shuffle_files, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~TextFileOp() = default;

@@ -186,10 +177,6 @@ class TextFileOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return text_files_list_; }
 
-  // Global shuffle flag getter
-  // @return Bool - whether this Op requires global shuffle
-  bool RequireGlobalShuffle() { return shuffle_global_; }
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -274,7 +261,6 @@ class TextFileOp : public ParallelOp {
   int64_t total_rows_;
   std::vector<std::string> text_files_list_;
   bool shuffle_files_;
-  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   int64_t all_num_rows_;
   int64_t num_rows_per_shard_;
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
@@ -55,7 +55,6 @@ TFReaderOp::Builder::Builder()
   builder_op_connector_size_ = config_manager->op_connector_size();
   builder_rows_per_buffer_ = config_manager->rows_per_buffer();
   builder_shuffle_files_ = false;
-  builder_shuffle_global_ = false;
   builder_data_schema_ = std::make_unique<DataSchema>();
 }

@@ -126,8 +125,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
   std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
     builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
     builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
-    builder_shuffle_files_, builder_shuffle_global_, builder_num_devices_, builder_device_id_,
-    builder_equal_rows_per_shard_);
+    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
   RETURN_IF_NOT_OK(new_tf_reader_op->Init());
   *out_tf_reader_op = std::move(new_tf_reader_op);

@@ -137,8 +135,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
 TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
                        int64_t total_num_rows, std::vector<std::string> dataset_files_list,
                        std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
-                       std::vector<std::string> columns_to_load, bool shuffle_files, bool shuffle_global,
-                       int32_t num_device, int32_t device_id, bool equal_rows_per_shard)
+                       std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
+                       int32_t device_id, bool equal_rows_per_shard)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),

@@ -148,7 +146,6 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
       columns_to_load_(std::move(columns_to_load)),
       finished_reading_dataset_(false),
       shuffle_files_(shuffle_files),
-      shuffle_global_(shuffle_global),
       data_schema_(std::move(data_schema)),
       filename_index_(std::make_unique<StringIndex>()),
       load_io_block_queue_(true),

@@ -174,7 +171,6 @@ void TFReaderOp::Print(std::ostream &out, bool show_all) const {
   // Then show any custom derived-internal stuff
   out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_
       << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
-      << "\nShuffle global: " << ((shuffle_global_) ? "yes" : "no")
       << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
   for (int i = 0; i < dataset_files_list_.size(); ++i) {
     out << " " << dataset_files_list_[i];
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
@@ -146,13 +146,6 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }
 
-    // Setter method.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetShuffleGlobal(bool shuffle_global) {
-      builder_shuffle_global_ = shuffle_global;
-      return *this;
-    }
-
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetShardEqualRows(bool shard_equal_rows) {

@@ -172,7 +165,6 @@ class TFReaderOp : public ParallelOp {
     std::vector<std::string> builder_dataset_files_list_;
     std::vector<std::string> builder_columns_to_load_;
     bool builder_shuffle_files_;
-    bool builder_shuffle_global_;
     bool builder_equal_rows_per_shard_;
   };

@@ -187,12 +179,11 @@ class TFReaderOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
-  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
              int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             bool shuffle_global, int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
 
   // Default destructor
   ~TFReaderOp() = default;

@@ -245,10 +236,6 @@ class TFReaderOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return dataset_files_list_; }
 
-  // Global shuffle flag getter
-  // @return Bool - whether this Op requires global shuffle
-  bool RequireGlobalShuffle() { return shuffle_global_; }
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -393,7 +380,6 @@ class TFReaderOp : public ParallelOp {
   std::vector<std::string> columns_to_load_;
   bool finished_reading_dataset_;
   bool shuffle_files_;
-  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   std::unique_ptr<StringIndex> filename_index_;
   bool load_io_block_queue_;
mindspore/ccsrc/dataset/engine/execution_tree.cc
@@ -19,8 +19,7 @@
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/datasetops/shuffle_op.h"
 #include "dataset/util/task_manager.h"
-#include "dataset/engine/opt/pre/map_column_reorder.h"
-#include "dataset/engine/opt/pre/global_shuffle.h"
+#include "dataset/engine/opt/pass.h"
 #include "dataset/engine/perf/profiling.h"
 #include "dataset/engine/perf/monitor.h"

@@ -42,6 +41,10 @@ ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); }
 // provides it with a link to the tree. A node cannot form any relationships (parent/child) with
 // other nodes unless they are associated with the same tree.
 Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
+  // If we are already a part of the tree, no-op
+  if (op->tree_ == this) {
+    return Status::OK();
+  }
   if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
     std::string err_msg = "Invalid tree state for adding a node. Current state: " +
                           std::to_string(static_cast<int>(tree_state_)) +

@@ -211,8 +214,7 @@ Status ExecutionTree::PrepareTreePreAction() {
   bool modified = false;
   std::vector<std::unique_ptr<Pass>> pre_actions;
   // Construct pre actions
-  pre_actions.push_back(std::make_unique<MapColumnReorder>());
-  pre_actions.push_back(std::make_unique<GlobalShufflePass>());
+  // example: pre_actions.push_back(new SomePass());
   // Apply pre action passes
   for (auto &pass : pre_actions) {
     RETURN_IF_NOT_OK(pass->Run(this, &modified));
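The guard added to AssociateNode makes association idempotent, which matters once a parse routine may associate an injected node itself before a generic caller associates the returned subtree again. A tiny self-contained sketch of the idea with toy types (not the repo's Status/DatasetOp):

#include <cassert>
#include <set>

struct Tree;

struct Op {
  Tree *tree = nullptr;  // which tree this op already belongs to, if any
};

struct Tree {
  std::set<Op *> ops;
  bool AssociateNode(Op *op) {
    if (op->tree == this) return true;  // already associated: no-op, not an error
    op->tree = this;
    ops.insert(op);
    return true;
  }
};

int main() {
  Tree t;
  Op shuffle;
  assert(t.AssociateNode(&shuffle));  // done inside the parse step that built it
  assert(t.AssociateNode(&shuffle));  // repeated by the generic caller: still fine
  assert(t.ops.size() == 1);
  return 0;
}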
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
@@ -2,7 +2,5 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-opt OBJECT
     pass.cc
-    pre/map_column_reorder.cc
-    pre/global_shuffle.cc
     util/printer_pass.cc
-)
\ No newline at end of file
+)
mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc (deleted, 100644 → 0; previous revision 6ba89d43)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "dataset/engine/opt/pre/global_shuffle.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
namespace mindspore {
namespace dataset {

Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) {
  std::vector<std::shared_ptr<TFReaderOp>> tf_readers;
  std::vector<std::shared_ptr<TextFileOp>> text_files;
  std::vector<std::shared_ptr<ClueOp>> clues;

  // Pass 1, search for all sources which require global shuffle
  for (auto &op : *tree) {
    if (auto ptr = std::dynamic_pointer_cast<TFReaderOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        tf_readers.push_back(ptr);
        continue;
      }
    }
    if (auto ptr = std::dynamic_pointer_cast<TextFileOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        text_files.push_back(ptr);
        continue;
      }
    }
    if (auto ptr = std::dynamic_pointer_cast<ClueOp>(op.shared_from_this())) {
      if (ptr->RequireGlobalShuffle()) {
        clues.push_back(ptr);
        continue;
      }
    }
  }

  // Pass 2, insert shuffle nodes
  // The following blocks can be implemented with template if we unify the CountTotalRows across all source nodes.
  for (auto node : tf_readers) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true);
    int32_t avg_file_size = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
  }

  for (auto node : text_files) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    TextFileOp::CountAllFileRows(node->FileNames(), &total_rows);
    int32_t avg_file_size = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
  }

  for (auto node : clues) {
    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
    int64_t total_rows = 0;
    ClueOp::CountAllFileRows(node->FileNames(), &total_rows);
    int32_t avg_file_size = total_rows / (node->FileNames().size());
    builder->SetShuffleSize(std::max(avg_file_size * 4, 10000));
    std::shared_ptr<ShuffleOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
  }

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h (deleted, 100644 → 0; previous revision 6ba89d43)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
#include <memory>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {

// Global Shuffle Pass will insert a ShuffleOp when a leaf node requires global shuffle.
// Example:
// Input Tree:  TFReader(GLOBAL_SHUFFLE) -> Batch
// Output Tree: TFReader -> Shuffle -> Batch
class GlobalShufflePass : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc (deleted, 100644 → 0; previous revision 6ba89d43)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include "dataset/engine/opt/pre/map_column_reorder.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
namespace mindspore {
namespace dataset {

Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) {
  std::vector<std::shared_ptr<MapOp>> to_process;

  // Pass 1, search for all MapOp with column orders
  for (auto &op : *tree) {
    if (auto mapOp = std::dynamic_pointer_cast<MapOp>(op.shared_from_this())) {
      if (mapOp->ColumnsOrder().size() != 0) {
        to_process.push_back(mapOp);
      }
    }
  }

  // Pass 2, insert nodes for all MapOp
  for (auto node : to_process) {
    std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(node->ColumnsOrder());
    std::shared_ptr<ProjectOp> op;
    RETURN_IF_NOT_OK(builder->Build(&op));
    RETURN_IF_NOT_OK(tree->AssociateNode(op));
    RETURN_IF_NOT_OK(node->InsertAsParent(op));
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h (deleted, 100644 → 0; previous revision 6ba89d43)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
#include <memory>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {

// Map Column Reorder Pass will insert a ProjectOp when a MapOp requires a full output columns reorder.
// Example:
// Input Tree:  TFReader -> MapOp(with col_order) -> Batch
// Output Tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch
class MapColumnReorder : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
mindspore/dataset/engine/iterators.py
@@ -172,13 +172,13 @@ class Iterator:
     # Convert python node into C node and add to C layer execution tree in postorder traversal.
     def __convert_node_postorder(self, node):
         op_type = self.__get_dataset_type(node)
-        c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
+        c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
 
         for py_child in node.children:
             c_child = self.__convert_node_postorder(py_child)
-            self.depipeline.AddChildToParentNode(c_child, c_node)
+            self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
 
-        return c_node
+        return c_nodes["top"]
 
     def __batch_node(self, dataset, level):
         """Recursively get batch node in the dataset tree."""
tests/ut/cpp/dataset/map_op_test.cc
@@ -130,75 +130,6 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
 std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
 
-// TestByPosition scenario:
-// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
-// A TensorOp that does nothing picks the label column and output a column also named label.
-// Thus, based on the new MapOp behaviour, the column ordering will be |image|label|A|B|.
-// Verify the column ordering based on the Tensor properties matching to that of in the schema file.
-TEST_F(MindDataTestMapOp, TestByPosition) {
-  Status rc;
-  MS_LOG(INFO) << "Doing TestByPosition.";
-
-  // Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
-  // of 10 rows.
-  auto my_tfreader_op = this->CreateTFReaderOp();
-  rc = my_tree_->AssociateNode(my_tfreader_op);
-  EXPECT_TRUE(rc.IsOk());
-  auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
-  std::vector<std::shared_ptr<TensorOp>> my_func_list;
-  my_func_list.push_back(my_no_op);
-  std::shared_ptr<MapOp> my_map_op;
-  MapOp::Builder builder;
-  builder.SetInColNames({"label"})
-    .SetOutColNames({})
-    .SetColOrder({"image", "label", "A", "B"})
-    .SetTensorFuncs(std::move(my_func_list))
-    .SetNumWorkers(100);
-  rc = builder.Build(&my_map_op);
-  EXPECT_TRUE(rc.IsOk());
-  rc = my_tree_->AssociateNode(my_map_op);
-  EXPECT_TRUE(rc.IsOk());
-  rc = my_map_op->AddChild(my_tfreader_op);
-  EXPECT_TRUE(rc.IsOk());
-  rc = my_tree_->AssignRoot(my_map_op);
-  EXPECT_TRUE(rc.IsOk());
-  rc = my_tree_->Prepare();
-  EXPECT_TRUE(rc.IsOk());
-  rc = my_tree_->Launch();
-  EXPECT_TRUE(rc.IsOk());
-
-  // Based on the schema file, create the golden result to compare with.
-  std::vector<DataType::Type> golden_types({DataType::Type::DE_UINT8, DataType::Type::DE_INT64,
-                                            DataType::Type::DE_FLOAT32, DataType::Type::DE_INT64});
-  std::vector<uint64_t> golden_ranks({3, 1, 4, 1});
-  std::vector<TensorShape> golden_shapes({TensorShape({3, 4, 2}), TensorShape({7}),
-                                          TensorShape({1, 13, 14, 12}), TensorShape({9})});
-
-  // Start the loop of reading tensors from our pipeline
-  DatasetIterator di(my_tree_);
-  TensorRow tensor_list;
-  rc = di.FetchNextTensorRow(&tensor_list);
-  EXPECT_TRUE(rc.IsOk());
-  EXPECT_EQ(tensor_list.size(), 4);
-  for (uint32_t i = 0; i < tensor_list.size(); i++) {
-    EXPECT_EQ(tensor_list[i]->type(), golden_types[i]);
-    EXPECT_EQ(tensor_list[i]->Rank(), golden_ranks[i]);
-    EXPECT_EQ(tensor_list[i]->shape(), golden_shapes[i]);
-    EXPECT_NE(tensor_list[i]->GetBuffer(), nullptr);
-  }
-}
-
 // TestAsMap scenario:
 // TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
 // A TensorOp that does nothing picks the "image" column and produces a column named "X".
tests/ut/python/dataset/test_opt_pass.py
@@ -16,8 +16,10 @@ import numpy as np
 import mindspore.dataset as ds
 
 
-def test_map_reorder_pass_0():
+# tests the construction of multiple ops from a single dataset.
+# map dataset with columns order arguments should produce a ProjectOp over MapOp
+# This test does not utilize the compiling passes at this time.
+def test_map_reorder0():
     def generator_mc(maxid=1):
         for _ in range(maxid):
             yield (np.array([0]), np.array([1]))

@@ -31,8 +33,10 @@ def test_map_reorder_pass_0():
     for item in data0.create_tuple_iterator():
         # each data is a dictionary
         assert item == [np.array(1), np.array(0)]
 
 
-def test_map_reorder_pass_1():
+# tests the construction of multiple ops from a single dataset.
+# map dataset with columns order arguments should produce a ProjectOp over MapOp
+# This test does not utilize the compiling passes at this time.
+def test_map_reorder1():
     def generator_mc(maxid=1):
         for _ in range(maxid):
             yield (np.array([0]), np.array([1]), np.array([2]))

@@ -48,8 +52,10 @@ def test_map_reorder_pass_1():
     for item in data2.create_tuple_iterator():
         assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)]
 
 
-def test_global_shuffle_pass():
+# tests the construction of multiple ops from a single dataset.
+# TFRecordDataset with global shuffle should produce a ShuffleOp over TfReaderOp.
+# This test does not utilize the compiling passes at this time.
+def test_shuffle():
     FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
     SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

@@ -85,6 +91,6 @@ def test_global_shuffle_pass():
 if __name__ == "__main__":
-    test_map_reorder_pass_0()
-    test_map_reorder_pass_1()
-    test_global_shuffle_pass()
+    test_map_reorder0()
+    test_map_reorder1()
+    test_shuffle()