logical_graph.cpp 2.0 KB
Newer Older
W
willzhang4a58 已提交
1
#include "graph/logical_graph.h"
W
willzhang4a58 已提交
2
#include "glog/logging.h"
W
willzhang4a58 已提交
3
#include "operator/operator_factory.h"
W
willzhang4a58 已提交
4 5 6

namespace oneflow {

7 8
LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
                           const Strategy& strategy_conf) {
W
willzhang4a58 已提交
9
  BuildGraphStruct(dl_net_conf);
W
willzhang4a58 已提交
10
  FillNodeWithParallelDesc(strategy_conf);
W
willzhang4a58 已提交
11 12
}

W
willzhang4a58 已提交
13
void LogicalGraph::BuildGraphStruct(const DLNetConf& dl_net_conf) {
W
willzhang4a58 已提交
14
  HashMap<std::string, LogicalNode*> lbn2node;
W
willzhang4a58 已提交
15 16 17
  // Process Op
  for (int op_i = 0; op_i < dl_net_conf.op_conf_size(); ++op_i) {
    const OperatorConf& cur_op_conf = dl_net_conf.op_conf(op_i);
W
willzhang4a58 已提交
18
    // Construct cur node
W
willzhang4a58 已提交
19
    LogicalNode* cur_node = NewFinalNode();
W
willzhang4a58 已提交
20
    cur_node->mut_op() = ConstructOpFromPbConf(cur_op_conf);
W
willzhang4a58 已提交
21
    // Connect input node
W
willzhang4a58 已提交
22
    for (const std::string& ibn : cur_node->op()->input_bns()) {
W
willzhang4a58 已提交
23
      std::string lbn = cur_node->op()->ibn2lbn(ibn);
W
willzhang4a58 已提交
24 25
      LogicalNode* pred_node = lbn2node.at(lbn);
      Connect(pred_node, NewFinalEdge(), cur_node);
W
willzhang4a58 已提交
26
    }
W
willzhang4a58 已提交
27
    // Construct output
W
willzhang4a58 已提交
28
    for (const std::string& obn : cur_node->op()->output_bns()) {
W
willzhang4a58 已提交
29
      std::string lbn = cur_node->op()->obn2lbn(obn);
W
willzhang4a58 已提交
30
      lbn2node.emplace(lbn, cur_node);
W
willzhang4a58 已提交
31
    }
W
willzhang4a58 已提交
32
  }
W
willzhang4a58 已提交
33
  lbn2node.clear();
W
willzhang4a58 已提交
34
  // Post Processing
W
willzhang4a58 已提交
35
  UpdateSourceAndSink();
W
willzhang4a58 已提交
36 37
}

W
willzhang4a58 已提交
38
void LogicalGraph::FillNodeWithParallelDesc(const Strategy& strategy_conf) {
W
willzhang4a58 已提交
39
  HashMap<std::string, LogicalNode*> op_name2node;
W
willzhang4a58 已提交
40
  for (const std::unique_ptr<LogicalNode>& logical_node : nodes()) {
W
willzhang4a58 已提交
41 42
    const std::string& op_name = logical_node->op()->op_name();
    CHECK(op_name2node.emplace(op_name, logical_node.get()).second);
W
willzhang4a58 已提交
43
  }
W
willzhang4a58 已提交
44 45 46 47
  for (int gid = 0; gid < strategy_conf.placement_groups_size(); ++gid) {
    const PlacementGroup& cur_group = strategy_conf.placement_groups(gid);
    for (int li = 0; li < cur_group.op_names_size(); ++li) {
      const std::string& op_name = cur_group.op_names(li);
W
willzhang4a58 已提交
48 49
      auto it = op_name2node.find(op_name);
      CHECK(it != op_name2node.end());
50
      auto parallel_desc_raw_ptr = new ParallelDesc(cur_group.parallel_conf());
W
willzhang4a58 已提交
51
      it->second->mut_parallel_desc().reset(parallel_desc_raw_ptr);
W
willzhang4a58 已提交
52 53
    }
  }
W
willzhang4a58 已提交
54 55 56
}

} // namespace oneflow