graph.h 5.9 KB
Newer Older
J
jiyuan 已提交
1 2
#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_GRAPH_H_
W
willzhang4a58 已提交
3

4
#include "oneflow/core/common/str_util.h"
J
jiyuan 已提交
5
#include "oneflow/core/graph/node.h"
W
willzhang4a58 已提交
6
#include "oneflow/core/persistence/persistent_out_stream.h"
W
willzhang4a58 已提交
7 8 9

namespace oneflow {

W
willzhang4a58 已提交
10
template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
11 12
class Graph {
 public:
W
willzhang4a58 已提交
13
  OF_DISALLOW_COPY_AND_MOVE(Graph);
W
willzhang4a58 已提交
14 15 16
  Graph() = default;
  virtual ~Graph() = default;

W
Will Zhang 已提交
17 18 19 20 21 22 23
  // For Each
  void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;
  void ForEachNode(std::function<void(NodeType*)> NodeHandler,
                   std::function<bool(NodeType*)> IsNodeReady) const;
  void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
  void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
  void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;
W
willzhang4a58 已提交
24

W
willzhang4a58 已提交
25
  // Getters
W
willzhang4a58 已提交
26 27 28 29 30 31 32
  const std::unordered_set<NodeType*>& source_nodes() const;
  const std::unordered_set<NodeType*>& sink_nodes() const;
  NodeType* SoleSourceNode() const;
  NodeType* SoleSinkNode() const;
  NodeType* SoleNode() const;
  size_t node_num() const { return nodes_.size(); }
  size_t edge_num() const { return edges_.size(); }
W
Will Zhang 已提交
33
  virtual const char* TypeName() const { return ""; }
W
willzhang4a58 已提交
34

W
willzhang4a58 已提交
35
  // Setters
W
Will Zhang 已提交
36 37
  template<typename DerivedNodeType = NodeType>
  DerivedNodeType* NewNode();
W
willzhang4a58 已提交
38
  EdgeType* NewEdge();
W
Will Zhang 已提交
39 40
  void AddAllocatedNode(NodeType*);
  void AddAllocatedEdge(EdgeType*);
W
willzhang4a58 已提交
41

W
willzhang4a58 已提交
42
  // ToDot
W
willzhang4a58 已提交
43
  template<typename StreamT>
W
Will Zhang 已提交
44 45 46
  void ToDotWithStream(StreamT& out_stream);
  void ToDotWithFilePath(const std::string& file_path);
  void ToDotWithAutoFilePath();
W
willzhang4a58 已提交
47

W
willzhang4a58 已提交
48
 private:
W
willzhang4a58 已提交
49 50
  std::vector<std::unique_ptr<NodeType>> nodes_;
  std::vector<std::unique_ptr<EdgeType>> edges_;
W
willzhang4a58 已提交
51
};
W
willzhang4a58 已提交
52

W
willzhang4a58 已提交
53
template<typename NodeType, typename EdgeType>
W
Will Zhang 已提交
54 55 56 57
void Graph<NodeType, EdgeType>::ForEachNode(
    std::function<void(NodeType*)> NodeHandler) const {
  for (auto& x : nodes_) { NodeHandler(x.get()); }
}
W
willzhang4a58 已提交
58

W
willzhang4a58 已提交
59
template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
60
void Graph<NodeType, EdgeType>::ForEachNode(
W
Will Zhang 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
    std::function<void(NodeType*)> NodeHandler,
    std::function<bool(NodeType*)> IsNodeReady) const {
  std::queue<NodeType*> node_queue;
  HashSet<NodeType*> nodes_pushed;
  for (auto& x : nodes_) {
    if (IsNodeReady(x.get())) {
      node_queue.push(x.get());
      CHECK(nodes_pushed.insert(x.get()).second);
    }
  }
  while (node_queue.empty() == false) {
    NodeType* cur_node = node_queue.front();
    node_queue.pop();
    NodeHandler(cur_node);
    cur_node->ForEachNodeOnInOutEdge([&](NodeType* candidate) {
      if (nodes_pushed.find(candidate) == nodes_pushed.end()
          && IsNodeReady(candidate)) {
        node_queue.push(candidate);
        CHECK(nodes_pushed.insert(candidate).second);
      }
    });
  }
W
willzhang4a58 已提交
83 84 85 86
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
    std::function<void(NodeType*)> NodeHandler) const {
  // For each node: how many times it has been reported as an out-edge
  // neighbour of an already-handled node.
  HashMap<NodeType*, size_t> handled_producers;
  ForEachNode(
      [&](NodeType* node) {
        NodeHandler(node);
        node->ForEachNodeOnOutEdge(
            [&](NodeType* consumer) { ++handled_producers[consumer]; });
      },
      // Ready once the count matches in_edges().size(). Nodes on, or
      // downstream of, a cycle never satisfy this and are never handled.
      [&](NodeType* node) {
        return node->in_edges().size() == handled_producers[node];
      });
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
    std::function<void(NodeType*)> NodeHandler) const {
  // Mirror image of TopoForEachNode: counts handled out-edge consumers.
  HashMap<NodeType*, size_t> handled_consumers;
  ForEachNode(
      [&](NodeType* node) {
        NodeHandler(node);
        node->ForEachNodeOnInEdge(
            [&](NodeType* producer) { ++handled_consumers[producer]; });
      },
      // Ready once the count matches out_edges().size().
      [&](NodeType* node) {
        return node->out_edges().size() == handled_consumers[node];
      });
}
W
willzhang4a58 已提交
112

W
willzhang4a58 已提交
113
template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
114
void Graph<NodeType, EdgeType>::ForEachEdge(
W
Will Zhang 已提交
115 116
    std::function<void(EdgeType*)> EdgeHandler) const {
  for (auto& x : edges_) { EdgeHandler(x.get()); }
W
willzhang4a58 已提交
117
}
W
willzhang4a58 已提交
118

W
willzhang4a58 已提交
119
template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
120 121 122
NodeType* Graph<NodeType, EdgeType>::SoleNode() const {
  CHECK_EQ(nodes_.size(), 1);
  return nodes_.front().get();
W
willzhang4a58 已提交
123 124 125
}

template<typename NodeType, typename EdgeType>
W
Will Zhang 已提交
126 127 128 129
template<typename DerivedNodeType>
DerivedNodeType* Graph<NodeType, EdgeType>::NewNode() {
  DerivedNodeType* ret = new DerivedNodeType;
  AddAllocatedNode(ret);
W
willzhang4a58 已提交
130
  return ret;
W
willzhang4a58 已提交
131 132 133
}

template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
134 135
EdgeType* Graph<NodeType, EdgeType>::NewEdge() {
  EdgeType* ret = new EdgeType;
W
Will Zhang 已提交
136
  AddAllocatedEdge(ret);
W
willzhang4a58 已提交
137
  return ret;
W
willzhang4a58 已提交
138 139 140
}

template<typename NodeType, typename EdgeType>
W
Will Zhang 已提交
141
void Graph<NodeType, EdgeType>::AddAllocatedNode(NodeType* node) {
W
willzhang4a58 已提交
142 143 144 145
  nodes_.emplace_back(node);
}

template<typename NodeType, typename EdgeType>
W
Will Zhang 已提交
146
void Graph<NodeType, EdgeType>::AddAllocatedEdge(EdgeType* edge) {
W
willzhang4a58 已提交
147 148 149 150
  edges_.emplace_back(edge);
}

template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
151
template<typename StreamT>
W
Will Zhang 已提交
152
void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) {
W
willzhang4a58 已提交
153
  out_stream << "digraph {\n";
W
Will Zhang 已提交
154
  this->ForEachNode([&](NodeType* node) {
W
willzhang4a58 已提交
155 156
    out_stream << "\"" << node->VisualStr() << "\"\n";
  });
W
Will Zhang 已提交
157
  this->ForEachEdge([&](const EdgeType* edge) {
W
willzhang4a58 已提交
158
    out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> "
W
willzhang4a58 已提交
159 160
               << "\"" << edge->dst_node()->VisualStr() << "\""
               << "[label=\"" << edge->VisualStr() << "\"];\n";
W
willzhang4a58 已提交
161 162 163 164 165
  });
  out_stream << "}\n";
}

template<typename NodeType, typename EdgeType>
W
willzhang4a58 已提交
166
void Graph<NodeType, EdgeType>::ToDotWithFilePath(
W
Will Zhang 已提交
167
    const std::string& file_path) {
168
  std::string dir_name = Dirname(file_path);
W
willzhang4a58 已提交
169 170
  if (!LocalFS()->IsDirectory(dir_name)) {
    LocalFS()->RecursivelyCreateDir(dir_name);
W
willzhang4a58 已提交
171
  }
W
willzhang4a58 已提交
172
  PersistentOutStream out_stream(LocalFS(), file_path);
W
willzhang4a58 已提交
173
  ToDotWithStream(out_stream);
W
willzhang4a58 已提交
174 175 176
}

template<typename NodeType, typename EdgeType>
W
Will Zhang 已提交
177
void Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() {
W
willzhang4a58 已提交
178 179
  std::string file_path =
      LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot";
W
willzhang4a58 已提交
180
  ToDotWithFilePath(file_path);
W
willzhang4a58 已提交
181 182
}

W
willzhang4a58 已提交
183
}  // namespace oneflow
W
willzhang4a58 已提交
184

W
willzhang4a58 已提交
185
#endif  // ONEFLOW_CORE_GRAPH_GRAPH_H_