From 4dc5918e1d64cdcd3e5449dfee5905143bd1077c Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Fri, 31 Mar 2017 00:49:22 -0400 Subject: [PATCH] Get Relalted Register/Edge --- oneflow/graph/comp_task_node.cpp | 12 ++++++------ oneflow/graph/in_boxing_task_node.cpp | 8 ++++---- oneflow/graph/task_node.cpp | 5 +++-- oneflow/graph/task_node.h | 12 +++++++----- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/oneflow/graph/comp_task_node.cpp b/oneflow/graph/comp_task_node.cpp index 0de3e8faae..b7fef53521 100644 --- a/oneflow/graph/comp_task_node.cpp +++ b/oneflow/graph/comp_task_node.cpp @@ -143,18 +143,18 @@ void CompTaskNode::FwSetRegisterPtrs4ExecNodes( for (const auto& pair : extern_in_lbn2consumers) { const std::string& lbn = pair.first; for (ExecNode* consumer : pair.second) { - consumer->AddConsumedLbnRegiPair(lbn, SoleInEdge()->register_desc()); + consumer->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(SoleInEdge())); } } // Out Register Desc for (const auto& lbn : chain_node()->output_lbns()) { ExecNode* producer = lbn2producer.at(lbn); - producer->AddProducedLbnRegiPair(lbn, SoleOutEdge()->register_desc()); + producer->AddProducedLbnRegiPair(lbn, GetRelatedRegister(SoleOutEdge())); } } void CompTaskNode::FwSetProducedRegisterDescs() { - RegisterDesc* data_register = SoleOutEdge()->register_desc(); + RegisterDesc* data_register = GetRelatedRegister(SoleOutEdge()); for (const std::unique_ptr& cur_edge : exec_graph().edges()) { data_register->AddPbn(cur_edge->pbn()); } @@ -219,7 +219,7 @@ void CompTaskNode::BpSetRegisterDescPtrs4Nodes( for (const auto& odbn : bp_node->op()->output_diff_blob_names()) { std::string lbn = bp_node->op()->odbn2lbn(odbn); if (found_lbns.find(lbn) == found_lbns.end()) { - bp_node->AddConsumedLbnRegiPair(lbn, SoleInEdge()->register_desc()); + bp_node->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(SoleInEdge())); } } } @@ -227,14 +227,14 @@ void CompTaskNode::BpSetRegisterDescPtrs4Nodes( for (ExecEdge* edge : cp_in_node->out_edges()) { const std::string& lbn = edge->lbn(); ExecNode* bp_node = fw_node2bp_node.at(edge->dst_node()); - bp_node->AddProducedLbnRegiPair(lbn, SoleOutEdge()->register_desc()); + bp_node->AddProducedLbnRegiPair(lbn, GetRelatedRegister(SoleOutEdge())); } } void CompTaskNode::BpSetProducedRegisterDescs() { std::unique_ptr model_diff_register(new ContigRegistDesc); std::unique_ptr model_tmp_register(new DisContigRegistDesc); - RegisterDesc* data_diff_register = SoleOutEdge()->register_desc(); + RegisterDesc* data_diff_register = GetRelatedRegister(SoleOutEdge()); for (const std::unique_ptr& cur_edge : exec_graph().edges()) { data_diff_register->AddPbn(cur_edge->pbn()); } diff --git a/oneflow/graph/in_boxing_task_node.cpp b/oneflow/graph/in_boxing_task_node.cpp index cdbeab1cf2..fc11b90dd1 100644 --- a/oneflow/graph/in_boxing_task_node.cpp +++ b/oneflow/graph/in_boxing_task_node.cpp @@ -148,13 +148,13 @@ void InBoxingTaskNode::FwBuildChainSortedEdgesPair( ExecNode* first_node = mut_exec_graph().NewExecNode(); first_node->mut_op() = op_pair.first; for (const TaskEdge* edge : sorted_in_edges) { - first_node->AddConsumedLbnRegiPair(lbn, edge->register_desc()); + first_node->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(edge)); } // Second Node ExecNode* second_node = mut_exec_graph().NewExecNode(); second_node->mut_op() = op_pair.second; for (const TaskEdge* edge : sorted_out_edges) { - second_node->AddProducedLbnRegiPair(lbn, edge->register_desc()); + second_node->AddProducedLbnRegiPair(lbn, GetRelatedRegister(edge)); } // Connect Connect(first_node, mut_exec_graph().NewExecEdge(lbn), second_node); @@ -179,9 +179,9 @@ void InBoxingTaskNode::SetProducedRegister() { namespace { RegisterDesc* GetBpRegisterFromFwRegister(RegisterDesc* fw_register) { - const TaskEdge* fw_edge = GetRelatedTaskEdge4Register(fw_register); + const TaskEdge* fw_edge = GetRelatedTaskEdge(fw_register); const TaskEdge* bp_edge = fw_edge->related_fwbp_edge(); - return bp_edge->register_desc(); + return GetRelatedRegister(bp_edge); } } diff --git a/oneflow/graph/task_node.cpp b/oneflow/graph/task_node.cpp index 5d1e729079..617520c4bc 100644 --- a/oneflow/graph/task_node.cpp +++ b/oneflow/graph/task_node.cpp @@ -72,8 +72,9 @@ RegisterDesc* TaskNode::GetProducedRegister4OutEdge(const TaskEdge* edge) const void TaskNode::SubscribeRegisterDescInnerPath() { for (const TaskEdge* edge : in_edges()) { - edge->register_desc()->AddSubscriber(this); - subscribed_register_descs_.insert(edge->register_desc()); + RegisterDesc* regi = GetRelatedRegister(edge); + regi->AddSubscriber(this); + subscribed_register_descs_.insert(regi); } } diff --git a/oneflow/graph/task_node.h b/oneflow/graph/task_node.h index 5e81eb2f54..a62d793b6f 100644 --- a/oneflow/graph/task_node.h +++ b/oneflow/graph/task_node.h @@ -33,6 +33,8 @@ class TaskNode : public Node { // std::unique_ptr BuildAndConnectBpNode(); void BuildExecGraphAndSetRegisterDescs(); + + // const TaskEdge* GetOutEdge4ProducedRegister(RegisterDesc*) const; RegisterDesc* GetProducedRegister4OutEdge(const TaskEdge*) const; @@ -90,16 +92,16 @@ class TaskEdge final : public Edge { related_fwbp_edge_ = new_val; } - RegisterDesc* register_desc() const { - return src_node()->GetProducedRegister4OutEdge(this); - } - private: TaskEdge* related_fwbp_edge_; }; -inline const TaskEdge* GetRelatedTaskEdge4Register(RegisterDesc* regi) { +inline RegisterDesc* GetRelatedRegister(const TaskEdge* edge) { + return edge->src_node()->GetProducedRegister4OutEdge(edge); +} + +inline const TaskEdge* GetRelatedTaskEdge(RegisterDesc* regi) { return regi->GetProducer()->GetOutEdge4ProducedRegister(regi); } -- GitLab