diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index abcba32a6492b114193cfab6756ff87247956f6c..4b403c46260c6129451809f276aac67ccc17c4d4 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -37,12 +37,14 @@ NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other) : stack_(other.stack_), visited_(other.visited_) {} Node &NodesDFSIterator::operator*() { - PADDLE_ENFORCE(!stack_.empty()); + PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange( + "The iterator exceeds range.")); return *stack_.top(); } NodesDFSIterator &NodesDFSIterator::operator++() { - PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range"); + PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange( + "The iterator exceeds range.")); visited_.insert(stack_.top()); auto *cur = stack_.top(); stack_.pop(); @@ -73,11 +75,18 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) { } NodesTSIterator::NodesTSIterator(const std::vector &source) { - PADDLE_ENFORCE(!source.empty(), - "Start points of topological sorting should not be empty!"); + PADDLE_ENFORCE_EQ( + source.empty(), false, + platform::errors::InvalidArgument( + "Start points of topological sorting should not be empty!")); // CHECK all the inputs' in-degree is 0 for (auto *node : source) { - PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); + PADDLE_ENFORCE_EQ( + CheckNodeIndegreeEquals(*node, 0), true, + platform::errors::InvalidArgument( + "In start points of topological sorting, the indegree of each " + "point should be 0. Node(%s)'s indegree is not 0.", + node->Name())); } std::set to_visit{source.begin(), source.end()}; @@ -106,7 +115,11 @@ NodesTSIterator::NodesTSIterator(const NodesTSIterator &other) : sorted_(other.sorted_), cursor_(other.cursor_) {} Node &NodesTSIterator::operator*() { - PADDLE_ENFORCE_LT(cursor_, sorted_.size()); + PADDLE_ENFORCE_LT( + cursor_, sorted_.size(), + platform::errors::OutOfRange( + "The iterator exceeds range. Container size is %d, but index is %d.", + sorted_.size(), cursor_)); return *sorted_[cursor_]; } @@ -128,7 +141,11 @@ bool NodesTSIterator::operator==(const NodesTSIterator &other) { } Node *NodesTSIterator::operator->() { - PADDLE_ENFORCE_LT(cursor_, sorted_.size()); + PADDLE_ENFORCE_LT( + cursor_, sorted_.size(), + platform::errors::OutOfRange( + "The iterator exceeds range. Container size is %d, but index is %d.", + sorted_.size(), cursor_)); return sorted_[cursor_]; } diff --git a/paddle/fluid/framework/ir/graph_traits.h b/paddle/fluid/framework/ir/graph_traits.h index f6772f9a37567c83c49bd44d551481edda1a74ae..bb4212bcd33d77cfe1c091b18387e18c4c3e5fa7 100644 --- a/paddle/fluid/framework/ir/graph_traits.h +++ b/paddle/fluid/framework/ir/graph_traits.h @@ -15,6 +15,8 @@ #pragma once #include +#include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -66,7 +68,7 @@ struct NodesDFSIterator struct NodesTSIterator : public std::iterator { NodesTSIterator() = default; - NodesTSIterator(const std::vector &source); + explicit NodesTSIterator(const std::vector &source); NodesTSIterator(NodesTSIterator &&other) : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { other.cursor_ = 0; @@ -104,7 +106,10 @@ struct GraphTraits { static iterator_range TS(const Graph &g) { auto start_points = ExtractStartPoints(g); - PADDLE_ENFORCE(!start_points.empty()); + PADDLE_ENFORCE_EQ( + start_points.empty(), false, + platform::errors::InvalidArgument( + "Start points of topological sorting should not be empty!")); NodesTSIterator x(start_points); return iterator_range(NodesTSIterator(start_points), NodesTSIterator()); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index 7f4519ad9919d7ad2a13c501e07b7ec92bd1eee1..64f5376a784c29eccadcfcf3021447e4655910c6 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -42,7 +42,10 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const { const std::string& graph_viz_path = Get(kGraphvizPath); VLOG(3) << "draw IR graph viz to " << graph_viz_path; std::unique_ptr fout(new std::ofstream(graph_viz_path)); - PADDLE_ENFORCE(fout->good()); + PADDLE_ENFORCE_EQ( + fout->good(), true, + platform::errors::Unavailable( + "Can not open file %s for printing the graph.", graph_viz_path)); std::ostream& sout = *fout; std::unordered_map node2dot; diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc index a39901e63bf65f7c314595a5fb2cc31d00959bd5..c8dfa02f469a351a8d3495bf19238a723029bb4b 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc @@ -64,7 +64,11 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) { auto* arguments = parameter.mutable_arguments(); auto it = std::find(arguments->begin(), arguments->end(), scale_in_name); - PADDLE_ENFORCE(it != arguments->end()); + PADDLE_ENFORCE_NE( + it, arguments->end(), + platform::errors::NotFound( + "Can not find input variable(%s) from scale op(%s).", + scale_in_name, pre_op_desc->Type())); *it = scale_out_name; } diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc index d67f2274ebf1f0b57cf0e9c9fedd2f61eb1d5c9d..456e642ad86ab18d55df2d36650f04c4d6635876 100644 --- a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -85,7 +85,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { // 1. record op nodes of different roles for (auto node : nodes) { if (!node->IsOp()) continue; - PADDLE_ENFORCE(node->Op(), "must find opdesc"); + PADDLE_ENFORCE_NOT_NULL( + node->Op(), platform::errors::InvalidArgument( + "Node(%s) must hold op description.", node->Name())); int op_role = BOOST_GET_CONST( int, node->Op()->GetAttr( framework::OpProtoAndCheckerMaker::OpRoleAttrName())); @@ -108,7 +110,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { } else if (op_role & static_cast(framework::OpRole::kLRSched)) { lr_ops.push_back(node); } else { // NOLINT - PADDLE_THROW("Invalid op_role: %d", static_cast(op_role)); + PADDLE_THROW(platform::errors::InvalidArgument( + "Invalid op role(%d), in node(%s).", static_cast(op_role), + node->Name())); } }