From e40cfb10109a3a48fd0fe81a8f766e8c1aa52fdd Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 15 Jan 2020 10:49:31 +0800 Subject: [PATCH] fix the bug of assert_is_op_output. test=develop (#22262) --- .../fluid/framework/ir/graph_pattern_detector.cc | 14 +++++++++++++- paddle/fluid/framework/ir/graph_pattern_detector.h | 3 +++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 96a37977ee5..be73ccb5f7b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) { return true; } +bool HasOutput(Node *op, const std::string &argument) { + PADDLE_ENFORCE_EQ( + op->IsOp(), true, + platform::errors::InvalidArgument( + "First parameter of function HasOuput must be Node::Op")); + auto const &names = op->Op()->OutputNames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE_EQ( var->IsVar(), true, @@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { op->IsOp(), true, platform::errors::InvalidArgument( "Second parameter of function IsNthOutput must be Node::Op")); - if (op->Op()->Output(argument).size() <= nth) return false; + if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth) + return false; return var->Name() == op->Op()->Output(argument)[nth]; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index d0f65a88b74..735b23f2731 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); // Check whether the op node has input of given name. bool HasInput(Node* op, const std::string& argument); +// Check whether the op node has output of given name. +bool HasOutput(Node* op, const std::string& argument); + // Tell whether a var node is a op node's nth output. bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); -- GitLab