提交 4b2b4667 编写于 作者: W wuyongkang

Revert "Optimization for ApplyTransform function"

This reverts commit 02dd305b.
上级 b4e37158
...@@ -92,18 +92,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode ...@@ -92,18 +92,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
return result; return result;
} }
static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) { static bool isTraversable(const AnfNodePtr &node) {
if (node->isa<CNode>() || node->isa<Parameter>()) { if (node == nullptr) {
return false; return false;
} }
if (node->isa<CNode>() || node->isa<Parameter>()) {
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) { return true;
if (!all_nodes.contains(node)) {
return false;
} }
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
return true; return true;
} }
return false; return false;
} }
...@@ -126,15 +124,9 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo ...@@ -126,15 +124,9 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
todo.pop_front(); todo.pop_front();
// check whether this node has been matched. // check whether this node has been matched.
if (node == nullptr || node->seen_ == seen) { if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
continue;
}
auto fg = node->func_graph();
if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) {
continue; continue;
} }
node->seen_ = seen; node->seen_ = seen;
// select nodes that this transform can be applied. // select nodes that this transform can be applied.
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ir/primitive.h" #include "ir/primitive.h"
#include "ir/manager.h" #include "ir/manager.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse_base.h"
#include "pipeline/parse/parse.h" #include "pipeline/parse/parse.h"
#include "./common.h" #include "./common.h"
...@@ -48,10 +47,9 @@ class PyFuncGraphFetcher { ...@@ -48,10 +47,9 @@ class PyFuncGraphFetcher {
py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...); py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
if (doResolve_) { if (doResolve_) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true); std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
mindspore::parse::ResolveAll(manager); mindspore::parse::ResolveAll(manager);
func_graph = BasicClone(func_graph);
} }
return func_graph; return func_graph;
} catch (py::error_already_set& e) { } catch (py::error_already_set& e) {
...@@ -73,9 +71,8 @@ class PyFuncGraphFetcher { ...@@ -73,9 +71,8 @@ class PyFuncGraphFetcher {
py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str()); py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
if (doResolve_) { if (doResolve_) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true); std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
mindspore::parse::ResolveAll(manager); mindspore::parse::ResolveAll(manager);
func_graph = BasicClone(func_graph);
} }
return func_graph; return func_graph;
} catch (py::error_already_set& e) { } catch (py::error_already_set& e) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册