diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 5f111f6b0d57e1ae5174d232e81a02d758fafcb9..006330da5b12bc9655f33d33d6047ecf9cb007a7 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -32,15 +32,9 @@ class MulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound("Input(X) of MulOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::NotFound("Input(Y) of MulOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of MulOp should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Mul"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Mul"); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index e2bd12a26acd2d4da5af3d78ab04647cd462f822..45aa32e17c32285a1113e9068c9c996b70b7cc22 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -348,6 +348,29 @@ struct EnforceNotMet : public std::exception { #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) +/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/ + +/* + * Summary: This macro is used to check whether op has specified + * Input or Output Variables. Because op's Input and Output + * checking are written similarly, so abstract this macro. + * + * Parameters: + *     __EXPR: (bool), the bool expression + * __ROLE: (string), Input or Output + * __NAME: (string), Input or Output name + * __OP_TYPE: (string), the op type + * + * Examples: + * OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul"); +*/ +#define OP_INOUT_CHECK(__EXPR, __ROLE, __NAME, __OP_TYPE) \ + do { \ + PADDLE_ENFORCE_EQ(__EXPR, true, paddle::platform::errors::NotFound( \ + "No %s(%s) found for %s operator.", \ + __ROLE, __NAME, __OP_TYPE)); \ + } while (0) + /** OTHER EXCEPTION AND ENFORCE **/ struct EOFException : public std::exception { diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index ff0e5d6ff0c49a0e4793fc2eb274e0ef5ef3502d..1215005ad80de9119533e714aa447cf874690dde 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -361,3 +361,17 @@ TEST(enforce, cannot_to_string_type) { list.push_back(4); PADDLE_ENFORCE_NE(list.begin(), list.end()); } + +TEST(OP_INOUT_CHECK_MACRO, SUCCESS) { + OP_INOUT_CHECK(true, "Input", "X", "dummy"); +} + +TEST(OP_INOUT_CHECK_MACRO, FAIL) { + bool caught_exception = false; + try { + OP_INOUT_CHECK(false, "Input", "X", "dummy"); + } catch (paddle::platform::EnforceNotMet& error) { + caught_exception = true; + } + EXPECT_TRUE(caught_exception); +}