提交 6d14de5f 编写于 作者: M Margaret_wangrui

handle switch input to partial

上级 b4e37158
...@@ -501,7 +501,50 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K ...@@ -501,7 +501,50 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
return new_cnode; return new_cnode;
} }
static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node_input);
MS_EXCEPTION_IF_NULL(graph);
// switch input generalizes partial
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) ||
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) {
return node_input->cast<CNodePtr>();
}
if (node_input->isa<CNode>()) {
MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call.";
}
std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
partial_inputs.emplace_back(node_input);
auto partial_node = graph->NewCNode(partial_inputs);
return partial_node;
}
KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input));
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
auto partial_node = graph->NewCNode(partial_inputs);
return partial_node;
}
CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(graph);
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
if (node->inputs().size() < kSwitchInputSize) {
MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize;
}
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name()));
std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)};
for (size_t index = 2; index < node->inputs().size(); index++) {
auto input = CreateSwitchInput(node->input(index), graph);
switch_inputs.emplace_back(input);
}
auto switch_node = graph->NewCNode(switch_inputs);
return switch_node;
}
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
// create primitive of cnode:call(partial or switch) // create primitive of cnode:call(partial or switch)
...@@ -526,7 +569,8 @@ static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, ...@@ -526,7 +569,8 @@ static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode,
}); });
return cnode_inputs; return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
cnode_inputs.emplace_back(cnode_input); auto switch_node = HandleSwitchInputs(cnode_input, graph);
cnode_inputs.emplace_back(switch_node);
return cnode_inputs; return cnode_inputs;
} }
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
......
...@@ -87,6 +87,10 @@ class SessionBasic { ...@@ -87,6 +87,10 @@ class SessionBasic {
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
// set parameters of final graph // set parameters of final graph
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
// set output of final graph // set output of final graph
......
...@@ -246,6 +246,7 @@ constexpr auto kAnfPartialFuncGraphIndex = 1; ...@@ -246,6 +246,7 @@ constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kTupleGetItemInputSize = 3; constexpr auto kTupleGetItemInputSize = 3;
constexpr auto kSwitchInputSize = 4;
// index define of control depend // index define of control depend
constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependPriorIndex = 1;
constexpr auto kControlDependBehindIndex = 2; constexpr auto kControlDependBehindIndex = 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册