logical_graph.cpp 4.6 KB
Newer Older
W
willzhang4a58 已提交
1
#include "graph/logical_graph.h"
W
willzhang4a58 已提交
2
#include "glog/logging.h"
W
willzhang4a58 已提交
3
#include "operator/operator_factory.h"
W
willzhang4a58 已提交
4 5 6

namespace oneflow {

7 8
LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
                           const Strategy& strategy_conf) {
W
willzhang4a58 已提交
9
  HashMap<LogicalEdge*, std::string> edge2lbn;
W
willzhang4a58 已提交
10 11
  HashMap<LogicalEdge*, std::string> edge2ibn;
  NaiveBuildGraphStruct(dl_net_conf, &edge2lbn, &edge2ibn);
W
willzhang4a58 已提交
12
  FillNodeWithParallelDesc(strategy_conf);
W
willzhang4a58 已提交
13
  AddCloneNodes(edge2lbn, edge2ibn);
W
willzhang4a58 已提交
14 15
}

W
willzhang4a58 已提交
16 17
void LogicalGraph::NaiveBuildGraphStruct(
    const DLNetConf& dl_net_conf,
W
willzhang4a58 已提交
18 19
    HashMap<LogicalEdge*, std::string>* edge2lbn,
    HashMap<LogicalEdge*, std::string>* edge2ibn) {
W
willzhang4a58 已提交
20
  HashMap<std::string, LogicalNode*> lbn2producer;
W
willzhang4a58 已提交
21 22 23
  // Process Op
  for (int op_i = 0; op_i < dl_net_conf.op_conf_size(); ++op_i) {
    const OperatorConf& cur_op_conf = dl_net_conf.op_conf(op_i);
W
willzhang4a58 已提交
24
    // Construct cur node
W
willzhang4a58 已提交
25
    LogicalNode* cur_node = NewFinalNode();
W
willzhang4a58 已提交
26
    cur_node->mut_op() = ConstructOpFromPbConf(cur_op_conf);
W
willzhang4a58 已提交
27
    // Connect input node
W
willzhang4a58 已提交
28
    for (const std::string& ibn : cur_node->op()->input_bns()) {
W
willzhang4a58 已提交
29
      std::string lbn = cur_node->op()->ibn2lbn(ibn);
W
willzhang4a58 已提交
30 31 32
      LogicalNode* pred_node = lbn2producer.at(lbn);
      LogicalEdge* edge = NewFinalEdge();
      CHECK(edge2lbn->emplace(edge, lbn).second);
W
willzhang4a58 已提交
33
      CHECK(edge2ibn->emplace(edge, ibn).second);
W
willzhang4a58 已提交
34
      Connect(pred_node, edge, cur_node);
W
willzhang4a58 已提交
35
    }
W
willzhang4a58 已提交
36
    // Construct output
W
willzhang4a58 已提交
37
    for (const std::string& obn : cur_node->op()->output_bns()) {
W
willzhang4a58 已提交
38
      std::string lbn = cur_node->op()->obn2lbn(obn);
W
willzhang4a58 已提交
39
      CHECK(lbn2producer.emplace(lbn, cur_node).second);
W
willzhang4a58 已提交
40
    }
W
willzhang4a58 已提交
41
  }
W
willzhang4a58 已提交
42
  lbn2producer.clear();
W
willzhang4a58 已提交
43
  // Post Processing
W
willzhang4a58 已提交
44
  UpdateSourceAndSink();
W
willzhang4a58 已提交
45 46
}

W
willzhang4a58 已提交
47
void LogicalGraph::FillNodeWithParallelDesc(const Strategy& strategy_conf) {
W
willzhang4a58 已提交
48
  HashMap<std::string, LogicalNode*> op_name2node;
W
willzhang4a58 已提交
49
  for (const std::unique_ptr<LogicalNode>& logical_node : nodes()) {
W
willzhang4a58 已提交
50 51
    const std::string& op_name = logical_node->op()->op_name();
    CHECK(op_name2node.emplace(op_name, logical_node.get()).second);
W
willzhang4a58 已提交
52
  }
W
willzhang4a58 已提交
53 54 55 56
  for (int gid = 0; gid < strategy_conf.placement_groups_size(); ++gid) {
    const PlacementGroup& cur_group = strategy_conf.placement_groups(gid);
    for (int li = 0; li < cur_group.op_names_size(); ++li) {
      const std::string& op_name = cur_group.op_names(li);
W
willzhang4a58 已提交
57 58
      auto it = op_name2node.find(op_name);
      CHECK(it != op_name2node.end());
59
      auto parallel_desc_raw_ptr = new ParallelDesc(cur_group.parallel_conf());
W
willzhang4a58 已提交
60
      it->second->mut_parallel_desc().reset(parallel_desc_raw_ptr);
W
willzhang4a58 已提交
61 62
    }
  }
W
willzhang4a58 已提交
63 64
}

