提交 a0b87f1b 编写于 作者: J Juncheng 提交者: Li Xinqi

Dev total loss instance num time shape (#2531)

* loss instance num time shape

* refine
上级 a77d1ebe
......@@ -493,18 +493,53 @@ void AddTotalLossInstanceNumOpConf(
const auto& lbi = GenLogicalBlobId(loss_lbn);
CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second);
}
const BlobDesc* blob_desc = nullptr;
const Shape src_time_shape(
{GlobalJobDesc().TotalBatchNum(), GlobalJobDesc().NumOfPiecesInBatch()});
const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt();
bool all_loss_time_shape_eq_src = true;
for (const auto& pair : loss_lbi2op_node) {
const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);
if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); }
blob_desc = cur_blob_desc;
const Shape* time_shape = pair.second->out_blob_time_shape();
const int64_t time_shape_elem_cnt = time_shape->elem_cnt();
if (time_shape_elem_cnt != source_time_shape_elem_cnt) {
CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0);
all_loss_time_shape_eq_src = false;
}
}
HashMap<ParallelDesc, int32_t> parallel_desc2optimizer_node_cnt;
CalcParallelDesc2OptimizerNodeCnt(op_graph, lbi2diff_lbi, &parallel_desc2optimizer_node_cnt);
if (blob_desc->is_dynamic()) {
AddTotalLossInstanceNumOpConfForDynamicDim0(parallel_desc2optimizer_node_cnt, loss_lbi2op_node,
job_builder, LossInstanceNum4ParallelDesc);
if (all_loss_time_shape_eq_src) {
const BlobDesc* blob_desc = nullptr;
for (const auto& pair : loss_lbi2op_node) {
const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);
if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); }
blob_desc = cur_blob_desc;
}
if (blob_desc->is_dynamic()) {
AddTotalLossInstanceNumOpConfForDynamicDim0(parallel_desc2optimizer_node_cnt,
loss_lbi2op_node, job_builder,
LossInstanceNum4ParallelDesc);
} else {
BuildConstantOpAsTotalLossInstanceNum(parallel_desc2optimizer_node_cnt, *blob_desc,
job_builder, LossInstanceNum4ParallelDesc);
}
} else {
std::unique_ptr<BlobDesc> blob_desc;
for (const auto& pair : loss_lbi2op_node) {
const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);
// TODO: support dynamic
CHECK(!cur_blob_desc->is_dynamic());
const DataType loss_data_type = cur_blob_desc->data_type();
const int64_t time_shape_elem_cnt = pair.second->out_blob_time_shape()->elem_cnt();
// TODO: consider batch_axis or sbp
const int64_t loss_elem_cnt =
cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt;
if (blob_desc) {
CHECK_EQ(blob_desc->data_type(), loss_data_type);
CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt);
} else {
blob_desc.reset(new BlobDesc(Shape({loss_elem_cnt}), loss_data_type));
}
}
BuildConstantOpAsTotalLossInstanceNum(parallel_desc2optimizer_node_cnt, *blob_desc, job_builder,
LossInstanceNum4ParallelDesc);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册