#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_ #define ONEFLOW_CORE_GRAPH_GRAPH_H_ #include "oneflow/core/common/str_util.h" #include "oneflow/core/graph/node.h" #include "oneflow/core/persistence/persistent_out_stream.h" namespace oneflow { template class Graph { public: OF_DISALLOW_COPY_AND_MOVE(Graph); Graph() = default; virtual ~Graph() = default; // For Each void ForEachNode(std::function NodeHandler) const; void ForEachNode(std::function NodeHandler, std::function IsNodeReady) const; void TopoForEachNode(std::function NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; void ForEachEdge(std::function EdgeHandler) const; // Getters const std::unordered_set& source_nodes() const; const std::unordered_set& sink_nodes() const; NodeType* SoleSourceNode() const; NodeType* SoleSinkNode() const; NodeType* SoleNode() const; size_t node_num() const { return nodes_.size(); } size_t edge_num() const { return edges_.size(); } virtual const char* TypeName() const { return ""; } // Setters template DerivedNodeType* NewNode(); EdgeType* NewEdge(); void AddAllocatedNode(NodeType*); void AddAllocatedEdge(EdgeType*); // ToDot template void ToDotWithStream(StreamT& out_stream); void ToDotWithFilePath(const std::string& file_path); void ToDotWithAutoFilePath(); private: std::vector> nodes_; std::vector> edges_; }; template void Graph::ForEachNode( std::function NodeHandler) const { for (auto& x : nodes_) { NodeHandler(x.get()); } } template void Graph::ForEachNode( std::function NodeHandler, std::function IsNodeReady) const { std::queue node_queue; HashSet nodes_pushed; for (auto& x : nodes_) { if (IsNodeReady(x.get())) { node_queue.push(x.get()); CHECK(nodes_pushed.insert(x.get()).second); } } while (node_queue.empty() == false) { NodeType* cur_node = node_queue.front(); node_queue.pop(); NodeHandler(cur_node); cur_node->ForEachNodeOnInOutEdge([&](NodeType* candidate) { if (nodes_pushed.find(candidate) == nodes_pushed.end() && IsNodeReady(candidate)) { node_queue.push(candidate); CHECK(nodes_pushed.insert(candidate).second); } }); } } template void Graph::TopoForEachNode( std::function NodeHandler) const { HashMap node2cnt; auto IncreaseCnt = [&](NodeType* node) { node2cnt[node] += 1; }; auto MyNodeHandler = [&](NodeType* node) { NodeHandler(node); node->ForEachNodeOnOutEdge(IncreaseCnt); }; ForEachNode(MyNodeHandler, [&](NodeType* node) { return node->in_edges().size() == node2cnt[node]; }); } template void Graph::ReverseTopoForEachNode( std::function NodeHandler) const { HashMap node2cnt; auto IncreaseCnt = [&](NodeType* node) { node2cnt[node] += 1; }; auto MyNodeHandler = [&](NodeType* node) { NodeHandler(node); node->ForEachNodeOnInEdge(IncreaseCnt); }; ForEachNode(MyNodeHandler, [&](NodeType* node) { return node->out_edges().size() == node2cnt[node]; }); } template void Graph::ForEachEdge( std::function EdgeHandler) const { for (auto& x : edges_) { EdgeHandler(x.get()); } } template NodeType* Graph::SoleNode() const { CHECK_EQ(nodes_.size(), 1); return nodes_.front().get(); } template template DerivedNodeType* Graph::NewNode() { DerivedNodeType* ret = new DerivedNodeType; AddAllocatedNode(ret); return ret; } template EdgeType* Graph::NewEdge() { EdgeType* ret = new EdgeType; AddAllocatedEdge(ret); return ret; } template void Graph::AddAllocatedNode(NodeType* node) { nodes_.emplace_back(node); } template void Graph::AddAllocatedEdge(EdgeType* edge) { edges_.emplace_back(edge); } template template void Graph::ToDotWithStream(StreamT& out_stream) { out_stream << "digraph {\n"; this->ForEachNode([&](NodeType* node) { out_stream << "\"" << node->VisualStr() << "\"\n"; }); this->ForEachEdge([&](const EdgeType* edge) { out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> " << "\"" << edge->dst_node()->VisualStr() << "\"" << "[label=\"" << edge->VisualStr() << "\"];\n"; }); out_stream << "}\n"; } template void Graph::ToDotWithFilePath( const std::string& file_path) { std::string dir_name = Dirname(file_path); if (!LocalFS()->IsDirectory(dir_name)) { LocalFS()->RecursivelyCreateDir(dir_name); } PersistentOutStream out_stream(LocalFS(), file_path); ToDotWithStream(out_stream); } template void Graph::ToDotWithAutoFilePath() { std::string file_path = LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot"; ToDotWithFilePath(file_path); } } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_GRAPH_H_