Commit 43cfd16e authored by mindspore-ci-bot, committed by Gitee

!444 gpu dynamic memory pool supports multi-graph

Merge pull request !444 from limingqi107/master
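This change makes the GPU dynamic memory pool usable when one runtime executes more than one kernel graph: the single `mem_reuse_util_ptr_` member, which each newly compiled graph would overwrite, becomes `mem_reuse_util_map_` keyed by `graph_id`, and `FreeKernelDynamicRes` takes the graph id so it consults the right graph's reference-count bookkeeping. A minimal standalone sketch of that pattern, before the actual diff below (the `Runtime` class and the `MemReuseUtil` body here are simplified stand-ins, not the real MindSpore types):

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

// Simplified stand-in for MindSpore's MemReuseUtil (per-graph ref-count bookkeeping).
struct MemReuseUtil {
  explicit MemReuseUtil(uint32_t graph_id) : graph_id_(graph_id) {}
  uint32_t graph_id_;
};
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;

class Runtime {
 public:
  // Before this commit: a single mem_reuse_util_ptr_ member, overwritten by
  // each newly compiled graph. After: one entry per graph, keyed by graph id.
  void InitKernelRefCount(uint32_t graph_id) {
    mem_reuse_util_map_[graph_id] = std::make_shared<MemReuseUtil>(graph_id);
  }
  // Kernels of any graph can now look up their own graph's bookkeeping.
  MemReuseUtilPtr Get(uint32_t graph_id) const {
    auto it = mem_reuse_util_map_.find(graph_id);
    return it == mem_reuse_util_map_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
};

int main() {
  Runtime rt;
  rt.InitKernelRefCount(0);  // e.g. the forward graph
  rt.InitKernelRefCount(1);  // e.g. the backward graph
  std::cout << "graph 1 bookkeeping present: " << (rt.Get(1) != nullptr) << "\n";
  return 0;
}
```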
```diff
@@ -155,7 +155,8 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
   mem_reuse_util_ptr->SetReuseRefCount();
   // Can't free the device address of graph output, so set the reference count of graph output specially.
   mem_reuse_util_ptr->SetGraphOutputRefCount();
-  mem_reuse_util_ptr_ = mem_reuse_util_ptr;
+  auto graph_id = graph->graph_id();
+  mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
 }
 
 void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
@@ -179,6 +180,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph
 bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
+  auto graph_id = graph->graph_id();
   // The inputs and outputs memory of communication kernel are special, so separate processing.
   AllocCommunicationOpDynamicRes(graph);
@@ -194,7 +196,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
       MS_LOG(ERROR) << "Launch kernel failed.";
       return false;
     }
-    FreeKernelDynamicRes(kernel, kernel_workspaces);
+    FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
   }
   if (!SyncStream()) {
@@ -341,14 +343,16 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
 }
 
 void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
-                                            const AddressPtrList &kernel_workspaces) {
+                                            const AddressPtrList &kernel_workspaces, uint32_t graph_id) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
   auto cnode = kernel->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // Free the input of kernel by reference count.
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto kernel_ref_count_ptr = mem_reuse_util_ptr_->GetKernelInputRef(cnode, i);
+    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i);
     if (kernel_ref_count_ptr == nullptr) {
       continue;
     }
@@ -361,7 +365,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
     // Reset the reference count.
     kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_;
     bool is_communication_op = false;
-    // The inputs and outputs memory of communication kernel are special, so separate processing.
     FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op);
     if (!is_communication_op) {
       auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
@@ -369,6 +372,17 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       }
     }
   }
+  // Free the output of kernel, if output has no reference.
+  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
+    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i);
+    if (kernel_ref_count_ptr == nullptr) {
+      continue;
+    }
+    if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
+      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+      mem_manager_->FreeMemFromMemPool(device_address);
+    }
+  }
   // Free the workspace of kernel.
   for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
     auto workspace = kernel_workspaces[i];
```
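Besides the per-graph map, the diff above adds a loop that frees a kernel's own outputs as soon as their dynamic reference count is zero, i.e. when no later kernel in this launch will read them. A minimal sketch of that check, using a simplified stand-in type (`KernelRefCount` here mirrors only the two counters the loop consults):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-in for MindSpore's per-output reference bookkeeping.
struct KernelRefCount {
  int ref_count_ = 0;              // consumers found by static analysis
  int ref_count_dynamic_use_ = 0;  // consumers still pending at runtime
};

// Free every output whose dynamic reference count is already zero.
void FreeUnreferencedOutputs(const std::vector<KernelRefCount> &outputs) {
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (outputs[i].ref_count_dynamic_use_ == 0) {
      // The real code calls mem_manager_->FreeMemFromMemPool(device_address) here.
      std::cout << "free output " << i << "\n";
    }
  }
}

int main() {
  std::vector<KernelRefCount> outputs = {{0, 0}, {2, 1}};  // first output has no consumers left
  FreeUnreferencedOutputs(outputs);
  return 0;
}
```

The header diff below declares the matching signature change and the new `std::unordered_map` member.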
```diff
@@ -21,6 +21,7 @@
 #include <memory>
 #include <vector>
 #include <utility>
+#include <unordered_map>
 #include "device/kernel_runtime.h"
 #include "device/kernel_runtime_manager.h"
@@ -57,11 +58,12 @@ class GPUKernelRuntime : public KernelRuntime {
   void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph);
   void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel);
   void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel);
-  void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces);
+  void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
+                            uint32_t graph_id);
   void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op);
   size_t communication_op_input_ref_count_{0};
   size_t communication_op_output_ref_count_{0};
-  MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
+  std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
 };
 MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
 }  // namespace gpu
```