Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9fb1904e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9fb1904e
编写于
7月 21, 2020
作者:
N
Nat Sutyanyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactoring opt/pre
上级
b13c7a3d
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
407 addition
and
532 deletion
+407
-532
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
+2
-2
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
+1
-3
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
...spore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
+0
-181
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
+0
-141
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
...c/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
+164
-10
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
...rc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
+117
-8
mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc
...c/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc
+26
-35
mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h
...rc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h
+20
-11
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
...re/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
+0
-58
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
...ore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
+0
-64
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
...ore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
+33
-7
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
...pore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
+39
-7
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+5
-5
未找到文件。
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
浏览文件 @
9fb1904e
...
...
@@ -23,7 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/
epoch_
injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
...
...
@@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std
::
vector
<
std
::
unique_ptr
<
Pass
>>
pre_actions
;
// Construct pre actions
MS_LOG
(
INFO
)
<<
"Running pre pass loops."
;
pre_actions
.
push_back
(
std
::
make_unique
<
InjectionPass
>
());
pre_actions
.
push_back
(
std
::
make_unique
<
Epoch
InjectionPass
>
());
pre_actions
.
push_back
(
std
::
make_unique
<
RemovalPass
>
());
pre_actions
.
push_back
(
std
::
make_unique
<
CacheTransformPass
>
());
// Apply pre action passes
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
浏览文件 @
9fb1904e
...
...
@@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library
(
engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc
pre/epoch_injection_pass.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
util/printer_pass.cc
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
已删除
100644 → 0
浏览文件 @
b13c7a3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
namespace
mindspore
{
namespace
dataset
{
// Constructor
CachePass
::
CachePass
(
CacheTransformPass
*
transform_pass
)
:
transform_pass_
(
transform_pass
),
is_caching_
(
false
),
leaf_op_
(
nullptr
)
{}
// Identifies the subtree below this node as a cached descendant tree.
Status
CachePass
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Cache transform pass: CacheOp found, identified descendant tree."
;
if
(
is_caching_
)
{
RETURN_STATUS_UNEXPECTED
(
"Nested cache operations is not supported!"
);
}
is_caching_
=
true
;
return
Status
::
OK
();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
is_caching_
=
false
;
// We a no longer in a cache subtree. clear the flag.
if
(
leaf_op_
)
{
MS_LOG
(
INFO
)
<<
"Cache transform pass: Set up transformation nodes for mappable cache."
;
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
transform_pass_
->
AddMappableCacheOperators
(
std
::
move
(
leaf_op_
),
node
);
}
else
{
// If there was no leaf_op set, then this is a non-mappable scenario.
if
(
sampler_
)
{
// Grab the sampler that was saved from the leaf and plug it into the cache op
node
->
SetSampler
(
std
::
move
(
sampler_
));
MS_LOG
(
INFO
)
<<
"Cache transform pass: Set up cache sampler from non-mappable leaf."
;
}
else
{
// We're a cache op but no sampler was saved from leaf, so create a default sampler
const
int64_t
num_samples
=
0
;
const
int64_t
start_index
=
0
;
sampler_
=
std
::
make_shared
<
SequentialSampler
>
(
num_samples
,
start_index
);
node
->
SetSampler
(
std
::
move
(
sampler_
));
MS_LOG
(
INFO
)
<<
"Cache transform pass: Creating default sequential sampler for cache op."
;
}
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t
cache_crc
=
DatasetOp
::
GenerateCRC
(
node
);
RETURN_IF_NOT_OK
(
node
->
CreateCache
(
cache_crc
));
}
return
Status
::
OK
();
}
// Common code for mappable leaf setup.
Status
CachePass
::
MappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
)
{
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if
(
is_caching_
&&
leaf_op_
)
{
RETURN_STATUS_UNEXPECTED
(
"There is currently no support for multiple leaf nodes under cache."
);
}
// If we are a leaf in the caching path, then save this leaf.
if
(
is_caching_
)
{
MS_LOG
(
DEBUG
)
<<
"Cache transform pass: Mappable leaf in a cache descendant tree detected"
;
leaf_op_
=
std
::
move
(
leaf_op
);
}
return
Status
::
OK
();
}
// Common code for non mappable leaf setup.
Status
CachePass
::
NonMappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
)
{
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if
(
is_caching_
&&
leaf_op_
)
{
RETURN_STATUS_UNEXPECTED
(
"There is currently no support for multiple leaf nodes under cache."
);
}
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree.
if
(
is_caching_
)
{
RETURN_IF_NOT_OK
(
leaf_op
->
FetchRemoveSampler
(
&
sampler_
));
MS_LOG
(
DEBUG
)
<<
"Cache transform pass: Non mappable leaf in a cache descendant tree detected"
;
}
else
{
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
std
::
shared_ptr
<
Sampler
>
sampler_from_leaf
;
RETURN_IF_NOT_OK
(
leaf_op
->
FetchRemoveSampler
(
&
sampler_from_leaf
));
}
return
Status
::
OK
();
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
TFReaderOp
>
node
,
bool
*
modified
)
{
if
(
is_caching_
)
{
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node
->
MakeSimpleProducer
();
}
return
NonMappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
RandomDataOp
>
node
,
bool
*
modified
)
{
return
NonMappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
ImageFolderOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
MnistOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
GeneratorOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
ManifestOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CifarOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
VOCOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CocoOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CelebAOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
MindRecordOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
已删除
100644 → 0
浏览文件 @
b13c7a3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"
namespace
mindspore
{
namespace
dataset
{
class
CacheTransformPass
;
/// \class CachePass cache_pass.h
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class
CachePass
:
public
NodePass
{
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
explicit
CachePass
(
CacheTransformPass
*
transform_pass
);
/// \brief Destructor
~
CachePass
()
=
default
;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
/// transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
TFReaderOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
RandomDataOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ImageFolderOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
MnistOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
GeneratorOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ManifestOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CifarOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
VOCOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CocoOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CelebAOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
MindRecordOp
>
node
,
bool
*
modified
)
override
;
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status
MappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status
NonMappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
);
bool
is_caching_
;
std
::
shared_ptr
<
DatasetOp
>
leaf_op_
;
std
::
shared_ptr
<
Sampler
>
sampler_
;
CacheTransformPass
*
transform_pass_
;
// Back pointer to the owning transform pass
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
浏览文件 @
9fb1904e
...
...
@@ -15,17 +15,177 @@
*/
#include <vector>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
namespace
mindspore
{
namespace
dataset
{
// Constructor
CacheTransformPass
::
CachePass
::
CachePass
()
:
is_caching_
(
false
),
leaf_op_
(
nullptr
)
{}
// Identifies the subtree below this node as a cached descendant tree.
Status
CacheTransformPass
::
CachePass
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Cache transform pass: CacheOp found, identified descendant tree."
;
if
(
is_caching_
)
{
RETURN_STATUS_UNEXPECTED
(
"Nested cache operations is not supported!"
);
}
is_caching_
=
true
;
return
Status
::
OK
();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
is_caching_
=
false
;
// We a no longer in a cache subtree. clear the flag.
if
(
leaf_op_
)
{
MS_LOG
(
INFO
)
<<
"Cache transform pass: Set up transformation nodes for mappable cache."
;
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
AddMappableCacheOperators
(
std
::
move
(
leaf_op_
),
node
);
}
else
{
// If there was no leaf_op set, then this is a non-mappable scenario.
if
(
sampler_
)
{
// Grab the sampler that was saved from the leaf and plug it into the cache op
node
->
SetSampler
(
std
::
move
(
sampler_
));
MS_LOG
(
INFO
)
<<
"Cache transform pass: Set up cache sampler from non-mappable leaf."
;
}
else
{
// We're a cache op but no sampler was saved from leaf, so create a default sampler
int64_t
num_samples
=
0
;
int64_t
start_index
=
0
;
sampler_
=
std
::
make_shared
<
SequentialSampler
>
(
num_samples
,
start_index
);
node
->
SetSampler
(
std
::
move
(
sampler_
));
MS_LOG
(
INFO
)
<<
"Cache transform pass: Creating default sequential sampler for cache op."
;
}
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t
cache_crc
=
DatasetOp
::
GenerateCRC
(
node
);
RETURN_IF_NOT_OK
(
node
->
CreateCache
(
cache_crc
));
}
return
Status
::
OK
();
}
// Common code for mappable leaf setup.
Status
CacheTransformPass
::
CachePass
::
MappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
)
{
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if
(
is_caching_
&&
leaf_op_
)
{
RETURN_STATUS_UNEXPECTED
(
"There is currently no support for multiple leaf nodes under cache."
);
}
// If we are a leaf in the caching path, then save this leaf.
if
(
is_caching_
)
{
MS_LOG
(
DEBUG
)
<<
"Cache transform pass: Mappable leaf in a cache descendant tree detected"
;
leaf_op_
=
std
::
move
(
leaf_op
);
}
return
Status
::
OK
();
}
// Common code for non mappable leaf setup.
Status
CacheTransformPass
::
CachePass
::
NonMappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
)
{
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if
(
is_caching_
&&
leaf_op_
)
{
RETURN_STATUS_UNEXPECTED
(
"There is currently no support for multiple leaf nodes under cache."
);
}
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree.
if
(
is_caching_
)
{
RETURN_IF_NOT_OK
(
leaf_op
->
FetchRemoveSampler
(
&
sampler_
));
MS_LOG
(
DEBUG
)
<<
"Cache transform pass: Non mappable leaf in a cache descendant tree detected"
;
}
else
{
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
std
::
shared_ptr
<
Sampler
>
sampler_from_leaf
;
RETURN_IF_NOT_OK
(
leaf_op
->
FetchRemoveSampler
(
&
sampler_from_leaf
));
}
return
Status
::
OK
();
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
TFReaderOp
>
node
,
bool
*
modified
)
{
if
(
is_caching_
)
{
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node
->
MakeSimpleProducer
();
}
return
NonMappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
RandomDataOp
>
node
,
bool
*
modified
)
{
return
NonMappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
ImageFolderOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
MnistOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
GeneratorOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
ManifestOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CifarOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
VOCOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CocoOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
CelebAOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Perform leaf node cache tranform identifications
Status
CacheTransformPass
::
CachePass
::
RunOnNode
(
std
::
shared_ptr
<
MindRecordOp
>
node
,
bool
*
modified
)
{
return
MappableCacheLeafSetup
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void
CacheTransformPass
::
CachePass
::
AddMappableCacheOperators
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
,
std
::
shared_ptr
<
CacheOp
>
cache_op
)
{
cache_pairs_
.
push_back
(
std
::
make_pair
(
leaf_op
,
cache_op
));
}
// constructor
CacheTransformPass
::
CacheTransformPass
()
{}
...
...
@@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG
(
INFO
)
<<
"Pre pass: Cache transform pass started."
;
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
std
::
unique_ptr
<
Pass
>
cache_pass
=
std
::
make_unique
<
CachePass
>
(
this
);
RETURN_IF_NOT_OK
(
cache_pass
->
Run
(
tree
,
modified
));
CachePass
cache_pass
=
CachePass
(
);
RETURN_IF_NOT_OK
(
cache_pass
.
Run
(
tree
,
modified
));
// Then, execute the transform for each pair
for
(
auto
cache_pair
:
cache_pa
irs_
)
{
for
(
auto
cache_pair
:
cache_pa
ss
.
cache_pairs
()
)
{
MS_LOG
(
DEBUG
)
<<
"Cache transform pass: Executing a cache op mappable transform."
;
ExecuteCacheTransform
(
tree
,
cache_pair
.
first
,
cache_pair
.
second
,
cache_pair
.
second
->
cache_client
());
}
...
...
@@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share
return
Status
::
OK
();
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void
CacheTransformPass
::
AddMappableCacheOperators
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
,
std
::
shared_ptr
<
CacheOp
>
cache_op
)
{
cache_pairs_
.
push_back
(
std
::
make_pair
(
leaf_op
,
cache_op
));
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
浏览文件 @
9fb1904e
...
...
@@ -33,6 +33,123 @@ class CacheClient;
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
/// operations
class
CacheTransformPass
:
public
TreePass
{
/// \class CachePass
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class
CachePass
:
public
NodePass
{
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
CachePass
();
/// \brief Destructor
~
CachePass
()
=
default
;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
TFReaderOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
RandomDataOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ImageFolderOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
MnistOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
GeneratorOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ManifestOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CifarOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
VOCOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CocoOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CelebAOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
MindRecordOp
>
node
,
bool
*
modified
)
override
;
/// \brief Getter
std
::
vector
<
std
::
pair
<
std
::
shared_ptr
<
DatasetOp
>
,
std
::
shared_ptr
<
CacheOp
>>>
cache_pairs
()
{
return
cache_pairs_
;
}
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status
MappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status
NonMappableCacheLeafSetup
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
);
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void
AddMappableCacheOperators
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
,
std
::
shared_ptr
<
CacheOp
>
cache_op
);
bool
is_caching_
;
std
::
shared_ptr
<
DatasetOp
>
leaf_op_
;
std
::
shared_ptr
<
Sampler
>
sampler_
;
// The two operators that work together to establish the cache transform
std
::
vector
<
std
::
pair
<
std
::
shared_ptr
<
DatasetOp
>
,
std
::
shared_ptr
<
CacheOp
>>>
cache_pairs_
;
};
public:
/// \brief Constructor
CacheTransformPass
();
...
...
@@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
override
;
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void
AddMappableCacheOperators
(
std
::
shared_ptr
<
DatasetOp
>
leaf_op
,
std
::
shared_ptr
<
CacheOp
>
cache_op
);
private:
/// \brief Helper function to execute the cache transformation.
///
...
...
@@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status
ExecuteCacheTransform
(
ExecutionTree
*
tree
,
std
::
shared_ptr
<
DatasetOp
>
leaf_op
,
std
::
shared_ptr
<
DatasetOp
>
cache_op
,
std
::
shared_ptr
<
CacheClient
>
cache_client
);
// The two operators that work together to establish the cache transform
std
::
vector
<
std
::
pair
<
std
::
shared_ptr
<
DatasetOp
>
,
std
::
shared_ptr
<
CacheOp
>>>
cache_pairs_
;
};
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc
→
mindspore/ccsrc/minddata/dataset/engine/opt/pre/
epoch_
injection_pass.cc
浏览文件 @
9fb1904e
...
...
@@ -16,7 +16,7 @@
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/
epoch_
injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
...
...
@@ -25,64 +25,55 @@ namespace mindspore {
namespace
dataset
{
// constructor
InjectionPass
::
InjectionFinder
::
InjectionFinder
(
InjectionPass
*
injection_pass
)
:
injection_pass_
(
injection_pass
)
{}
EpochInjectionPass
::
InjectionFinder
::
InjectionFinder
(
std
::
shared_ptr
<
DatasetOp
>
node
)
:
injection_point_
(
node
)
{}
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status
InjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
BuildVocabOp
>
node
,
bool
*
modified
)
{
if
(
injection_pass_
)
{
injection_pass_
->
epoch_ctrl_bypass_
=
true
;
return
Status
::
OK
();
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Missing outer injection pass object from inside InjectionFinder!"
);
}
Status
EpochInjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
BuildVocabOp
>
node
,
bool
*
modified
)
{
injection_point_
=
nullptr
;
return
Status
::
OK
();
}
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
Status
InjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
BuildSentencePieceVocabOp
>
node
,
bool
*
modified
)
{
if
(
injection_pass_
)
{
injection_pass_
->
epoch_ctrl_bypass_
=
true
;
return
Status
::
OK
();
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Missing outer injection pass object from inside InjectionFinder!"
);
}
Status
EpochInjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
BuildSentencePieceVocabOp
>
node
,
bool
*
modified
)
{
injection_point_
=
nullptr
;
return
Status
::
OK
();
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status
InjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
if
(
injection_pass_
)
{
injection_pass_
->
epoch_ctrl_bypass_
=
true
;
return
Status
::
OK
();
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Missing outer injection pass object from inside InjectionFinder!"
);
}
Status
EpochInjectionPass
::
InjectionFinder
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
injection_point_
=
nullptr
;
return
Status
::
OK
();
}
Status
EpochInjectionPass
::
InjectionFinder
::
RunOnNode
(
std
::
shared_ptr
<
DeviceQueueOp
>
node
,
bool
*
modified
)
{
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
injection_point_
=
node
->
child
(
0
);
return
Status
::
OK
();
}
// constructor
InjectionPass
::
InjectionPass
()
:
epoch_ctrl_bypass_
(
false
)
{}
EpochInjectionPass
::
EpochInjectionPass
(
)
{}
// Runs an injection pass to inject in operators needed at the pre pass stage
Status
InjectionPass
::
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
{
Status
Epoch
InjectionPass
::
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
{
MS_LOG
(
INFO
)
<<
"Pre pass: Injection pass started."
;
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
// The finder can make updates to the InjectionPass object.
InjectionPass
::
InjectionFinder
finder
(
this
);
finder
.
Run
(
tree
,
modified
);
// The finder can make updates to the
Epoch
InjectionPass object.
EpochInjectionPass
::
InjectionFinder
finder
(
tree
->
root
()
);
RETURN_IF_NOT_OK
(
finder
.
Run
(
tree
,
modified
)
);
// The first injection logic is to check if we should inject the epoch control op as the root node.
// Do not inject the op if the number of epochs is 1.
int32_t
num_epochs
=
tree
->
num_epochs
();
if
(
num_epochs
!=
1
&&
!
epoch_ctrl_bypass_
)
{
std
::
shared_ptr
<
DatasetOp
>
epoch_inject_node
=
finder
.
injection_point
();
if
(
num_epochs
!=
1
&&
epoch_inject_node
!=
nullptr
)
{
std
::
shared_ptr
<
EpochCtrlOp
>
epoch_ctrl_op
;
RETURN_IF_NOT_OK
(
EpochCtrlOp
::
Builder
(
num_epochs
).
Build
(
&
epoch_ctrl_op
));
RETURN_IF_NOT_OK
(
tree
->
AssociateNode
(
epoch_ctrl_op
));
std
::
shared_ptr
<
DatasetOp
>
node
=
tree
->
root
();
if
(
std
::
dynamic_pointer_cast
<
DeviceQueueOp
>
(
node
)
==
nullptr
)
{
tree
->
root
()
->
InsertAsParent
(
epoch_ctrl_op
);
}
else
{
tree
->
root
()
->
child
(
0
)
->
InsertAsParent
(
epoch_ctrl_op
);
}
epoch_inject_node
->
InsertAsParent
(
epoch_ctrl_op
);
}
MS_LOG
(
INFO
)
<<
"Pre pass: Injection pass complete."
;
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h
→
mindspore/ccsrc/minddata/dataset/engine/opt/pre/
epoch_
injection_pass.h
浏览文件 @
9fb1904e
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#ifndef DATASET_ENGINE_OPT_PASS_PRE_
EPOCH_
INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_
EPOCH_
INJECTION_PASS_H_
#include <memory>
#include <vector>
...
...
@@ -26,10 +26,10 @@ namespace dataset {
class
DatasetOp
;
/// \class
InjectionPass
injection_pass.h
/// \class
EpochInjectionPass epoch_
injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
/// parsing.
class
InjectionPass
:
public
TreePass
{
class
Epoch
InjectionPass
:
public
TreePass
{
/// \class InjectionFinder
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
...
...
@@ -37,7 +37,10 @@ class InjectionPass : public TreePass {
class
InjectionFinder
:
public
NodePass
{
public:
/// \brief Constructor
explicit
InjectionFinder
(
InjectionPass
*
injection_pass
);
explicit
InjectionFinder
(
std
::
shared_ptr
<
DatasetOp
>
node
);
/// \brief Destructor
~
InjectionFinder
()
=
default
;
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
...
...
@@ -58,24 +61,30 @@ class InjectionPass : public TreePass {
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
DeviceQueueOp
>
node
,
bool
*
modified
)
override
;
/// \brief Getter
std
::
shared_ptr
<
DatasetOp
>
injection_point
()
{
return
injection_point_
;
}
private:
InjectionPass
*
injection_pass
_
;
std
::
shared_ptr
<
DatasetOp
>
injection_point
_
;
};
public:
/// \brief Constructor
InjectionPass
();
Epoch
InjectionPass
();
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
Status
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
override
;
private:
bool
epoch_ctrl_bypass_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#endif // DATASET_ENGINE_OPT_PASS_PRE_
EPOCH_
INJECTION_PASS_H_
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
已删除
100644 → 0
浏览文件 @
b13c7a3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
namespace
mindspore
{
namespace
dataset
{
RemovalNodes
::
RemovalNodes
(
RemovalPass
*
removal_pass
)
:
removal_pass_
(
removal_pass
),
is_caching_
(
false
)
{}
// Identifies the subtree below this node as a cached descendant tree.
Status
RemovalNodes
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Removal pass: CacheOp found, identified descendant tree."
;
is_caching_
=
true
;
return
Status
::
OK
();
}
// Resets the tracking of the cache within the tree
Status
RemovalNodes
::
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Removal pass: cache descendant tree complete."
;
is_caching_
=
false
;
return
Status
::
OK
();
}
// Perform ShuffleOp removal check.
Status
RemovalNodes
::
RunOnNode
(
std
::
shared_ptr
<
ShuffleOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if
(
is_caching_
)
{
MS_LOG
(
INFO
)
<<
"ShuffleOp identified for removal (CacheOp is in ascendant tree)"
;
if
(
removal_pass_
)
{
removal_pass_
->
AddToRemovalList
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
else
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"Back reference to removal pass is missing!"
);
}
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
已删除
100644 → 0
浏览文件 @
b13c7a3d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#include <memory>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
namespace
mindspore
{
namespace
dataset
{
/// \class RemovalNodes removal_nodes.h
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class
RemovalNodes
:
public
NodePass
{
public:
/// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit
RemovalNodes
(
RemovalPass
*
removal_pass
);
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Destructor
~
RemovalNodes
()
=
default
;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ShuffleOp
>
node
,
bool
*
modified
)
override
;
private:
bool
is_caching_
;
RemovalPass
*
removal_pass_
;
// Back pointer to the owning removal pass
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
浏览文件 @
9fb1904e
...
...
@@ -16,32 +16,58 @@
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/execution_tree.h"
namespace
mindspore
{
namespace
dataset
{
RemovalPass
::
RemovalNodes
::
RemovalNodes
()
:
is_caching_
(
false
)
{}
// Identifies the subtree below this node as a cached descendant tree.
Status
RemovalPass
::
RemovalNodes
::
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Removal pass: CacheOp found, identified descendant tree."
;
is_caching_
=
true
;
return
Status
::
OK
();
}
// Resets the tracking of the cache within the tree
Status
RemovalPass
::
RemovalNodes
::
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
MS_LOG
(
INFO
)
<<
"Removal pass: cache descendant tree complete."
;
is_caching_
=
false
;
return
Status
::
OK
();
}
// Perform ShuffleOp removal check.
Status
RemovalPass
::
RemovalNodes
::
RunOnNode
(
std
::
shared_ptr
<
ShuffleOp
>
node
,
bool
*
modified
)
{
*
modified
=
false
;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if
(
is_caching_
)
{
MS_LOG
(
INFO
)
<<
"ShuffleOp identified for removal (CacheOp is in ascendant tree)"
;
nodes_to_remove_
.
push_back
(
std
::
static_pointer_cast
<
DatasetOp
>
(
node
));
}
return
Status
::
OK
();
}
// constructor
RemovalPass
::
RemovalPass
()
{}
//
Runs a removal_nodes pass first to find out which
nodes to remove, then removes them.
//
Walk the tree to collect the
nodes to remove, then removes them.
Status
RemovalPass
::
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
{
MS_LOG
(
INFO
)
<<
"Pre pass: removal pass started."
;
// Create the removal node pass which can identify which nodes need to be removed.
std
::
unique_ptr
<
Pass
>
removal_nodes
=
std
::
make_unique
<
RemovalNodes
>
(
this
);
std
::
unique_ptr
<
RemovalPass
::
RemovalNodes
>
removal_nodes
=
std
::
make_unique
<
RemovalPass
::
RemovalNodes
>
(
);
RETURN_IF_NOT_OK
(
removal_nodes
->
Run
(
tree
,
modified
));
// Then, execute the removal of any nodes that were set up for removal
for
(
auto
node
:
removal_nodes
_
)
{
for
(
auto
node
:
removal_nodes
->
nodes_to_remove
()
)
{
node
->
Remove
();
}
MS_LOG
(
INFO
)
<<
"Pre pass: removal pass complete."
;
return
Status
::
OK
();
}
// Adds an operator to the list of operators to be removed
void
RemovalPass
::
AddToRemovalList
(
std
::
shared_ptr
<
DatasetOp
>
dataset_op
)
{
removal_nodes_
.
push_back
(
dataset_op
);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
浏览文件 @
9fb1904e
...
...
@@ -30,6 +30,45 @@ class DatasetOp;
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
/// nodes should be removed, and then removes them.
class
RemovalPass
:
public
TreePass
{
/// \class RemovalNodes
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class
RemovalNodes
:
public
NodePass
{
public:
/// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
RemovalNodes
();
/// \brief Destructor
~
RemovalNodes
()
=
default
;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
PreRunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
CacheOp
>
node
,
bool
*
modified
)
override
;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status
RunOnNode
(
std
::
shared_ptr
<
ShuffleOp
>
node
,
bool
*
modified
)
override
;
/// \brief Getter
/// \return All the nodes to be removed
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
nodes_to_remove
()
{
return
nodes_to_remove_
;
}
private:
bool
is_caching_
;
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
nodes_to_remove_
;
};
public:
/// \brief Constructor
RemovalPass
();
...
...
@@ -42,13 +81,6 @@ class RemovalPass : public TreePass {
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
Status
RunOnTree
(
ExecutionTree
*
tree
,
bool
*
modified
)
override
;
/// \brief Adds an operator to the list of operators to be removed
/// \param[in] dataset_op The operator to add to the removal list
void
AddToRemovalList
(
std
::
shared_ptr
<
DatasetOp
>
dataset_op
);
private:
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
removal_nodes_
;
};
}
// namespace dataset
}
// namespace mindspore
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
9fb1904e
...
...
@@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
str
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
str
(
error_info
.
value
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
...
...
@@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
str
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 0).'
in
str
(
error_info
.
value
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
...
...
@@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
str
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
str
(
error_info
.
value
)
with
pytest
.
raises
(
Exception
)
as
error_info
:
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
2
,
5
)
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
str
(
error_info
)
assert
'Input shard_id is not within the required interval of (0 to 1).'
in
str
(
error_info
.
value
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
...
...
@@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
num_iter
+=
1
with
pytest
.
raises
(
Exception
)
as
error_info
:
partitions
(
5
)
assert
'num_samples should be a positive integer value, but got num_samples=0'
in
str
(
error_info
)
assert
'num_samples should be a positive integer value, but got num_samples=0'
in
str
(
error_info
.
value
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录