提交 295baf25 编写于 作者: W willzhang4a58

fix bug: consume in always


Former-commit-id: aa24910e
上级 502e8a5c
......@@ -32,7 +32,7 @@ void BackwardCompTaskNode::ConsumeAllRegsts() {
}
}
if (GetProducedRegst("in_diff")) { ConsumeRegst("in", GetRelatedInRegst()); }
ConsumeRegst("in", GetRelatedInRegst());
}
void BackwardCompTaskNode::BuildExecGphAndRegst() {
......@@ -88,7 +88,6 @@ void BackwardCompTaskNode::BuildActivationDiffRegst() {
void BackwardCompTaskNode::BuildInDiffRegst() {
std::shared_ptr<RegstDesc> in_diff_regst = GetProducedRegst("in_diff");
if (!in_diff_regst) { return; }
std::shared_ptr<RegstDesc> in_regst = GetConsumedRegst("in");
mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
HashSet<std::string> found_lbns;
......@@ -98,8 +97,10 @@ void BackwardCompTaskNode::BuildInDiffRegst() {
for (const std::string& idbn : cur_node->op()->input_diff_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(idbn);
if (found_lbns.find(lbn) != found_lbns.end()) { continue; }
in_diff_regst->AddLbn(lbn);
cur_node->BindBnInOpAndRegst(idbn, in_diff_regst);
if (in_diff_regst) {
in_diff_regst->AddLbn(lbn);
cur_node->BindBnInOpAndRegst(idbn, in_diff_regst);
}
cur_node->BindBnInOpAndRegst(GenUnDiffBn(idbn), in_regst);
}
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册