提交 ecf7b2d0 编写于 作者: W willzhang4a58

start,stop -> source,sink

上级 c3621c61
...@@ -270,7 +270,7 @@ void ChainGraph::Init(const LogicalGraph* logical_graph) { ...@@ -270,7 +270,7 @@ void ChainGraph::Init(const LogicalGraph* logical_graph) {
} }
} }
// Post processing // Post processing
UpdateStartAndStop(); UpdateSourceAndSink();
CollectInputAndOutputLbns(); CollectInputAndOutputLbns();
} }
......
...@@ -15,7 +15,7 @@ void CompTransfmGraph::FwBuildGraph() { ...@@ -15,7 +15,7 @@ void CompTransfmGraph::FwBuildGraph() {
} }
FwAddCloneOp(); FwAddCloneOp();
FwSetRelatedTaskEdges(lbn2producer, extern_in_lbn2consumers); FwSetRelatedTaskEdges(lbn2producer, extern_in_lbn2consumers);
UpdateStartAndStop(); UpdateSourceAndSink();
} }
void CompTransfmGraph::FwBuildFromUserOps( void CompTransfmGraph::FwBuildFromUserOps(
...@@ -136,7 +136,7 @@ void CompTransfmGraph::FwSetRelatedTaskEdges( ...@@ -136,7 +136,7 @@ void CompTransfmGraph::FwSetRelatedTaskEdges(
void CompTransfmGraph::BpBuildGraph() { void CompTransfmGraph::BpBuildGraph() {
const TransfmGraph* fw_graph = task_node()->GetFwNode()->transfm_graph(); const TransfmGraph* fw_graph = task_node()->GetFwNode()->transfm_graph();
const TransfmNode* cp_in_node = fw_graph->start_node().SoleOutEdge()->dst_node(); const TransfmNode* cp_in_node = fw_graph->source_node().SoleOutEdge()->dst_node();
std::unordered_map<const TransfmNode*, TransfmNode*> fw_node2bp_node; std::unordered_map<const TransfmNode*, TransfmNode*> fw_node2bp_node;
// Copy Nodes // Copy Nodes
for (const std::unique_ptr<TransfmNode>& fw_node : fw_graph->nodes()) { for (const std::unique_ptr<TransfmNode>& fw_node : fw_graph->nodes()) {
......
...@@ -10,10 +10,10 @@ namespace oneflow { ...@@ -10,10 +10,10 @@ namespace oneflow {
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
class Graph { class Graph {
public: public:
// Topologically ergodic all nodes except start_node_,stop_node_ // Topologically ergodic all nodes except source_node_,sink_node_
class Iterator; class Iterator;
class ConstIterator; class ConstIterator;
// Reverse Topologically ergodic all nodes except start_node_,stop_node_ // Reverse Topologically ergodic all nodes except source_node_,sink_node_
class ReverseIterator; class ReverseIterator;
class ConstReverseIterator; class ConstReverseIterator;
...@@ -21,8 +21,8 @@ class Graph { ...@@ -21,8 +21,8 @@ class Graph {
Graph() = default; Graph() = default;
virtual ~Graph() = default; virtual ~Graph() = default;
const NodeType& start_node() const { return start_node_; } const NodeType& source_node() const { return source_node_; }
const NodeType& stop_node() const { return stop_node_; } const NodeType& sink_node() const { return sink_node_; }
Iterator begin(); Iterator begin();
Iterator end(); Iterator end();
...@@ -46,14 +46,14 @@ class Graph { ...@@ -46,14 +46,14 @@ class Graph {
} }
bool IsFirstNode(const NodeType* node) const { bool IsFirstNode(const NodeType* node) const {
return node->SoleInEdge()->src_node() == &start_node_; return node->SoleInEdge()->src_node() == &source_node_;
} }
bool IsLastNode(const NodeType* node) const { bool IsLastNode(const NodeType* node) const {
return node->SoleOutEdge()->dst_node() == &stop_node_; return node->SoleOutEdge()->dst_node() == &sink_node_;
} }
protected: protected:
void UpdateStartAndStop(); void UpdateSourceAndSink();
// Register // Register
void RegisterNode(NodeType* new_node) { void RegisterNode(NodeType* new_node) {
...@@ -83,12 +83,12 @@ class Graph { ...@@ -83,12 +83,12 @@ class Graph {
} }
private: private:
NodeType start_node_; NodeType source_node_;
NodeType stop_node_; NodeType sink_node_;
std::vector<std::unique_ptr<EdgeType>> start_edges_; std::vector<std::unique_ptr<EdgeType>> source_edges_;
std::vector<std::unique_ptr<EdgeType>> stop_edges_; std::vector<std::unique_ptr<EdgeType>> sink_edges_;
// manage nodes,edges that are not related to start,stop // manage nodes,edges that are not related to source,sink
std::vector<std::unique_ptr<NodeType>> nodes_; std::vector<std::unique_ptr<NodeType>> nodes_;
std::vector<std::unique_ptr<EdgeType>> edges_; std::vector<std::unique_ptr<EdgeType>> edges_;
}; };
...@@ -101,9 +101,9 @@ class Graph<NodeType, EdgeType>::Iterator final { ...@@ -101,9 +101,9 @@ class Graph<NodeType, EdgeType>::Iterator final {
Iterator() = default; Iterator() = default;
~Iterator() = default; ~Iterator() = default;
void Init(NodeType* start_node) { void Init(NodeType* source_node) {
bfs_queue_ = std::queue<NodeType*> (); bfs_queue_ = std::queue<NodeType*> ();
bfs_queue_.push(start_node); bfs_queue_.push(source_node);
} }
NodeType& operator * (); NodeType& operator * ();
...@@ -148,9 +148,9 @@ class Graph<NodeType, EdgeType>::ReverseIterator final { ...@@ -148,9 +148,9 @@ class Graph<NodeType, EdgeType>::ReverseIterator final {
ReverseIterator() = default; ReverseIterator() = default;
~ReverseIterator() = default; ~ReverseIterator() = default;
void Init(NodeType* stop_node) { void Init(NodeType* sink_node) {
bfs_queue_ = std::queue<NodeType*> (); bfs_queue_ = std::queue<NodeType*> ();
bfs_queue_.push(stop_node); bfs_queue_.push(sink_node);
} }
NodeType& operator * (); NodeType& operator * ();
...@@ -189,21 +189,21 @@ class Graph<NodeType, EdgeType>::ConstReverseIterator final { ...@@ -189,21 +189,21 @@ class Graph<NodeType, EdgeType>::ConstReverseIterator final {
}; };
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::UpdateStartAndStop() { void Graph<NodeType, EdgeType>::UpdateSourceAndSink() {
start_node_.DisconnectAllEdges(); source_node_.DisconnectAllEdges();
stop_node_.DisconnectAllEdges(); sink_node_.DisconnectAllEdges();
start_edges_.clear(); source_edges_.clear();
stop_edges_.clear(); sink_edges_.clear();
for (const std::unique_ptr<NodeType>& node : nodes_) { for (const std::unique_ptr<NodeType>& node : nodes_) {
if (node->in_edges().empty()) { if (node->in_edges().empty()) {
EdgeType* start_edge = new EdgeType; EdgeType* source_edge = new EdgeType;
start_edges_.emplace_back(start_edge); source_edges_.emplace_back(source_edge);
Connect(&start_node_, start_edge, node.get()); Connect(&source_node_, source_edge, node.get());
} }
if (node->out_edges().empty()) { if (node->out_edges().empty()) {
EdgeType* stop_edge = new EdgeType; EdgeType* sink_edge = new EdgeType;
stop_edges_.emplace_back(stop_edge); sink_edges_.emplace_back(sink_edge);
Connect(node.get(), stop_edge, &stop_node_); Connect(node.get(), sink_edge, &sink_node_);
} }
} }
} }
...@@ -211,13 +211,13 @@ void Graph<NodeType, EdgeType>::UpdateStartAndStop() { ...@@ -211,13 +211,13 @@ void Graph<NodeType, EdgeType>::UpdateStartAndStop() {
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::begin() -> Iterator { auto Graph<NodeType, EdgeType>::begin() -> Iterator {
Iterator ret; Iterator ret;
ret.Init(&start_node_); ret.Init(&source_node_);
return ++ret; return ++ret;
} }
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::end() -> Iterator { auto Graph<NodeType, EdgeType>::end() -> Iterator {
Iterator ret; Iterator ret;
ret.Init(&stop_node_); ret.Init(&sink_node_);
return ret; return ret;
} }
...@@ -237,13 +237,13 @@ auto Graph<NodeType, EdgeType>::cend() const -> ConstIterator { ...@@ -237,13 +237,13 @@ auto Graph<NodeType, EdgeType>::cend() const -> ConstIterator {
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::rbegin() -> ReverseIterator { auto Graph<NodeType, EdgeType>::rbegin() -> ReverseIterator {
ReverseIterator ret; ReverseIterator ret;
ret.Init(&stop_node_); ret.Init(&sink_node_);
return ++ret; return ++ret;
} }
template<typename NodeType, typename EdgeType> template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::rend() -> ReverseIterator { auto Graph<NodeType, EdgeType>::rend() -> ReverseIterator {
ReverseIterator ret; ReverseIterator ret;
ret.Init(&start_node_); ret.Init(&source_node_);
return ret; return ret;
} }
......
...@@ -32,7 +32,7 @@ void LogicalGraph::BuildGraphStruct(const DLNetConf& dl_net_conf) { ...@@ -32,7 +32,7 @@ void LogicalGraph::BuildGraphStruct(const DLNetConf& dl_net_conf) {
} }
lbn2node.clear(); lbn2node.clear();
// Post Processing // Post Processing
UpdateStartAndStop(); UpdateSourceAndSink();
} }
void LogicalGraph::FillNodeWithParallelDesc(const Strategy& strategy_conf) { void LogicalGraph::FillNodeWithParallelDesc(const Strategy& strategy_conf) {
......
...@@ -74,7 +74,7 @@ void StageGraph::Init(std::unique_ptr<const ChainGraph>&& chain_graph) { ...@@ -74,7 +74,7 @@ void StageGraph::Init(std::unique_ptr<const ChainGraph>&& chain_graph) {
} }
} }
// Post processing // Post processing
UpdateStartAndStop(); UpdateSourceAndSink();
} }
} // namespace oneflow } // namespace oneflow
...@@ -35,7 +35,7 @@ void TaskGraph::BuildWithoutTransfm( ...@@ -35,7 +35,7 @@ void TaskGraph::BuildWithoutTransfm(
InitCompTaskNodes(*stage_graph, id_map, &stage2task_nodes); InitCompTaskNodes(*stage_graph, id_map, &stage2task_nodes);
InitBoxingTaskNodes(*stage_graph, id_map, &stage2task_nodes); InitBoxingTaskNodes(*stage_graph, id_map, &stage2task_nodes);
ConnectTaskNodes(*stage_graph, &stage2task_nodes); ConnectTaskNodes(*stage_graph, &stage2task_nodes);
UpdateStartAndStop(); UpdateSourceAndSink();
if (job_need_bp) { if (job_need_bp) {
BuildBpStruct(); BuildBpStruct();
} }
...@@ -222,7 +222,7 @@ void TaskGraph::BuildBpStruct() { ...@@ -222,7 +222,7 @@ void TaskGraph::BuildBpStruct() {
std::vector<TaskNode*> turning_node_vec; std::vector<TaskNode*> turning_node_vec;
GenerateRelatedBpNodes(&turning_node_vec); GenerateRelatedBpNodes(&turning_node_vec);
BackwardConnect(turning_node_vec); BackwardConnect(turning_node_vec);
UpdateStartAndStop(); UpdateSourceAndSink();
} }
void TaskGraph::GenerateRelatedBpNodes( void TaskGraph::GenerateRelatedBpNodes(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册