From 02dd305bb030d993e754a056c159288b5943d90c Mon Sep 17 00:00:00 2001 From: wuyongkang Date: Thu, 2 Jul 2020 16:02:13 +0800 Subject: [PATCH] Optimization for ApplyTransform function --- mindspore/ccsrc/optimizer/opt.cc | 20 ++++++++++++++------ tests/ut/cpp/common/py_func_graph_fetcher.h | 7 +++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 4c2e85157..b5248d7dd 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -96,16 +96,18 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode return result; } -static bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } +static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) { if (node->isa() || node->isa()) { - return true; + return false; } + if (IsValueNode(node) || IsValueNode(node)) { + if (!all_nodes.contains(node)) { + return false; + } return true; } + return false; } @@ -128,9 +130,15 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo todo.pop_front(); // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { + if (node == nullptr || node->seen_ == seen) { + continue; + } + + auto fg = node->func_graph(); + if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) { continue; } + node->seen_ = seen; // select nodes that this transform can be applied. diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index 98552a96b..9d374fcd6 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -22,6 +22,7 @@ #include "ir/primitive.h" #include "ir/manager.h" #include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse.h" #include "./common.h" @@ -47,9 +48,10 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, false); + std::shared_ptr manager = mindspore::Manage(func_graph, true); mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::ResolveAll(manager); + func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) { @@ -71,8 +73,9 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str()); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, false); + std::shared_ptr manager = mindspore::Manage(func_graph, true); mindspore::parse::ResolveAll(manager); + func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) { -- GitLab