未验证 提交 508b40ec 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix Inference CE Error by Topo Order (#34521)

The background explanation is too long for a commit message; see details at https://github.com/PaddlePaddle/Paddle/pull/34521
上级 393a0b16
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/operator.h"
DEFINE_bool(convert_all_blocks, false,
DEFINE_bool(convert_all_blocks, true,
"Convert all blocks in program into SSAgraphs");
namespace paddle {
......
......@@ -24,15 +24,17 @@ namespace paddle {
namespace framework {
namespace ir {
namespace {
void SortHelper(const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>,
ir::NodeComp> &adj_list,
template <class NodeComparator = ir::NodeComp>
void SortHelper(const std::map<ir::Node *, std::set<ir::Node *, NodeComparator>,
NodeComparator> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
SortHelper<NodeComparator>(adj_list, adj, visited, ret);
}
}
......@@ -41,10 +43,11 @@ void SortHelper(const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>,
ret->push_back(node);
}
template <class NodeComparator = ir::NodeComp>
bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
&adj_list,
const std::map<ir::Node *, std::set<ir::Node *, NodeComparator>,
NodeComparator> &adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace,
std::vector<std::vector<ir::Node *>> *circles) {
......@@ -54,7 +57,8 @@ bool HasCircleHelper(
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
HasCircleHelper<NodeComparator>(in, adj_list, visited, in_trace,
circles)) {
return true;
} else if (in_trace->find(in) != in_trace->end()) {
if (circles != nullptr) {
......@@ -77,14 +81,16 @@ bool HasCircleHelper(
return false;
}
template <class NodeComparator = ir::NodeComp>
bool HasCircleInternal(
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
&adj_list,
const std::map<ir::Node *, std::set<ir::Node *, NodeComparator>,
NodeComparator> &adj_list,
std::vector<std::vector<ir::Node *>> *circles) {
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
if (HasCircleHelper<NodeComparator>(adj.first, adj_list, &visited,
&in_trace, circles)) {
return true;
}
}
......@@ -136,7 +142,7 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &ret);
SortHelper<ir::NodeComp>(adj_list, adj.first, &visited, &ret);
}
}
......@@ -169,34 +175,6 @@ bool IsTopologySortOperationsUnique(const Graph &graph) {
return true;
}
// Build operator inlink edge table.
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
BuildOperationAdjList(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list;
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::set<ir::Node *, ir::NodeComp>();
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
}
}
}
return adj_list;
}
// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
const Graph &graph) {
......@@ -424,81 +402,33 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
class DescOrderComparator {
public:
bool operator()(const Node *n1, const Node *n2) {
return (n1->DescOrder() > n2->DescOrder()) ||
((n1->DescOrder() == n2->DescOrder()) &&
(n1->ToString() > n2->ToString()));
bool operator()(Node *const &n1, Node *const &n2) const {
if (n1->DescOrder() < n2->DescOrder()) {
return true;
} else if (n1->DescOrder() == n2->DescOrder()) {
return n1->id() < n2->id() ||
(n1->id() == n2->id() && n1->ToString() < n2->ToString());
}
return false;
}
};
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
std::vector<ir::Node *> sorted_ops;
std::priority_queue<Node *, std::vector<Node *>, DescOrderComparator> q;
std::unordered_map<Node *, std::unordered_set<Node *>> in_ops;
std::unordered_map<Node *, std::unordered_set<Node *>> out_ops;
// ensure all op node in 'in_ops' and 'out_ops'
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
in_ops.emplace(n, std::unordered_set<Node *>());
out_ops.emplace(n, std::unordered_set<Node *>());
}
// record all op's input op and output op
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
// traverse all input op
for (const auto &var : n->inputs) {
for (const auto &in : var->inputs) {
// use at instead of [] to prevent no unrecorded op node
in_ops.at(n).insert(in);
out_ops.at(in).insert(n);
}
}
}
// find topology entrance
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (in_ops.at(n).empty()) {
q.push(n);
}
}
// topological sorting
while (!q.empty()) {
// Do not get by reference!!! The element will pop later.
const auto cur_op = q.top();
q.pop();
sorted_ops.push_back(cur_op);
for (const auto &out : out_ops.at(cur_op)) {
PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0,
platform::errors::InvalidArgument(
"We find %s in %s's output list, "
"but cannot find %s in %s's input list. "
"Please ensure graph completely.",
out->Name().c_str(), cur_op->Name().c_str(),
cur_op->Name().c_str(), out->Name().c_str()));
in_ops.at(out).erase(cur_op);
// push if in-degree is 0
if (in_ops.at(out).empty()) {
q.push(out);
}
std::map<ir::Node *, std::set<ir::Node *, DescOrderComparator>,
DescOrderComparator>
adj_list = BuildOperationAdjList<DescOrderComparator>(graph);
PADDLE_ENFORCE_EQ(HasCircleInternal<DescOrderComparator>(adj_list, nullptr),
false, platform::errors::InvalidArgument(
"Generated graph shouldn't contain cycle."));
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper<DescOrderComparator>(adj_list, adj.first, &visited, &ret);
}
}
PADDLE_ENFORCE_EQ(
sorted_ops.size(), in_ops.size(),
platform::errors::InvalidArgument("Topological sorting incompletely, "
"only sorted %zd op but total %zd.",
sorted_ops.size(), in_ops.size()));
return sorted_ops;
return ret;
}
} // namespace ir
......
......@@ -77,9 +77,34 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph, SortKind sort_kind);
// Clean the nodes that don't connect to others.
void CleanIndividualNodes(Graph *graph);
// Build an adjacency list of operations for the `graph`.
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
BuildOperationAdjList(const Graph &graph);
// Build an in-link adjacency list of operations for the `graph`.
template <class NodeComparator = ir::NodeComp>
std::map<ir::Node *, std::set<ir::Node *, NodeComparator>, NodeComparator>
BuildOperationAdjList(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, NodeComparator>, NodeComparator>
adj_list;
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::set<ir::Node *, NodeComparator>();
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
}
}
}
return adj_list;
}
template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
......
......@@ -31,6 +31,7 @@
namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::NodeComp;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::GraphNum;
......@@ -50,7 +51,7 @@ void BindGraph(py::module *m) {
m->def("graph_num", GraphNum);
m->def("topology_sort", TopologySortOperations,
return_value_policy::reference);
m->def("build_adjacency_list", BuildOperationAdjList,
m->def("build_adjacency_list", BuildOperationAdjList<NodeComp>,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
......
......@@ -179,6 +179,7 @@ def __bootstrap__():
sysstr = platform.system()
read_env_flags = [
'check_nan_inf',
'convert_all_blocks',
'benchmark',
'eager_delete_scope',
'fraction_of_cpu_memory_to_use',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册