未验证 提交 11ab5581 编写于 作者: C cheng cheng 提交者: GitHub

remove TaskGraph::AcyclicTopoForEachNode (#4352)

Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 3cf6a446
......@@ -338,37 +338,6 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
}
}
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
const std::function<void(TaskNode* node)>& Handler) const {
auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsBackEdge(node_on_in_edge, node)) { return; }
Handler(const_cast<TaskNode*>(node_on_in_edge));
});
};
auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
if (IsBackEdge(node, node_on_out_edge)) { return; }
Handler(const_cast<TaskNode*>(node_on_out_edge));
});
};
auto IsSourceNode = [&](TaskNode* node) {
int32_t in_node_num = 0;
ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
return in_node_num == 0;
};
std::list<TaskNode*> starts;
ForEachNode([&](TaskNode* node) {
if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
});
// DfsTopo will cause inappropriate chain graph
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}
void TaskGraph::AcyclicTopoForEachNode(const std::function<void(TaskNode* node)>& Handler) const {
return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
}
void TaskGraph::RemoveEmptyRegsts() {
ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedBlob(); });
ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
......@@ -388,7 +357,7 @@ void TaskGraph::SetOrderInGraphForEachNode() {
ordered_task_nodes_.emplace_back(task_node);
++order_in_graph;
};
AcyclicTopoForEachNode(SetOrderInGraph);
TopoForEachNode(SetOrderInGraph);
}
void TaskGraph::MergeChain() {
......@@ -746,6 +715,4 @@ void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
}
}
bool IsBackEdge(TaskNode* src, TaskNode* dst) { return false; }
} // namespace oneflow
......@@ -44,8 +44,6 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable);
void AcyclicTopoForEachNode(const std::function<void(TaskNode* node)>& Handler) const;
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing);
......@@ -58,9 +56,6 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D);
private:
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
const std::function<void(TaskNode* node)>& Handler) const;
void BuildTaskPath(
CompTaskNode* src, CompTaskNode* dst,
std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册