Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
93e7c97a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93e7c97a
编写于
5月 21, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1272 [Dataset] MindData Tree Optimizer Infrastructure
Merge pull request !1272 from JunhanHu/minddata_opt
上级
a3b9c238
8f774d61
变更
42
显示空白变更内容
内联
并排
Showing
42 changed file
with
826 addition
and
5 deletion
+826
-5
mindspore/ccsrc/dataset/CMakeLists.txt
mindspore/ccsrc/dataset/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/engine/CMakeLists.txt
mindspore/ccsrc/dataset/engine/CMakeLists.txt
+3
-2
mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc
mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+12
-0
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc
+8
-0
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/map_op.h
mindspore/ccsrc/dataset/engine/datasetops/map_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/project_op.cc
mindspore/ccsrc/dataset/engine/datasetops/project_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/project_op.h
mindspore/ccsrc/dataset/engine/datasetops/project_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc
mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/rename_op.h
mindspore/ccsrc/dataset/engine/datasetops/rename_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc
...re/ccsrc/dataset/engine/datasetops/source/generator_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h
...ore/ccsrc/dataset/engine/datasetops/source/generator_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
...ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
.../ccsrc/dataset/engine/datasetops/source/image_folder_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
...e/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
...re/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
...re/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
...ore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
+6
-0
mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc
mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h
+6
-0
mindspore/ccsrc/dataset/engine/execution_tree.cc
mindspore/ccsrc/dataset/engine/execution_tree.cc
+47
-1
mindspore/ccsrc/dataset/engine/execution_tree.h
mindspore/ccsrc/dataset/engine/execution_tree.h
+32
-2
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
+6
-0
mindspore/ccsrc/dataset/engine/opt/pass.cc
mindspore/ccsrc/dataset/engine/opt/pass.cc
+157
-0
mindspore/ccsrc/dataset/engine/opt/pass.h
mindspore/ccsrc/dataset/engine/opt/pass.h
+146
-0
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc
+111
-0
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h
+62
-0
tests/ut/python/dataset/test_opt.py
tests/ut/python/dataset/test_opt.py
+46
-0
未找到文件。
mindspore/ccsrc/dataset/CMakeLists.txt
浏览文件 @
93e7c97a
...
...
@@ -66,6 +66,7 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>
)
...
...
mindspore/ccsrc/dataset/engine/CMakeLists.txt
浏览文件 @
93e7c97a
# Add the datasetops and opt subdirectories to the build.
add_subdirectory(datasetops)
add_subdirectory(opt)
# The tdt subdirectory is only added when ENABLE_TDTQUE is set.
if (ENABLE_TDTQUE)
    add_subdirectory(tdt)
endif ()
...
...
@@ -14,7 +15,7 @@ add_library(engine OBJECT
target_include_directories
(
engine PRIVATE
${
pybind11_INCLUDE_DIRS
}
)
if
(
ENABLE_TDTQUE
)
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-tdt
)
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-tdt
engine-opt
)
else
()
add_dependencies
(
engine engine-datasetops engine-datasetops-source
)
add_dependencies
(
engine engine-datasetops engine-datasetops-source
engine-opt
)
endif
()
mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc
浏览文件 @
93e7c97a
...
...
@@ -22,6 +22,7 @@
#include "dataset/core/pybind_support.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
using
float16
=
Eigen
::
half
;
...
...
@@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d
return
Status
::
OK
();
}
// NodePass visitor acceptor for BatchOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status BatchOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as BatchOp before handing it to the pass
  std::shared_ptr<BatchOp> self = std::static_pointer_cast<BatchOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
