diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 4d4ad4ad06435e2fbbc8541b27f0b944bf9917ef..d89c053299d4654c0eed530a38ed5c8d88589b2d 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -178,10 +178,8 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { #endif // alloc mem MemoryAlloc(root_graph.get()); - // task generate - GenerateTaskInfo(root_graph); - // load task into device - LoadTask(root_graph); + // generate and load task into device + Load(root_graph); DumpAllGraphs(all_graphs); // return the root_graph id to backend auto graph_id = root_graph->graph_id(); @@ -258,10 +256,8 @@ void AscendSession::BuildGraph(GraphId graph_id) { } else { // alloc memory, including static memory and dynamic memory MemoryAlloc(graph.get()); - // generate task info for task sink mode - GenerateTaskInfo(graph); - // load task info to device if it is sink mode - LoadTask(graph); + // generate and load task info to device if it is sink mode + Load(graph); } // sync the inital const tensor to device SyncInitialTenosrToDevice(); @@ -322,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vectorRunOpClearMemory(kernel_graph); } -void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { +void AscendSession::Load(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Generate task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::LoadTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->LoadTask(kernel_graph.get()); + bool ret_ok = runtime_instance->Load(kernel_graph.get()); if (!ret_ok) { MS_LOG(EXCEPTION) << "Load task error!"; } MS_LOG(INFO) << "Finish!"; } -void AscendSession::ExecTask(const std::shared_ptr &kernel_graph) const { +void AscendSession::Execute(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 5ddf77354f86951c8a9c6fe1dc777730be3f8dfb..4b2f2c232decdf05bcdc56961261d05770480878 100755 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -81,9 +81,8 @@ class AscendSession : public SessionBasic { void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector &input_tensors, KernelGraph *kernel_graph) const; void RunOpMemoryClear(const KernelGraph *kernel_graph) const; - void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; - void LoadTask(const std::shared_ptr &kernel_graph) const; - void ExecTask(const std::shared_ptr &kernel_graph) const; + void Load(const std::shared_ptr &kernel_graph) const; + void Execute(const std::shared_ptr &kernel_graph) const; void Dump(const std::shared_ptr &kernel_graph) const; void DumpAllGraphs(const std::vector &all_graphs); void LoadTensor(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index a58a1c08628e6d95cb421d6efc1e3154e00f9761..6511e227a2a631dab27b76e0548bd61e7e7dc931 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -454,19 +454,31 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size return std::make_shared(device_ptr, device_size, format, type_id); } +bool AscendKernelRuntime::Load(session::KernelGraph *graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); + if (!is_task_sink) { + return true; + } + if (!GenTask(graph)) { + return false; + } + if (!LoadTask(graph)) { + return false; + } + return true; +} + bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { SetContext(); if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; } MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); +#ifdef MEM_REUSE_DEBUG auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); - if (!is_task_sink) { - return true; - } -#ifdef MEM_REUSE_DEBUG if (!context_ptr->get_param(MS_CTX_ENABLE_MEM_REUSE)) { // Get normal graph ir for memreuse mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); @@ -517,13 +529,6 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; } MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); - if (!is_task_sink) { - return true; - } - if (GraphWithEmptyTaskList(graph)) { MS_LOG(WARNING) << "LoadTask end, task list is empty"; return true; @@ -604,6 +609,36 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { } } +bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { + bool ret = false; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); + if (is_task_sink) { + ret = RunTask(graph); + } else { + ret = LaunchKernel(graph); + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; +#endif + return ret; +} + bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { SetContext(); MS_EXCEPTION_IF_NULL(graph); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 6b7ccc085cdcdaed6f8a2a498e63038526639dab..f68a9c36c0f4688dfeeeac3a3ad2548d9c1fa791 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -40,10 +40,12 @@ class AscendKernelRuntime : public KernelRuntime { ~AscendKernelRuntime() override; bool Init() override; bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override; - bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; - bool GenTask(const session::KernelGraph *graph) override; - bool RunTask(const session::KernelGraph *graph) override; - bool LoadTask(const session::KernelGraph *graph) override; + bool LoadData(session::KernelGraph *graph, Debugger *debugger); + bool GenTask(const session::KernelGraph *graph); + bool LoadTask(const session::KernelGraph *graph); + bool RunTask(const session::KernelGraph *graph); + bool Load(session::KernelGraph *graph) override; + bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs, const std::unordered_set &value_nodes, const std::vector &execution_order) override; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index a82d248d9ef90f6373ce43065cd0d0316a1d77a2..ec213d4189d5b1fbafcb4207c569941c133e5303 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -40,37 +40,8 @@ KernelRuntime::~KernelRuntime() { #endif } -bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { - bool ret = false; - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); - if (is_task_sink) { - ret = RunTask(graph); - } else { - ret = LaunchKernel(graph); - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; -#endif - return ret; -} +bool KernelRuntime::Load(session::KernelGraph *graph) { return true; } -// for D to impl bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { if (graph != nullptr) { return true; @@ -78,37 +49,6 @@ bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *d return false; } -// for D to impl -bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { - if (graph != nullptr) { - return true; - } - return false; -} - -// for D to impl -bool KernelRuntime::GenTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -bool KernelRuntime::LoadTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -// for D to impl -bool KernelRuntime::RunTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { MS_EXCEPTION_IF_NULL(kernel); if (AnfAlgo::OutputAddrExist(kernel, index)) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index b81954a557eafa887a1283dd600c59b1b8de3c4b..5265f4666d57cce220901b31a5ae517eed3d578b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -58,11 +58,9 @@ class KernelRuntime { void RunOpClearMemory(const session::KernelGraph *graph); bool DumpDataEnabled(); bool DumpDataEnabledIteration(); - virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr); virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr); - virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); - virtual bool RunTask(const session::KernelGraph *graph); - virtual bool GenTask(const session::KernelGraph *graph); + virtual bool Load(session::KernelGraph *graph); + virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0; bool LaunchKernel(const session::KernelGraph *graph); bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs, const AddressPtrList &kernel_outputs, @@ -80,7 +78,6 @@ class KernelRuntime { #ifdef ENABLE_DUMP_E2E DumpConfPtr GetDumpConf(); #endif - virtual bool LoadTask(const session::KernelGraph *graph); // for GPU and D to impl virtual void ReleaseDeviceRes() {} void set_device_id(uint32_t device_id) { device_id_ = device_id; }