提交 9a59515a 编写于 作者: X Xinqi

refine code


Former-commit-id: 8ed2fd1609f3c3ee3b6cbe0a5d1f380afea8bd25
上级 e3978c9f
......@@ -420,23 +420,20 @@ void TaskGraph::EnableMemSharingInVariableOp() {
if (variable_op == nullptr) { return; }
std::string model_bn = variable_op->op_conf().variable_conf().model_name();
auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
if (fw_task_node) {
const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(model_bn);
RegstDesc* model_regst = fw_task_node->GetSoleConsumedRegst("model").get();
if (model_regst->enable_mem_sharing() == false) {
model_regst->set_enable_mem_sharing(true);
model_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
model_regst->set_mem_shared_offset(0);
}
RegstDesc* out_regst = fw_task_node->GetProducedRegst("out").get();
CHECK_EQ(out_regst->NumOfLbi(), 1);
out_regst->set_enable_mem_sharing(true);
out_regst->set_mem_shared_id(model_regst->mem_shared_id());
out_regst->set_mem_shared_offset(model_regst->mem_shared_offset()
+ model_regst->ByteOffsetInPackedBlobDescBody(lbi));
} else {
// do nothing
if (fw_task_node == nullptr) { return; }
const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(model_bn);
RegstDesc* model_regst = fw_task_node->GetSoleConsumedRegst("model").get();
if (model_regst->enable_mem_sharing() == false) {
model_regst->set_enable_mem_sharing(true);
model_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
model_regst->set_mem_shared_offset(0);
}
RegstDesc* out_regst = fw_task_node->GetProducedRegst("out").get();
CHECK_EQ(out_regst->NumOfLbi(), 1);
out_regst->set_enable_mem_sharing(true);
out_regst->set_mem_shared_id(model_regst->mem_shared_id());
out_regst->set_mem_shared_offset(model_regst->mem_shared_offset()
+ model_regst->ByteOffsetInPackedBlobDescBody(lbi));
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册