Commit 1b4a7cde authored by lizhenyu

fix mem swap bug

Skip all communication ops (not only AllReduce) when freeing kernel dynamic resources, build the mem swap execution order from real non-nop kernels only, validate the user-kernel topo mapping, keep communication-relevant tensors out of the swap plan, and fix the tensor size threshold comparison in RetreatSwapInfo.

Parent 0cd9e4cc
@@ -565,7 +565,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
   MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
   auto cnode = kernel->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) {
+  if (AnfAlgo::IsCommunicationOp(kernel)) {
     return;
   }
   // Free the input of kernel by reference count.
......
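The hunk above widens a guard that previously matched only AllReduce by name: any communication kernel now keeps its dynamic resources instead of having its inputs freed by reference count, presumably because an in-flight communication stream may still be reading them. A minimal sketch of the difference (ShouldKeepDynamicRes is a hypothetical helper for exposition; only the IsCommunicationOp call appears in the patch):

    // Name matching catches exactly one op; the kind check catches the
    // whole communication family.
    bool ShouldKeepDynamicRes(const AnfNodePtr &kernel) {
      // old: return AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName;
      return AnfAlgo::IsCommunicationOp(kernel);
    }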
@@ -24,7 +24,15 @@ namespace device {
 namespace memswap {
 void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  execution_order_ = kernel_graph->execution_order();
+  graph_manager_ = kernel_graph->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager_);
+  auto &kernels = kernel_graph->execution_order();
+  for (const auto &kernel : kernels) {
+    if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) {
+      execution_order_.push_back(kernel);
+    }
+  }
   size_t kernel_index = 0;
   for (const auto &kernel : execution_order_) {
     // parse topo order of kernel
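Init no longer copies the raw execution order; it keeps only real, non-nop kernels, so every topo index assigned below refers to a kernel that actually launches. The same filter can be expressed with std::copy_if (an equivalent sketch, not the committed code; it needs <algorithm> and <iterator>):

    std::copy_if(kernels.begin(), kernels.end(), std::back_inserter(execution_order_),
                 [](const CNodePtr &kernel) {
                   return AnfAlgo::IsRealCNodeKernel(kernel) && !opt::IsNopNode(kernel);
                 });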
@@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
   }
   // parse topo order of user kernel
-  SaveUserKernelTopoOrder(kernel_graph);
+  SaveUserKernelTopoOrder();
   sort(ordered_tensors_.begin(), ordered_tensors_.end(),
        [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
@@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
   mem_copy_manager_->Init();
 }

-void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  FuncGraphManagerPtr manager = kernel_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  NodeUsersMap user_map = manager->node_users();
+bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
+  MS_EXCEPTION_IF_NULL(kernel);
+  NodeUsersMap &user_map = graph_manager_->node_users();
+  auto iter = user_map.find(kernel);
+  bool adjacent_with_communication_op = false;
+  if (iter != user_map.end()) {
+    AnfNodeIndexSet node_set = iter->second;
+    adjacent_with_communication_op = std::any_of(
+        node_set.begin(), node_set.end(),
+        [](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); });
+  }
+  return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op;
+}
+
+void MemSwapManager::SaveUserKernelTopoOrder() {
+  NodeUsersMap &user_map = graph_manager_->node_users();
   for (const auto &kernel : execution_order_) {
     auto iter = user_map.find(kernel);
     if (iter == user_map.end()) {
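The new IsCommunicationRelevantOp treats a kernel as communication-relevant when it is itself a communication op or when any of its users is one; node_users() maps each node to its consumers as (user, input index) pairs. A self-contained sketch of the same std::any_of pattern over a stand-in user list (FakeNode and AdjacentWithCommunicationOp are illustrative names, not the MindSpore API):

    #include <algorithm>
    #include <utility>
    #include <vector>

    struct FakeNode { bool is_communication_op; };  // stand-in for AnfNode
    using FakeUsers = std::vector<std::pair<FakeNode *, int>>;  // (user, input index)

    bool AdjacentWithCommunicationOp(const FakeUsers &users) {
      return std::any_of(users.begin(), users.end(), [](const std::pair<FakeNode *, int> &user) {
        return user.first->is_communication_op;
      });
    }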
@@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
     auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
     for (auto &node_pair : node_set) {
       auto user_kernel = node_pair.first;
-      if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) {
+      if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) {
         continue;
       }
       size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_;
       auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1);
       auto &output_idx = kernel_with_index.second;
+      if (kernel_with_index.first.get() != kernel.get()) {
+        MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
+      }
       kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort);
     }
     for (auto &node_user_pair : kernel_exec_info.node_users_map_) {
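The added guard fails fast on inconsistent graph bookkeeping: node_pair.second appears to count the user's inputs including the primitive slot (hence the - 1), so GetPrevNodeOutput(user_kernel, node_pair.second - 1) should resolve back to exactly the kernel whose users are being recorded, together with the output index that user consumes. If it resolves to a different node, recording the user under output_idx would corrupt node_users_map_, and raising an exception is safer than silently building a wrong swap schedule.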
@@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() {
     size_t output_idx = tensor.output_idx_;
     const AnfNodePtr &kernel = tensor.kernel_;
+    if (IsCommunicationRelevantOp(kernel)) {
+      continue;
+    }
     auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
     auto &node_users_map = kernel_exec_info.node_users_map_;
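With the helper in place, AddSwapInfo leaves every communication-relevant tensor resident on the device, since swapping it to host could race with the communication stream that produces or consumes it. A condensed sketch of the resulting selection loop, assuming the surrounding loop walks ordered_tensors_ as the tensor.kernel_ accesses suggest:

    for (const auto &tensor : ordered_tensors_) {
      if (IsCommunicationRelevantOp(tensor.kernel_)) {
        continue;  // pinned: produced by, or feeding, a communication op
      }
      // ... unchanged swap-point bookkeeping for this tensor ...
    }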
@@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() {
   while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
     ++tensor_size_threshold_idx_;
-    if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
+    if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
       tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
       break;
     }
......
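The RetreatSwapInfo hunk is the heart of the fix: the old condition compared the loop index tensor_size_threshold_idx_ against a tensor's byte size, which is essentially never true for realistic tensor sizes, so the threshold never actually retreated. The corrected code compares the current threshold, and because ordered_tensors_ is sorted largest-first in Init, retreating means advancing to the next strictly smaller size. A standalone sketch with hypothetical sizes (RetreatThreshold is an illustrative name):

    #include <cstddef>
    #include <vector>

    // sizes assumed sorted descending, e.g. {1 << 20, 1 << 20, 1 << 18, 1 << 16}
    size_t RetreatThreshold(const std::vector<size_t> &sizes, size_t &idx, size_t threshold) {
      while (idx + 1 < sizes.size()) {
        ++idx;
        if (threshold > sizes[idx]) {  // old code compared idx, an index, against sizes[idx]
          threshold = sizes[idx];      // retreat to the next smaller tensor size
          break;
        }
      }
      return threshold;
    }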
@@ -91,7 +91,7 @@ class MemSwapManager {
   void ResetSwapInfo();

-  void SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph);
+  void SaveUserKernelTopoOrder();

   void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap);
@@ -99,6 +99,8 @@ class MemSwapManager {
   void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);

+  bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
+
   std::vector<CNodePtr> execution_order_;
   std::vector<TensorInfo> ordered_tensors_;
   std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
@@ -113,7 +115,8 @@
   size_t tensor_size_num_;
   size_t distance_threshold_;
-  MemCopyManagerPtr mem_copy_manager_;
+  MemCopyManagerPtr mem_copy_manager_{nullptr};
+  FuncGraphManagerPtr graph_manager_{nullptr};
   bool mem_swap_initialized_{false};
   bool swap_info_already_set_{false};
   bool trigger_swap_{false};
......