diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index fa5fb6d67401d56bc283fa3d254d43f47c754f11..2e8b7bc5d6d9f6c1e9adfed660a9e25b8d491aba 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -29,6 +29,7 @@ #include "hccl/hcom.h" #include "common/trans.h" #include "runtime/context.h" +#include "device/ascend/ascend_label_assign.h" #include "device/ascend/ascend_stream_assign.h" #include "device/ascend/ascend_memory_pool.h" #include "framework/ge_runtime/model_runner.h" @@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { return true; } - AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); + AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance(); + AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list; - assign_instance.GetWaitStreams(&wait_active_stream_list); - auto force_copy_stream_list = assign_instance.hcom_streams(); + stream_assign_instance.GetWaitStreams(&wait_active_stream_list); + auto force_copy_stream_list = stream_assign_instance.hcom_streams(); - MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() - << ", total event num:" << assign_instance.total_event_num() + MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum() + << ", total event num:" << stream_assign_instance.total_event_num() + << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); + 0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), + stream_assign_instance.total_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc index e4239117c2e3e06dbcda0192b1cb0c4e2e02533d..db68516500b64e6a8149b8b958bf83ead33ab1de 100644 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc @@ -15,6 +15,8 @@ */ #include +#include +#include #include "device/ascend/ascend_label_assign.h" #include "session/anf_runtime_algorithm.h" @@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull node) { uint32_t goto_label_id = GetValue(value); AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; + node->set_inputs({node->input(0)}); } static void UpdateLabelSwitch(NotNull node) { @@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull node) { MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; } AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), node.get()); + node->set_inputs({node->input(0), node->input(1)}); } -void AscendLabelAssign::AssignLabel(NotNull &> graph) { - auto cnode_list = graph->execution_order(); - // 1 assign label id to label_set - uint32_t cur_label_id = 0; - for (auto &node : cnode_list) { - if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { - AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(cur_label_id), node); - MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; - ++cur_label_id; +static void AssignLabelForLabelSet(NotNull> graph, NotNull label_id, + NotNull> *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + + MS_LOG(INFO) << "Assign label for " << graph->ToString(); + auto nodes = TopoSort(graph->get_return()); + for (auto &node : nodes) { + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string node_name = AnfAlgo::GetCNodeName(node); + if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(*label_id), node); + MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; + ++(*label_id); } } - // 2 update label_switch / label_goto - for (auto &node : cnode_list) { - if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { - UpdateLabelGoto(NOT_NULL(node)); + + for (auto &cg : graph->child_graph_order()) { + AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); + } +} + +static void AssignLabelForGotoSwitch(NotNull> graph, + NotNull> *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + + MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); + auto nodes = TopoSort(graph->get_return()); + for (auto &node : nodes) { + if (!node->isa()) { + continue; } - if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { - UpdateLabelSwitch(NOT_NULL(node)); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string node_name = AnfAlgo::GetCNodeName(node); + if (node_name == kLabelGotoOpName) { + UpdateLabelGoto(NOT_NULL(cnode)); + cnode->set_abstract(nullptr); } + + if (node_name == kLabelSwitchOpName) { + UpdateLabelSwitch(NOT_NULL(cnode)); + } + } + for (auto &cg : graph->child_graph_order()) { + AssignLabelForGotoSwitch(NOT_NULL(cg), memo); + } +} + +void AscendLabelAssign::AssignLabel(NotNull> graph) { + MS_LOG(INFO) << "Assign label start."; + std::set> memo; + uint32_t label_id = 0; + AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); + memo.clear(); + { + std::lock_guard lock(label_num_mutex_); + label_num_[graph.get().get()] = label_id; } + AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); + MS_LOG(INFO) << "Assign label end."; +} + +uint32_t AscendLabelAssign::GetLabelNum(NotNull graph) { + std::lock_guard lock(label_num_mutex_); + auto iter = label_num_.find(graph.get()); + if (iter == label_num_.end()) { + MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label."; + return 1; + } + return iter->second; +} + +uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { + return GetLabelNum(NOT_NULL(graph.get().get())); } } // namespace ascend diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/device/ascend/ascend_label_assign.h index 1cc0351c60efa46c35f52479652bc22f31590aea..743976fba13db459c01d0bc78962616400a052a0 100644 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.h @@ -18,6 +18,7 @@ #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ #include +#include #include "session/kernel_graph.h" #include "utils/contract.h" @@ -35,11 +36,16 @@ class AscendLabelAssign { AscendLabelAssign(const AscendLabelAssign &) = delete; AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; - void AssignLabel(NotNull &> graph); + void AssignLabel(NotNull> graph); + uint32_t GetLabelNum(NotNull graph); + uint32_t GetLabelNum(NotNull> graph); private: AscendLabelAssign() = default; ~AscendLabelAssign() = default; + + std::map label_num_; + std::mutex label_num_mutex_; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/kernel/rts/label_switch.cc b/mindspore/ccsrc/kernel/rts/label_switch.cc index 6647ac7eb693d06febb49af825fcfe03b9f48f16..168e1f4844bb6dc4aa538b8c8da99bb5648405f4 100644 --- a/mindspore/ccsrc/kernel/rts/label_switch.cc +++ b/mindspore/ccsrc/kernel/rts/label_switch.cc @@ -17,6 +17,7 @@ #include "kernel/rts/label_switch.h" #include #include +#include #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "session/anf_runtime_algorithm.h" @@ -66,13 +67,33 @@ std::vector LabelSwitchKernel::GenTask(const std::vector task_info_list; cond_ = inputs[0]->addr; - // std::shared_ptr task_info_ptr = - // std::make_shared(stream_id, label_size_, &label_list_, cond_); - // need updata ge task info define - std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_size_); + // todo: need update ge task info define + auto task_info_ptr = std::make_shared(stream_id, 0); + // auto task_info_ptr = std::make_shared(stream_id, label_size_, label_list_, cond_); MS_EXCEPTION_IF_NULL(task_info_ptr); task_info_list.emplace_back(task_info_ptr); return task_info_list; } + +std::vector> LabelSwitchDesc::GetKernelInfo() { + std::vector> label_switch_build_info{}; + + vector input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT}; + vector input_type{kNumberTypeUInt32, kNumberTypeBool}; + if (input_format.size() != input_type.size()) { + MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " + << input_type.size(); + } + for (size_t i = 0; i < input_format.size(); ++i) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat({input_format[i]}); + builder.SetInputsDeviceType({input_type[i]}); + builder.SetProcessor(AICORE); + builder.SetKernelType(RT_KERNEL); + builder.SetFusionType(OPAQUE); + label_switch_build_info.emplace_back(builder.Build()); + } + return label_switch_build_info; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_switch.h b/mindspore/ccsrc/kernel/rts/label_switch.h index 0accd26afbf638b33a7e47a9fda7b6f68b68cd8e..858f851b2abe9b80acdc8c80086c4ab644362859 100644 --- a/mindspore/ccsrc/kernel/rts/label_switch.h +++ b/mindspore/ccsrc/kernel/rts/label_switch.h @@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel { void *cond_; }; +class LabelSwitchDesc : public RtKerDesc { + public: + LabelSwitchDesc() = default; + ~LabelSwitchDesc() override = default; + std::vector> GetKernelInfo() override; +}; + +MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc b/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc index c841ce4611f70bbba10a0ffb001ddae28a1f8b35..14f5a60a07068f02e8c16b463bfea2f35e346965 100755 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc +++ b/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc @@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() { return _this; } +static bool IsDefaultKernelInfo(const std::string &name) { + static const std::set white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, + kLabelGotoOpName}; + return white_list.find(name) != white_list.end(); +} + void GetRtKelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_info_list); @@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node, } // if can't find kernel info in kernel info database, use the default kernel info auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (node_name == "StreamSwitch" || node_name == "StreamActive") { + if (IsDefaultKernelInfo(node_name)) { auto kernel_build_info_builder = std::make_shared(); // set input infos auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 75f9f76db4749d5b5b8d761ece5b0e9064278d0a..7d6b17992ed0bfd1eb2981774439a4178ca4c441 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) { } auto graph_id = res->results()[kOutput].cast(); - auto bc_ptr = res->results()[kBackend].cast>(); + std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); + std::shared_ptr msbc_ptr = std::dynamic_pointer_cast(bc_ptr); + MS_EXCEPTION_IF_NULL(msbc_ptr); compile::VmEvalFuncPtr run = - std::make_shared([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { - MS_LOG(INFO) << "Execute args size" << args.size(); - auto outs = bc_ptr->RunGraph(graph_id, args); - MS_LOG(DEBUG) << "out size" << outs.size(); + std::make_shared([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { + MS_LOG(INFO) << "Execute args size " << args.size(); + auto outs = msbc_ptr->RunGraph(graph_id, args); + MS_LOG(DEBUG) << "out size " << outs.size(); return outs[0]; }); res->results()[kOutput] = run; diff --git a/mindspore/ccsrc/session/CMakeLists.txt b/mindspore/ccsrc/session/CMakeLists.txt index 8143c0a34e4e589594ff9162acaf2a9cfae70da2..2824af8a5d1b3a31b12f99c1a20beea6b804efbb 100644 --- a/mindspore/ccsrc/session/CMakeLists.txt +++ b/mindspore/ccsrc/session/CMakeLists.txt @@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ) if (ENABLE_GPU) - file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" ) list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) endif () if (ENABLE_CPU) - file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu_session.cc" ) list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) endif () if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend_session.cc" + "ascend_control_parser.cc" ) list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) endif () diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..665876781446633a36b29cf09972418268441dd2 --- /dev/null +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -0,0 +1,319 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "session/ascend_control_parser.h" +#include "session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace session { + +static VectorRef GetCallArgs(std::vector::iterator iter_begin, std::vector::iterator iter_end) { + VectorRef call_args; + for (auto iter = iter_begin; iter != iter_end; ++iter) { + if (utils::isa(*iter)) { + call_args.push_back(GetValueNode(*iter)); + } else { + call_args.push_back(*iter); + } + } + return call_args; +} + +void AscendControlParser::LinkGraph(NotNull kg) { + std::set memo; + ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); +} + +NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label, const VectorRef &args, + NotNull *> memo) { + MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); + // 0. recursive condition + if (memo->find(kg) != memo->end()) { + MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); + return NOT_NULL(kg->get_start_label()); + } + + // 2. args replace placeholder + LinkParentGraph(kg, last_node, last_label, args); + // 3. topological sort + std::vector nodes = GetCNodes(TopoSort(kg->get_return())); + if (nodes.empty()) { + MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; + } + // 4. insert first_label + auto start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + for (auto node : nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { + InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); + break; + } + } + + kg->set_start_label(start_label); + // 5. traverse + for (size_t i = 0; i < nodes.size(); ++i) { + auto &cnode = nodes[i]; + if (cnode->size() < kCNodePrim + 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + AnfNodePtr fn = cnode->input(kCNodePrim); + if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { + MS_LOG(DEBUG) << "continue node " << cnode->DebugString(); + continue; + } + AnfNodePtr arg = cnode->input(kCNodeCallArg); + if (IsValueNode(arg)) { + RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); + } else if (!arg->isa()) { + MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); + } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { + auto arg_cnode = arg->cast(); + cnode->set_inputs(cnode->inputs()); + RecurseSwitch(kg, NOT_NULL(cnode), memo); + } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { + auto arg_cnode = arg->cast(); + cnode->set_inputs(cnode->inputs()); + RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); + } + } + + MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); + return NOT_NULL(start_label); +} + +std::vector AscendControlParser::GetCNodes(const std::vector &in) { + std::vector out; + for (auto &node : in) { + if (node->isa()) { + out.push_back(node->cast()); + } + } + return out; +} + +void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { + std::vector inputs = {NewValueNode(std::make_shared("depend"))}; + auto return_node = kg->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + inputs.push_back(return_node->input(1)); + inputs.push_back(attch_node.get()); + auto depend_node = kg->NewCNode(inputs); + return_node->set_input(1, depend_node); +} + +void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, + NotNull second_node) { + MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() + << ", the second node is " << second_node->DebugString(); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), + first_node, second_node}; + auto control_depend = kg->NewCNode(inputs); + InsertDependToGraph(kg, NOT_NULL(control_depend)); +} + +void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, + const CNodePtr &last_label, const VectorRef &args) { + if (from_graph_call_node != nullptr) { + SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); + } + + auto origin_return = kg->get_return(); + std::vector origin_return_inputs = origin_return->inputs(); + // if entry graph, replace return with make_tuple + if (from_graph_call_node == nullptr || last_label == nullptr) { + MS_LOG(INFO) << kg->ToString() << " is entry graph."; + std::vector make_tuple_inputs = {std::make_shared(prim::kPrimMakeTuple)}; + make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end()); + auto make_tuple = kg->NewCNode(make_tuple_inputs); + origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple}); + } else { + // else replace return with label_goto + auto label_goto = + kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); + InsertDependToGraph(kg, NOT_NULL(label_goto)); + } +} + +void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + NotNull *> memo) { + MS_LOG(INFO) << "process call func " << cur_node->DebugString(); + + // 1 get kernel graph + auto origin_inputs = cur_node->inputs(); + std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; + auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); + if (!IsValueNode(origin_inputs[kCNodeCallArg])) { + MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; + return; + } + // 2 return label + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + // 3 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + auto call_kg = GetValueNode(origin_inputs[kCNodeCallArg]); + // 4 modify call op to goto op + cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); + // 5 recurse sub graph + CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); + new_inputs.push_back(sub_label); + new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); + cur_node->set_inputs(new_inputs); + cur_node->set_abstract(nullptr); + MS_LOG(INFO) << "success process call func " << cur_node->DebugString(); +} + +void AscendControlParser::RecurseSwitch(NotNull kg, NotNull cur_node, + NotNull *> memo) { + MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); + + if (cur_node->size() < kCNodeSwitchLength) { + MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; + } + // 1 return label + auto back_label = kg->NewCNode({std::make_shared(prim::kPrimLabelSet)}); + // 2 recurse sub graph + auto origin_switch_inputs = cur_node->inputs(); + std::vector new_switch_inputs = { + std::make_shared(std::make_shared(kLabelSwitchOpName)), + origin_switch_inputs[kCNodeSwitchCond]}; + for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { + // 2.1 branch kernel graph and args + CNodePtr partial; + KernelGraphPtr branch_fg; + VectorRef call_args; + std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + // 2.2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + // 2.3 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); + new_switch_inputs.push_back(branch_label); + } + std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); + new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); + cur_node->set_inputs(new_switch_inputs); + cur_node->set_abstract(nullptr); + MS_LOG(INFO) << "success process switch func " << cur_node->DebugString(); +} + +void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, + NotNull *> memo) { + MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); + + if (cur_node->size() < kCNodeSwitchLayerLength) { + MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; + } + + auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); + MS_EXCEPTION_IF_NULL(branch_tuple); + if (!branch_tuple->isa()) { + MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; + } + auto branch_partial = utils::cast(branch_tuple)->inputs(); + // 1 return label + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSwitchOpName))}); + // 2 recurse sub graph + auto origin_switch_inputs = cur_node->inputs(); + std::vector new_switch_inputs = {std::make_shared(prim::kPrimLabelSwitch), + origin_switch_inputs[kCNodeSwitchCond]}; + for (size_t i = 0; i < branch_partial.size(); ++i) { + // 2.1 branch kernel graph and args + CNodePtr partial; + KernelGraphPtr branch_fg; + VectorRef call_args; + std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + // 2.2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + // 2.3 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); + new_switch_inputs.push_back(branch_label); + } + new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); + cur_node->set_inputs(new_switch_inputs); + cur_node->set_abstract(nullptr); + MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); +} + +std::tuple AscendControlParser::ParsePartial(NotNull node) { + if (!node.get()->isa()) { + MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); + } + // 2.1 branch kernel graph and args + auto partial_cnode = utils::cast(node.get()); + if (partial_cnode->size() < kCNodePartialLength) { + MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; + } + auto partial_inputs = partial_cnode->inputs(); + auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); + auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); + + return {partial_cnode, branch_kg, call_args}; +} + +void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + NotNull to) { + if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && + AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { + return; + } + if (from.get() == to.get()) { + return; + } + MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " + << to->DebugString(); + // config inputs of assign node + std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; + // generate a new cnode + auto assign_node = kg->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_node); + assign_node->set_abstract(to->abstract()); + // append the assign at the end of from graph + InsertDependToGraph(kg, NOT_NULL(assign_node)); +} + +size_t AscendControlParser::SetChildGraphInput(NotNull kg, NotNull node, + size_t input_index) { + auto output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return input_index + output_num; + } + + auto &graph_inputs = kg->inputs(); + if (input_index >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_index]; + if (node.get()->isa()) { + MS_EXCEPTION_IF_NULL(backend_parameter); + MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() + << "] will be replaced."; + kg->ReplaceNode(backend_parameter, node); + return input_index; + } + InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); + return input_index + 1; +} + +void AscendControlParser::SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, + const VectorRef &args) {} + +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..ca215ef0c219c6cadd0ed0b3910ca2cd50d3cac9 --- /dev/null +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H + +#include +#include +#include +#include "session/kernel_graph.h" +#include "utils/base_ref.h" +#include "utils/contract.h" + +namespace mindspore { +namespace session { + +class AscendControlParser { + public: + static void LinkGraph(NotNull kg); + + static void InsertDependToGraph(NotNull kg, NotNull attch_node); + static void InsertControlDependToGraph(NotNull kg, NotNull first_node, + NotNull second_node); + + private: + static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label, const VectorRef &args, + NotNull *> memo); + static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + NotNull *> memo); + static void RecurseSwitch(NotNull kg, NotNull cur_node, + NotNull *> memo); + static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, + NotNull *> memo); + + static std::vector GetCNodes(const std::vector &in); + static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, + const CNodePtr &last_label, const VectorRef &args); + static void SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, + const VectorRef &args); + static std::tuple ParsePartial(NotNull node); + static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static size_t SetChildGraphInput(NotNull kg, NotNull node, size_t input_index); + + static constexpr size_t kCNodePrim = 0; + static constexpr size_t kCNodeCallArg = 1; + static constexpr size_t kCNodeSwitchCond = 1; + static constexpr size_t kCNodeSwitchTrue = 2; + static constexpr size_t kCNodeSwitchFalse = 3; + static constexpr size_t kCNodeSwitchLength = 4; + static constexpr size_t kCNodePartialLength = 2; + static constexpr size_t kCNodePartialFunc = 1; + static constexpr size_t kCNodeSwitchLayerCond = 1; + static constexpr size_t kCNodeSwitchLayerBranch = 2; + static constexpr size_t kCNodeSwitchLayerLength = 3; +}; + +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 862d66c25dd3bb7e0b842633d849ef0e722ff668..f4f2f4bb5fd43aac68d795ee882e7fb3e6186d77 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { std::vector GetCNodes(const std::vector &anf_nodes) { std::vector cnodes = {}; size_t i = 0; - for (const auto anf : anf_nodes) { + for (auto anf : anf_nodes) { MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); MS_EXCEPTION_IF_NULL(anf); if (anf->isa()) { cnodes.push_back(anf->cast()); } } - return std::move(cnodes); + return cnodes; } std::vector> GetChildList(const KernelGraph &cur_graph, const std::vector &cnodes) { @@ -189,7 +189,7 @@ std::vector> GetChildList(const KernelGraph &cur_graph, co ret.push_back(std::vector(cnodes.begin() + after_call_index, cnodes.end())); } } - return std::move(ret); + return ret; } void UpdateRealInput(KernelGraph *graph) { @@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) { auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); partial_cnode->set_inputs( std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); - return std::move(ret); + return ret; }; bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); @@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { // split switch SplitGraph(graph); // insert goto labels and label_sets - LinkChildGraphs(graph.get()); + LinkChildGraphs(NOT_NULL(graph)); // resource initialize InitRuntimeResource(); - // ir fusion - IRFusion(graph); - // kernel select - SelectKernelGraphKernel(*graph); - // convert model of predict module - ConvertPredictModel(graph); - // hardware optimize - HardwareOptimizeGraphs(graph); + // assign label + AssignLabel(NOT_NULL(graph)); + if (!graph->executable()) { + return graph->graph_id(); + } + for (auto iter : graphs_) { + if (iter.second == graph) { + MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); + final_graph_id_ = graph->graph_id(); + } + MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString(); + CompileChildGraph(iter.second); + } // adjust kernel AdjustKernel(graph); // root graph valiate,include genearte execute order and so on RootGraphExecutorValidate(graph.get()); // assign stream AssignStream(graph); - // assign label - AssignLabel(NOT_NULL(graph)); - // build kernel if node is cnode - BuildKernel(graph); // alloc mem MemoryAlloc(graph.get()); // task generate @@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr &kernel_grap MS_LOG(INFO) << "Finish!"; } -void AscendSession::AssignLabel(NotNull kernel_graph) const { +void AscendSession::AssignLabel(NotNull kernel_graph) const { MS_LOG(INFO) << "Start!"; device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); MS_LOG(INFO) << "Finish!"; @@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived } void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { - MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString(); - auto graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(graph); - std::vector inputs = {NewValueNode(std::make_shared("depend"))}; - auto return_node = graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - inputs.push_back(return_node->input(1)); - inputs.push_back(attch_node); - auto depend_node = graph->NewCNode(inputs); - return_node->set_input(1, depend_node); + AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); } void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node) { - MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() - << ", the second node is " << second_node->DebugString(); - auto graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(graph); - std::vector inputs = {NewValueNode(std::make_shared("ControlDepend"))}; - inputs.push_back(first_node); - inputs.push_back(second_node); - auto control_depend = graph->NewCNode(inputs); - InsertDependToGraph(graph_id, control_depend); + AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), + NOT_NULL(second_node)); } size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { @@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { SplitGraph(child_graph); } } + +void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } + } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index a916263e0590dd4720ecb08654ca38c263a121c8..aa1050b61b8510764ee5bd37d5849f239cb9e758 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -28,6 +28,7 @@ #include "session/kernel_graph.h" #include "kernel/kernel.h" #include "session/session_factory.h" +#include "session/ascend_control_parser.h" namespace mindspore { namespace session { @@ -74,7 +75,7 @@ class AscendSession : public SessionBasic { void AdjustKernel(const std::shared_ptr &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; void AssignStream(const std::shared_ptr &kernel_graph) const; - void AssignLabel(NotNull kernel_graph) const; + void AssignLabel(NotNull kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; @@ -96,7 +97,8 @@ class AscendSession : public SessionBasic { void SetFinalGraphOutput(const VectorRef &vec_output); void SplitGraph(const KernelGraphPtr &graph); - void LinkChildGraphs(KernelGraph *graph) {} + void LinkChildGraphs(NotNull graph); + void IRFusion(const KernelGraphPtr &graph) {} void SelectKernelGraphKernel(const KernelGraph &graph) {} void ConvertPredictModel(const KernelGraphPtr graph) {} diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 9a3823532ac3e48b4b31f4b4e2bcd65f4bef54c8..de55949c7bb5addccf93681b3a2c7db6d270bee3 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -28,6 +28,7 @@ #include "ir/func_graph.h" #include "ir/anf.h" #include "utils/graph_utils.h" +#include "utils/contract.h" #include "device/kernel_info.h" namespace mindspore { @@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph { std::vector> child_graph_order() const { return child_graph_order_; } // checkout whether current graph is leaf graph bool IsLeafGraph() const; + // set input_tensors pointer of control parameter void set_input_ctrl_tensors(const std::shared_ptr> &input_tensors_ptr) { input_ctrl_tensors_ = input_tensors_ptr; @@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph { // used to dump ir std::string ToString() const override; + void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } + CNodePtr get_start_label() { return start_label_; } + private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); @@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph { std::map> node_to_child_graphs_; // child graph execute order in root graph std::vector> child_graph_order_; + // input_tensors of control parameter std::shared_ptr> input_ctrl_tensors_; + // parameter graph std::shared_ptr parent_graph_; // record real parameters,inputs_ is the formal parameters std::map> real_inputs_; + + CNodePtr start_label_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 17d7cb1e359997a111ab4a85e2bb80a622311cbe..9213e41450e969149f536f0df15b6b66334949c9 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/transform/*.cc" "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" "../../../mindspore/ccsrc/session/ascend_session.cc" + "../../../mindspore/ccsrc/session/ascend_control_parser.cc" "../../../mindspore/ccsrc/session/kernel_graph.cc" "../../../mindspore/ccsrc/session/session_basic.cc" "../../../mindspore/ccsrc/session/session_factory.cc" diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index 619f2385b4e00586a190fcc0f66ff6d35160c611..5d8e33b25699992dd5942ef65f2305ebdb42dce6 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -22,7 +22,9 @@ namespace mindspore { namespace device { namespace ascend { -void AscendLabelAssign::AssignLabel(NotNull &>) {} +void AscendLabelAssign::AssignLabel(NotNull> graph) {} +uint32_t AscendLabelAssign::GetLabelNum(NotNull graph) { return 1; } +uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { return 1; } void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } @@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::ve } // namespace ascend void KernelAdjust::Reorder(const std::shared_ptr &kernel_graph_ptr) { return; } void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { return; } -bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { - return true; -} +bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { return true; } bool KernelAdjust::NeedInsertSwitch() { return true; } void KernelAdjust::Profiling(NotNull kernel_graph_ptr) { return; } } // namespace device