diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 17817ebebab2e75e4fc7c9091e745a79284a3bef..387977d34503d0e63ebc3dd01c67e373e0408cd2 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -184,6 +184,10 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); + auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Reset the reference count. + mem_reuse_util_ptr->ResetDynamicUsedRefCount(); // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); @@ -360,16 +364,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, if (kernel_ref_count_ptr == nullptr) { continue; } - // Can't free the output of graph. - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) { - continue; - } kernel_ref_count_ptr->ref_count_dynamic_use_--; + if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { + MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; + } if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); mem_manager_->FreeMemFromMemPool(device_address); - // Reset the reference count. - kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; } } // Free the output of kernel, if output has no reference. diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 952dfe97e4b716f6743641b75164bf0147a519d5..6128f14582930dfa015cfafa73abbebdaef33f9a 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -288,6 +288,14 @@ void MemReuseUtil::SetGraphOutputRefCount() { #endif } +void MemReuseUtil::ResetDynamicUsedRefCount() { + for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { + for (auto &ref_count : iter->second) { + ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; + } + } +} + void MemReuseUtil::SetAllInfo(KernelGraph *graph) { if (!InitDynamicKernelRef(graph)) { MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h index cae0e4565f81896c5d7ca12ab57005ee62b855a1..20a362e76fd5f81974bfb9f0fbba222376571917 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h @@ -64,6 +64,8 @@ class MemReuseUtil { void SetReuseRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); + // Reset the dynamic used reference count by ref_count_. + void ResetDynamicUsedRefCount(); KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc index 6825da24dad9f7d8619a1325b6b97308de59ad3b..1dd276ad63fde8e1962afd37b4cc839cc3b23f00 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc @@ -161,7 +161,8 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li total_ori_value_size_ = CalculOriValue(graph); total_ori_dy_size_ = CalculOriDy(graph); total_ori_wkspace_size_ = CalculOriWk(graph); - std::string filename = "./memreuse.ir"; + std::string graph_id = std::to_string(graph->graph_id()); + std::string filename = "./memreuse_" + graph_id + ".ir"; std::ofstream ofs(filename); if (!ofs.is_open()) { MS_LOG(ERROR) << "Open file [" << filename << "] failed!";