提交 6cdaa084 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!560 Optimize-depend pass enhance

Merge pull request !560 from huanghui/optimize-depend
...@@ -28,8 +28,7 @@ namespace mindspore { ...@@ -28,8 +28,7 @@ namespace mindspore {
namespace opt { namespace opt {
constexpr auto kSingleInputIndex = 1; constexpr auto kSingleInputIndex = 1;
namespace { namespace {
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return nullptr; return nullptr;
...@@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node ...@@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
return nullptr; return nullptr;
} }
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// Check whether the node has only one output node.
if (manager->node_users().find(cnode) == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "The node should be used by at least another node's input";
}
if (manager->node_users()[cnode].size() > 1) {
return nullptr;
}
CheckCNodeInputSize(cnode, kSingleInputIndex + 1); CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
return cnode->input(kSingleInputIndex); return cnode->input(kSingleInputIndex);
} }
...@@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { ...@@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
std::vector<AnfNodePtr> new_make_tuple_inputs; std::vector<AnfNodePtr> new_make_tuple_inputs;
bool need_update = false; bool need_update = false;
for (const auto &input : cnode->inputs()) { for (const auto &input : cnode->inputs()) {
AnfNodePtr replace_input = GetReplaceNode(func_graph, input); AnfNodePtr replace_input = GetReplaceNode(input);
// If replace input is not null, it will be the input of the TransData or Cast. // If replace input is not null, it will be the input of the TransData or Cast.
if (replace_input == nullptr) { if (replace_input == nullptr) {
new_make_tuple_inputs.push_back(input); new_make_tuple_inputs.push_back(input);
...@@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con ...@@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
if (ReplaceMakeTuple(func_graph, replacing_cnode)) { if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
return nullptr; return nullptr;
} }
AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) { if (replace_node == nullptr) {
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
return nullptr; return nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册