diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 584f66eee7e27d630c231919f2e39f47bb69bff7..11b8bdc16288c2540476a669889b307f48c6d944 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -155,7 +155,8 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { mem_reuse_util_ptr->SetReuseRefCount(); // Can't free the device address of graph output, so set the reference count of graph output specially. mem_reuse_util_ptr->SetGraphOutputRefCount(); - mem_reuse_util_ptr_ = mem_reuse_util_ptr; + auto graph_id = graph->graph_id(); + mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; } void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { @@ -179,6 +180,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); + auto graph_id = graph->graph_id(); // The inputs and outputs memory of communication kernel are special, so separate processing. AllocCommunicationOpDynamicRes(graph); @@ -194,7 +196,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_LOG(ERROR) << "Launch kernel failed."; return false; } - FreeKernelDynamicRes(kernel, kernel_workspaces); + FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); } if (!SyncStream()) { @@ -341,14 +343,16 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf } void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, - const AddressPtrList &kernel_workspaces) { + const AddressPtrList &kernel_workspaces, uint32_t graph_id) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); + auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); auto cnode = kernel->cast(); MS_EXCEPTION_IF_NULL(cnode); // Free the input of kernel by reference count. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr_->GetKernelInputRef(cnode, i); + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); if (kernel_ref_count_ptr == nullptr) { continue; } @@ -361,7 +365,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, // Reset the reference count. kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; bool is_communication_op = false; - // The inputs and outputs memory of communication kernel are special, so separate processing. FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); if (!is_communication_op) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); @@ -369,6 +372,17 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } } + // Free the output of kernel, if output has no reference. + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); + if (kernel_ref_count_ptr == nullptr) { + continue; + } + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); + mem_manager_->FreeMemFromMemPool(device_address); + } + } // Free the workspace of kernel. for (size_t i = 0; i < kernel_workspaces.size(); ++i) { auto workspace = kernel_workspaces[i]; diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index 6f761342d36e2b611d6edc0be1a04dde25ef5c57..e0eb2dc3f17fdf076d88ad3a85625aea8fbf4589 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "device/kernel_runtime.h" #include "device/kernel_runtime_manager.h" @@ -57,11 +58,12 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); - void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces); + void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, + uint32_t graph_id); void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op); size_t communication_op_input_ref_count_{0}; size_t communication_op_output_ref_count_{0}; - MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; + std::unordered_map mem_reuse_util_map_; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); } // namespace gpu