diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index c2470a0df12ca4ba5aacfd9ae31d53a2cc819c2a..ea42257502fe28afd02cdd038cbd14d9a173f797 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -497,7 +497,50 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K return new_cnode; } -static std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { +CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node_input); + MS_EXCEPTION_IF_NULL(graph); + // switch input generalizes partial + if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || + AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { + return node_input->cast(); + } + if (node_input->isa()) { + MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; + } + std::vector partial_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; + if (node_input->isa() && IsValueNode(node_input)) { + partial_inputs.emplace_back(node_input); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; + } + KernelGraphPtr kernel_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(kernel_graph); + kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); + partial_inputs.emplace_back(std::make_shared(kernel_graph)); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; +} + +CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(graph); + auto node = anf_node->cast(); + MS_EXCEPTION_IF_NULL(node); + if (node->inputs().size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; + } + auto primitive = NewValueNode(std::make_shared(prim::kPrimSwitch->name())); + std::vector switch_inputs = {primitive, node->input(1)}; + for (size_t index = 2; index < node->inputs().size(); index++) { + auto input = CreateSwitchInput(node->input(index), graph); + switch_inputs.emplace_back(input); + } + auto switch_node = graph->NewCNode(switch_inputs); + return switch_node; +} + +std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); // create primitive of cnode:call(partial or switch) @@ -522,7 +565,8 @@ static std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, }); return cnode_inputs; } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - cnode_inputs.emplace_back(cnode_input); + auto switch_node = HandleSwitchInputs(cnode_input, graph); + cnode_inputs.emplace_back(switch_node); return cnode_inputs; } MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index cf85dd02250d5733655c2ee6c2e56ead6dd193f2..8f8f88e65ac27d977a4c9f1ea87c3f03c6ce3466 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -87,6 +87,10 @@ class SessionBasic { std::unordered_map *other_graph_cnode); CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); + CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); + CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); + std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + // set parameters of final graph virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } // set output of final graph diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index b3538a3d745633e32002ff24a5058cd55ee1347c..5e3b545cb17e23ff7678588ed714ac7a718c0a4e 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -257,6 +257,7 @@ constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kTupleGetItemInputSize = 3; +constexpr auto kSwitchInputSize = 4; // index define of control depend constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependBehindIndex = 2;