diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
index b286bcbc2c64f4e0189e8ff55c591501538972c1..011b20c4abf5cadbcea479683aaed42e0619fec5 100644
--- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
@@ -83,6 +83,7 @@ class MemReuseUtil {
   void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
   uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
   uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
+  bool is_all_nop_node() const { return is_all_nop_node_; }
 
  private:
   int util_index_;
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index ddf73841b77c8bb16de3548a6365ce78f2affea9..185df37e4df22a96f97cd4a3a99cfc33a01eb772 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -160,6 +160,12 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
     }
     mem_swap_manager_ = iter->second;
     MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+    auto mem_reuse_iter = mem_reuse_util_map_.find(graph_id);
+    if (mem_reuse_iter == mem_reuse_util_map_.end()) {
+      MS_LOG(EXCEPTION) << "Find memory reuse map failed.";
+    }
+    mem_reuse_util_ = mem_reuse_iter->second;
+    MS_EXCEPTION_IF_NULL(mem_reuse_util_);
     while (!LaunchKernelDynamic(graph)) {
       MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment.";
       if (!UpdateMemorySwapInfo(graph)) {
@@ -246,18 +252,11 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
 
 bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  auto graph_id = graph->graph_id();
-  auto iter = mem_reuse_util_map_.find(graph_id);
-  if (iter == mem_reuse_util_map_.end()) {
-    MS_LOG(EXCEPTION) << "Find memory reuse map failed.";
-  }
-  auto mem_reuse_util_ptr = iter->second;
-  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   // Reset the reference count.
-  mem_reuse_util_ptr->ResetDynamicUsedRefCount();
+  mem_reuse_util_->ResetDynamicUsedRefCount();
   // The inputs and outputs memory of communication kernel need be continuous, so separate processing.
   AllocCommunicationOpDynamicRes(graph);
-
   auto &kernels = graph->execution_order();
   for (const auto &kernel : kernels) {
     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
@@ -272,7 +271,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
     if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed.";
     }
-    FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
+    FreeKernelDynamicRes(kernel, kernel_workspaces);
     UpdateMemorySwapTask(kernel);
   }
   CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
@@ -450,9 +449,16 @@ bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
 bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_inputs);
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    DeviceAddressPtr device_address;
+    if (mem_reuse_util_->is_all_nop_node()) {
+      // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    } else {
+      // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
+      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
+    }
     MS_EXCEPTION_IF_NULL(device_address);
     UpdateHostSwapQueue(device_address);
     MS_EXCEPTION_IF_NULL(device_address->ptr_);
@@ -525,13 +531,21 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph
 
 void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   bool is_need_alloc_memory = false;
   bool is_need_free_memory = false;
   size_t total_size = 0;
   std::vector<size_t> size_list;
   DeviceAddressPtrList addr_list;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    DeviceAddressPtr device_address;
+    if (mem_reuse_util_->is_all_nop_node()) {
+      // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    } else {
+      // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
+      device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
+    }
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr) {
       is_need_alloc_memory = true;
@@ -593,11 +607,10 @@ void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, boo
 }
 
 void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
-                                            const AddressPtrList &kernel_workspaces, uint32_t graph_id) {
+                                            const AddressPtrList &kernel_workspaces) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
-  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   auto cnode = kernel->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   if (AnfAlgo::IsCommunicationOp(kernel)) {
@@ -605,7 +618,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
   }
   // Free the input of kernel by reference count.
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i);
+    auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
     if (kernel_ref_count_ptr == nullptr) {
       continue;
     }
@@ -614,14 +627,21 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
     }
     if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
-      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+      DeviceAddressPtr device_address;
+      if (mem_reuse_util_->is_all_nop_node()) {
+        // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+        device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+      } else {
+        // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
+        device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
+      }
       mem_manager_->FreeMemFromMemPool(device_address);
       device_address->set_status(DeviceAddressStatus::kInDevice);
     }
   }
   // Free the output of kernel, if output has no reference.
   for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
-    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i);
+    auto kernel_ref_count_ptr = mem_reuse_util_->GetRef(cnode, i);
     if (kernel_ref_count_ptr == nullptr) {
       continue;
     }
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
index 2b1f8198ce182a8b2f8d6a216c415a229d4c6420..e1ba34586614047926ceb075cff7b1f30e5bdfee 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
@@ -72,8 +72,7 @@ class GPUKernelRuntime : public KernelRuntime {
   void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory,
                                   const DeviceAddressPtrList addr_list, size_t total_size,
                                   std::vector<size_t> size_list);
-  void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
-                            uint32_t graph_id);
+  void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces);
   bool AddMemorySwapTask(const AnfNodePtr &kernel);
   bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
   bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
@@ -82,6 +81,7 @@ class GPUKernelRuntime : public KernelRuntime {
   void ClearSwapQueue();
   std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
   std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
+  MemReuseUtilPtr mem_reuse_util_{nullptr};
   MemSwapManagerPtr mem_swap_manager_{nullptr};
 };
 MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
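
Note: the three call sites patched above repeat the same decision before calling AnfAlgo::GetPrevNodeMutableOutputAddr. The sketch below is illustrative only and is not part of this patch; it assumes the cached mem_reuse_util_ member introduced in gpu_kernel_runtime.h, and the helper name GetPrevNodeAddrByNopState is hypothetical.

// Illustrative sketch (not in this patch): a hypothetical helper that centralizes the branch
// repeated in AllocKernelInputDynamicRes, AllocCommunicationOpInputDynamicRes and FreeKernelDynamicRes.
DeviceAddressPtr GPUKernelRuntime::GetPrevNodeAddrByNopState(const AnfNodePtr &kernel, size_t input_index) {
  MS_EXCEPTION_IF_NULL(mem_reuse_util_);
  // If the graph is made up entirely of nop nodes that were not removed, nop nodes must not be skipped;
  // otherwise the input may arrive through "nop node + depend + node" and the nop node has to be skipped
  // to reach the real output address.
  bool skip_nop_node = !mem_reuse_util_->is_all_nop_node();
  return AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, input_index, skip_nop_node);
}

With such a helper, each patched call site would reduce to a single line, e.g. device_address = GetPrevNodeAddrByNopState(kernel, i);, keeping the nop-node handling in one place.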