diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index fd240d41cbb5a34fa25d7480bd091cd2b89005ce..32e2d51c252d512a92d7b97ba8c1232d626c4c84 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1014,6 +1014,7 @@ void AscendSession::AssignStaticMemory(NotNull graph, // assign static memory for parameters auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->ClearGlobalIdleMem(); runtime_instance->AssignStaticMemoryInput(graph.get().get()); runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); for (auto &child_graph : graph->child_graph_order()) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 8b9f267bded039767b8e01a466b75e0d4a6575e6..7b0f2621cf89b64172f604cceee0d843c36a7993 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -155,6 +155,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std } } +void AscendKernelRuntime::ClearGlobalIdleMem() { mem_manager_->ClearGlobalIdleMem(); } + bool AscendKernelRuntime::NeedDestroyHccl() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index f23da565ff43af4ef2daa378c082397868777c60..8afe6a39ca2c12213d807f5138990b9527b6cea6 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -49,6 +49,7 @@ class AscendKernelRuntime : public KernelRuntime { void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs, const std::unordered_set &value_nodes, const std::vector &execution_order) override; + void ClearGlobalIdleMem() override; bool SyncStream() override; protected: diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index 93e38278f14f3879a69e30c7eefee0e8a40413c0..bf2f0316f867762bb58a18a00e13354551b402df 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -77,6 +77,8 @@ void AscendMemoryManager::ResetDynamicMemory() { AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); } +void AscendMemoryManager::ClearGlobalIdleMem() { AscendMemoryPool::GetInstance().ResetIdleMemBuf(); } + void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { auto align_size = GetCommonAlignSize(size); return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h index fc684f3fd857ac2a7792e5cce44f3be503bb7521..77812b489cbdeda569e1b7fdf08fb7079787babd 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h @@ -28,6 +28,7 @@ class AscendMemoryManager : public MemoryManager { void MallocDeviceMemory() override; void FreeDeviceMemory() override; void ResetDynamicMemory() override; + void ClearGlobalIdleMem() override; void *MallocMemFromMemPool(size_t size) override; protected: diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 636b5c888404e3666f0fa38607c68b390d614d26..1906e778f372bc616d99fc660789edfa4a85a21e 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -74,6 +74,7 @@ class KernelRuntime { const std::unordered_set &value_nodes, const std::vector &execution_order); virtual bool SyncStream() = 0; + virtual void ClearGlobalIdleMem() {} #ifdef ENABLE_DUMP_E2E DumpConfPtr GetDumpConf(); diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index cb045f8d2742ed50a46b1831703447862698a21d..ad2142d1a22ea7a0f4ad6a926c7d5b96bb362a6a 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -39,6 +39,7 @@ class MemoryManager { total_dynamic_size_ = 0; dynamic_mem_offset_ = 0; } + virtual void ClearGlobalIdleMem() {} void MallocReusedDynamicMem(const session::KernelGraph *graph); uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,