diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 96a37977ee53fcc8edec0c296f24171c7be3e384..be73ccb5f7b21cfd887f17d94090005867628a41 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) { return true; } +bool HasOutput(Node *op, const std::string &argument) { + PADDLE_ENFORCE_EQ( + op->IsOp(), true, + platform::errors::InvalidArgument( + "First parameter of function HasOuput must be Node::Op")); + auto const &names = op->Op()->OutputNames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE_EQ( var->IsVar(), true, @@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { op->IsOp(), true, platform::errors::InvalidArgument( "Second parameter of function IsNthOutput must be Node::Op")); - if (op->Op()->Output(argument).size() <= nth) return false; + if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth) + return false; return var->Name() == op->Op()->Output(argument)[nth]; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index d0f65a88b7452125a6684f4ec90e345beef3e844..735b23f27314d57a13b19281aafec3059195e7f8 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); // Check whether the op node has input of given name. bool HasInput(Node* op, const std::string& argument); +// Check whether the op node has output of given name. +bool HasOutput(Node* op, const std::string& argument); + // Tell whether a var node is a op node's nth output. bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);