Commit f5dc6fbe authored by mindspore-ci-bot, committed by Gitee

!2919 gpu kernel runtime code review

Merge pull request !2919 from limingqi107/master
......@@ -137,6 +137,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
if (is_enable_dynamic_mem) {
// Use the dynamic memory pool.
InitKernelRefCount(graph);
InitMemorySwapInfo(graph);
InitKernelOutputAddress(graph);
} else {
AssignDynamicMemory(graph);
......@@ -144,27 +145,24 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
}
bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
bool ret = true;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
auto iter = mem_swap_map_.find(graph);
if (iter == mem_swap_map_.end()) {
GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
iter = mem_swap_map_.emplace(graph, std::make_shared<MemSwapManager>(gpu_mem_copy_manager)).first;
}
mem_swap_manager_ = iter->second;
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
auto graph_id = graph->graph_id();
auto iter = mem_swap_map_.find(graph_id);
if (iter == mem_swap_map_.end()) {
MS_LOG(EXCEPTION) << "Find memory swap map failed.";
}
mem_swap_manager_ = iter->second;
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
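// Retry the dynamic launch: each failed attempt refreshes the memory swap plan before trying again.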
while (!LaunchKernelDynamic(graph)) {
ClearKernelOutputAddress(graph);
if (!mem_swap_manager_->mem_swap_init()) {
mem_swap_manager_->Init(graph);
}
if (!mem_swap_manager_->RetreatSwapInfo()) {
MS_LOG(WARNING) << "Out of memory, trying memory swapping; this may take some time, please wait a moment.";
if (!UpdateMemorySwapInfo(graph)) {
return false;
}
}
......@@ -197,6 +195,16 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
}
void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
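// Create a per-graph memory swap manager backed by a GPU memory copy manager and register it by graph id.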
GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
MS_EXCEPTION_IF_NULL(gpu_mem_copy_manager);
MemSwapManagerPtr mem_swap_manager = std::make_shared<MemSwapManager>(gpu_mem_copy_manager);
MS_EXCEPTION_IF_NULL(mem_swap_manager);
auto graph_id = graph->graph_id();
mem_swap_map_[graph_id] = mem_swap_manager;
}
void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
......@@ -227,7 +235,6 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
if (!AnfAlgo::OutputAddrExist(kernel, i)) {
continue;
}
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
if (device_address->ptr_) {
mem_manager_->FreeMemFromMemPool(device_address);
......@@ -239,9 +246,12 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
auto graph_id = graph->graph_id();
auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
auto iter = mem_reuse_util_map_.find(graph_id);
if (iter == mem_reuse_util_map_.end()) {
MS_LOG(EXCEPTION) << "Find memory reuse map failed.";
}
auto mem_reuse_util_ptr = iter->second;
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
// Reset the reference count.
mem_reuse_util_ptr->ResetDynamicUsedRefCount();
......@@ -263,27 +273,14 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
MS_LOG(EXCEPTION) << "Launch kernel failed.";
}
FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
if (!AddMemSwapTask(kernel)) {
return false;
}
}
if (mem_swap_manager_->trigger_swap()) {
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
}
UpdateMemorySwapTask(kernel);
}
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
if (mem_swap_manager_->trigger_swap()) {
mem_swap_manager_->ClearSwapQueue();
}
ClearSwapQueue();
return true;
}
bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
for (auto &mem_swap_info : mem_swap_info_list) {
......@@ -311,14 +308,92 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
return true;
}
bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
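// Free the kernel output memory allocated so far, initialize swap info on first use, and ask the swap manager whether a further swap plan is available.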
ClearKernelOutputAddress(graph);
if (!mem_swap_manager_->mem_swap_init()) {
mem_swap_manager_->Init(graph);
}
return mem_swap_manager_->RetreatSwapInfo();
}
bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return true;
}
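// If this kernel is marked as a swap trigger, synchronize the compute stream and enqueue its swap tasks, then sync the device-to-host copy stream.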
if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
if (!AddMemorySwapTask(kernel)) {
return false;
}
}
CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed.");
return true;
}
void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
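// Drain finished host-to-device copies, then reconcile the status of the given device address.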
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
auto status = device_address->status();
switch (status) {
case DeviceAddressStatus::kInDevice:
break;
case DeviceAddressStatus::kInDeviceToHost: {
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
device_address->set_status(DeviceAddressStatus::kInDevice);
break;
}
case DeviceAddressStatus::kInHostToDevice: {
while (device_address->status() != DeviceAddressStatus::kInDevice) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
}
break;
}
case DeviceAddressStatus::kInHost:
MS_LOG(ERROR) << "Invalid device address status:" << status;
break;
default:
MS_LOG(EXCEPTION) << "Invalid device address status:" << status;
}
}
void GPUKernelRuntime::UpdateDeviceSwapQueue() {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
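// Drain finished device-to-host copies and release their device memory back to the pool, unless the address is blacklisted for swap-in.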
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
}
}
void GPUKernelRuntime::ClearSwapQueue() {
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (!mem_swap_manager_->trigger_swap()) {
return;
}
mem_swap_manager_->ClearSwapQueue();
}
bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
if (!ret) {
if (!mem_swap_manager_->trigger_swap()) {
return false;
}
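// Allocation failed: wait for pending device-to-host swaps, free the swapped-out device memory, then retry the allocation.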
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
......@@ -326,7 +401,6 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
}
ret = mem_manager_->MallocMemFromMemPool(device_address, size);
if (!ret) {
return false;
......@@ -337,12 +411,12 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
if (!device_ptr) {
if (!mem_swap_manager_->trigger_swap()) {
return nullptr;
}
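// Same fallback as above: drain device-to-host swaps, reclaim their device memory, and retry the pool allocation.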
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
......@@ -350,7 +424,6 @@ void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
}
device_ptr = mem_manager_->MallocMemFromMemPool(size);
if (!device_ptr) {
return nullptr;
......@@ -377,40 +450,11 @@ bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
// The graph may consist entirely of nop nodes that have not been removed, so nop nodes cannot be skipped here.
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_address);
if (mem_swap_manager_->trigger_swap()) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
auto status = device_address->status();
switch (status) {
case DeviceAddressStatus::kInDevice:
break;
case DeviceAddressStatus::kInHost:
break;
case DeviceAddressStatus::kInDeviceToHost: {
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
device_address->set_status(DeviceAddressStatus::kInDevice);
break;
}
case DeviceAddressStatus::kInHostToDevice: {
while (device_address->status() != DeviceAddressStatus::kInDevice) {
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
}
}
break;
}
default:
MS_LOG(ERROR) << "Invalid device address status";
return false;
}
}
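// Make sure the input tensor is resident on device (waits for any in-flight swap-in of this address).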
UpdateHostSwapQueue(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
......@@ -426,16 +470,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_outputs);
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
if (mem_swap_manager_->trigger_swap()) {
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
}
}
}
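// Reclaim device memory from finished swap-outs before allocating the kernel outputs.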
UpdateDeviceSwapQueue();
auto output_sizes = kernel_mod.GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
......
......@@ -53,9 +53,9 @@ class GPUKernelRuntime : public KernelRuntime {
// The related functions and members for using dynamic memory pool.
void InitKernelRefCount(const session::KernelGraph *graph);
void InitKernelOutputAddress(const session::KernelGraph *graph);
void InitMemorySwapInfo(const session::KernelGraph *graph);
void ClearKernelOutputAddress(const session::KernelGraph *graph);
bool LaunchKernelDynamic(const session::KernelGraph *graph);
bool AddMemSwapTask(const AnfNodePtr &kernel);
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
void *AttemptMallocMem(size_t size);
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
......@@ -74,8 +74,14 @@ class GPUKernelRuntime : public KernelRuntime {
std::vector<size_t> size_list);
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
uint32_t graph_id);
bool AddMemorySwapTask(const AnfNodePtr &kernel);
bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
void UpdateHostSwapQueue(const DeviceAddressPtr device_address);
void UpdateDeviceSwapQueue();
void ClearSwapQueue();
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
std::unordered_map<void *, MemSwapManagerPtr> mem_swap_map_;
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
MemSwapManagerPtr mem_swap_manager_{nullptr};
};
MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
......
......@@ -187,8 +187,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
GetSummaryNodes(graph.get());
// Remove NoOp from execution graph
opt::RemoveNopNode(graph.get());
// Alloc memory, including static memory and dynamic memory
AllocateMemory(graph.get());
// Set graph manager.
MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = MakeManager({graph});
context_->AddManager(manager);
......@@ -196,6 +195,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
// Alloc memory, including static memory and dynamic memory
AllocateMemory(graph.get());
return graph_id;
}
......