提交 8c895ee9 编写于 作者: C cheng cheng 提交者: Jinhui Yuan

rm IsBwClone (#1078)

上级 8d2daef3
......@@ -295,7 +295,6 @@ void LogicalGraph::AddOneBackwardClone(const BackwardCloneInfo& clone_info) {
LogicalNode* clone_node = NewNode<NormalBackwardLogicalNode>();
clone_node->mut_op_vec() = {clone_op};
clone_node->mut_parallel_desc() = clone_info.succ_node->parallel_desc();
CHECK(bw_clone2fw_producer_.emplace(clone_node, nullptr).second);
*(clone_op->MutBnInOp2Lbi(clone_op->SoleIbn())) = clone_info.lbi;
*(clone_op->MutBnInOp2Lbi(clone_op->SoleIdbn())) = clone_info.lbi;
......@@ -567,9 +566,6 @@ void LogicalGraph::ConnectFwToBw() {
if (bw_node->fw_node() == nullptr) { return; }
Connect<LogicalNode>(bw_node->fw_node(), NewEdge(), bw_node);
});
for (auto& pair : bw_clone2fw_producer_) {
if (pair.second) { Connect<LogicalNode>(pair.second, NewEdge(), pair.first); }
}
}
void LogicalGraph::UpdateEdge2Ibn(const LogicalEdge* edge, const std::string& ibn) {
......
......@@ -69,7 +69,6 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
HashMap<const LogicalEdge*, std::string> edge2ibn_;
HashMap<const LogicalEdge*, std::string> edge2obn_;
HashMap<LogicalNode*, LogicalNode*> bw_clone2fw_producer_;
};
} // namespace oneflow
......
......@@ -37,7 +37,7 @@ void NormalBackwardCompTaskNode::ConsumeAllRegsts() {
}
}
CompTaskNode* fw_task = GetRelatedFwTaskNode();
if (fw_task && !IsBwClone()) {
if (fw_task) {
const std::list<std::weak_ptr<RegstDesc>>& in_regst = fw_task->GetConsumedRegst("in");
for (std::weak_ptr<RegstDesc> regst : in_regst) { ConsumeRegst("in", regst.lock()); }
}
......@@ -78,7 +78,7 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() {
}
});
CompTaskNode* fw_task = GetRelatedFwTaskNode();
if (fw_task && !IsBwClone()) {
if (fw_task) {
const HashSet<LogicalBlobId>& lbi_boxing = fw_task->logical_node()->lbi_boxing();
const HashSet<LogicalBlobId>& lbi_121 = fw_task->logical_node()->lbi_121();
std::shared_ptr<RegstDesc> out_regst_boxing = GetSoleConsumedRegst("boxing_out");
......@@ -106,7 +106,6 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() {
void NormalBackwardCompTaskNode::LinkFwExecNode() {
CompTaskNode* fw_task = GetRelatedFwTaskNode();
if (fw_task == nullptr) { return; }
if (IsBwClone()) { return; }
HashMap<std::string, ExecNode*> op_name2fw_exec;
fw_task->exec_gph().ForEachNode([&](ExecNode* fw_exec) {
CHECK(op_name2fw_exec.emplace(fw_exec->op()->op_name(), fw_exec).second);
......@@ -139,14 +138,7 @@ void NormalBackwardCompTaskNode::BuildInDiffRegst() {
const LogicalBlobId& lbi = cur_node->op()->BnInOp2Lbi(idbn);
CompTaskNode* fw_task = GetRelatedFwTaskNode();
if (fw_task) {
if (IsBwClone()) {
std::list<std::weak_ptr<RegstDesc>> out_regsts;
out_regsts.push_back(GetSoleConsumedRegst("boxing_out"));
out_regsts.push_back(GetSoleConsumedRegst("121_out"));
cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), out_regsts);
} else {
cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), GetConsumedRegst("in"));
}
cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), GetConsumedRegst("in"));
}
if (TryAddLbiToB121RegstAndBindIt(cur_node, idbn, "in_diff") == false) {
CHECK(found_lbis.empty() || found_lbis.find(lbi) != found_lbis.end());
......@@ -174,27 +166,23 @@ void NormalBackwardCompTaskNode::BindModelDiffRegst() {
void NormalBackwardCompTaskNode::InferBlobDescsInProducedRegsts() {
if (GetRelatedFwTaskNode()) {
if (IsBwClone()) {
mut_exec_gph().SoleNode()->InferDiffBlobDescsWithoutFwNode(parallel_ctx());
} else {
std::shared_ptr<RegstDesc> in_diff_regst_boxing = GetProducedRegst("boxing_in_diff");
for (std::weak_ptr<RegstDesc> regst : GetConsumedRegst("in")) {
in_diff_regst_boxing->CopyBlobDescWithoutAddLbi(regst.lock().get());
}
std::shared_ptr<RegstDesc> in_diff_regst_boxing = GetProducedRegst("boxing_in_diff");
for (std::weak_ptr<RegstDesc> regst : GetConsumedRegst("in")) {
in_diff_regst_boxing->CopyBlobDescWithoutAddLbi(regst.lock().get());
}
std::shared_ptr<RegstDesc> in_diff_regst_121 = GetProducedRegst("121_in_diff");
for (std::weak_ptr<RegstDesc> regst : GetConsumedRegst("in")) {
in_diff_regst_121->CopyBlobDescWithoutAddLbi(regst.lock().get());
}
std::shared_ptr<RegstDesc> in_diff_regst_121 = GetProducedRegst("121_in_diff");
for (std::weak_ptr<RegstDesc> regst : GetConsumedRegst("in")) {
in_diff_regst_121->CopyBlobDescWithoutAddLbi(regst.lock().get());
}
std::shared_ptr<RegstDesc> md_diff_regst = GetProducedRegst("model_diff");
if (md_diff_regst) { md_diff_regst->CopyBlobDescFrom(GetSoleConsumedRegst("model").get()); }
std::shared_ptr<RegstDesc> md_diff_regst = GetProducedRegst("model_diff");
if (md_diff_regst) { md_diff_regst->CopyBlobDescFrom(GetSoleConsumedRegst("model").get()); }
std::shared_ptr<RegstDesc> activation_diff_regst = GetProducedRegst("activation_diff");
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("activation").get());
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("boxing_out").get());
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("121_out").get());
}
std::shared_ptr<RegstDesc> activation_diff_regst = GetProducedRegst("activation_diff");
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("activation").get());
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("boxing_out").get());
activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("121_out").get());
} else {
mut_exec_gph().SoleNode()->InferDiffBlobDescsWithoutFwNode(parallel_ctx());
}
......@@ -210,11 +198,4 @@ CompTaskNode* NormalBackwardCompTaskNode::GetRelatedFwTaskNode() {
return nullptr;
}
bool NormalBackwardCompTaskNode::IsBwClone() const {
const BackwardLogicalNode* bw_logical_node =
dynamic_cast<const BackwardLogicalNode*>(logical_node());
CHECK_NOTNULL(bw_logical_node);
return bw_logical_node->fw_node() == nullptr;
}
} // namespace oneflow
......@@ -15,7 +15,6 @@ class NormalBackwardCompTaskNode final : public CompTaskNode {
void ConsumeAllRegsts() override;
void BuildExecGphAndRegst() override;
TaskType GetTaskType() const override { return TaskType::kNormalBackward; }
bool IsBwClone() const;
protected:
void BuildExecGphAndBindOutDiffRegst();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册