// loss_accumulate_comp_task_node.cpp (1.3 KB)
// Blame view: the following lines 1-12 were committed by willzhang4a58.
#include "oneflow/core/graph/loss_accumulate_comp_task_node.h"
#include "oneflow/core/graph/loss_accumulate_task_graph.h"

namespace oneflow {

// Builds the execution graph of this loss-accumulation compute task and
// enrolls its register descriptors.
//
// Two cases, selected by whether the chain node carries any ops:
//  * No ops (pass-through): the "loss" register produced by the loss task of
//    the enclosing LossAccTaskGraph is bound directly to this node's sole out
//    edge; no exec node is created.
//  * Otherwise: a "loss_acc" register is produced, and a single exec node runs
//    the chain node's sole op, reading the register that arrives on the sole
//    in edge (consumed under the name "loss") and writing "loss_acc".
//
// @param gph  The task graph that owns this node. It is static_cast to
//             LossAccTaskGraph* without a runtime check, so callers must pass
//             a LossAccTaskGraph.
void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
  if (chain_node()->op_vec().empty()) {
    // Pass-through: forward the loss task's "loss" register unchanged.
    CompTaskNode* loss_task = static_cast<LossAccTaskGraph*>(gph)->loss_task();
    auto loss_regst = loss_task->GetProducedRegstDesc("loss");
    BindProducedRegstAndOutEdge(loss_regst, SoleOutEdge());
    return;
  }
  // NOTE(review): 1 / kMaxRegisterNum presumably are the min/max register
  // counts for "loss_acc" — confirm against NewProducedRegstDesc.
  NewProducedRegstDesc("loss_acc", 1, kMaxRegisterNum);
  auto loss_regst = GetRelatedRegst(SoleInEdge());
  auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
  // One exec node running the sole op: sole input <- loss, sole output -> loss_acc.
  ExecNode* exec_node = mut_exec_gph().NewNode();
  exec_node->mut_op() = chain_node()->SoleOp();
  exec_node->BindBnInOpAndRegst(exec_node->op()->SoleIbn(), loss_regst);
  exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(), loss_acc_regst);
  ConsumeRegstDesc("loss", loss_regst);
  // The accumulator register takes over the lbns of the incoming loss register.
  loss_acc_regst->CopyLbnFrom(loss_regst.get());
  mut_exec_gph().UpdateSourceAndSink();
}

// Blame view: the following line 25 was committed by willzhang4a58.
// Infers the blob descriptors of the registers this node produces.
//
// When the chain node has ops, the produced "loss_acc" register copies its
// blob descriptors verbatim from the consumed "loss" register. When the chain
// node has no ops, nothing is inferred here.
//
// @param gph  Unused by this implementation.
void LossAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
  if (!chain_node()->op_vec().empty()) {
    auto loss_regst = GetConsumedRegstDesc("loss");
    auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
    loss_acc_regst->CopyBlobDescFrom(loss_regst.get());
  }
}

}  // namespace oneflow