diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h index dea88ac10a93a3b116b8d439e558fb6c403895a1..90c8b2dfca71f9e0105e0c4cff00968c08ff59c7 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h +++ b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h @@ -23,7 +23,7 @@ namespace ascend { class AscendMemoryManager : public MemoryManager { public: AscendMemoryManager() = default; - virtual ~AscendMemoryManager() = default; + ~AscendMemoryManager() override = default; void MallocDeviceMemory() override; void FreeDeviceMemory() override; diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h index c2a29725f410832abd0e3fc2f6f6e80bb37895d7..a02bd453b2c0782bce2a7ca7333fb7499c55212a 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h @@ -26,6 +26,8 @@ namespace ascend { class AscendMemoryPool : public DynamicMemPoolBestFit { public: ~AscendMemoryPool() override = default; + AscendMemoryPool(const AscendMemoryPool&) = delete; + AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; bool FreeDeviceMem(const DeviceMemPtr& addr) override; @@ -51,13 +53,11 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { private: AscendMemoryPool() = default; - AscendMemoryPool(const AscendMemoryPool&) = delete; - AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; bool has_malloc_{false}; uint8_t* device_mem_pool_base_{nullptr}; uint64_t device_mem_pool_size_{0}; - size_t free_mem_size_; - size_t total_mem_size_; + size_t free_mem_size_{0}; + size_t total_mem_size_{0}; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 0fcb3ce39e805ef403156003dbda1a23460faab4..026a6dd95b5a1322e15329c60587d1334b75d087 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -858,6 +858,14 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { return false; } +bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) { + return true; + } + return false; +} + bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { auto kernel_name = AnfAlgo::GetCNodeName(node); return kernel_name == kGetNextOpName; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index a70a63b6786b12c64320767f497d56520e6e18b2..78359cdd5a44f6a765897240211798df48ae08d2 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -176,6 +176,7 @@ class AnfRuntimeAlgorithm { // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsAllReduceOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); }; } // namespace session diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 3c647bf21d9537ee822e06629fb4c05917a811f3..139539ccb23df4db503e32652ce8059135fd5e76 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -50,90 +50,127 @@ std::vector KernelGraph::outputs() const { } void KernelGraph::SetExecOrderByDefault() { - BfsToUpdateNodeOutput(); + std::stack seed_nodes; + UpdateNodeEdgeList(&seed_nodes); execution_order_.clear(); - std::queue allreduce_nodes; - std::queue zero_output_nodes; std::unordered_set visited_nodes; - auto clear_output = [&zero_output_nodes, &allreduce_nodes, &visited_nodes, this](const AnfNodePtr &input) -> void { - if (node_output_num_[input] == 0 && visited_nodes.find(input) == visited_nodes.end()) { - MS_EXCEPTION_IF_NULL(input); - MS_LOG(DEBUG) << "Clear output num:" << input->DebugString(); - (void)visited_nodes.insert(input); - if (input->isa() && AnfAlgo::GetCNodeName(input) == kAllReduceOpName) { - allreduce_nodes.push(input); - } else { - zero_output_nodes.push(input); + std::queue zero_input_nodes; + + auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue *visit_queue) { + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + // value node and parameter has no input,no need to print log + if (node->isa()) { + MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; } + return; + } + + // visit all reduce node first, then other nodes + std::vector active_nodes; + for (const auto &output_edge : it->second) { + auto next_node = output_edge.first; + if (node_input_num_.find(next_node) == node_input_num_.end()) { + MS_EXCEPTION_IF_NULL(next_node); + MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; + } + MS_EXCEPTION_IF_NULL(next_node); + MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() + << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; + if (node_input_num_[next_node] < output_edge.second) { + MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" + << node_input_num_[next_node] << ",depend edge:" << output_edge.second; + } + node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; + // allreduce first + if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) { + (void)visited_nodes.insert(next_node); + if (AnfAlgo::IsAllReduceOp(next_node)) { + MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); + visit_queue->push(next_node); + } else { + active_nodes.emplace_back(next_node); + } + } + } + + for (auto &node : active_nodes) { + MS_LOG(DEBUG) << "visit node:" << node->DebugString(); + visit_queue->push(node); } }; - zero_output_nodes.emplace(get_return()); - while (!zero_output_nodes.empty() || !allreduce_nodes.empty()) { - AnfNodePtr node; - if (!zero_output_nodes.empty()) { - node = zero_output_nodes.front(); - zero_output_nodes.pop(); + + AnfNodePtr last_allreduce_node = nullptr; + std::queue allreduce_descendants; + while (!seed_nodes.empty() || last_allreduce_node != nullptr) { + // seed nodes first, then visit last all reduce node descendant + if (seed_nodes.empty()) { + visit_node_descendant(last_allreduce_node, &allreduce_descendants); + last_allreduce_node = nullptr; } else { - node = allreduce_nodes.front(); - allreduce_nodes.pop(); - } - MS_EXCEPTION_IF_NULL(node); - if (node->isa() && AnfAlgo::IsRealKernel(node)) { - execution_order_.push_back(node->cast()); + zero_input_nodes.push(seed_nodes.top()); + seed_nodes.pop(); } - auto it = node_input_edges_.find(node); - if (it == node_input_edges_.end()) { - // value node and parameter has no input,no need to print log - if (node->isa()) { - MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; + + // all reduce node descendant first, then common queue + while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) { + AnfNodePtr node = nullptr; + bool is_allreduce_descendant = false; + if (allreduce_descendants.empty()) { + node = zero_input_nodes.front(); + zero_input_nodes.pop(); + } else { + node = allreduce_descendants.front(); + allreduce_descendants.pop(); + is_allreduce_descendant = true; } - continue; - } - for (const auto &input_edge : it->second) { - if (node_output_num_.find(input_edge.first) == node_output_num_.end()) { - MS_EXCEPTION_IF_NULL(input_edge.first); - MS_LOG(EXCEPTION) << "Can't find node[" << input_edge.first->DebugString() << "]"; + // add execute node + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && AnfAlgo::IsRealKernel(node)) { + execution_order_.push_back(node->cast()); } - MS_EXCEPTION_IF_NULL(input_edge.first); - MS_LOG(DEBUG) << "Decrease input:" << input_edge.first->DebugString() << ",node:" << node->DebugString() - << ",num: " << node_output_num_[input_edge.first] << ",decrease num:" << input_edge.second; - if (node_output_num_[input_edge.first] < input_edge.second) { - MS_LOG(EXCEPTION) << "Input node:" << input_edge.first->DebugString() << ",node_output_num" - << node_output_num_[input_edge.first] << "depend edge:" << input_edge.second; + // for all reduce node, visit last all reduce node descendant + if (AnfAlgo::IsAllReduceOp(node)) { + if (last_allreduce_node != nullptr) { + visit_node_descendant(last_allreduce_node, &allreduce_descendants); + } + last_allreduce_node = node; + } else if (is_allreduce_descendant) { + visit_node_descendant(node, &allreduce_descendants); + } else { + visit_node_descendant(node, &zero_input_nodes); } - node_output_num_[input_edge.first] = node_output_num_[input_edge.first] - input_edge.second; - clear_output(input_edge.first); } } + CheckLoop(); - std::reverse(execution_order_.begin(), execution_order_.end()); } void KernelGraph::CheckLoop() { - std::map none_zero_output; - if (node_output_edges_.size() != node_output_num_.size()) { - MS_LOG(EXCEPTION) << "node_output_edges_ size :" << node_output_edges_.size() - << "not equal to node_output_num_ size:" << node_output_num_.size(); + std::map none_zero_nodes; + if (node_input_edges_.size() != node_input_num_.size()) { + MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() + << "not equal to node_input_num_ size:" << node_input_num_.size(); } - for (auto &it : node_output_num_) { + for (auto &it : node_input_num_) { MS_EXCEPTION_IF_NULL(it.first); string str; - auto node_output_it = node_output_edges_.find(it.first); - if (node_output_it == node_output_edges_.end()) { + auto node_input_it = node_input_edges_.find(it.first); + if (node_input_it == node_input_edges_.end()) { MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; } - for (const auto &output_edge : node_output_edges_[it.first]) { - MS_EXCEPTION_IF_NULL(output_edge.first); - str = str.append(output_edge.first->DebugString()).append("|"); + for (const auto &input_edge : node_input_edges_[it.first]) { + MS_EXCEPTION_IF_NULL(input_edge.first); + str = str.append(input_edge.first->DebugString()).append("|"); } if (it.second != 0) { - MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",outputs:" << str << ",output num:" << it.second; - none_zero_output[it.first] = it.second; + MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; + none_zero_nodes[it.first] = it.second; } } // if don't consider control depend and loop exit,a exception will be throw - if (!none_zero_output.empty()) { - MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_output.size(); + if (!none_zero_nodes.empty()) { + MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); } } @@ -346,12 +383,13 @@ void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, } else { input_it->second.push_back(input_depend_edge); } - // add the depend sum of node - auto depend_it = node_output_num_.find(input); - if (depend_it == node_output_num_.end()) { - node_output_num_[input] = 0; + // add node input depend num + auto depend_it = node_input_num_.find(node); + if (depend_it == node_input_num_.end()) { + node_input_num_[node] = depend_edge_num; + } else { + depend_it->second += depend_edge_num; } - node_output_num_[input] += depend_edge_num; } std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { @@ -429,9 +467,9 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *seed_nodes) { node_output_edges_.clear(); - node_output_num_.clear(); + node_input_num_.clear(); node_input_edges_.clear(); std::vector control_depends; std::unordered_set visited_nodes; @@ -441,6 +479,11 @@ void KernelGraph::BfsToUpdateNodeOutput() { auto node = que.front(); que.pop(); MS_EXCEPTION_IF_NULL(node); + if (node->isa() || node->isa()) { + seed_nodes->push(node); + continue; + } + if (!node->isa()) { continue; } @@ -454,10 +497,6 @@ void KernelGraph::BfsToUpdateNodeOutput() { control_depends.push_back(input); depend_edge_num = 0; } - // the 2rd input of depend is no depend edge - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && input == cnode->input(kDependAttachNodeIndex)) { - depend_edge_num = 0; - } PushNoVisitedNode(input, &que, &visited_nodes); AddDependEdge(node, input, depend_edge_num); } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index ff964482bba4454de1b00eea3950b8698e517e33..54b16014a3da0f8a75b6ebcda9019d6c6b70cf62 100755 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include "ir/func_graph.h" @@ -93,8 +94,8 @@ class KernelGraph : public FuncGraph { private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); - // BFS to update all nodes' output - void BfsToUpdateNodeOutput(); + // update node edge list + void UpdateNodeEdgeList(std::stack *seed_nodes); // add node depend edge by data edge or control depend void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); // handle control depend @@ -114,7 +115,7 @@ class KernelGraph : public FuncGraph { std::unordered_map tensor_to_value_node_map_; // include all value nodes std::unordered_set graph_value_nodes_; - std::unordered_map node_output_num_; + std::unordered_map node_input_num_; std::unordered_map>> node_input_edges_; // record map between ref final output anf with index and ref origin input with index std::map ref_out_in_map_; diff --git a/tests/st/networks/test_gpu_lstm.py b/tests/st/networks/test_gpu_lstm.py index 43871798127a26238eadef30cbaf5172cf291f78..e5208ff669f080b497399790344632d3595b867c 100644 --- a/tests/st/networks/test_gpu_lstm.py +++ b/tests/st/networks/test_gpu_lstm.py @@ -135,4 +135,5 @@ def test_LSTM(): for epoch in range(num_epochs): loss = train_network(train_features, train_labels) losses.append(loss) + print("loss:", loss.asnumpy()) assert(losses[-1].asnumpy() < 0.01)