提交 59b9db64 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

bugfix: bind in regst in backward task nodes (#1193)



Former-commit-id: 400cf2a6
上级 6c7fb61c
......@@ -48,6 +48,7 @@ void NormalBackwardCompTaskNode::BuildExecGphAndRegst() {
LinkFwExecNode();
BuildActivationDiffRegst();
BuildInDiffRegst();
BindInRegst();
BindModelDiffRegst();
InferBlobDescsInProducedRegsts();
}
......@@ -129,10 +130,6 @@ void NormalBackwardCompTaskNode::BuildInDiffRegst() {
}
for (const std::string& idbn : cur_node->op()->input_diff_bns()) {
const LogicalBlobId& lbi = cur_node->op()->BnInOp2Lbi(idbn);
CompTaskNode* fw_task = GetRelatedFwTaskNode();
if (fw_task) {
cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), GetConsumedRegst("in"));
}
if (logical_node()->IsDataLbiOnOutEdge(lbi)) {
in_diff_regst->AddLbi(lbi);
cur_node->BindBnWithRegst(idbn, in_diff_regst);
......@@ -143,6 +140,16 @@ void NormalBackwardCompTaskNode::BuildInDiffRegst() {
});
}
void NormalBackwardCompTaskNode::BindInRegst() {
mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
for (const std::string& ibn : cur_node->op()->input_bns()) {
if (GetRelatedFwTaskNode()) {
cur_node->BindBnWithOneOfTheRegsts(ibn, GetConsumedRegst("in"));
}
}
});
}
void NormalBackwardCompTaskNode::BindModelDiffRegst() {
std::shared_ptr<RegstDesc> data_tmp_regst = GetSoleConsumedRegst("data_tmp");
std::shared_ptr<RegstDesc> bw_buf_regst = GetProducedRegst("bw_buf");
......
......@@ -26,6 +26,7 @@ class NormalBackwardCompTaskNode final : public CompTaskNode {
void FixPackedBlobDescOfProducedRegst() override;
void LinkFwExecNode();
void BindModelDiffRegst();
void BindInRegst();
void InferBlobDescsInProducedRegsts();
CompTaskNode* GetRelatedFwTaskNode();
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册