From 2ffe76981dee627968ec3483ec5790c2a7f12b2a Mon Sep 17 00:00:00 2001 From: Jamie Nisbet Date: Tue, 30 Jun 2020 10:57:25 -0400 Subject: [PATCH] added a pre pass for node removals cpplint --- .../dataset/engine/datasetops/batch_op.cc | 2 +- .../dataset/engine/datasetops/dataset_op.cc | 51 ++++++++++++++++++ .../dataset/engine/datasetops/dataset_op.h | 30 +++++++++-- .../engine/datasetops/device_queue_op.cc | 2 +- .../dataset/engine/datasetops/filter_op.cc | 2 +- .../ccsrc/dataset/engine/datasetops/map_op.cc | 2 +- .../dataset/engine/datasetops/project_op.cc | 2 +- .../dataset/engine/datasetops/rename_op.cc | 2 +- .../dataset/engine/datasetops/repeat_op.cc | 2 +- .../dataset/engine/datasetops/shuffle_op.cc | 2 +- .../dataset/engine/datasetops/skip_op.cc | 2 +- .../engine/datasetops/source/generator_op.cc | 2 +- .../datasetops/source/image_folder_op.cc | 2 +- .../engine/datasetops/source/mindrecord_op.cc | 2 +- .../engine/datasetops/source/tf_reader_op.cc | 2 +- .../dataset/engine/datasetops/take_op.cc | 2 +- .../ccsrc/dataset/engine/datasetops/zip_op.cc | 2 +- .../ccsrc/dataset/engine/execution_tree.cc | 4 +- .../ccsrc/dataset/engine/opt/CMakeLists.txt | 6 ++- mindspore/ccsrc/dataset/engine/opt/pass.cc | 2 +- mindspore/ccsrc/dataset/engine/opt/pass.h | 35 +++++++----- .../dataset/engine/opt/pre/removal_nodes.cc | 42 +++++++++++++++ .../dataset/engine/opt/pre/removal_nodes.h | 51 ++++++++++++++++++ .../dataset/engine/opt/pre/removal_pass.cc | 45 ++++++++++++++++ .../dataset/engine/opt/pre/removal_pass.h | 53 +++++++++++++++++++ 25 files changed, 314 insertions(+), 35 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc index f311c90c3..8bfa8c287 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc @@ -409,7 +409,7 @@ Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, // Visitor accept method for NodePass Status BatchOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc index 91ed7fbc5..170d9a7ce 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc @@ -111,6 +111,51 @@ void DatasetOp::RemoveParent(const DatasetOp *parent) { parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); } +// Removes this node from the tree and connects it's parent/child together +Status DatasetOp::Remove() { + if (parent_.size() > 1) { + std::string err_msg("No support for op removal if the operator has more than one parent"); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (child_.size() > 1) { + std::string err_msg("No support for op removal if the operator has more than one child"); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Scenario's when removing node B: + // A -> B -> C + // A -> B + // B -> C + // + // If we remove B, then first take our child A and update it's parent to be C + // It's possible the parent is null if we are the root node being removed. + if (!child_.empty()) { + // If we have a parent, then assign chlid's parent to point to our parent. + if (!parent_.empty()) { + child_[0]->parent_[0] = parent_[0]; + } else { + // We don't have a parent, so we are the root node being removed. + // clear the parent list of our child so that it becomes the new root. + child_[0]->parent_.clear(); + tree_->AssignRoot(child_[0]); + } + } + + // Next, if we had a parent, then set it's child to be our child. + if (!parent_.empty()) { + // if we have a child, then set our parent to point to it + if (!child_.empty()) { + parent_[0]->child_[0] = child_[0]; + } else { + // We don't have a child, so clear the child list of the current + // parent because it will be empty once we are removed. + parent_[0]->child_.clear(); + } + } + + return Status::OK(); +} + // Getter function to get a shared pointer to our childAdds a operator to become our child. std::shared_ptr DatasetOp::child(int32_t child_index) const { MS_ASSERT(child_index < static_cast(child_.size())); @@ -289,6 +334,12 @@ Status DatasetOp::ComputeColMap() { return Status::OK(); } +Status DatasetOp::PreAccept(NodePass *p, bool *modified) { + // DatasetOp is the base class of visitor target pre-visit. + // This method will only be called if its derived class does not implement one. + return p->PreRunOnNode(shared_from_this(), modified); +} + Status DatasetOp::Accept(NodePass *p, bool *modified) { // DatasetOp is the base class of visitor target. // This method will only be called if its derived class does not implement one. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h index 254cd411c..9ee287d05 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h @@ -71,6 +71,10 @@ class DatasetOp : public std::enable_shared_from_this { // @param child - shared pointer to the child to remove. Status RemoveChild(std::shared_ptr child); + /// \brief Removes this node from the tree and connects it's parent/child together. + /// \return Status eerror code returned + Status Remove(); + // Getter function to get a shared pointer to our child // @param child_index - An operator can have n children. Indicates choose which child to return. std::shared_ptr child(int32_t child_index) const; @@ -264,10 +268,20 @@ class DatasetOp : public std::enable_shared_from_this { // @return Vector of Children std::vector> Children() const { return child_; } - // Base method for NodePass visit. - // Subclass needs to override this if it requires special node visit access. - // Check "dataset/engine/opt/pass.h" for more details. - // @return Statue of the node visit + /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up + /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main + /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it + /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + virtual Status PreAccept(NodePass *p, bool *modified); + + /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. + /// Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit virtual Status Accept(NodePass *p, bool *modified); // Op name getter @@ -285,6 +299,14 @@ class DatasetOp : public std::enable_shared_from_this { // Computes a CRC value for the operator static uint32_t GenerateCRC(const std::shared_ptr &op); + /// \brief A helper templated function for casting "this" pointer to shared_ptr + /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr + /// \return A shared_ptr casted to the derived class + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } + protected: // Adds a parent operator to this operator // @notes External callers do not have access to this function. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc index 84bad9db1..0f1fefc0f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc @@ -313,7 +313,7 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { // Visitor accept method for NodePass Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc index a1c5ed007..81c93c6e1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -261,7 +261,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate // Visitor accept method for NodePass Status FilterOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 020f40d26..05a1ac792 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -367,7 +367,7 @@ void MapOp::CreateFinalColMap(std::unordered_map *col_name // Visitor accept method for NodePass Status MapOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc index 14b064bab..5ce405602 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc @@ -131,7 +131,7 @@ Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } // Visitor accept method for NodePass Status ProjectOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } // Compute the column map and save it into our own column name map diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc index bebca780f..23cd29d29 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc @@ -176,7 +176,7 @@ Status RenameOp::EoeReceived(int32_t) { // Visitor accept method for NodePass Status RenameOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc index 86903e540..4999dddd0 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc @@ -190,7 +190,7 @@ int32_t RepeatOp::num_producers() const { // Visitor accept method for NodePass Status RepeatOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index c16f3f962..f86fcc602 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -298,7 +298,7 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { // Visitor accept method for NodePass Status ShuffleOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index c00fd486b..f6b0fe689 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -130,7 +130,7 @@ Status SkipOp::EofReceived(int32_t worker_id) { // Visitor accept method for NodePass Status SkipOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc index eb5ba3264..36c221fc1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc @@ -249,7 +249,7 @@ Status GeneratorOp::Reset() { // Visitor accept method for NodePass Status GeneratorOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } Status GeneratorOp::ComputeColMap() { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index cb17158bf..837eae1e3 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -411,7 +411,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se // Visitor accept method for NodePass Status ImageFolderOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } Status ImageFolderOp::ComputeColMap() { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 3c95b9b05..2b9d010eb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -496,7 +496,7 @@ Status MindRecordOp::CountTotalRows(const std::vector dataset_path, // Visitor accept method for NodePass Status MindRecordOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } Status MindRecordOp::ComputeColMap() { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index a2b04bcc0..48f13ff76 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -1004,7 +1004,7 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &file // Visitor accept method for NodePass Status TFReaderOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } Status TFReaderOp::ComputeColMap() { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index 259ae8e62..8bc449cdc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -136,7 +136,7 @@ Status TakeOp::PrepareNodePostAction() { // Visitor accept method for NodePass Status TakeOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc index 55734324f..70bce16a8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc @@ -237,7 +237,7 @@ Status ZipOp::EoeReceived(int32_t) { // Visitor accept method for NodePass Status ZipOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor - return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); + return p->RunOnNode(shared_from_base(), modified); } Status ZipOp::ComputeColMap() { diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc index 2f88ee179..385722e25 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc @@ -20,6 +20,7 @@ #include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/util/task_manager.h" #include "dataset/engine/opt/pass.h" +#include "dataset/engine/opt/pre/removal_pass.h" #include "dataset/engine/perf/profiling.h" #include "dataset/engine/perf/monitor.h" @@ -214,7 +215,8 @@ Status ExecutionTree::PrepareTreePreAction() { bool modified = false; std::vector> pre_actions; // Construct pre actions - // example: pre_actions.push_back(new SomePass()); + MS_LOG(INFO) << "Running pre pass"; + pre_actions.push_back(std::make_unique(RemovalPass())); // Apply pre action passes for (auto &pass : pre_actions) { RETURN_IF_NOT_OK(pass->Run(this, &modified)); diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt index af0a8918d..080d968cf 100644 --- a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt @@ -1,6 +1,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine-opt OBJECT - pass.cc - util/printer_pass.cc + pass.cc + pre/removal_nodes.cc + pre/removal_pass.cc + util/printer_pass.cc ) diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc index a032d46cb..27769f056 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pass.cc @@ -61,6 +61,7 @@ Status NodePass::Run(ExecutionTree *tree, bool *modified) { // Helper function to perform DFS visit Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { + RETURN_IF_NOT_OK(node->PreAccept(this, modified)); for (const auto &c : node->Children()) { RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); } @@ -159,6 +160,5 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); } - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h index 39682b22f..129c2fab3 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/dataset/engine/opt/pass.h @@ -66,14 +66,16 @@ class Pass : public std::enable_shared_from_this { // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. class TreePass : public Pass { public: - // Run the transformation pass against the execution tree. - // @param tree - Pointer to the execution tree to be transformed. - // @param modified - Pointer to the modified flag, + /// \brief Run the transformation pass against the execution tree. + /// \param[inout] tree Pointer to the execution tree to be transformed. + /// \param[inout] modified Indicate if the tree was modified Status Run(ExecutionTree *tree, bool *modified) final; - // Derived classes may implement the runOnTree function to implement tree transformation. - // "modified" flag needs to be set to true if tree is modified during the pass execution. - // @return Status - The error code return + /// \brief Derived classes may implement the runOnTree function to implement tree transformation. + /// "modified" flag needs to be set to true if tree is modified during the pass execution. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } }; @@ -90,14 +92,23 @@ class NodePass : public Pass { ~NodePass() = default; - // Run the transformation pass against the execution tree. - // @param tree - Pointer to the execution tree to be transformed. - // @param modified - Pointer to the modified flag, + /// \brief Run the transformation pass against the execution tree + /// \param[inout] tree Pointer to the execution tree to be transformed + /// \param[inout] modified Indicator if the tree was changed Status Run(ExecutionTree *tree, bool *modified) final; - // Derived classes may implement the runOnNode function to implement node level tree transformation. - // "modified" flag needs to be set to true if tree is modified during the pass execution. - // @return Status - The error code return + /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down + /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all + /// \return Status The error code return + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation + /// "modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all. + /// \return Status The error code return virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } // Visit methods to be overridden. diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc new file mode 100644 index 000000000..831a2a76b --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "dataset/engine/opt/pre/removal_nodes.h" +#include "dataset/engine/opt/pre/removal_pass.h" +#include "dataset/engine/datasetops/shuffle_op.h" + +namespace mindspore { +namespace dataset { + +RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} + +// Perform ShuffleOp removal check. +Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + // If we are in a cache descendant tree, then this shuffle op needs to be removed + if (is_caching_) { + MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; + if (removal_pass_) { + removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); + } + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h new file mode 100644 index 000000000..11ef37d80 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ + +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class RemovalPass; + +/// \class RemovalNodes removal_nodes.h +/// \brief This is a NodePass who's job is to identify which nodes should be removed. +/// It works in conjunction with the removal_pass. +class RemovalNodes : public NodePass { + public: + /// \brief Constructor + /// \param[in] removal_pass Raw pointer back to controlling tree pass + explicit RemovalNodes(RemovalPass *removal_pass); + + /// \brief Perform ShuffleOp removal check + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + bool is_caching_; + RemovalPass *removal_pass_; // Back pointer to the owning removal pass +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc new file mode 100644 index 000000000..31ec31234 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "dataset/engine/opt/pre/removal_nodes.h" +#include "dataset/engine/opt/pre/removal_pass.h" +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { + +// constructor +RemovalPass::RemovalPass() {} + +// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. +Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { + // Create the removal node pass which can identify which nodes need to be removed. + std::unique_ptr removal_nodes = std::make_unique(this); + RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); + + // Then, execute the removal of any nodes that were set up for removal + for (auto node : removal_nodes_) { + node->Remove(); + } + return Status::OK(); +} + +// Adds an operator to the list of operators to be removed +void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h new file mode 100644 index 000000000..6523ca69b --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ + +#include +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class RemovalPass removal_pass.h +/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which +/// nodes should be removed, and then removes them. +class RemovalPass : public TreePass { + public: + /// \brief Constructor + RemovalPass(); + + /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(ExecutionTree *tree, bool *modified) override; + + /// \brief Adds an operator to the list of operators to be removed + /// \param[in] dataset_op The operator to add to the removal list + void AddToRemovalList(std::shared_ptr dataset_op); + + private: + std::vector> removal_nodes_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ -- GitLab