handle partial node in CreateNewCNode

上级 ea475637
...@@ -447,6 +447,37 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K ...@@ -447,6 +447,37 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
return new_cnode; return new_cnode;
} }
static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
// create primitive of cnode:call(partial or switch)
std::vector<AnfNodePtr> cnode_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
if (cnode_input == nullptr) {
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
<< ", but input[0] has not been created.";
}
// if the node is partial, insert the inputs of partial to the call
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
auto partial_node = attr_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
auto partial_inputs = partial_node->inputs();
std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
return graph->GetBackendAnfByFrontAnf(node);
});
return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
cnode_inputs.emplace_back(cnode_input);
return cnode_inputs;
}
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
...@@ -471,18 +502,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) ...@@ -471,18 +502,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
} }
} }
} else if (attr_input->isa<CNode>()) { } else if (attr_input->isa<CNode>()) {
// create primitive of cnode:call(switch) cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
}
cnode_inputs.emplace_back(cnode_input);
} else {
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
<< ", but input[0] has not been created.";
}
} else { } else {
// get primitive of old node // get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode); auto prim = AnfAlgo::GetCNodePrimitive(cnode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部