From e895f19e803a6fb7ca56f65d8a5ee3e111cbe9a8 Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Fri, 22 May 2020 07:36:20 +0000 Subject: [PATCH] fix transform bug which high order pritimive is not convert to graph --- mindspore/ccsrc/vm/transform.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 93d5f33cf..636d36f93 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -499,15 +499,20 @@ void TraverseGraphMap( for (auto &use : users) { CNodePtr node = use.first->cast(); MS_EXCEPTION_IF_NULL(node); + if (node->func_graph() != fg) { + continue; + } int key = use.second; if (key != 0) { MS_EXCEPTION_IF_NULL(node->input(0)); bool key_is_const = node->input(0)->isa(); PrimitivePtr value = GetValueNode(node->input(0)); - bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); - bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); - if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { - continue; + if (value != nullptr) { + bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); + bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); + if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { + continue; + } } FuncGraphPtr g = get_prim_graph(GetValueNode(const_primitive_node), dyn_cast(const_primitive_node->abstract())); @@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { FuncGraphTransaction tr = manager_ptr->Transact(); auto &fgs = manager_ptr->func_graphs(); TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); + tr.Commit(); return graph; } -- GitLab