浏览文件 @
93e7c97a
...
...
@@ -192,6 +192,12 @@ class BatchOp : public ParallelOp {
Status
PadTensor
(
std
::
shared_ptr
<
Tensor
>
src
,
std
::
shared_ptr
<
Tensor
>
*
dst
,
const
std
::
vector
<
dsize_t
>
&
pad_shape
,
float
pad_val
);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
// it is only meant to be called by PadTensor.
...
...
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
浏览文件 @
93e7c97a
...
...
@@ -25,6 +25,7 @@
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
...
...
@@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() {
}
return
Status
::
OK
();
}
// Base NodePass visitor acceptor.
// DatasetOp is the base class of the visitor targets, so this version only runs
// when a derived class does not provide its own Accept.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status DatasetOp::Accept(NodePass *p, bool *modified) {
  // No downcast here: the pass receives the node typed as the base DatasetOp
  std::shared_ptr<DatasetOp> self = shared_from_this();
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
浏览文件 @
93e7c97a
...
...
@@ -32,6 +32,8 @@ class ExecutionTree;
class
DataBuffer
;
class
NodePass
;
// The base class DatasetOp is the main tree node. It is an abstract class, so
// the actual implementation of the operators will be derived from here.
class
DatasetOp
:
public
std
::
enable_shared_from_this
<
DatasetOp
>
{
...
...
@@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @return - the column name map as a string
std
::
string
ColumnNameMapAsString
()
const
;
// Children getter.
// @return A copy of child_, the vector of shared pointers to this operator's children
//         (returned by value, so the caller gets its own vector).
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
// Base method for NodePass visit.
// Subclass needs to override this if it requires special node visit access.
// Check "dataset/engine/opt/pass.h" for more details.
// @return Status of the node visit
virtual
Status
Accept
(
NodePass
*
p
,
bool
*
modified
);
protected:
// Adds a parent operator to this operator
// @notes External callers do not have access to this function.
...
...
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc
浏览文件 @
93e7c97a
...
...
@@ -24,6 +24,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/status.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#ifdef ENABLE_TDTQUE
#include "tdt/tsd_client.h"
...
...
@@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
out
<<
"
\n
Channel name: "
<<
channel_name_
<<
"
\n
Prefetch size: "
<<
prefetch_size_
<<
"
\n\n
"
;
}
}
// NodePass visitor acceptor for DeviceQueueOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as DeviceQueueOp before handing it to the pass
  std::shared_ptr<DeviceQueueOp> self = std::static_pointer_cast<DeviceQueueOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h
浏览文件 @
93e7c97a
...
...
@@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp {
Status
operator
()()
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// Name: checkExceptions(DataBuffer);
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
...
...
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
浏览文件 @
93e7c97a
...
...
@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
...
...
@@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
}
return
Status
(
StatusCode
::
kOK
,
"FilterOp predicate func call succeed"
);
}
// NodePass visitor acceptor for FilterOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status FilterOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as FilterOp before handing it to the pass
  std::shared_ptr<FilterOp> self = std::static_pointer_cast<FilterOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
浏览文件 @
93e7c97a
...
...
@@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
// @param show_all A bool to control if you want to show all info or just a summary.
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// predicate_func python callable which returns a boolean value.
py
::
function
predicate_func_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
浏览文件 @
93e7c97a
...
...
@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
...
...
@@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
column_name_id_map_
=
final_col_name_id_map
;
}
}
// NodePass visitor acceptor for MapOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status MapOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as MapOp before handing it to the pass
  std::shared_ptr<MapOp> self = std::static_pointer_cast<MapOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/map_op.h
浏览文件 @
93e7c97a
...
...
@@ -171,6 +171,12 @@ class MapOp : public ParallelOp {
// @return the number of threads consuming data from previous op's output Connector.
int32_t
num_consumers
()
const
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// Local queues where worker threads can pop from.
// Popping directly from the Connector can block if the previous designated threads haven't pop.
...
...
mindspore/ccsrc/dataset/engine/datasetops/project_op.cc
浏览文件 @
93e7c97a
...
...
@@ -25,6 +25,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
...
...
@@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) {
}
// EOF handler for ProjectOp. Does no work and reports success.
// @param worker_id - The worker id (not used by this implementation).
// @return Status - Always Status::OK().
Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
// NodePass visitor acceptor for ProjectOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status ProjectOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as ProjectOp before handing it to the pass
  std::shared_ptr<ProjectOp> self = std::static_pointer_cast<ProjectOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/project_op.h
