提交 4dc5918e 编写于 作者: W willzhang4a58

Get Relalted Register/Edge

上级 eb80f68a
......@@ -143,18 +143,18 @@ void CompTaskNode::FwSetRegisterPtrs4ExecNodes(
for (const auto& pair : extern_in_lbn2consumers) {
const std::string& lbn = pair.first;
for (ExecNode* consumer : pair.second) {
consumer->AddConsumedLbnRegiPair(lbn, SoleInEdge()->register_desc());
consumer->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(SoleInEdge()));
}
}
// Out Register Desc
for (const auto& lbn : chain_node()->output_lbns()) {
ExecNode* producer = lbn2producer.at(lbn);
producer->AddProducedLbnRegiPair(lbn, SoleOutEdge()->register_desc());
producer->AddProducedLbnRegiPair(lbn, GetRelatedRegister(SoleOutEdge()));
}
}
void CompTaskNode::FwSetProducedRegisterDescs() {
RegisterDesc* data_register = SoleOutEdge()->register_desc();
RegisterDesc* data_register = GetRelatedRegister(SoleOutEdge());
for (const std::unique_ptr<ExecEdge>& cur_edge : exec_graph().edges()) {
data_register->AddPbn(cur_edge->pbn());
}
......@@ -219,7 +219,7 @@ void CompTaskNode::BpSetRegisterDescPtrs4Nodes(
for (const auto& odbn : bp_node->op()->output_diff_blob_names()) {
std::string lbn = bp_node->op()->odbn2lbn(odbn);
if (found_lbns.find(lbn) == found_lbns.end()) {
bp_node->AddConsumedLbnRegiPair(lbn, SoleInEdge()->register_desc());
bp_node->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(SoleInEdge()));
}
}
}
......@@ -227,14 +227,14 @@ void CompTaskNode::BpSetRegisterDescPtrs4Nodes(
for (ExecEdge* edge : cp_in_node->out_edges()) {
const std::string& lbn = edge->lbn();
ExecNode* bp_node = fw_node2bp_node.at(edge->dst_node());
bp_node->AddProducedLbnRegiPair(lbn, SoleOutEdge()->register_desc());
bp_node->AddProducedLbnRegiPair(lbn, GetRelatedRegister(SoleOutEdge()));
}
}
void CompTaskNode::BpSetProducedRegisterDescs() {
std::unique_ptr<RegisterDesc> model_diff_register(new ContigRegistDesc);
std::unique_ptr<RegisterDesc> model_tmp_register(new DisContigRegistDesc);
RegisterDesc* data_diff_register = SoleOutEdge()->register_desc();
RegisterDesc* data_diff_register = GetRelatedRegister(SoleOutEdge());
for (const std::unique_ptr<ExecEdge>& cur_edge : exec_graph().edges()) {
data_diff_register->AddPbn(cur_edge->pbn());
}
......
......@@ -148,13 +148,13 @@ void InBoxingTaskNode::FwBuildChainSortedEdgesPair(
ExecNode* first_node = mut_exec_graph().NewExecNode();
first_node->mut_op() = op_pair.first;
for (const TaskEdge* edge : sorted_in_edges) {
first_node->AddConsumedLbnRegiPair(lbn, edge->register_desc());
first_node->AddConsumedLbnRegiPair(lbn, GetRelatedRegister(edge));
}
// Second Node
ExecNode* second_node = mut_exec_graph().NewExecNode();
second_node->mut_op() = op_pair.second;
for (const TaskEdge* edge : sorted_out_edges) {
second_node->AddProducedLbnRegiPair(lbn, edge->register_desc());
second_node->AddProducedLbnRegiPair(lbn, GetRelatedRegister(edge));
}
// Connect
Connect(first_node, mut_exec_graph().NewExecEdge(lbn), second_node);
......@@ -179,9 +179,9 @@ void InBoxingTaskNode::SetProducedRegister() {
namespace {
RegisterDesc* GetBpRegisterFromFwRegister(RegisterDesc* fw_register) {
const TaskEdge* fw_edge = GetRelatedTaskEdge4Register(fw_register);
const TaskEdge* fw_edge = GetRelatedTaskEdge(fw_register);
const TaskEdge* bp_edge = fw_edge->related_fwbp_edge();
return bp_edge->register_desc();
return GetRelatedRegister(bp_edge);
}
}
......
......@@ -72,8 +72,9 @@ RegisterDesc* TaskNode::GetProducedRegister4OutEdge(const TaskEdge* edge) const
void TaskNode::SubscribeRegisterDescInnerPath() {
for (const TaskEdge* edge : in_edges()) {
edge->register_desc()->AddSubscriber(this);
subscribed_register_descs_.insert(edge->register_desc());
RegisterDesc* regi = GetRelatedRegister(edge);
regi->AddSubscriber(this);
subscribed_register_descs_.insert(regi);
}
}
......
......@@ -33,6 +33,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
//
std::unique_ptr<TaskNode> BuildAndConnectBpNode();
void BuildExecGraphAndSetRegisterDescs();
//
const TaskEdge* GetOutEdge4ProducedRegister(RegisterDesc*) const;
RegisterDesc* GetProducedRegister4OutEdge(const TaskEdge*) const;
......@@ -90,16 +92,16 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
related_fwbp_edge_ = new_val;
}
RegisterDesc* register_desc() const {
return src_node()->GetProducedRegister4OutEdge(this);
}
private:
TaskEdge* related_fwbp_edge_;
};
inline const TaskEdge* GetRelatedTaskEdge4Register(RegisterDesc* regi) {
inline RegisterDesc* GetRelatedRegister(const TaskEdge* edge) {
return edge->src_node()->GetProducedRegister4OutEdge(edge);
}
inline const TaskEdge* GetRelatedTaskEdge(RegisterDesc* regi) {
return regi->GetProducer()->GetOutEdge4ProducedRegister(regi);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册