提交 ecf7b2d0 编写于 作者: W willzhang4a58

start,stop -> source,sink

上级 c3621c61
......@@ -270,7 +270,7 @@ void ChainGraph::Init(const LogicalGraph* logical_graph) {
}
}
// Post processing
UpdateStartAndStop();
UpdateSourceAndSink();
CollectInputAndOutputLbns();
}
......
......@@ -15,7 +15,7 @@ void CompTransfmGraph::FwBuildGraph() {
}
FwAddCloneOp();
FwSetRelatedTaskEdges(lbn2producer, extern_in_lbn2consumers);
UpdateStartAndStop();
UpdateSourceAndSink();
}
void CompTransfmGraph::FwBuildFromUserOps(
......@@ -136,7 +136,7 @@ void CompTransfmGraph::FwSetRelatedTaskEdges(
void CompTransfmGraph::BpBuildGraph() {
const TransfmGraph* fw_graph = task_node()->GetFwNode()->transfm_graph();
const TransfmNode* cp_in_node = fw_graph->start_node().SoleOutEdge()->dst_node();
const TransfmNode* cp_in_node = fw_graph->source_node().SoleOutEdge()->dst_node();
std::unordered_map<const TransfmNode*, TransfmNode*> fw_node2bp_node;
// Copy Nodes
for (const std::unique_ptr<TransfmNode>& fw_node : fw_graph->nodes()) {
......
......@@ -10,10 +10,10 @@ namespace oneflow {
template<typename NodeType, typename EdgeType>
class Graph {
public:
// Topologically ergodic all nodes except start_node_,stop_node_
// Topologically ergodic all nodes except source_node_,sink_node_
class Iterator;
class ConstIterator;
// Reverse Topologically ergodic all nodes except start_node_,stop_node_
// Reverse Topologically ergodic all nodes except source_node_,sink_node_
class ReverseIterator;
class ConstReverseIterator;
......@@ -21,8 +21,8 @@ class Graph {
Graph() = default;
virtual ~Graph() = default;
const NodeType& start_node() const { return start_node_; }
const NodeType& stop_node() const { return stop_node_; }
const NodeType& source_node() const { return source_node_; }
const NodeType& sink_node() const { return sink_node_; }
Iterator begin();
Iterator end();
......@@ -46,14 +46,14 @@ class Graph {
}
bool IsFirstNode(const NodeType* node) const {
return node->SoleInEdge()->src_node() == &start_node_;
return node->SoleInEdge()->src_node() == &source_node_;
}
bool IsLastNode(const NodeType* node) const {
return node->SoleOutEdge()->dst_node() == &stop_node_;
return node->SoleOutEdge()->dst_node() == &sink_node_;
}
protected:
void UpdateStartAndStop();
void UpdateSourceAndSink();
// Register
void RegisterNode(NodeType* new_node) {
......@@ -83,12 +83,12 @@ class Graph {
}
private:
NodeType start_node_;
NodeType stop_node_;
std::vector<std::unique_ptr<EdgeType>> start_edges_;
std::vector<std::unique_ptr<EdgeType>> stop_edges_;
NodeType source_node_;
NodeType sink_node_;
std::vector<std::unique_ptr<EdgeType>> source_edges_;
std::vector<std::unique_ptr<EdgeType>> sink_edges_;
// manage nodes,edges that are not related to start,stop
// manage nodes,edges that are not related to source,sink
std::vector<std::unique_ptr<NodeType>> nodes_;
std::vector<std::unique_ptr<EdgeType>> edges_;
};
......@@ -101,9 +101,9 @@ class Graph<NodeType, EdgeType>::Iterator final {
Iterator() = default;
~Iterator() = default;
void Init(NodeType* start_node) {
void Init(NodeType* source_node) {
bfs_queue_ = std::queue<NodeType*> ();
bfs_queue_.push(start_node);
bfs_queue_.push(source_node);
}
NodeType& operator * ();
......@@ -148,9 +148,9 @@ class Graph<NodeType, EdgeType>::ReverseIterator final {
ReverseIterator() = default;
~ReverseIterator() = default;
void Init(NodeType* stop_node) {
void Init(NodeType* sink_node) {
bfs_queue_ = std::queue<NodeType*> ();
bfs_queue_.push(stop_node);
bfs_queue_.push(sink_node);
}
NodeType& operator * ();
......@@ -189,21 +189,21 @@ class Graph<NodeType, EdgeType>::ConstReverseIterator final {
};
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::UpdateStartAndStop() {
start_node_.DisconnectAllEdges();
stop_node_.DisconnectAllEdges();
start_edges_.clear();
stop_edges_.clear();
void Graph<NodeType, EdgeType>::UpdateSourceAndSink() {
source_node_.DisconnectAllEdges();
sink_node_.DisconnectAllEdges();
source_edges_.clear();
sink_edges_.clear();
for (const std::unique_ptr<NodeType>& node : nodes_) {
if (node->in_edges().empty()) {
EdgeType* start_edge = new EdgeType;
start_edges_.emplace_back(start_edge);
Connect(&start_node_, start_edge, node.get());
EdgeType* source_edge = new EdgeType;
source_edges_.emplace_back(source_edge);
Connect(&source_node_, source_edge, node.get());
}
if (node->out_edges().empty()) {
EdgeType* stop_edge = new EdgeType;
stop_edges_.emplace_back(stop_edge);
Connect(node.get(), stop_edge, &stop_node_);
EdgeType* sink_edge = new EdgeType;
sink_edges_.emplace_back(sink_edge);
Connect(node.get(), sink_edge, &sink_node_);
}
}
}
......@@ -211,13 +211,13 @@ void Graph<NodeType, EdgeType>::UpdateStartAndStop() {
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::begin() -> Iterator {
Iterator ret;
ret.Init(&start_node_);
ret.Init(&source_node_);
return ++ret;
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::end() -> Iterator {
Iterator ret;
ret.Init(&stop_node_);
ret.Init(&sink_node_);
return ret;
}
......@@ -237,13 +237,13 @@ auto Graph<NodeType, EdgeType>::cend() const -> ConstIterator {
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::rbegin() -> ReverseIterator {
ReverseIterator ret;
ret.Init(&stop_node_);
ret.Init(&sink_node_);
return ++ret;
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::rend() -> ReverseIterator {
ReverseIterator ret;
ret.Init(&start_node_);
ret.Init(&source_node_);
return ret;
}
......
......@@ -32,7 +32,7 @@ void LogicalGraph::BuildGraphStruct(const DLNetConf& dl_net_conf) {
}
lbn2node.clear();
// Post Processing
UpdateStartAndStop();
UpdateSourceAndSink();
}
void LogicalGraph::FillNodeWithParallelDesc(const Strategy& strategy_conf) {
......
......@@ -74,7 +74,7 @@ void StageGraph::Init(std::unique_ptr<const ChainGraph>&& chain_graph) {
}
}
// Post processing
UpdateStartAndStop();
UpdateSourceAndSink();
}
} // namespace oneflow
......@@ -35,7 +35,7 @@ void TaskGraph::BuildWithoutTransfm(
InitCompTaskNodes(*stage_graph, id_map, &stage2task_nodes);
InitBoxingTaskNodes(*stage_graph, id_map, &stage2task_nodes);
ConnectTaskNodes(*stage_graph, &stage2task_nodes);
UpdateStartAndStop();
UpdateSourceAndSink();
if (job_need_bp) {
BuildBpStruct();
}
......@@ -222,7 +222,7 @@ void TaskGraph::BuildBpStruct() {
std::vector<TaskNode*> turning_node_vec;
GenerateRelatedBpNodes(&turning_node_vec);
BackwardConnect(turning_node_vec);
UpdateStartAndStop();
UpdateSourceAndSink();
}
void TaskGraph::GenerateRelatedBpNodes(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册