提交 0c316e52 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5866 clean idle mem at proper time

Merge pull request !5866 from liangzelang/fix_global_step_error
...@@ -1014,6 +1014,7 @@ void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph, ...@@ -1014,6 +1014,7 @@ void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
// assign static memory for parameters // assign static memory for parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->ClearGlobalIdleMem();
runtime_instance->AssignStaticMemoryInput(graph.get().get()); runtime_instance->AssignStaticMemoryInput(graph.get().get());
runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); runtime_instance->AssignStaticMemoryValueNode(graph.get().get());
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
......
...@@ -155,6 +155,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std ...@@ -155,6 +155,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
} }
} }
// Forwards the request to the memory manager's ClearGlobalIdleMem().
// NOTE(review): mem_manager_ is dereferenced without a null check here; confirm it is
// always initialized before this method can be called.
void AscendKernelRuntime::ClearGlobalIdleMem() { mem_manager_->ClearGlobalIdleMem(); }
bool AscendKernelRuntime::NeedDestroyHccl() { bool AscendKernelRuntime::NeedDestroyHccl() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
......
...@@ -49,6 +49,7 @@ class AscendKernelRuntime : public KernelRuntime { ...@@ -49,6 +49,7 @@ class AscendKernelRuntime : public KernelRuntime {
void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes, const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override; const std::vector<CNodePtr> &execution_order) override;
// Ascend-specific override of KernelRuntime::ClearGlobalIdleMem().
void ClearGlobalIdleMem() override;
bool SyncStream() override; bool SyncStream() override;
protected: protected:
......
...@@ -77,6 +77,8 @@ void AscendMemoryManager::ResetDynamicMemory() { ...@@ -77,6 +77,8 @@ void AscendMemoryManager::ResetDynamicMemory() {
AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
} }
// Calls ResetIdleMemBuf() on the pool object returned by AscendMemoryPool::GetInstance().
void AscendMemoryManager::ClearGlobalIdleMem() { AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
auto align_size = GetCommonAlignSize(size); auto align_size = GetCommonAlignSize(size);
return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
......
...@@ -28,6 +28,7 @@ class AscendMemoryManager : public MemoryManager { ...@@ -28,6 +28,7 @@ class AscendMemoryManager : public MemoryManager {
void MallocDeviceMemory() override; void MallocDeviceMemory() override;
void FreeDeviceMemory() override; void FreeDeviceMemory() override;
void ResetDynamicMemory() override; void ResetDynamicMemory() override;
// Ascend-specific override of MemoryManager::ClearGlobalIdleMem().
void ClearGlobalIdleMem() override;
void *MallocMemFromMemPool(size_t size) override; void *MallocMemFromMemPool(size_t size) override;
protected: protected:
......
...@@ -74,6 +74,7 @@ class KernelRuntime { ...@@ -74,6 +74,7 @@ class KernelRuntime {
const std::unordered_set<ValueNodePtr> &value_nodes, const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order); const std::vector<CNodePtr> &execution_order);
virtual bool SyncStream() = 0; virtual bool SyncStream() = 0;
// Hook for clearing idle memory; the base implementation is a no-op.
// Derived runtimes may override it (AscendKernelRuntime declares an override).
virtual void ClearGlobalIdleMem() {}
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
DumpConfPtr GetDumpConf(); DumpConfPtr GetDumpConf();
......
...@@ -39,6 +39,7 @@ class MemoryManager { ...@@ -39,6 +39,7 @@ class MemoryManager {
total_dynamic_size_ = 0; total_dynamic_size_ = 0;
dynamic_mem_offset_ = 0; dynamic_mem_offset_ = 0;
} }
// Hook for clearing idle memory; the base implementation is a no-op.
// Derived managers may override it (AscendMemoryManager declares an override).
virtual void ClearGlobalIdleMem() {}
void MallocReusedDynamicMem(const session::KernelGraph *graph); void MallocReusedDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册