From d9197b591a25834a0d45aad1f7365619541a8949 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Wed, 6 May 2020 10:52:10 +0800 Subject: [PATCH] gpu optimize the use of reference count --- mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc | 13 +++++++------ mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 8 ++++++++ mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h | 2 ++ .../pre_activate/mem_reuse/mem_reuse_checker.cc | 3 ++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 17817ebeb..387977d34 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -184,6 +184,10 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); + auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Reset the reference count. + mem_reuse_util_ptr->ResetDynamicUsedRefCount(); // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); @@ -360,16 +364,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, if (kernel_ref_count_ptr == nullptr) { continue; } - // Can't free the output of graph. - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) { - continue; - } kernel_ref_count_ptr->ref_count_dynamic_use_--; + if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { + MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; + } if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); mem_manager_->FreeMemFromMemPool(device_address); - // Reset the reference count. - kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; } } // Free the output of kernel, if output has no reference. diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 952dfe97e..6128f1458 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -288,6 +288,14 @@ void MemReuseUtil::SetGraphOutputRefCount() { #endif } +void MemReuseUtil::ResetDynamicUsedRefCount() { + for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { + for (auto &ref_count : iter->second) { + ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; + } + } +} + void MemReuseUtil::SetAllInfo(KernelGraph *graph) { if (!InitDynamicKernelRef(graph)) { MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h index cae0e4565..20a362e76 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h @@ -64,6 +64,8 @@ class MemReuseUtil { void SetReuseRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); + // Reset the dynamic used reference count by ref_count_. + void ResetDynamicUsedRefCount(); KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc index 6825da24d..1dd276ad6 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc @@ -161,7 +161,8 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li total_ori_value_size_ = CalculOriValue(graph); total_ori_dy_size_ = CalculOriDy(graph); total_ori_wkspace_size_ = CalculOriWk(graph); - std::string filename = "./memreuse.ir"; + std::string graph_id = std::to_string(graph->graph_id()); + std::string filename = "./memreuse_" + graph_id + ".ir"; std::ofstream ofs(filename); if (!ofs.is_open()) { MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; -- GitLab