diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 6d849bd2a5748b512557bfea8877cd47bb5af6be..935e694636f5545baccf318a42cfe32e097904a3 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -54,9 +54,9 @@ static const size_t PRAMATER_OUTPUT_INDEX = 0; AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } void AscendKernelRuntime::ClearGraphModelMap() { - for (auto &iter : graph_model_id_map_) { - MS_LOG(INFO) << "Ge UnloadModel " << iter.second; - auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.second); + for (auto &iter : graph_model_map_) { + MS_LOG(INFO) << "Ge UnloadModel " << iter.first; + auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.first); if (!ret) { MS_LOG(ERROR) << "UnloadModel failed"; } @@ -249,6 +249,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size } bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { + if (graph == nullptr) { + MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; + } + MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_task_sink = context_ptr->enable_task_sink(); @@ -261,19 +265,15 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); } #endif - if (graph == nullptr) { - MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; - } vector> task_info_list; auto anf_node_list = graph->execution_order(); TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); // Store the task_info_list - auto iter = task_map_.find(graph); - if (iter != task_map_.end()) { - MS_LOG(EXCEPTION) << "graph TaskInfo list already exist"; + auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); + if (!insert_ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; } - task_map_[graph] = task_info_list; // Graph may have no compute node, such TensorAddGrad. if (task_info_list.empty()) { @@ -296,25 +296,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); - graph_model_map_[graph] = model; - graph_model_id_map_[graph] = graph->graph_id(); + auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); + if (!ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; + } MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; return true; } -uint32_t AscendKernelRuntime::GetGraphModelId(const session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto iter = graph_model_id_map_.find(kernel_graph); - if (iter == graph_model_id_map_.end()) { - MS_LOG(EXCEPTION) << "graph not in the map"; - } - return iter->second; -} - bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; } + MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_task_sink = context_ptr->enable_task_sink(); @@ -327,23 +321,22 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { return true; } - auto task_iter = graph_model_map_.find(graph); - if (task_iter == graph_model_map_.end()) { - MS_LOG(ERROR) << "task not exist"; + auto model_iter = graph_model_map_.find(graph->graph_id()); + if (model_iter == graph_model_map_.end()) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask."; return false; } - auto model_id = GetGraphModelId(graph); std::shared_ptr listener; - MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_id; - bool status = - ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_id, task_iter->second, listener); + MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; + bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, + model_iter->second, listener); if (!status) { - MS_LOG(INFO) << "load task failed"; + MS_LOG(ERROR) << "load task failed"; return false; } if (ProfilingManager::GetInstance().IsProfiling()) { - std::vector task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_id); + std::vector task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first); ProfilingUtils::ReportProfilingData(graph->graph_id(), task_ids); } return true; @@ -351,6 +344,8 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); ge::InputData input_tensors = ge::InputData(); @@ -360,8 +355,12 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { return true; } - auto model_id = GetGraphModelId(graph); - bool status = ge::model_runner::ModelRunner::Instance().RunModel(model_id, input_tensors, output_tensors); + if (!CheckGraphIdValid(graph->graph_id())) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask."; + return false; + } + + bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); if (!status) { MS_LOG(INFO) << "run task failed"; return false; @@ -497,12 +496,16 @@ bool AscendKernelRuntime::DestroyHccl() { } bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { - auto iter = task_map_.find(graph); + auto iter = task_map_.find(graph->graph_id()); if (iter == task_map_.end()) { MS_LOG(EXCEPTION) << "Unknown graph ptr"; } return iter->second.empty(); } + +bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const { + return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end(); +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h index 547228d32fcc487e04e87e0f5690e1f262ad6384..5d0f61d0a670fe02a13ea98db5c31d7bd1ae636b 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h @@ -23,6 +23,7 @@ #include "runtime/context.h" #include "framework/ge_runtime/davinci_model.h" #include "device/kernel_runtime_manager.h" +#include "session/session_basic.h" using ge::model_runner::TaskInfo; using std::unordered_map; @@ -54,14 +55,13 @@ class AscendKernelRuntime : public KernelRuntime { void ClearGraphModelMap(); void ReleaseDeviceRes() override; - uint32_t GetGraphModelId(const session::KernelGraph *kernel_graph); bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; + bool CheckGraphIdValid(GraphId graph_id) const; rtContext_t rt_context_{nullptr}; bool initialized_{false}; - unordered_map>> task_map_; - unordered_map> graph_model_map_; - unordered_map graph_model_id_map_; + unordered_map>> task_map_; + unordered_map> graph_model_map_; }; MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);