diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 8095a503e3b2f1757b0c0459873e36c805e90d84..ad0e093d7ffed891abe23098b93a403dbb5be812 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -565,7 +565,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); auto cnode = kernel->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { + if (AnfAlgo::IsCommunicationOp(kernel)) { return; } // Free the input of kernel by reference count. diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc index d81364edfb16da7df26c2bceb821fa05b00a48a2..14073bfbc95445dc22a5e5c75185c136f7276cc2 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc @@ -24,7 +24,15 @@ namespace device { namespace memswap { void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); - execution_order_ = kernel_graph->execution_order(); + graph_manager_ = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(graph_manager_); + auto &kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { + execution_order_.push_back(kernel); + } + } + size_t kernel_index = 0; for (const auto &kernel : execution_order_) { // parse topo order of kernel @@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { } // parse topo order of user kernel - SaveUserKernelTopoOrder(kernel_graph); + SaveUserKernelTopoOrder(); sort(ordered_tensors_.begin(), ordered_tensors_.end(), [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); @@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { mem_copy_manager_->Init(); } -void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - FuncGraphManagerPtr manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - NodeUsersMap user_map = manager->node_users(); +bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + NodeUsersMap &user_map = graph_manager_->node_users(); + auto iter = user_map.find(kernel); + bool adjacent_with_communication_op = false; + if (iter != user_map.end()) { + AnfNodeIndexSet node_set = iter->second; + adjacent_with_communication_op = std::any_of( + node_set.begin(), node_set.end(), + [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); + } + return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; +} + +void MemSwapManager::SaveUserKernelTopoOrder() { + NodeUsersMap &user_map = graph_manager_->node_users(); for (const auto &kernel : execution_order_) { auto iter = user_map.find(kernel); if (iter == user_map.end()) { @@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGra auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); for (auto &node_pair : node_set) { auto user_kernel = node_pair.first; - if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { + if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { continue; } size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); auto &output_idx = kernel_with_index.second; + if (kernel_with_index.first.get() != kernel.get()) { + MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); } for (auto &node_user_pair : kernel_exec_info.node_users_map_) { @@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() { size_t output_idx = tensor.output_idx_; const AnfNodePtr &kernel = tensor.kernel_; + if (IsCommunicationRelevantOp(kernel)) { + continue; + } auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); auto &node_users_map = kernel_exec_info.node_users_map_; @@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() { while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { ++tensor_size_threshold_idx_; - if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { + if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; break; } diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h index 7e2823d27c4fd5f49c3eb793c0501f1a683b480d..1969dadb54c3c0ab58ada3aa4e7df7e16ebc344b 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h @@ -91,7 +91,7 @@ class MemSwapManager { void ResetSwapInfo(); - void SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph); + void SaveUserKernelTopoOrder(); void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); @@ -99,6 +99,8 @@ class MemSwapManager { void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; + std::vector execution_order_; std::vector ordered_tensors_; std::unordered_map kernel_execution_info_; @@ -113,7 +115,8 @@ class MemSwapManager { size_t tensor_size_num_; size_t distance_threshold_; - MemCopyManagerPtr mem_copy_manager_; + MemCopyManagerPtr mem_copy_manager_{nullptr}; + FuncGraphManagerPtr graph_manager_{nullptr}; bool mem_swap_initialized_{false}; bool swap_info_already_set_{false}; bool trigger_swap_{false};