提交 b3135ba2 编写于 作者: S ShawnXuan

SetDepth4Nodes


Former-commit-id: ebceff288f3cc978e122f93b473b4c3d49e32094
上级 fb6a9af8
......@@ -112,9 +112,9 @@ void LogicalGraph::ForEachLogicalNode(std::function<void(LogicalNodeType*)> func
void LogicalGraph::BuildFwStruct() {
HashMap<std::string, std::vector<LogicalNode*>> op_name2nodes;
NaiveBuildFwStruct(&op_name2nodes);
SetDepth4Nodes(op_name2nodes);
FixSharedModelNodes(op_name2nodes);
LinkUnpackFw2PackFw(op_name2nodes);
SetDepth4Nodes(op_name2nodes);
total_mbn_num_ = 0;
ForEachNode([&](LogicalNode* node) {
total_mbn_num_ +=
......@@ -174,7 +174,52 @@ void LogicalGraph::NaiveBuildFwStruct(
}
void LogicalGraph::SetDepth4Nodes(
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {}
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {
HashMap<LogicalNode*, bool> has_handled;
std::queue<LogicalNode*> queue;
std::list<LogicalNode*> starts;
ForEachLogicalNode<RecordLoadLogicalNode>(
[&](RecordLoadLogicalNode* node) { starts.push_back(node); });
for (LogicalNode* start : starts) {
queue.push(start);
has_handled[start] = false;
}
// int depth = 0;
auto ForEachInNode = [&](LogicalNode* node, const std::function<void(LogicalNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](LogicalNode* node_on_in_edge) { Handler(node_on_in_edge); });
};
auto ForEachOutNode = [&](LogicalNode* node, const std::function<void(LogicalNode*)>& Handler) {
node->ForEachNodeOnOutEdge([&](LogicalNode* node_on_out_edge) { Handler(node_on_out_edge); });
};
while (!queue.empty()) {
LogicalNode* cur_node = queue.front();
queue.pop();
int max_in_nodes_depth = 0;
bool need_push = false;
ForEachInNode(cur_node, [&](LogicalNode* in_node) {
if (in_node->depth() > max_in_nodes_depth) max_in_nodes_depth = in_node->depth();
if (!has_handled[in_node]) need_push = true;
});
cur_node->set_depth(max_in_nodes_depth + 1);
if (need_push)
queue.push(cur_node);
else
has_handled[cur_node] = true;
if (has_handled[cur_node]) {
ForEachOutNode(cur_node, [&](LogicalNode* out_node) {
queue.push(out_node);
// has_handled[out_node] = false;
});
}
}
}
void LogicalGraph::FixSharedModelNodes(
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {
......
......@@ -97,7 +97,9 @@ class Node {
virtual std::string VisualStr() const { return ""; }
int depth() const { return depth_; }
void set_depth(const int val) { depth_ = val; }
void set_depth(const int val) {
if (val > depth_) depth_ = val;
}
private:
friend void Connect<NodeType, EdgeType>(NodeType* src_node, EdgeType* edge, NodeType* dst_node);
......@@ -107,7 +109,7 @@ class Node {
HashSet<EdgeType*> in_edges_;
HashSet<EdgeType*> out_edges_;
int depth_; // 4 visualization
int depth_ = 0; // 4 visualization
};
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册