未验证 提交 e40cfb10 编写于 作者: Z Zhen Wang 提交者: GitHub

fix the bug of assert_is_op_output. test=develop (#22262)

上级 a46bb2e6
......@@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) {
return true;
}
bool HasOutput(Node *op, const std::string &argument) {
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
platform::errors::InvalidArgument(
"First parameter of function HasOuput must be Node::Op"));
auto const &names = op->Op()->OutputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE_EQ(
var->IsVar(), true,
......@@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
op->IsOp(), true,
platform::errors::InvalidArgument(
"Second parameter of function IsNthOutput must be Node::Op"));
if (op->Op()->Output(argument).size() <= nth) return false;
if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
......
......@@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
// Check whether the op node has input of given name.
bool HasInput(Node* op, const std::string& argument);
// Check whether the op node has output of given name.
bool HasOutput(Node* op, const std::string& argument);
// Tell whether a var node is a op node's nth output.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册