浏览文件 @
93e7c97a
...
...
@@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp {
// @return Status - The error code returned.
Status
EofReceived
(
int32_t
worker_id
)
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
std
::
vector
<
std
::
string
>
columns_to_project_
;
std
::
vector
<
int32_t
>
projected_column_indices_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc
浏览文件 @
93e7c97a
...
...
@@ -24,6 +24,7 @@
#include "dataset/core/global_context.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
...
...
@@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) {
state_
=
OpState
::
kDeOpIdle
;
return
Status
::
OK
();
}
// NodePass visitor acceptor for RenameOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status RenameOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as RenameOp before handing it to the pass
  std::shared_ptr<RenameOp> self = std::static_pointer_cast<RenameOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/rename_op.h
浏览文件 @
93e7c97a
...
...
@@ -110,6 +110,12 @@ class RenameOp : public PipelineOp {
// @return Status - The error code return
Status
operator
()()
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
protected:
// Rename core functionality
Status
RenameColumns
();
...
...
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
浏览文件 @
93e7c97a
...
...
@@ -21,6 +21,7 @@
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
...
...
@@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const {
return
child_
[
0
]
->
num_producers
();
}
}
// NodePass visitor acceptor for RepeatOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status RepeatOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as RepeatOp before handing it to the pass
  std::shared_ptr<RepeatOp> self = std::static_pointer_cast<RepeatOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
浏览文件 @
93e7c97a
...
...
@@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp {
// @param workerId - The worker id
int32_t
num_producers
()
const
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
int32_t
max_repeats_
;
// The number of repeats that the user requested
int32_t
repeat_count_
;
// A counter for the current number of executed repeats
...
...
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
浏览文件 @
93e7c97a
...
...
@@ -30,6 +30,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
...
...
@@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
state_
=
OpState
::
kDeOpIdle
;
return
Status
::
OK
();
}
// NodePass visitor acceptor for ShuffleOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as ShuffleOp before handing it to the pass
  std::shared_ptr<ShuffleOp> self = std::static_pointer_cast<ShuffleOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
浏览文件 @
93e7c97a
...
...
@@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp {
// @return Status - The error code return
Status
EoeReceived
(
int32_t
worker_id
)
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// Private function to add a new row to the shuffle buffer.
// @return Status - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
浏览文件 @
93e7c97a
...
...
@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
...
...
@@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) {
MS_LOG
(
DEBUG
)
<<
"Skip operator EOF received, do nothing now."
;
return
Status
::
OK
();
}
// NodePass visitor acceptor for SkipOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status SkipOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as SkipOp before handing it to the pass
  std::shared_ptr<SkipOp> self = std::static_pointer_cast<SkipOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
