未验证 提交 7f1ad510 编写于 作者: C Chen Weihang 提交者: GitHub

Add op inout check macro to simplify error message writing (#23430)

* add op inout check macro, test=develop

* fix enforce_test, test=develop
上级 bc2981e9
...@@ -32,15 +32,9 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -32,15 +32,9 @@ class MulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Mul");
platform::errors::NotFound("Input(X) of MulOp should not be null.")); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Mul");
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MulOp should not be null."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
......
...@@ -348,6 +348,29 @@ struct EnforceNotMet : public std::exception { ...@@ -348,6 +348,29 @@ struct EnforceNotMet : public std::exception {
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/
/*
* Summary: This macro is used to check whether op has specified
* Input or Output Variables. Because op's Input and Output
* checking are written similarly, so abstract this macro.
*
* Parameters:
*     __EXPR: (bool), the bool expression
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Examples:
* OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
*/
#define OP_INOUT_CHECK(__EXPR, __ROLE, __NAME, __OP_TYPE) \
do { \
PADDLE_ENFORCE_EQ(__EXPR, true, paddle::platform::errors::NotFound( \
"No %s(%s) found for %s operator.", \
__ROLE, __NAME, __OP_TYPE)); \
} while (0)
/** OTHER EXCEPTION AND ENFORCE **/ /** OTHER EXCEPTION AND ENFORCE **/
struct EOFException : public std::exception { struct EOFException : public std::exception {
......
...@@ -361,3 +361,17 @@ TEST(enforce, cannot_to_string_type) { ...@@ -361,3 +361,17 @@ TEST(enforce, cannot_to_string_type) {
list.push_back(4); list.push_back(4);
PADDLE_ENFORCE_NE(list.begin(), list.end()); PADDLE_ENFORCE_NE(list.begin(), list.end());
} }
TEST(OP_INOUT_CHECK_MACRO, SUCCESS) {
OP_INOUT_CHECK(true, "Input", "X", "dummy");
}
TEST(OP_INOUT_CHECK_MACRO, FAIL) {
bool caught_exception = false;
try {
OP_INOUT_CHECK(false, "Input", "X", "dummy");
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
}
EXPECT_TRUE(caught_exception);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册