未验证 提交 e40cfb10 编写于 作者: Z Zhen Wang 提交者: GitHub

fix the bug of assert_is_op_output. test=develop (#22262)

上级 a46bb2e6
...@@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) { ...@@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) {
return true; return true;
} }
bool HasOutput(Node *op, const std::string &argument) {
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
platform::errors::InvalidArgument(
"First parameter of function HasOuput must be Node::Op"));
auto const &names = op->Op()->OutputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var->IsVar(), true, var->IsVar(), true,
...@@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { ...@@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
op->IsOp(), true, op->IsOp(), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Second parameter of function IsNthOutput must be Node::Op")); "Second parameter of function IsNthOutput must be Node::Op"));
if (op->Op()->Output(argument).size() <= nth) return false; if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
return false;
return var->Name() == op->Op()->Output(argument)[nth]; return var->Name() == op->Op()->Output(argument)[nth];
} }
......
...@@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); ...@@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
// Check whether the op node has input of given name. // Check whether the op node has input of given name.
bool HasInput(Node* op, const std::string& argument); bool HasInput(Node* op, const std::string& argument);
// Check whether the op node has output of given name.
bool HasOutput(Node* op, const std::string& argument);
// Tell whether a var node is a op node's nth output. // Tell whether a var node is a op node's nth output.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册