提交 99f12f91 编写于 作者: L limingqi107

gpu uses dynamic memory pool by default

上级 a4cf9028
......@@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
if (is_enable_dynamic_mem) {
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
ret = LaunchKernelDynamic(graph);
} else {
ret = LaunchKernel(graph);
......@@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
}
mem_reuse_util_ptr->SetKernelDefMap();
mem_reuse_util_ptr->SetReuseRefCount();
// Can't free the device address of graph output, so set the reference count of graph output specially,
// Can't free the device address of graph output, so set the reference count of graph output specially.
mem_reuse_util_ptr->SetGraphOutputRefCount();
mem_reuse_util_ptr_ = mem_reuse_util_ptr;
}
......@@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
if (kernel_ref_count_ptr == nullptr) {
continue;
}
// Can't free the output of graph.
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
continue;
}
kernel_ref_count_ptr->ref_count_dynamic_use_--;
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
// Reset the reference count.
......@@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op);
if (!is_communication_op) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
mem_manager_->FreeMemFromMemPool(device_address);
}
}
}
// Free the workspace of kernel.
for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
auto workspace = kernel_workspaces[i];
......@@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_input_ref_count_--;
if (communication_op_input_ref_count_ == 0) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
mem_manager_->FreeMemFromMemPool(device_address);
}
*is_communication_op = true;
return;
......@@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_output_ref_count_--;
if (communication_op_output_ref_count_ == 0) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0);
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
mem_manager_->FreeMemFromMemPool(device_address);
}
*is_communication_op = true;
}
......
......@@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) {
return nullptr;
}
// Frees the device memory referenced by `address` and clears the address's pointer.
// `address` and `address->ptr_` are each validated with MS_EXCEPTION_IF_NULL before use.
// Note: `const` applies only to the `address` handle itself; the pointed-to object's
// `ptr_` member is still modified below.
void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(address->ptr_);
// Delegate the actual release to the raw-pointer overload.
FreeMemFromMemPool(address->ptr_);
// Reset the pointer so the address no longer refers to the memory that was just freed.
address->ptr_ = nullptr;
}
void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
if (device_ptr == nullptr) {
MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
......
......@@ -47,6 +47,7 @@ class MemoryManager {
virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size);
// Frees the memory held by `address` (device-address overload of the raw-pointer variant below).
virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr);
size_t GetCommonAlignSize(size_t input_size) const;
......
......@@ -273,30 +273,21 @@ void MemReuseUtil::SetReuseRefCount() {
}
void MemReuseUtil::SetGraphOutputRefCount() {
for (const auto &output : graph_->outputs()) {
MS_EXCEPTION_IF_NULL(output);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) {
if (!(output->isa<CNode>())) {
continue;
}
auto cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->input(i + 1);
MS_EXCEPTION_IF_NULL(input_node);
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
MS_EXCEPTION_IF_NULL(kernel_input.first);
if (!(kernel_input.first->isa<CNode>())) {
continue;
}
auto ak_node = kernel_input.first->cast<CNodePtr>();
auto key = ak_node.get();
auto iter = kernel_output_refs_.find(key);
if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
}
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
for (const auto &node : nodes) {
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
MS_EXCEPTION_IF_NULL(kernel_input.first);
if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
continue;
}
auto ak_node = kernel_input.first->cast<CNodePtr>();
auto key = ak_node.get();
auto iter = kernel_output_refs_.find(key);
if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
}
}
#ifdef MEM_REUSE_DEBUG
......
......@@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) {
precompile_only_ = false;
auto_mixed_precision_flag_ = true;
enable_pynative_infer_ = false;
enable_dynamic_mem_pool_ = false;
enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0";
MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << ".";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册