浏览文件 @
93e7c97a
...
...
@@ -74,6 +74,12 @@ class SkipOp : public PipelineOp {
// @param worker_id - The worker id
Status
EofReceived
(
int32_t
worker_id
)
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
int32_t
max_skips_
;
// The number of skips that the user requested
int32_t
skip_count_
;
// A counter for the current number of executed skips
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc
浏览文件 @
93e7c97a
...
...
@@ -20,6 +20,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -250,5 +251,11 @@ Status GeneratorOp::Reset() {
wp_
.
Set
();
return
Status
(
StatusCode
::
kOK
,
"GeneratorOp Reset Succeed"
);
}
// NodePass visitor acceptor for GeneratorOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as GeneratorOp before handing it to the pass
  std::shared_ptr<GeneratorOp> self = std::static_pointer_cast<GeneratorOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h
浏览文件 @
93e7c97a
...
...
@@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp {
// @return Status - The error code return
Status
Reset
()
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
py
::
function
generator_function_
;
std
::
vector
<
std
::
string
>
column_names_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
93e7c97a
...
...
@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
(
*
num_rows
)
=
(
row_cnt
/
num_dev
)
+
(
row_cnt
%
num_dev
==
0
?
0
:
1
);
return
Status
::
OK
();
}
// NodePass visitor acceptor for ImageFolderOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as ImageFolderOp before handing it to the pass
  std::shared_ptr<ImageFolderOp> self = std::static_pointer_cast<ImageFolderOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
浏览文件 @
93e7c97a
...
...
@@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
const
std
::
set
<
std
::
string
>
&
exts
,
int64_t
*
num_rows
,
int64_t
*
num_classes
,
int64_t
dev_id
=
0
,
int64_t
num_dev
=
1
);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
93e7c97a
...
...
@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
...
...
@@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
}
return
Status
::
OK
();
}
// NodePass visitor acceptor for MindRecordOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as MindRecordOp before handing it to the pass
  std::shared_ptr<MindRecordOp> self = std::static_pointer_cast<MindRecordOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
