提交 d90e1215 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!306 GPU uses dynamic memory pool by default

Merge pull request !306 from limingqi107/master
...@@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { ...@@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
struct timeval start_time, end_time; struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr); (void)gettimeofday(&start_time, nullptr);
if (is_enable_dynamic_mem) { if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
ret = LaunchKernelDynamic(graph); ret = LaunchKernelDynamic(graph);
} else { } else {
ret = LaunchKernel(graph); ret = LaunchKernel(graph);
...@@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { ...@@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
} }
mem_reuse_util_ptr->SetKernelDefMap(); mem_reuse_util_ptr->SetKernelDefMap();
mem_reuse_util_ptr->SetReuseRefCount(); mem_reuse_util_ptr->SetReuseRefCount();
// Can't free the device address of graph output, so set the reference count of graph output specially, // Can't free the device address of graph output, so set the reference count of graph output specially.
mem_reuse_util_ptr->SetGraphOutputRefCount(); mem_reuse_util_ptr->SetGraphOutputRefCount();
mem_reuse_util_ptr_ = mem_reuse_util_ptr; mem_reuse_util_ptr_ = mem_reuse_util_ptr;
} }
...@@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, ...@@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
if (kernel_ref_count_ptr == nullptr) { if (kernel_ref_count_ptr == nullptr) {
continue; continue;
} }
// Can't free the output of graph.
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
continue;
}
kernel_ref_count_ptr->ref_count_dynamic_use_--; kernel_ref_count_ptr->ref_count_dynamic_use_--;
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
// Reset the reference count. // Reset the reference count.
...@@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, ...@@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op);
if (!is_communication_op) { if (!is_communication_op) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
} }
} }
// Free the workspace of kernel. // Free the workspace of kernel.
for (size_t i = 0; i < kernel_workspaces.size(); ++i) { for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
auto workspace = kernel_workspaces[i]; auto workspace = kernel_workspaces[i];
...@@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr ...@@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_input_ref_count_--; communication_op_input_ref_count_--;
if (communication_op_input_ref_count_ == 0) { if (communication_op_input_ref_count_ == 0) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
return; return;
...@@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr ...@@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_output_ref_count_--; communication_op_output_ref_count_--;
if (communication_op_output_ref_count_ == 0) { if (communication_op_output_ref_count_ == 0) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
} }
......
...@@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) { ...@@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) {
return nullptr; return nullptr;
} }
// Releases the device memory referenced by `address` back to the memory pool.
// Both `address` and `address->ptr_` are checked with MS_EXCEPTION_IF_NULL before use
// (macro defined elsewhere; presumably raises on a null argument -- TODO confirm).
// The release itself is delegated to the FreeMemFromMemPool(void *) overload.
void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(address->ptr_);
FreeMemFromMemPool(address->ptr_);
// Clear the stored pointer after the release so `address` no longer holds the freed pointer.
address->ptr_ = nullptr;
}
void MemoryManager::FreeMemFromMemPool(void *device_ptr) { void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
if (device_ptr == nullptr) { if (device_ptr == nullptr) {
MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
......
...@@ -47,6 +47,7 @@ class MemoryManager { ...@@ -47,6 +47,7 @@ class MemoryManager {
virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size); virtual void *MallocMemFromMemPool(size_t size);
virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr); virtual void FreeMemFromMemPool(void *device_ptr);
size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommonAlignSize(size_t input_size) const;
......
...@@ -273,19 +273,11 @@ void MemReuseUtil::SetReuseRefCount() { ...@@ -273,19 +273,11 @@ void MemReuseUtil::SetReuseRefCount() {
} }
void MemReuseUtil::SetGraphOutputRefCount() { void MemReuseUtil::SetGraphOutputRefCount() {
for (const auto &output : graph_->outputs()) { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
MS_EXCEPTION_IF_NULL(output); for (const auto &node : nodes) {
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (!(output->isa<CNode>())) {
continue;
}
auto cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->input(i + 1);
MS_EXCEPTION_IF_NULL(input_node);
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
MS_EXCEPTION_IF_NULL(kernel_input.first); MS_EXCEPTION_IF_NULL(kernel_input.first);
if (!(kernel_input.first->isa<CNode>())) { if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
continue; continue;
} }
auto ak_node = kernel_input.first->cast<CNodePtr>(); auto ak_node = kernel_input.first->cast<CNodePtr>();
...@@ -298,7 +290,6 @@ void MemReuseUtil::SetGraphOutputRefCount() { ...@@ -298,7 +290,6 @@ void MemReuseUtil::SetGraphOutputRefCount() {
kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
} }
} }
}
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG
auto graph = *graph_; auto graph = *graph_;
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
......
...@@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) { ...@@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) {
precompile_only_ = false; precompile_only_ = false;
auto_mixed_precision_flag_ = true; auto_mixed_precision_flag_ = true;
enable_pynative_infer_ = false; enable_pynative_infer_ = false;
enable_dynamic_mem_pool_ = false; enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0"; graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0"; variable_memory_max_size_ = "0";
MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << "."; MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << ".";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册