From 8c895ee941de4897ab51d025b0b4a3f472668262 Mon Sep 17 00:00:00 2001 From: cheng cheng <472491134@qq.com> Date: Thu, 2 Aug 2018 18:57:10 +0800 Subject: [PATCH] rm IsBwClone (#1078) --- oneflow/core/graph/logical_graph.cpp | 4 -- oneflow/core/graph/logical_graph.h | 1 - .../normal_backward_compute_task_node.cpp | 53 ++++++------------- .../graph/normal_backward_compute_task_node.h | 1 - 4 files changed, 17 insertions(+), 42 deletions(-) diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index d6d05b9903..f8af4cc6f3 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -295,7 +295,6 @@ void LogicalGraph::AddOneBackwardClone(const BackwardCloneInfo& clone_info) { LogicalNode* clone_node = NewNode(); clone_node->mut_op_vec() = {clone_op}; clone_node->mut_parallel_desc() = clone_info.succ_node->parallel_desc(); - CHECK(bw_clone2fw_producer_.emplace(clone_node, nullptr).second); *(clone_op->MutBnInOp2Lbi(clone_op->SoleIbn())) = clone_info.lbi; *(clone_op->MutBnInOp2Lbi(clone_op->SoleIdbn())) = clone_info.lbi; @@ -567,9 +566,6 @@ void LogicalGraph::ConnectFwToBw() { if (bw_node->fw_node() == nullptr) { return; } Connect(bw_node->fw_node(), NewEdge(), bw_node); }); - for (auto& pair : bw_clone2fw_producer_) { - if (pair.second) { Connect(pair.second, NewEdge(), pair.first); } - } } void LogicalGraph::UpdateEdge2Ibn(const LogicalEdge* edge, const std::string& ibn) { diff --git a/oneflow/core/graph/logical_graph.h b/oneflow/core/graph/logical_graph.h index 1c5dbce170..e14feb700c 100644 --- a/oneflow/core/graph/logical_graph.h +++ b/oneflow/core/graph/logical_graph.h @@ -69,7 +69,6 @@ class LogicalGraph final : public Graph { HashMap edge2ibn_; HashMap edge2obn_; - HashMap bw_clone2fw_producer_; }; } // namespace oneflow diff --git a/oneflow/core/graph/normal_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp index 7a61ff7e88..a128809733 100644 --- a/oneflow/core/graph/normal_backward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp @@ -37,7 +37,7 @@ void NormalBackwardCompTaskNode::ConsumeAllRegsts() { } } CompTaskNode* fw_task = GetRelatedFwTaskNode(); - if (fw_task && !IsBwClone()) { + if (fw_task) { const std::list>& in_regst = fw_task->GetConsumedRegst("in"); for (std::weak_ptr regst : in_regst) { ConsumeRegst("in", regst.lock()); } } @@ -78,7 +78,7 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() { } }); CompTaskNode* fw_task = GetRelatedFwTaskNode(); - if (fw_task && !IsBwClone()) { + if (fw_task) { const HashSet& lbi_boxing = fw_task->logical_node()->lbi_boxing(); const HashSet& lbi_121 = fw_task->logical_node()->lbi_121(); std::shared_ptr out_regst_boxing = GetSoleConsumedRegst("boxing_out"); @@ -106,7 +106,6 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() { void NormalBackwardCompTaskNode::LinkFwExecNode() { CompTaskNode* fw_task = GetRelatedFwTaskNode(); if (fw_task == nullptr) { return; } - if (IsBwClone()) { return; } HashMap op_name2fw_exec; fw_task->exec_gph().ForEachNode([&](ExecNode* fw_exec) { CHECK(op_name2fw_exec.emplace(fw_exec->op()->op_name(), fw_exec).second); @@ -139,14 +138,7 @@ void NormalBackwardCompTaskNode::BuildInDiffRegst() { const LogicalBlobId& lbi = cur_node->op()->BnInOp2Lbi(idbn); CompTaskNode* fw_task = GetRelatedFwTaskNode(); if (fw_task) { - if (IsBwClone()) { - std::list> out_regsts; - out_regsts.push_back(GetSoleConsumedRegst("boxing_out")); - out_regsts.push_back(GetSoleConsumedRegst("121_out")); - cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), out_regsts); - } else { - cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), GetConsumedRegst("in")); - } + cur_node->BindBnWithOneOfTheRegsts(GenUnDiffBn(idbn), GetConsumedRegst("in")); } if (TryAddLbiToB121RegstAndBindIt(cur_node, idbn, "in_diff") == false) { CHECK(found_lbis.empty() || found_lbis.find(lbi) != found_lbis.end()); @@ -174,27 +166,23 @@ void NormalBackwardCompTaskNode::BindModelDiffRegst() { void NormalBackwardCompTaskNode::InferBlobDescsInProducedRegsts() { if (GetRelatedFwTaskNode()) { - if (IsBwClone()) { - mut_exec_gph().SoleNode()->InferDiffBlobDescsWithoutFwNode(parallel_ctx()); - } else { - std::shared_ptr in_diff_regst_boxing = GetProducedRegst("boxing_in_diff"); - for (std::weak_ptr regst : GetConsumedRegst("in")) { - in_diff_regst_boxing->CopyBlobDescWithoutAddLbi(regst.lock().get()); - } + std::shared_ptr in_diff_regst_boxing = GetProducedRegst("boxing_in_diff"); + for (std::weak_ptr regst : GetConsumedRegst("in")) { + in_diff_regst_boxing->CopyBlobDescWithoutAddLbi(regst.lock().get()); + } - std::shared_ptr in_diff_regst_121 = GetProducedRegst("121_in_diff"); - for (std::weak_ptr regst : GetConsumedRegst("in")) { - in_diff_regst_121->CopyBlobDescWithoutAddLbi(regst.lock().get()); - } + std::shared_ptr in_diff_regst_121 = GetProducedRegst("121_in_diff"); + for (std::weak_ptr regst : GetConsumedRegst("in")) { + in_diff_regst_121->CopyBlobDescWithoutAddLbi(regst.lock().get()); + } - std::shared_ptr md_diff_regst = GetProducedRegst("model_diff"); - if (md_diff_regst) { md_diff_regst->CopyBlobDescFrom(GetSoleConsumedRegst("model").get()); } + std::shared_ptr md_diff_regst = GetProducedRegst("model_diff"); + if (md_diff_regst) { md_diff_regst->CopyBlobDescFrom(GetSoleConsumedRegst("model").get()); } - std::shared_ptr activation_diff_regst = GetProducedRegst("activation_diff"); - activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("activation").get()); - activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("boxing_out").get()); - activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("121_out").get()); - } + std::shared_ptr activation_diff_regst = GetProducedRegst("activation_diff"); + activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("activation").get()); + activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("boxing_out").get()); + activation_diff_regst->CopyBlobDescWithoutAddLbi(GetSoleConsumedRegst("121_out").get()); } else { mut_exec_gph().SoleNode()->InferDiffBlobDescsWithoutFwNode(parallel_ctx()); } @@ -210,11 +198,4 @@ CompTaskNode* NormalBackwardCompTaskNode::GetRelatedFwTaskNode() { return nullptr; } -bool NormalBackwardCompTaskNode::IsBwClone() const { - const BackwardLogicalNode* bw_logical_node = - dynamic_cast(logical_node()); - CHECK_NOTNULL(bw_logical_node); - return bw_logical_node->fw_node() == nullptr; -} - } // namespace oneflow diff --git a/oneflow/core/graph/normal_backward_compute_task_node.h b/oneflow/core/graph/normal_backward_compute_task_node.h index c939ef89e1..f524c116bf 100644 --- a/oneflow/core/graph/normal_backward_compute_task_node.h +++ b/oneflow/core/graph/normal_backward_compute_task_node.h @@ -15,7 +15,6 @@ class NormalBackwardCompTaskNode final : public CompTaskNode { void ConsumeAllRegsts() override; void BuildExecGphAndRegst() override; TaskType GetTaskType() const override { return TaskType::kNormalBackward; } - bool IsBwClone() const; protected: void BuildExecGphAndBindOutDiffRegst(); -- GitLab