浏览文件 @
93e7c97a
...
...
@@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp {
Status
SetColumnsBlob
();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
Status
GetBufferFromReader
(
std
::
unique_ptr
<
DataBuffer
>
*
fetched_buffer
,
int64_t
buffer_id
,
int32_t
worker_id
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
浏览文件 @
93e7c97a
...
...
@@ -37,6 +37,7 @@
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/path.h"
#include "dataset/util/queue.h"
#include "dataset/util/random.h"
...
...
@@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
return
rows_read
;
}
// NodePass visitor acceptor for TFReaderOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as TFReaderOp before handing it to the pass
  std::shared_ptr<TFReaderOp> self = std::static_pointer_cast<TFReaderOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
浏览文件 @
93e7c97a
...
...
@@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp {
static
Status
CountTotalRows
(
int64_t
*
out_total_rows
,
const
std
::
vector
<
std
::
string
>
&
filenames
,
int64_t
threads
=
1
,
bool
estimate
=
false
);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
...
...
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
浏览文件 @
93e7c97a
...
...
@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() {
tree_
->
AddToRepeatStack
(
shared_from_this
());
return
Status
::
OK
();
}
// NodePass visitor acceptor for TakeOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status TakeOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as TakeOp before handing it to the pass
  std::shared_ptr<TakeOp> self = std::static_pointer_cast<TakeOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
浏览文件 @
93e7c97a
...
...
@@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
// before providing their own implementations.
Status
PrepareNodePostAction
()
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
int32_t
max_takes_
;
// The number of takes that the user requested
int32_t
take_count_
;
// A counter for the current number of executed takes
...
...
mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc
浏览文件 @
93e7c97a
...
...
@@ -19,6 +19,7 @@
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"
...
...
@@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) {
state_
=
OpState
::
kDeOpIdle
;
return
Status
::
OK
();
}
// NodePass visitor acceptor for ZipOp.
// @param p - Pointer to the NodePass visiting this node.
// @param modified - Flag forwarded to the pass's RunOnNode call.
// @return - Status returned by the pass's RunOnNode call.
Status ZipOp::Accept(NodePass *p, bool *modified) {
  // Re-type the shared pointer to this op as ZipOp before handing it to the pass
  std::shared_ptr<ZipOp> self = std::static_pointer_cast<ZipOp>(shared_from_this());
  return p->RunOnNode(self, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h
浏览文件 @
93e7c97a
...
...
@@ -104,6 +104,12 @@ class ZipOp : public PipelineOp {
// @return Status - The error code return
Status
operator
()()
override
;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status
Accept
(
NodePass
*
p
,
bool
*
modified
)
override
;
private:
// Handles preprocessing of the main loop, used when starting new epoch
Status
prepare
(
TensorQTable
*
const
table
);
...
...
mindspore/ccsrc/dataset/engine/execution_tree.cc
浏览文件 @
93e7c97a
...
...
@@ -20,6 +20,8 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/util/printer_pass.h"
namespace
mindspore
{
namespace
dataset
{
// Constructor
...
...
@@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return
Status
::
OK
();
}
// Entry point of the execution tree's prepare phase.
// The steps below run in this order, each one guarded by RETURN_IF_NOT_OK:
//
//   1. PrepareTreePreAction()
//      Compulsory transformation/action before optimization.
//      For example, CacheOp insertion.
//
//   2. Optimize()
//      Optional optimization transformation/action.
//      For example, MapOp fusion.
//
//   3. PrepareTreePostAction()
//      Compulsory transformation/action after optimization.
//      For example, repeatOp inlining.
//
//   4. PrepareDeprecated()
//      The existing transformation implementation; will be removed later.
//
// @return Status - The error code return
Status ExecutionTree::Prepare() {
  // Compulsory step that runs ahead of the optimizer
  RETURN_IF_NOT_OK(PrepareTreePreAction());

  // Optional optimizer step
  RETURN_IF_NOT_OK(Optimize());

  // Compulsory step that runs once the optimizer is done
  RETURN_IF_NOT_OK(PrepareTreePostAction());

  // Legacy prepare logic, kept until it is removed
  RETURN_IF_NOT_OK(PrepareDeprecated());

  return Status::OK();
}
// Pre-optimization step of the prepare phase.
// Currently performs no work and reports success.
// @return Status - Always Status::OK().
Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
// Post-optimization step of the prepare phase.
// Currently performs no work and reports success.
// @return Status - Always Status::OK().
Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
// Optimization step of the prepare phase.
// No pass is run at the moment; the disabled lines below sketch how a pass
// such as PrinterPass would be invoked on this tree.
// NOTE(review): if the sketch is re-enabled, the object created with raw `new`
// has no visible owner and its Run() result is ignored - confirm before enabling.
// @return Status - Always Status::OK().
Status ExecutionTree::Optimize() {
  //  auto pp = new PrinterPass();
  //  bool modified = false;
  //  pp->Run(this, &modified);
  return Status::OK();
}
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
Status
ExecutionTree
::
Prepare
()
{
//
// This driver is deprecated.
Status
ExecutionTree
::
PrepareDeprecated
()
{
// Tree must be in pending prepare state before we can assign root to it
if
(
tree_state_
!=
kDeTStatePrepare
)
{
std
::
string
err_msg
=
...
...
mindspore/ccsrc/dataset/engine/execution_tree.h
浏览文件 @
93e7c97a
...
...
@@ -152,11 +152,41 @@ class ExecutionTree {
// @return the prepare flags (the current value of the prepare_flags_ member)
uint32_t
PrepareFlags
()
const
{
return
prepare_flags_
;
}
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status
Prepare
();
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
Status
PrepareTreePreAction
();
// Compulsory transformation/action post optimization.
// @return Status - The error code return
Status
PrepareTreePostAction
();
// Optimization transformation/action, optional.
// @return Status - The error code return
Status
Optimize
();
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
// @return Status - The error code return
Status
Prepare
();
Status
Prepare
Deprecated
();
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
...
...
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
0 → 100644
浏览文件 @
93e7c97a
# Gather every .cc file below this directory (paths relative to it) and attach the
# SUBMODULE_ID compile definition to each of them.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES}
             PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# Object library with the pass base classes and the printer utility pass.
add_library(engine-opt OBJECT
    pass.cc
    util/printer_pass.cc
)
\ No newline at end of file
mindspore/ccsrc/dataset/engine/opt/pass.cc
0 → 100644
浏览文件 @
93e7c97a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
namespace
mindspore
{
namespace
dataset
{
// Driver method for TreePass.
// Forwards the tree and the modified flag to RunOnTree() and returns its Status unchanged.
Status TreePass::Run(ExecutionTree *tree, bool *modified) {
  return RunOnTree(tree, modified);
}
// Driver method for NodePass.
// Fetches the root of the tree and walks the operators in the order selected at construction.
// An order that is neither DFS nor BFS visits nothing and returns OK.
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
  std::shared_ptr<DatasetOp> start = tree->root();
  switch (traversalOrder_) {
    case Order::DFS:
      return DFSNodeVisit(start, modified);
    case Order::BFS:
      return BFSNodeVisit(start, modified);
    default:
      break;
  }
  return Status::OK();
}
// Helper function to perform DFS visit.
// Recurses into every child first, then lets the node itself accept this pass, so a node is
// visited only after all of its children have been visited.
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
  for (const auto &child : node->Children()) {
    RETURN_IF_NOT_OK(DFSNodeVisit(child, modified));
  }
  return node->Accept(this, modified);
}
// Helper function to perform BFS visit.
// Starting from root, each dequeued node accepts this pass before its children are enqueued.
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) {
  // Work list of nodes still to be visited, seeded with the root
  std::queue<std::shared_ptr<DatasetOp>> pending;
  pending.push(root);

  while (!pending.empty()) {
    // Take the oldest entry off the work list
    std::shared_ptr<DatasetOp> current = pending.front();
    pending.pop();

    // Visit the node
    RETURN_IF_NOT_OK(current->Accept(this, modified));

    // Children are read after the visit and appended behind everything already queued
    for (const auto &child : current->Children()) {
      pending.push(child);
    }
  }
  return Status::OK();
}
// Default BatchOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default MapOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default ProjectOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default RenameOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default FilterOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default SkipOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default ShuffleOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default GeneratorOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default MindRecordOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default TFReaderOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default TakeOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default ZipOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default DeviceQueueOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
// Default ImageFolderOp visitor: upcasts the node and delegates to the generic DatasetOp visitor.
Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> as_base = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(as_base, modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/opt/pass.h
0 → 100644
浏览文件 @
93e7c97a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_H_
#define DATASET_ENGINE_OPT_PASS_H_
#include <memory>
#include <queue>
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
class
BatchOp
;
class
MapOp
;
class
ProjectOp
;
class
RenameOp
;
class
FilterOp
;
class
SkipOp
;
class
ShuffleOp
;
class
GeneratorOp
;
class
MindRecordOp
;
class
TFReaderOp
;
class
TakeOp
;
class
ZipOp
;
class
DeviceQueueOp
;
class
ImageFolderOp
;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
 public:
  // Pass is a polymorphic base class: the destructor is virtual so that destroying a derived
  // pass through a Pass pointer runs the derived destructor.
  virtual ~Pass() = default;

  // Run the transformation pass against the execution tree.
  // The base implementation does nothing: it touches neither argument.
  // @param tree - Pointer to the execution tree to be transformed.
  // @param modified - Pointer to the modified flag.
  // @return Status - The error code return (always Status::OK() here)
  virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
 public:
  // Entry point of the pass; final, so derived classes customise RunOnTree() instead.
  // @param tree - Pointer to the execution tree to be transformed.
  // @param modified - Pointer to the modified flag.
  // @return Status - The error code return
  Status Run(ExecutionTree *tree, bool *modified) final;

  // Hook for derived classes that implement a whole-tree transformation.
  // The "modified" flag needs to be set to true if the tree is modified during the pass execution.
  // The default implementation does nothing.
  // @param tree - Pointer to the execution tree to be transformed.
  // @param modified - Pointer to the modified flag.
  // @return Status - The error code return
  virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
// NodePass is a basic Pass class which performs transformation on Node visiting.
// NodePass implements Visitor design pattern.
class NodePass : public Pass {
 public:
  // Tree traversal order
  enum Order { DFS, BFS };

  // Constructor
  // @param order - Traversal order used by Run(); DFS when omitted.
  explicit NodePass(Order order = Order::DFS) : traversalOrder_(order) {}

  // Entry point of the pass; final, so derived classes customise the RunOnNode() overloads.
  // @param tree - Pointer to the execution tree to be transformed.
  // @param modified - Pointer to the modified flag.
  // @return Status - The error code return
  Status Run(ExecutionTree *tree, bool *modified) final;

  // Generic visitor. Derived classes may override it to implement a node level tree transformation.
  // The "modified" flag needs to be set to true if the tree is modified during the pass execution.
  // The default implementation does nothing.
  // @return Status - The error code return
  virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }

  // Type-specific visit methods to be overridden.
  // Note that a member template can not be virtual, so any op which wants to work with NodePass
  // should declare a RunOnNode of its own type and override "Accept" from DatasetOp.
  virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);

 private:
  // Helper function to perform DFS visit
  Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

  // Helper function to perform BFS visit
  Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);

  // Tree traversal order of the NodePass
  Order traversalOrder_;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_H_
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc
0 → 100644
浏览文件 @
93e7c97a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "dataset/engine/opt/util/printer_pass.h"
namespace
mindspore
{
namespace
dataset
{
namespace {
// Shared body of every PrinterPass visitor.
// Clears the modified flag (printing never changes the tree) and writes the given message,
// followed by a newline, to std::cout.
// @param message - Text to print, e.g. "Visiting BatchOp".
// @param modified - Pointer to the modified flag; set to false.
// @return Status - always Status::OK()
Status ReportVisit(const char *message, bool *modified) {
  *modified = false;
  std::cout << message << '\n';
  return Status::OK();
}
}  // namespace

// Each visitor below prints the name of the operator type it was dispatched for.
// The node itself is not inspected.

Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
  return ReportVisit("Visiting DatasetOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
  return ReportVisit("Visiting BatchOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
  return ReportVisit("Visiting MapOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
  return ReportVisit("Visiting ProjectOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
  return ReportVisit("Visiting RenameOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
  return ReportVisit("Visiting FilterOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
  return ReportVisit("Visiting SkipOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
  return ReportVisit("Visiting ShuffleOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
  return ReportVisit("Visiting GeneratorOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
  return ReportVisit("Visiting MindRecordOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
  return ReportVisit("Visiting TFReaderOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
  return ReportVisit("Visiting TakeOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
  return ReportVisit("Visiting ZipOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
  return ReportVisit("Visiting DeviceQueueOp", modified);
}

Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
  return ReportVisit("Visiting ImageFolderOp", modified);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h
0 → 100644
浏览文件 @
93e7c97a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#include <memory>
#include "dataset/engine/opt/pass.h"
namespace
mindspore
{
namespace
dataset
{
// PrinterPass is a NodePass that overrides every RunOnNode overload declared by NodePass,
// one per operator type plus the generic DatasetOp one.
class PrinterPass : public NodePass {
 public:
  // Generic visitor
  Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;

  // Type-specific visitors
  Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;

  Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
tests/ut/python/dataset/test_opt.py
0 → 100644
浏览文件 @
93e7c97a
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """Yield 64 one-column rows; row k is the tuple (np.array([k]),) for k in 0..63."""
    for value in range(64):
        row = (np.array([value]),)
        yield row
def test_case_0():
    """
    Run a Generator -> shuffle -> map -> batch pipeline end to end.

    generator_1d yields 64 rows and they are batched in pairs, so iterating the
    pipeline is expected to produce 32 batches.
    """
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.shuffle(2)
    data1 = data1.map(["data"], operations=(lambda x: x))
    data1 = data1.batch(2)

    # each item is a dictionary; only the number of batches is checked
    num_batches = sum(1 for _ in data1.create_dict_iterator())
    assert num_batches == 32
if
__name__
==
"__main__"
:
test_case_0
()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录