model_save_comp_task_node.cpp 1.4 KB
Newer Older
J
jiyuan 已提交
1 2
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
W
willzhang4a58 已提交
3
#include "oneflow/core/graph/model_update_comp_task_node.h"
4 5 6 7 8

namespace oneflow {

void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
  CHECK(IsFwNode());
W
willzhang4a58 已提交
9
  auto md_save_gph = static_cast<MdSaveTaskGraph*>(gph);
L
LeGend-AI 已提交
10
  CompTaskNode* updt_task = md_save_gph->update_task();
11 12 13 14
  if (in_edges().empty()) {
    BindProducedRegstAndOutEdge(updt_task->GetProducedRegstDesc("model"),
                                SoleOutEdge());
  } else if (out_edges().empty()) {
J
Jinhui Yuan 已提交
15
    ConsumeRegstDesc("model", GetRelatedRegst(SoleInEdge()));
16 17

    OperatorConf op_conf;
L
LeGend-AI 已提交
18
    op_conf.set_name("model_save_op" + updt_task->node_id_str());
19
    op_conf.mutable_model_save_conf();
20 21 22
    GetRelatedRegst(SoleInEdge())->ForEachLbn([&](const std::string& lbn) {
      op_conf.mutable_model_save_conf()->add_lbns(lbn);
    });
23 24

    ExecNode* exec_node = mut_exec_gph().NewNode();
W
willzhang4a58 已提交
25
    exec_node->mut_op() = OpMgr::Singleton()->AddOp(op_conf);
26
    for (const std::string& ibn : exec_node->op()->input_bns()) {
27 28 29
      exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
    }
    mut_exec_gph().UpdateSourceAndSink();
30 31 32 33 34
  } else {
    UNEXPECTED_RUN();
  }
}

W
willzhang4a58 已提交
35
void MdSaveCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
36 37 38
  CHECK(IsFwNode());
}

W
willzhang4a58 已提交
39
}  // namespace oneflow