提交 e895f19e 编写于 作者: Z zhousiyi

fix transform bug which high order pritimive is not convert to graph

上级 afb0e76d
......@@ -499,15 +499,20 @@ void TraverseGraphMap(
for (auto &use : users) {
CNodePtr node = use.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
if (node->func_graph() != fg) {
continue;
}
int key = use.second;
if (key != 0) {
MS_EXCEPTION_IF_NULL(node->input(0));
bool key_is_const = node->input(0)->isa<ValueNode>();
PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
continue;
if (value != nullptr) {
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
continue;
}
}
FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
......@@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
FuncGraphTransaction tr = manager_ptr->Transact();
auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
tr.Commit();
return graph;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册