提交 c6d21ccd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!528 optimize execute order for memreuse

Merge pull request !528 from kisnwang/optimize-execute-order-for-memreuse
......@@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
}
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
size_t membuf_index = 0;
size_t membuf_index = membuf_ptr_list.size();
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
auto membuf = membuf_ptr_list[n];
MS_EXCEPTION_IF_NULL(membuf);
if (membuf->index_ == fac_idx) {
membuf_index = n;
break;
......
......@@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_name = AnfAlgo::GetCNodeName(node);
auto kernel_type = AnfAlgo::GetKernelType(node);
if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) {
return true;
if (!node->isa<CNode>()) {
return false;
}
return false;
}
bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
kernel_name == kReduceScatterOpName) {
return true;
}
return false;
......
......@@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsAllReduceOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
};
} // namespace session
......
......@@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
return std::vector<AnfNodePtr>();
}
void KernelGraph::SetExecOrderByDefault() {
std::stack<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes;
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
auto it = node_output_edges_.find(node);
if (it == node_output_edges_.end()) {
// value node and parameter has no input,no need to print log
if (node->isa<CNode>()) {
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
}
return;
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(visit_queue);
MS_EXCEPTION_IF_NULL(visited_nodes);
auto it = node_output_edges_.find(node);
if (it == node_output_edges_.end()) {
// value node and parameter has no input,no need to print log
if (node->isa<CNode>()) {
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
}
return;
}
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
}
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
(void)visited_nodes.insert(next_node);
if (AnfAlgo::IsAllReduceOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
}
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
<< ",depend edge:" << output_edge.second;
}
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node);
if (AnfAlgo::IsCommunicationOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
}
}
}
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
visit_queue->push(node);
}
};
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
visit_queue->push(node);
}
}
AnfNodePtr last_allreduce_node = nullptr;
std::queue<AnfNodePtr> allreduce_descendants;
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
void KernelGraph::SetExecOrderByDefault() {
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes;
AnfNodePtr last_communication_node = nullptr;
std::queue<AnfNodePtr> communication_descendants;
while (!seed_nodes.empty() || last_communication_node != nullptr) {
// seed nodes first, then visit last all reduce node descendant
if (seed_nodes.empty()) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
last_allreduce_node = nullptr;
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
last_communication_node = nullptr;
} else {
zero_input_nodes.push(seed_nodes.top());
zero_input_nodes.push(seed_nodes.front());
seed_nodes.pop();
}
// all reduce node descendant first, then common queue
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
AnfNodePtr node = nullptr;
bool is_allreduce_descendant = false;
if (allreduce_descendants.empty()) {
bool is_communication_descendant = false;
if (communication_descendants.empty()) {
node = zero_input_nodes.front();
zero_input_nodes.pop();
} else {
node = allreduce_descendants.front();
allreduce_descendants.pop();
is_allreduce_descendant = true;
node = communication_descendants.front();
communication_descendants.pop();
is_communication_descendant = true;
}
// add execute node
MS_EXCEPTION_IF_NULL(node);
......@@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_.push_back(node->cast<CNodePtr>());
}
// for all reduce node, visit last all reduce node descendant
if (AnfAlgo::IsAllReduceOp(node)) {
if (last_allreduce_node != nullptr) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
if (AnfAlgo::IsCommunicationOp(node)) {
if (last_communication_node != nullptr) {
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
}
last_allreduce_node = node;
} else if (is_allreduce_descendant) {
visit_node_descendant(node, &allreduce_descendants);
last_communication_node = node;
} else if (is_communication_descendant) {
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
} else {
visit_node_descendant(node, &zero_input_nodes);
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
}
}
}
CheckLoop();
}
......@@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
return true;
}
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
node_output_edges_.clear();
node_input_num_.clear();
node_input_edges_.clear();
......@@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
seed_nodes->push(node);
continue;
}
if (!node->isa<CNode>()) {
continue;
}
......
......@@ -22,7 +22,6 @@
#include <utility>
#include <string>
#include <queue>
#include <stack>
#include <map>
#include <unordered_set>
#include "ir/func_graph.h"
......@@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph {
private:
// remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes);
// update node edge list
void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
// handle control depend
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册