提交 9f8a4986 编写于 作者: W willzhang4a58

template pattern for boxing build exec

上级 b1a23d70
...@@ -55,6 +55,12 @@ OpPair FwBuildBoxingOpModelModel() { ...@@ -55,6 +55,12 @@ OpPair FwBuildBoxingOpModelModel() {
} }
void BoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() {
SetOutEdgeRegisterPtr();
FwBuildExecGraph();
SetProducedRegister();
}
void BoxingTaskNode::SetOutEdgeRegisterPtr() { void BoxingTaskNode::SetOutEdgeRegisterPtr() {
for (TaskEdge* edge : out_edges()) { for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + std::to_string(edge->edge_id()); std::string name = "boxing_out_" + std::to_string(edge->edge_id());
......
...@@ -20,7 +20,6 @@ class BoxingTaskNode : public TaskNode { ...@@ -20,7 +20,6 @@ class BoxingTaskNode : public TaskNode {
std::pair<const ChainNode*, std::vector<const TaskEdge*>>; std::pair<const ChainNode*, std::vector<const TaskEdge*>>;
using Chain2EdgesMap = using Chain2EdgesMap =
std::unordered_map<const ChainNode*, std::vector<const TaskEdge*>>; std::unordered_map<const ChainNode*, std::vector<const TaskEdge*>>;
void SetOutEdgeRegisterPtr();
void FwInitChain2SortedEdgesMaps( void FwInitChain2SortedEdgesMaps(
Chain2EdgesMap* chain2sorted_edges, Chain2EdgesMap* chain2sorted_edges,
const std::unordered_set<TaskEdge*>& (TaskNode::*in_out_edges)() const, const std::unordered_set<TaskEdge*>& (TaskNode::*in_out_edges)() const,
...@@ -33,10 +32,14 @@ class BoxingTaskNode : public TaskNode { ...@@ -33,10 +32,14 @@ class BoxingTaskNode : public TaskNode {
void FwBuildChainSortedEdgesPair( void FwBuildChainSortedEdgesPair(
const ChainEdgesPair& chain_sorted_in_edges, const ChainEdgesPair& chain_sorted_in_edges,
const ChainEdgesPair& chain_sorted_out_edges); const ChainEdgesPair& chain_sorted_out_edges);
void SetProducedRegister(); virtual void FwBuildExecGraph() = 0;
void BpBuildExecGraphAndSetProducedRegisterDescs() override;
private: private:
void FwBuildExecGraphAndSetProducedRegisterDescs() override;
void BpBuildExecGraphAndSetProducedRegisterDescs() override;
void SetOutEdgeRegisterPtr();
void SetProducedRegister();
}; };
......
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
namespace oneflow { namespace oneflow {
void InBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() { void InBoxingTaskNode::FwBuildExecGraph() {
SetOutEdgeRegisterPtr();
Chain2EdgesMap chain2sorted_in_edges; Chain2EdgesMap chain2sorted_in_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_in_edges, FwInitChain2SortedEdgesMaps(&chain2sorted_in_edges,
&TaskNode::in_edges, &TaskNode::in_edges,
...@@ -21,7 +20,6 @@ void InBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() { ...@@ -21,7 +20,6 @@ void InBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() {
for (const ChainEdgesPair& chain_sorted_in_edges : chain2sorted_in_edges) { for (const ChainEdgesPair& chain_sorted_in_edges : chain2sorted_in_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges); FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
} }
SetProducedRegister();
mut_exec_graph().UpdateSourceAndSink(); mut_exec_graph().UpdateSourceAndSink();
} }
......
...@@ -18,8 +18,7 @@ class InBoxingTaskNode final : public BoxingTaskNode { ...@@ -18,8 +18,7 @@ class InBoxingTaskNode final : public BoxingTaskNode {
void InitWithFwNode(TaskNode* fw_node) override { void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node); BoxingTaskNode::InitWithFwNode(fw_node);
} }
void FwBuildExecGraph() override;
void FwBuildExecGraphAndSetProducedRegisterDescs() override;
}; };
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
namespace oneflow { namespace oneflow {
// In future, we can use template-pattern void OutBoxingTaskNode::FwBuildExecGraph() {
void OutBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() {
SetOutEdgeRegisterPtr();
Chain2EdgesMap chain2sorted_out_edges; Chain2EdgesMap chain2sorted_out_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_out_edges, FwInitChain2SortedEdgesMaps(&chain2sorted_out_edges,
&TaskNode::out_edges, &TaskNode::out_edges,
...@@ -13,13 +11,12 @@ void OutBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() { ...@@ -13,13 +11,12 @@ void OutBoxingTaskNode::FwBuildExecGraphAndSetProducedRegisterDescs() {
ChainEdgesPair chain_sorted_in_edges; ChainEdgesPair chain_sorted_in_edges;
chain_sorted_in_edges.first = chain_node(); chain_sorted_in_edges.first = chain_node();
chain_sorted_in_edges.second.assign(in_edges().begin(), in_edges().end()); chain_sorted_in_edges.second.assign(in_edges().begin(), in_edges().end());
FwSortEdgesInnerStage(&chain_sorted_in_edges, FwSortEdgesInnerStage(&chain_sorted_in_edges.second,
&TaskEdge::src_node, &TaskEdge::src_node,
&TaskNode::SoleInEdge); &TaskNode::SoleInEdge);
for (const ChainEdgesPair& chain_sorted_out_edges : chain2sorted_out_edges) { for (const ChainEdgesPair& chain_sorted_out_edges : chain2sorted_out_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges); FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
} }
SetProducedRegister();
mut_exec_graph().UpdateSourceAndSink(); mut_exec_graph().UpdateSourceAndSink();
} }
......
...@@ -18,7 +18,7 @@ class OutBoxingTaskNode final : public BoxingTaskNode { ...@@ -18,7 +18,7 @@ class OutBoxingTaskNode final : public BoxingTaskNode {
void InitWithFwNode(TaskNode* fw_node) override { void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node); BoxingTaskNode::InitWithFwNode(fw_node);
} }
void FwBuildExecGraphAndSetProducedRegisterDescs() override; void FwBuildExecGraph() override;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册