From a0b87f1bd76ef96dea8176dab664fbc6f0d8568a Mon Sep 17 00:00:00 2001 From: Juncheng Date: Tue, 31 Dec 2019 14:20:11 +0800 Subject: [PATCH] Dev total loss instance num time shape (#2531) * loss instance num time shape * refine --- oneflow/core/job_completer/autograd.cpp | 49 +++++++++++++++++++++---- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/oneflow/core/job_completer/autograd.cpp b/oneflow/core/job_completer/autograd.cpp index df573b9103..e67446a85d 100644 --- a/oneflow/core/job_completer/autograd.cpp +++ b/oneflow/core/job_completer/autograd.cpp @@ -493,18 +493,53 @@ void AddTotalLossInstanceNumOpConf( const auto& lbi = GenLogicalBlobId(loss_lbn); CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second); } - const BlobDesc* blob_desc = nullptr; + const Shape src_time_shape( + {GlobalJobDesc().TotalBatchNum(), GlobalJobDesc().NumOfPiecesInBatch()}); + const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt(); + bool all_loss_time_shape_eq_src = true; for (const auto& pair : loss_lbi2op_node) { - const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); - if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); } - blob_desc = cur_blob_desc; + const Shape* time_shape = pair.second->out_blob_time_shape(); + const int64_t time_shape_elem_cnt = time_shape->elem_cnt(); + if (time_shape_elem_cnt != source_time_shape_elem_cnt) { + CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0); + all_loss_time_shape_eq_src = false; + } } HashMap parallel_desc2optimizer_node_cnt; CalcParallelDesc2OptimizerNodeCnt(op_graph, lbi2diff_lbi, &parallel_desc2optimizer_node_cnt); - if (blob_desc->is_dynamic()) { - AddTotalLossInstanceNumOpConfForDynamicDim0(parallel_desc2optimizer_node_cnt, loss_lbi2op_node, - job_builder, LossInstanceNum4ParallelDesc); + if (all_loss_time_shape_eq_src) { + const BlobDesc* blob_desc = nullptr; + for (const auto& pair : loss_lbi2op_node) { + const BlobDesc* cur_blob_desc = 
&pair.second->LogicalBlobDesc4Lbi(pair.first); + if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); } + blob_desc = cur_blob_desc; + } + if (blob_desc->is_dynamic()) { + AddTotalLossInstanceNumOpConfForDynamicDim0(parallel_desc2optimizer_node_cnt, + loss_lbi2op_node, job_builder, + LossInstanceNum4ParallelDesc); + } else { + BuildConstantOpAsTotalLossInstanceNum(parallel_desc2optimizer_node_cnt, *blob_desc, + job_builder, LossInstanceNum4ParallelDesc); + } } else { + std::unique_ptr<BlobDesc> blob_desc; + for (const auto& pair : loss_lbi2op_node) { + const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first); + // TODO: support dynamic + CHECK(!cur_blob_desc->is_dynamic()); + const DataType loss_data_type = cur_blob_desc->data_type(); + const int64_t time_shape_elem_cnt = pair.second->out_blob_time_shape()->elem_cnt(); + // TODO: consider batch_axis or sbp + const int64_t loss_elem_cnt = + cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt; + if (blob_desc) { + CHECK_EQ(blob_desc->data_type(), loss_data_type); + CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt); + } else { + blob_desc.reset(new BlobDesc(Shape({loss_elem_cnt}), loss_data_type)); + } + } BuildConstantOpAsTotalLossInstanceNum(parallel_desc2optimizer_node_cnt, *blob_desc, job_builder, LossInstanceNum4ParallelDesc); } -- GitLab