提交 d28e2ad5 编写于 作者: S ShawnXuan

get longest path


Former-commit-id: 7628e866e8246f41960c5b03e1fa92376fb661c7
上级 0f40d420
......@@ -114,7 +114,7 @@ void LogicalGraph::BuildFwStruct() {
NaiveBuildFwStruct(&op_name2nodes);
FixSharedModelNodes(op_name2nodes);
LinkUnpackFw2PackFw(op_name2nodes);
SetDepth4Nodes();
SetDepthAndBranchId4Nodes();
total_mbn_num_ = 0;
ForEachNode([&](LogicalNode* node) {
total_mbn_num_ +=
......@@ -173,7 +173,7 @@ void LogicalGraph::NaiveBuildFwStruct(
});
}
void LogicalGraph::SetDepth4Nodes() {
void LogicalGraph::SetDepthAndBranchId4Nodes() {
auto ForEachInNode = [&](LogicalNode* node, const std::function<void(LogicalNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](LogicalNode* node_on_in_edge) { Handler(node_on_in_edge); });
};
......@@ -181,21 +181,23 @@ void LogicalGraph::SetDepth4Nodes() {
node->ForEachNodeOnOutEdge([&](LogicalNode* node_on_out_edge) { Handler(node_on_out_edge); });
};
HashMap<LogicalNode*, bool> has_handled;
HashMap<LogicalNode*, bool> has_queued;
std::queue<LogicalNode*> queue;
std::list<LogicalNode*> starts;
ForEachLogicalNode<LogicalNode>([&](LogicalNode* node) {
if (node->in_edges().size() < 1) starts.push_back(node);
});
// set depth for nodes
HashMap<LogicalNode*, bool> has_handled;
HashMap<LogicalNode*, bool> has_queued;
std::queue<LogicalNode*> queue;
for (LogicalNode* start : starts) {
queue.push(start);
has_queued[start] = true;
has_handled[start] = false;
}
int max_depth = 0;
while (!queue.empty()) {
LogicalNode* cur_node = queue.front();
queue.pop();
......@@ -206,6 +208,7 @@ void LogicalGraph::SetDepth4Nodes() {
if (!has_handled[in_node]) need_push = true;
});
cur_node->set_depth(max_in_nodes_depth + 1);
if (max_depth < max_in_nodes_depth + 1) max_depth = max_in_nodes_depth + 1;
if (need_push)
queue.push(cur_node);
else
......@@ -220,11 +223,44 @@ void LogicalGraph::SetDepth4Nodes() {
});
}
}
// ForEachLogicalNode<LogicalNode>([&](LogicalNode* node) {
// LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
// });
ForEachLogicalNode<LogicalNode>([&](LogicalNode* node) {
LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
});
// set branch id for nodes
// 1. get longest path
std::stack<LogicalNode*> longest_path_stack;
for (LogicalNode* start : starts) {
ASSERT_TRUE(longest_path_stack.empty());
if (GetLongestPath<LogicalNode>(start, max_depth, longest_path_stack, ForEachOutNode)) {
break;
}
}
LOG(INFO) << "longest path size " << max_depth << " " << longest_path_stack.size();
while (!longest_path_stack.empty()) {
LogicalNode* node = longest_path_stack.top();
longest_path_stack.pop();
LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
}
}
template<typename NodeType>
bool LogicalGraph::GetLongestPath(
LogicalNode* start, int max_depth, std::stack<NodeType*>& longest_path_stack,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode) {
longest_path_stack.push(start);
if (start->depth() == max_depth) return true;
bool get_it = false;
ForEachOutNode(start, [&](NodeType* out_node) {
if (!get_it) {
if (start->depth() + 1 == out_node->depth())
get_it = GetLongestPath<NodeType>(out_node, max_depth, longest_path_stack, ForEachOutNode);
}
});
if (!get_it) longest_path_stack.pop();
return get_it;
}
void LogicalGraph::FixSharedModelNodes(
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {
HashSet<std::string> all_shared_model_op_names;
......
......@@ -40,8 +40,6 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
void BuildFwStruct();
void NaiveBuildFwStruct(HashMap<std::string, std::vector<LogicalNode*>>* op_name2nodes);
void SetDepth4Nodes();
void SetBranchId4Nodes();
void FixSharedModelNodes(const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes);
void LinkUnpackFw2PackFw(const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes);
void ReConnectToFwClone(LogicalNode* clone_node, const LogicalBlobId& lbi,
......@@ -74,6 +72,12 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
bool MustHaveModelDiffAcc();
void SetDepthAndBranchId4Nodes();
template<typename NodeType>
bool GetLongestPath(
LogicalNode* start, int max_depth, std::stack<NodeType*>& longest_path_stack,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode);
int64_t total_mbn_num_;
std::vector<std::vector<const LogicalNode*>> fw_node_groups_;
......
......@@ -112,8 +112,8 @@ class Node {
HashSet<EdgeType*> in_edges_;
HashSet<EdgeType*> out_edges_;
int depth_ = 0; // 4 visualization
int branch_id_ = 0; // 4 visualization
int depth_ = 0; // 4 visualization
int branch_id_ = 0; // 4 visualization
};
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册