Commit 093c2cae authored by mindspore-ci-bot, committed by Gitee

!337 optimize execute order sort

Merge pull request !337 from kisnwang/optimize-execute-order-sort
@@ -23,7 +23,7 @@ namespace ascend {
 class AscendMemoryManager : public MemoryManager {
  public:
   AscendMemoryManager() = default;
-  virtual ~AscendMemoryManager() = default;
+  ~AscendMemoryManager() override = default;
   void MallocDeviceMemory() override;
   void FreeDeviceMemory() override;
...
@@ -26,6 +26,8 @@ namespace ascend {
 class AscendMemoryPool : public DynamicMemPoolBestFit {
  public:
   ~AscendMemoryPool() override = default;
+  AscendMemoryPool(const AscendMemoryPool&) = delete;
+  AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
   size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
   bool FreeDeviceMem(const DeviceMemPtr& addr) override;
@@ -51,13 +53,11 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
  private:
   AscendMemoryPool() = default;
-  AscendMemoryPool(const AscendMemoryPool&) = delete;
-  AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
   bool has_malloc_{false};
   uint8_t* device_mem_pool_base_{nullptr};
   uint64_t device_mem_pool_size_{0};
-  size_t free_mem_size_;
-  size_t total_mem_size_;
+  size_t free_mem_size_{0};
+  size_t total_mem_size_{0};
 };
 }  // namespace ascend
 }  // namespace device
...
@@ -858,6 +858,14 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
   return false;
 }
 
+bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
+    return true;
+  }
+  return false;
+}
+
 bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
   auto kernel_name = AnfAlgo::GetCNodeName(node);
   return kernel_name == kGetNextOpName;
...
@@ -176,6 +176,7 @@ class AnfRuntimeAlgorithm {
   // get real input index for some tbe ops which input order is different between me and tbe impl
   static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
   static bool IsCommunicationOp(const AnfNodePtr &node);
+  static bool IsAllReduceOp(const AnfNodePtr &node);
   static bool IsGetNext(const NotNull<AnfNodePtr> &node);
 };
 }  // namespace session
...
@@ -50,90 +50,127 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
 }
 
 void KernelGraph::SetExecOrderByDefault() {
-  BfsToUpdateNodeOutput();
+  std::stack<AnfNodePtr> seed_nodes;
+  UpdateNodeEdgeList(&seed_nodes);
   execution_order_.clear();
-  std::queue<AnfNodePtr> allreduce_nodes;
-  std::queue<AnfNodePtr> zero_output_nodes;
   std::unordered_set<AnfNodePtr> visited_nodes;
-  auto clear_output = [&zero_output_nodes, &allreduce_nodes, &visited_nodes, this](const AnfNodePtr &input) -> void {
-    if (node_output_num_[input] == 0 && visited_nodes.find(input) == visited_nodes.end()) {
-      MS_EXCEPTION_IF_NULL(input);
-      MS_LOG(DEBUG) << "Clear output num:" << input->DebugString();
-      (void)visited_nodes.insert(input);
-      if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kAllReduceOpName) {
-        allreduce_nodes.push(input);
-      } else {
-        zero_output_nodes.push(input);
-      }
-    }
-  };
-  zero_output_nodes.emplace(get_return());
-  while (!zero_output_nodes.empty() || !allreduce_nodes.empty()) {
-    AnfNodePtr node;
-    if (!zero_output_nodes.empty()) {
-      node = zero_output_nodes.front();
-      zero_output_nodes.pop();
-    } else {
-      node = allreduce_nodes.front();
-      allreduce_nodes.pop();
-    }
-    MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
-      execution_order_.push_back(node->cast<CNodePtr>());
-    }
-    auto it = node_input_edges_.find(node);
-    if (it == node_input_edges_.end()) {
-      // value node and parameter has no input,no need to print log
-      if (node->isa<CNode>()) {
-        MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
-      }
-      continue;
-    }
-    for (const auto &input_edge : it->second) {
-      if (node_output_num_.find(input_edge.first) == node_output_num_.end()) {
-        MS_EXCEPTION_IF_NULL(input_edge.first);
-        MS_LOG(EXCEPTION) << "Can't find node[" << input_edge.first->DebugString() << "]";
-      }
-      MS_EXCEPTION_IF_NULL(input_edge.first);
-      MS_LOG(DEBUG) << "Decrease input:" << input_edge.first->DebugString() << ",node:" << node->DebugString()
-                    << ",num: " << node_output_num_[input_edge.first] << ",decrease num:" << input_edge.second;
-      if (node_output_num_[input_edge.first] < input_edge.second) {
-        MS_LOG(EXCEPTION) << "Input node:" << input_edge.first->DebugString() << ",node_output_num"
-                          << node_output_num_[input_edge.first] << "depend edge:" << input_edge.second;
-      }
-      node_output_num_[input_edge.first] = node_output_num_[input_edge.first] - input_edge.second;
-      clear_output(input_edge.first);
-    }
-  }
+  std::queue<AnfNodePtr> zero_input_nodes;
+
+  auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
+    auto it = node_output_edges_.find(node);
+    if (it == node_output_edges_.end()) {
+      // value node and parameter has no input,no need to print log
+      if (node->isa<CNode>()) {
+        MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
+      }
+      return;
+    }
+
+    // visit all reduce node first, then other nodes
+    std::vector<AnfNodePtr> active_nodes;
+    for (const auto &output_edge : it->second) {
+      auto next_node = output_edge.first;
+      if (node_input_num_.find(next_node) == node_input_num_.end()) {
+        MS_EXCEPTION_IF_NULL(next_node);
+        MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
+      }
+      MS_EXCEPTION_IF_NULL(next_node);
+      MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
+                    << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
+      if (node_input_num_[next_node] < output_edge.second) {
+        MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
+                          << node_input_num_[next_node] << ",depend edge:" << output_edge.second;
+      }
+      node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
+      // allreduce first
+      if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
+        (void)visited_nodes.insert(next_node);
+        if (AnfAlgo::IsAllReduceOp(next_node)) {
+          MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
+          visit_queue->push(next_node);
+        } else {
+          active_nodes.emplace_back(next_node);
+        }
+      }
+    }
+
+    for (auto &node : active_nodes) {
+      MS_LOG(DEBUG) << "visit node:" << node->DebugString();
+      visit_queue->push(node);
+    }
+  };
+
+  AnfNodePtr last_allreduce_node = nullptr;
+  std::queue<AnfNodePtr> allreduce_descendants;
+  while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
+    // seed nodes first, then visit last all reduce node descendant
+    if (seed_nodes.empty()) {
+      visit_node_descendant(last_allreduce_node, &allreduce_descendants);
+      last_allreduce_node = nullptr;
+    } else {
+      zero_input_nodes.push(seed_nodes.top());
+      seed_nodes.pop();
+    }
+
+    // all reduce node descendant first, then common queue
+    while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
+      AnfNodePtr node = nullptr;
+      bool is_allreduce_descendant = false;
+      if (allreduce_descendants.empty()) {
+        node = zero_input_nodes.front();
+        zero_input_nodes.pop();
+      } else {
+        node = allreduce_descendants.front();
+        allreduce_descendants.pop();
+        is_allreduce_descendant = true;
+      }
+      // add execute node
+      MS_EXCEPTION_IF_NULL(node);
+      if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
+        execution_order_.push_back(node->cast<CNodePtr>());
+      }
+      // for all reduce node, visit last all reduce node descendant
+      if (AnfAlgo::IsAllReduceOp(node)) {
+        if (last_allreduce_node != nullptr) {
+          visit_node_descendant(last_allreduce_node, &allreduce_descendants);
+        }
+        last_allreduce_node = node;
+      } else if (is_allreduce_descendant) {
+        visit_node_descendant(node, &allreduce_descendants);
+      } else {
+        visit_node_descendant(node, &zero_input_nodes);
+      }
+    }
+  }
   CheckLoop();
-  std::reverse(execution_order_.begin(), execution_order_.end());
 }
 
 void KernelGraph::CheckLoop() {
-  std::map<AnfNodePtr, size_t> none_zero_output;
-  if (node_output_edges_.size() != node_output_num_.size()) {
-    MS_LOG(EXCEPTION) << "node_output_edges_ size :" << node_output_edges_.size()
-                      << "not equal to node_output_num_ size:" << node_output_num_.size();
+  std::map<AnfNodePtr, size_t> none_zero_nodes;
+  if (node_input_edges_.size() != node_input_num_.size()) {
+    MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
+                      << "not equal to node_input_num_ size:" << node_input_num_.size();
   }
-  for (auto &it : node_output_num_) {
+  for (auto &it : node_input_num_) {
     MS_EXCEPTION_IF_NULL(it.first);
     string str;
-    auto node_output_it = node_output_edges_.find(it.first);
-    if (node_output_it == node_output_edges_.end()) {
+    auto node_input_it = node_input_edges_.find(it.first);
+    if (node_input_it == node_input_edges_.end()) {
       MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
     }
-    for (const auto &output_edge : node_output_edges_[it.first]) {
-      MS_EXCEPTION_IF_NULL(output_edge.first);
-      str = str.append(output_edge.first->DebugString()).append("|");
+    for (const auto &input_edge : node_input_edges_[it.first]) {
+      MS_EXCEPTION_IF_NULL(input_edge.first);
+      str = str.append(input_edge.first->DebugString()).append("|");
    }
     if (it.second != 0) {
-      MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",outputs:" << str << ",output num:" << it.second;
-      none_zero_output[it.first] = it.second;
+      MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
+      none_zero_nodes[it.first] = it.second;
     }
   }
   // if don't consider control depend and loop exit,a exception will be throw
-  if (!none_zero_output.empty()) {
-    MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_output.size();
+  if (!none_zero_nodes.empty()) {
+    MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
   }
 }
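
Editor's note: taken together, the new SetExecOrderByDefault is Kahn's topological sort over the kernel graph. UpdateNodeEdgeList seeds the traversal with zero-input nodes (Parameters and ValueNodes), node_input_num_ holds each node's remaining in-degree, and a second ready queue lets an AllReduce's descendants run ahead of unrelated nodes so communication can overlap with computation. The standalone C++ sketch below illustrates only the two-queue scheme on a toy graph; it is not the MindSpore code (Node, SortExecOrder, and the example graph are invented for illustration), and it omits the last_allreduce_node handling the real implementation uses to serialize successive AllReduce ops.

#include <iostream>
#include <queue>
#include <string>
#include <vector>

// Toy node: a name, an "is AllReduce" flag, and consumer indices.
struct Node {
  std::string name;
  bool is_allreduce = false;
  std::vector<int> outputs;
};

// Kahn's topological sort with two ready queues: nodes that become ready
// while an AllReduce's descendants are being drained go to the priority
// queue, so communication ops and their consumers are scheduled early.
std::vector<int> SortExecOrder(const std::vector<Node> &nodes) {
  std::vector<int> in_degree(nodes.size(), 0);
  for (const auto &n : nodes) {
    for (int out : n.outputs) ++in_degree[out];
  }
  std::queue<int> common, allreduce_first;
  for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
    if (in_degree[i] == 0) common.push(i);  // seed nodes: zero inputs
  }
  std::vector<int> order;
  while (!common.empty() || !allreduce_first.empty()) {
    bool from_priority = !allreduce_first.empty();
    std::queue<int> &q = from_priority ? allreduce_first : common;
    int cur = q.front();
    q.pop();
    order.push_back(cur);
    for (int out : nodes[cur].outputs) {
      if (--in_degree[out] == 0) {
        // AllReduce nodes, and nodes reached while draining one, jump the line.
        if (nodes[out].is_allreduce || from_priority) {
          allreduce_first.push(out);
        } else {
          common.push(out);
        }
      }
    }
  }
  return order;
}

int main() {
  // x feeds both an AllReduce and unrelated compute z; y consumes the AllReduce.
  std::vector<Node> g = {{"x", false, {1, 3}},
                         {"allreduce", true, {2}},
                         {"y", false, {}},
                         {"z", false, {}}};
  for (int i : SortExecOrder(g)) std::cout << g[i].name << ' ';
  std::cout << '\n';  // prints: x allreduce y z
}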
@@ -346,12 +383,13 @@ void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input,
   } else {
     input_it->second.push_back(input_depend_edge);
   }
-  // add the depend sum of node
-  auto depend_it = node_output_num_.find(input);
-  if (depend_it == node_output_num_.end()) {
-    node_output_num_[input] = 0;
+  // add node input depend num
+  auto depend_it = node_input_num_.find(node);
+  if (depend_it == node_input_num_.end()) {
+    node_input_num_[node] = depend_edge_num;
+  } else {
+    depend_it->second += depend_edge_num;
   }
-  node_output_num_[input] += depend_edge_num;
 }
 
 std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
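
Editor's note: the AddDependEdge change above is what enables the new sort. The old code counted each producer's outputs (node_output_num_); the new code accumulates each consumer's weighted in-degree (node_input_num_), which SetExecOrderByDefault then decrements toward zero. A minimal sketch of that accounting, with illustrative types standing in for AnfNodePtr (per the diff, a pure control-depend attachment contributes a depend_edge_num of 0, a data edge contributes a positive count):

#include <cstddef>
#include <map>
#include <string>

// Weighted in-degree per consumer node.
std::map<std::string, std::size_t> node_input_num;

void AddDependEdge(const std::string &node, std::size_t depend_edge_num) {
  auto it = node_input_num.find(node);
  if (it == node_input_num.end()) {
    node_input_num[node] = depend_edge_num;  // first edge into this node
  } else {
    it->second += depend_edge_num;  // accumulate later edges
  }
}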
@@ -429,9 +467,9 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
   return true;
 }
 
-void KernelGraph::BfsToUpdateNodeOutput() {
+void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
   node_output_edges_.clear();
-  node_output_num_.clear();
+  node_input_num_.clear();
   node_input_edges_.clear();
   std::vector<AnfNodePtr> control_depends;
   std::unordered_set<AnfNodePtr> visited_nodes;
@@ -441,6 +479,11 @@ void KernelGraph::BfsToUpdateNodeOutput() {
     auto node = que.front();
     que.pop();
     MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
+      seed_nodes->push(node);
+      continue;
+    }
+
     if (!node->isa<CNode>()) {
       continue;
     }
@@ -454,10 +497,6 @@ void KernelGraph::BfsToUpdateNodeOutput() {
       control_depends.push_back(input);
       depend_edge_num = 0;
     }
-    // the 2rd input of depend is no depend edge
-    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && input == cnode->input(kDependAttachNodeIndex)) {
-      depend_edge_num = 0;
-    }
     PushNoVisitedNode(input, &que, &visited_nodes);
     AddDependEdge(node, input, depend_edge_num);
   }
...
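Editor's note: UpdateNodeEdgeList (formerly BfsToUpdateNodeOutput) now also produces the traversal's seeds: during the BFS from the return node, Parameters and ValueNodes are pushed onto a stack instead of being given edge entries. A standalone sketch of that seed collection, with a toy node type standing in for AnfNodePtr:

#include <queue>
#include <stack>
#include <unordered_set>
#include <vector>

// Stand-in for an ANF node: parameters/values are graph leaves.
struct ToyNode {
  bool is_param_or_value = false;
  std::vector<ToyNode *> inputs;
};

// BFS from the return node; leaves go onto the seed stack instead of
// getting edge entries, and the topological sort later starts from them.
void CollectSeeds(ToyNode *ret, std::stack<ToyNode *> *seeds) {
  std::queue<ToyNode *> que;
  std::unordered_set<ToyNode *> visited;
  que.push(ret);
  visited.insert(ret);
  while (!que.empty()) {
    ToyNode *node = que.front();
    que.pop();
    if (node->is_param_or_value) {
      seeds->push(node);  // zero-input seed
      continue;
    }
    for (ToyNode *input : node->inputs) {
      // the real code records node_input_edges_/node_output_edges_ here
      if (visited.insert(input).second) que.push(input);
    }
  }
}
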
@@ -22,6 +22,7 @@
 #include <utility>
 #include <string>
 #include <queue>
+#include <stack>
 #include <map>
 #include <unordered_set>
 #include "ir/func_graph.h"
@@ -93,8 +94,8 @@ class KernelGraph : public FuncGraph {
  private:
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
-  // BFS to update all nodes' output
-  void BfsToUpdateNodeOutput();
+  // update node edge list
+  void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
   // add node depend edge by data edge or control depend
   void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
   // handle control depend
@@ -114,7 +115,7 @@ class KernelGraph : public FuncGraph {
   std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
   // include all value nodes
   std::unordered_set<ValueNodePtr> graph_value_nodes_;
-  std::unordered_map<AnfNodePtr, size_t> node_output_num_;
+  std::unordered_map<AnfNodePtr, size_t> node_input_num_;
   std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
   // record map between ref final output anf with index and ref origin input with index
   std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
...
@@ -135,4 +135,5 @@ def test_LSTM():
     for epoch in range(num_epochs):
         loss = train_network(train_features, train_labels)
         losses.append(loss)
+        print("loss:", loss.asnumpy())
     assert(losses[-1].asnumpy() < 0.01)