提交 2ffe7698 编写于 作者: J Jamie Nisbet

added a pre pass for node removals

cpplint
上级 b57d4ea2
......@@ -409,7 +409,7 @@ Status BatchOp::UnpackPadInfo(const PadInfo &pad_info,
// Visitor accept method for NodePass
Status BatchOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<BatchOp>(), modified);
}
} // namespace dataset
......
......@@ -111,6 +111,51 @@ void DatasetOp::RemoveParent(const DatasetOp *parent) {
parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
}
// Removes this node from the tree and connects it's parent/child together
Status DatasetOp::Remove() {
if (parent_.size() > 1) {
std::string err_msg("No support for op removal if the operator has more than one parent");
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (child_.size() > 1) {
std::string err_msg("No support for op removal if the operator has more than one child");
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Scenario's when removing node B:
// A -> B -> C
// A -> B
// B -> C
//
// If we remove B, then first take our child A and update it's parent to be C
// It's possible the parent is null if we are the root node being removed.
if (!child_.empty()) {
// If we have a parent, then assign chlid's parent to point to our parent.
if (!parent_.empty()) {
child_[0]->parent_[0] = parent_[0];
} else {
// We don't have a parent, so we are the root node being removed.
// clear the parent list of our child so that it becomes the new root.
child_[0]->parent_.clear();
tree_->AssignRoot(child_[0]);
}
}
// Next, if we had a parent, then set it's child to be our child.
if (!parent_.empty()) {
// if we have a child, then set our parent to point to it
if (!child_.empty()) {
parent_[0]->child_[0] = child_[0];
} else {
// We don't have a child, so clear the child list of the current
// parent because it will be empty once we are removed.
parent_[0]->child_.clear();
}
}
return Status::OK();
}
// Getter function to get a shared pointer to our childAdds a operator to become our child.
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
MS_ASSERT(child_index < static_cast<int>(child_.size()));
......@@ -289,6 +334,12 @@ Status DatasetOp::ComputeColMap() {
return Status::OK();
}
Status DatasetOp::PreAccept(NodePass *p, bool *modified) {
// DatasetOp is the base class of visitor target pre-visit.
// This method will only be called if its derived class does not implement one.
return p->PreRunOnNode(shared_from_this(), modified);
}
Status DatasetOp::Accept(NodePass *p, bool *modified) {
// DatasetOp is the base class of visitor target.
// This method will only be called if its derived class does not implement one.
......
......@@ -71,6 +71,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @param child - shared pointer to the child to remove.
Status RemoveChild(std::shared_ptr<DatasetOp> child);
/// \brief Removes this node from the tree and connects it's parent/child together.
/// \return Status eerror code returned
Status Remove();
// Getter function to get a shared pointer to our child
// @param child_index - An operator can have n children. Indicates choose which child to return.
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
......@@ -264,10 +268,20 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @return Vector of Children
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
// Base method for NodePass visit.
// Subclass needs to override this if it requires special node visit access.
// Check "dataset/engine/opt/pass.h" for more details.
// @return Statue of the node visit
/// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main
/// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it
/// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status PreAccept(NodePass *p, bool *modified);
/// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status Accept(NodePass *p, bool *modified);
// Op name getter
......@@ -285,6 +299,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// Computes a CRC value for the operator
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
/// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
/// \return A shared_ptr casted to the derived class
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}
protected:
// Adds a parent operator to this operator
// @notes External callers do not have access to this function.
......
......@@ -313,7 +313,7 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
// Visitor accept method for NodePass
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<DeviceQueueOp>(), modified);
}
} // namespace dataset
......
......@@ -261,7 +261,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
// Visitor accept method for NodePass
Status FilterOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<FilterOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -367,7 +367,7 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
// Visitor accept method for NodePass
Status MapOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<MapOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -131,7 +131,7 @@ Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
// Visitor accept method for NodePass
Status ProjectOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<ProjectOp>(), modified);
}
// Compute the column map and save it into our own column name map
......
......@@ -176,7 +176,7 @@ Status RenameOp::EoeReceived(int32_t) {
// Visitor accept method for NodePass
Status RenameOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<RenameOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -190,7 +190,7 @@ int32_t RepeatOp::num_producers() const {
// Visitor accept method for NodePass
Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -298,7 +298,7 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
// Visitor accept method for NodePass
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<ShuffleOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -130,7 +130,7 @@ Status SkipOp::EofReceived(int32_t worker_id) {
// Visitor accept method for NodePass
Status SkipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<SkipOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -249,7 +249,7 @@ Status GeneratorOp::Reset() {
// Visitor accept method for NodePass
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<GeneratorOp>(), modified);
}
Status GeneratorOp::ComputeColMap() {
......
......@@ -411,7 +411,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se
// Visitor accept method for NodePass
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified);
}
Status ImageFolderOp::ComputeColMap() {
......
......@@ -496,7 +496,7 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
// Visitor accept method for NodePass
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<MindRecordOp>(), modified);
}
Status MindRecordOp::ComputeColMap() {
......
......@@ -1004,7 +1004,7 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
// Visitor accept method for NodePass
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<TFReaderOp>(), modified);
}
Status TFReaderOp::ComputeColMap() {
......
......@@ -136,7 +136,7 @@ Status TakeOp::PrepareNodePostAction() {
// Visitor accept method for NodePass
Status TakeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<TakeOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -237,7 +237,7 @@ Status ZipOp::EoeReceived(int32_t) {
// Visitor accept method for NodePass
Status ZipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
}
Status ZipOp::ComputeColMap() {
......
......@@ -20,6 +20,7 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/perf/profiling.h"
#include "dataset/engine/perf/monitor.h"
......@@ -214,7 +215,8 @@ Status ExecutionTree::PrepareTreePreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
// example: pre_actions.push_back(new SomePass());
MS_LOG(INFO) << "Running pre pass";
pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
......
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
pass.cc
util/printer_pass.cc
pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
util/printer_pass.cc
)
......@@ -61,6 +61,7 @@ Status NodePass::Run(ExecutionTree *tree, bool *modified) {
// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
RETURN_IF_NOT_OK(node->PreAccept(this, modified));
for (const auto &c : node->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
}
......@@ -159,6 +160,5 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -66,14 +66,16 @@ class Pass : public std::enable_shared_from_this<Pass> {
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
public:
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
/// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
Status Run(ExecutionTree *tree, bool *modified) final;
// Derived classes may implement the runOnTree function to implement tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
......@@ -90,14 +92,23 @@ class NodePass : public Pass {
~NodePass() = default;
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
/// \brief Run the transformation pass against the execution tree
/// \param[inout] tree Pointer to the execution tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
Status Run(ExecutionTree *tree, bool *modified) final;
// Derived classes may implement the runOnNode function to implement node level tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
/// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down
/// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
/// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation
/// "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
// Visit methods to be overridden.
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/pre/removal_nodes.h"
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/datasetops/shuffle_op.h"
namespace mindspore {
namespace dataset {
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
// Perform ShuffleOp removal check.
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
if (removal_pass_) {
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#include <memory>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class RemovalPass;
/// \class RemovalNodes removal_nodes.h
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit RemovalNodes(RemovalPass *removal_pass);
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
private:
bool is_caching_;
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "dataset/engine/opt/pre/removal_nodes.h"
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// constructor
RemovalPass::RemovalPass() {}
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
// Create the removal node pass which can identify which nodes need to be removed.
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
// Then, execute the removal of any nodes that were set up for removal
for (auto node : removal_nodes_) {
node->Remove();
}
return Status::OK();
}
// Adds an operator to the list of operators to be removed
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
#include <memory>
#include <vector>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class RemovalPass removal_pass.h
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
/// nodes should be removed, and then removes them.
class RemovalPass : public TreePass {
public:
/// \brief Constructor
RemovalPass();
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
/// \brief Adds an operator to the list of operators to be removed
/// \param[in] dataset_op The operator to add to the removal list
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
private:
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册