Commit 1b4a7cde, authored by: lizhenyu

fix mem swap bug

Parent commit: 0cd9e4cc
...@@ -565,7 +565,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, ...@@ -565,7 +565,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
auto cnode = kernel->cast<CNodePtr>(); auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { if (AnfAlgo::IsCommunicationOp(kernel)) {
return; return;
} }
// Free the input of kernel by reference count. // Free the input of kernel by reference count.
......
...@@ -24,7 +24,15 @@ namespace device { ...@@ -24,7 +24,15 @@ namespace device {
namespace memswap { namespace memswap {
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
execution_order_ = kernel_graph->execution_order(); graph_manager_ = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(graph_manager_);
auto &kernels = kernel_graph->execution_order();
for (const auto &kernel : kernels) {
if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) {
execution_order_.push_back(kernel);
}
}
size_t kernel_index = 0; size_t kernel_index = 0;
for (const auto &kernel : execution_order_) { for (const auto &kernel : execution_order_) {
// parse topo order of kernel // parse topo order of kernel
...@@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { ...@@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
} }
// parse topo order of user kernel // parse topo order of user kernel
SaveUserKernelTopoOrder(kernel_graph); SaveUserKernelTopoOrder();
sort(ordered_tensors_.begin(), ordered_tensors_.end(), sort(ordered_tensors_.begin(), ordered_tensors_.end(),
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
...@@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { ...@@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
mem_copy_manager_->Init(); mem_copy_manager_->Init();
} }
void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) { bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel);
FuncGraphManagerPtr manager = kernel_graph->manager(); NodeUsersMap &user_map = graph_manager_->node_users();
MS_EXCEPTION_IF_NULL(manager); auto iter = user_map.find(kernel);
NodeUsersMap user_map = manager->node_users(); bool adjacent_with_communication_op = false;
if (iter != user_map.end()) {
AnfNodeIndexSet node_set = iter->second;
adjacent_with_communication_op = std::any_of(
node_set.begin(), node_set.end(),
[](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); });
}
return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op;
}
void MemSwapManager::SaveUserKernelTopoOrder() {
NodeUsersMap &user_map = graph_manager_->node_users();
for (const auto &kernel : execution_order_) { for (const auto &kernel : execution_order_) {
auto iter = user_map.find(kernel); auto iter = user_map.find(kernel);
if (iter == user_map.end()) { if (iter == user_map.end()) {
...@@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGra ...@@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGra
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
for (auto &node_pair : node_set) { for (auto &node_pair : node_set) {
auto user_kernel = node_pair.first; auto user_kernel = node_pair.first;
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) {
continue; continue;
} }
size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_;
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1);
auto &output_idx = kernel_with_index.second; auto &output_idx = kernel_with_index.second;
if (kernel_with_index.first.get() != kernel.get()) {
MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort);
} }
for (auto &node_user_pair : kernel_exec_info.node_users_map_) { for (auto &node_user_pair : kernel_exec_info.node_users_map_) {
...@@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() { ...@@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() {
size_t output_idx = tensor.output_idx_; size_t output_idx = tensor.output_idx_;
const AnfNodePtr &kernel = tensor.kernel_; const AnfNodePtr &kernel = tensor.kernel_;
if (IsCommunicationRelevantOp(kernel)) {
continue;
}
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
auto &node_users_map = kernel_exec_info.node_users_map_; auto &node_users_map = kernel_exec_info.node_users_map_;
...@@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() { ...@@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() {
while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
++tensor_size_threshold_idx_; ++tensor_size_threshold_idx_;
if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
break; break;
} }
......
...@@ -91,7 +91,7 @@ class MemSwapManager { ...@@ -91,7 +91,7 @@ class MemSwapManager {
void ResetSwapInfo(); void ResetSwapInfo();
void SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph); void SaveUserKernelTopoOrder();
void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap);
...@@ -99,6 +99,8 @@ class MemSwapManager { ...@@ -99,6 +99,8 @@ class MemSwapManager {
void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
std::vector<CNodePtr> execution_order_; std::vector<CNodePtr> execution_order_;
std::vector<TensorInfo> ordered_tensors_; std::vector<TensorInfo> ordered_tensors_;
std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_; std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
...@@ -113,7 +115,8 @@ class MemSwapManager { ...@@ -113,7 +115,8 @@ class MemSwapManager {
size_t tensor_size_num_; size_t tensor_size_num_;
size_t distance_threshold_; size_t distance_threshold_;
MemCopyManagerPtr mem_copy_manager_; MemCopyManagerPtr mem_copy_manager_{nullptr};
FuncGraphManagerPtr graph_manager_{nullptr};
bool mem_swap_initialized_{false}; bool mem_swap_initialized_{false};
bool swap_info_already_set_{false}; bool swap_info_already_set_{false};
bool trigger_swap_{false}; bool trigger_swap_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册