#include "oneflow/core/graph/model_diff_accumulate_task_graph.h"
#include "oneflow/core/graph/model_diff_accumulate_comp_task_node.h"

namespace oneflow {

// Constructs the model-diff accumulation task graph.
//
// Steps, in order:
//   1. store `name` as this graph's name;
//   2. build the task graph from `data_chain` (see BuildTaskGraph);
//   3. index every task in `sorted_fw_comptasks4data_chain` by its
//      parallel id in `parallel_id2fw_task_`;
//   4. call BuildExecAndEnrollLbn2Regsts() (defined elsewhere).
//
// @param name                            name assigned to this task graph.
// @param data_chain                      chain whose parallel description the
//                                        accumulate chain is derived from.
// @param sorted_fw_comptasks4data_chain  forward comp tasks to index; their
//                                        parallel ids must be pairwise
//                                        distinct.
MdDiffAccTaskGraph::MdDiffAccTaskGraph(
    const std::string& name, const ChainNode* data_chain,
    const std::vector<CompTaskNode*>& sorted_fw_comptasks4data_chain) {
  mut_name() = name;
  BuildTaskGraph(data_chain);
  // emplace() reports `second == false` when the parallel id is already
  // present, so a duplicated id aborts via CHECK instead of being dropped.
  for (CompTaskNode* fw_task : sorted_fw_comptasks4data_chain) {
    CHECK(parallel_id2fw_task_.emplace(fw_task->parallel_id(), fw_task).second);
  }
  BuildExecAndEnrollLbn2Regsts();
}

W
willzhang4a58 已提交
17
void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
L
LeGend-AI 已提交
18 19 20
  // Construct ModelDiffAccOp
  OperatorConf op_conf;
  op_conf.set_name("model_diff_acc_" + NewUniqueId());
W
willzhang4a58 已提交
21
  op_conf.mutable_accumulate_conf();
W
willzhang4a58 已提交
22
  auto model_diff_acc_op = OpMgr::Singleton()->AddOp(op_conf);
L
LeGend-AI 已提交
23
  // ModelDiffAccChain
W
willzhang4a58 已提交
24
  auto chain_gph = of_make_unique<ChainGraph>();
L
LeGend-AI 已提交
25 26
  ChainNode* diff_acc_chain = chain_gph->NewNode();
  diff_acc_chain->mut_op_vec() = {model_diff_acc_op};
L
LeGend-AI 已提交
27 28
  auto parallel_desc4diff_acc =
      new ParallelDesc(*(data_chain->parallel_desc()));
L
LeGend-AI 已提交
29 30 31
  parallel_desc4diff_acc->mut_policy() = kModelParallel;
  diff_acc_chain->mut_parallel_desc().reset(parallel_desc4diff_acc);
  // FakerChain
L
LeGend-AI 已提交
32
  if (data_chain->parallel_desc()->policy() == kDataParallel) {
L
LeGend-AI 已提交
33 34 35 36 37
    ChainNode* faker_chain = chain_gph->NewNode();
    faker_chain->mut_op_vec().clear();
    auto parallel_desc4faker = new ParallelDesc(*(data_chain->parallel_desc()));
    parallel_desc4faker->mut_policy() = kFakerMdUpdt;
    faker_chain->mut_parallel_desc().reset(parallel_desc4faker);
W
willzhang4a58 已提交
38 39
    faker_chain->mut_output_lbns() = {kPackedBlobName};
    diff_acc_chain->mut_input_lbns() = {kPackedBlobName};
L
LeGend-AI 已提交
40 41 42 43
    Connect(faker_chain, chain_gph->NewEdge(), diff_acc_chain);
  }
  //
  chain_gph->UpdateSourceAndSink();
W
willzhang4a58 已提交
44 45
  chain_gph->ToDotWithAutoFilePath();
  BuildFromChainGph<MdDiffAccCompTaskNode>(std::move(chain_gph), false);
L
LeGend-AI 已提交
46 47
}

}  // namespace oneflow