W
willzhang4a58 已提交
65
void LogicalGraph::AddCloneNodes(
W
willzhang4a58 已提交
66 67
    const HashMap<LogicalEdge*, std::string>& edge2lbn,
    const HashMap<LogicalEdge*, std::string>& edge2ibn) {
W
willzhang4a58 已提交
68 69 70
  std::vector<CloneInfo> clone_infos;
  CollectCloneInfos(&clone_infos, edge2lbn);
  for (const CloneInfo& clone_info : clone_infos) {
W
willzhang4a58 已提交
71
    AddOneCloneNode(clone_info, edge2ibn);
W
willzhang4a58 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
  }
}

void LogicalGraph::CollectCloneInfos(
    std::vector<CloneInfo>* clone_infos,
    const HashMap<LogicalEdge*, std::string>& edge2lbn) {
  for (const std::unique_ptr<LogicalNode>& cur_node : nodes()) {
    HashMap<std::string, std::vector<LogicalEdge*>> lbn2edges;
    for (LogicalEdge* edge : cur_node->out_edges()) {
      lbn2edges[edge2lbn.at(edge)].push_back(edge);
    }
    for (auto& pair : lbn2edges) {
      const std::string& lbn = pair.first;
      std::vector<LogicalEdge*>& edges = pair.second;
      if (edges.size() <= 1) { continue; }
      // Construct clone op
      OperatorConf pb_op_conf;
      pb_op_conf.set_name("clone_" + lbn + "_" + cur_node->node_id_str());
W
willzhang4a58 已提交
90 91
      pb_op_conf.mutable_clone_conf()->set_out_num(edges.size());
      pb_op_conf.mutable_clone_conf()->set_lbn(lbn);
W
willzhang4a58 已提交
92 93 94 95 96 97 98 99 100 101 102
      auto clone_op = ConstructOpFromPbConf(pb_op_conf);
      // Set clone_info
      CloneInfo clone_info;
      clone_info.clone_op = clone_op;
      clone_info.pred_node = cur_node.get();
      clone_info.edges = std::move(edges);
      clone_infos->push_back(clone_info);
    }
  }
}

W
willzhang4a58 已提交
103 104 105
void LogicalGraph::AddOneCloneNode(
    const CloneInfo& clone_info,
    const HashMap<LogicalEdge*, std::string>& edge2ibn) {
W
willzhang4a58 已提交
106 107 108 109
  LogicalNode* clone_node = NewFinalNode();
  clone_node->mut_op() = clone_info.clone_op;
  Connect(clone_info.pred_node, NewFinalEdge(), clone_node);
  CHECK_EQ(clone_node->op()->output_bns().size(), clone_info.edges.size());
W
willzhang4a58 已提交
110 111 112 113 114
  for (size_t i = 0; i < clone_info.edges.size(); ++i) {
    const std::string& obn = clone_node->op()->output_bns().at(i);
    std::string lbn = clone_node->op()->obn2lbn(obn);
    LogicalEdge* edge = clone_info.edges.at(i);
    const std::string& ibn = edge2ibn.at(edge);
W
willzhang4a58 已提交
115
    LogicalNode* dst_node = edge->dst_node();
W
willzhang4a58 已提交
116
    dst_node->mut_op()->AddSpecialIbn2Lbn(ibn, lbn);
W
willzhang4a58 已提交
117 118 119 120 121
    DisConnect(edge);
    Connect(clone_node, edge, dst_node);
  }
}

W
willzhang4a58 已提交
122
} // namespace oneflow