diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index dd83686b9d27490b6a8426ccc01ec915ea5eb7c9..7eb4be2137e1fe08e27909a885b35cbb3bb4a5e4 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -258,7 +258,12 @@ inline void throw_on_error(T e) { #define PADDLE_THROW(...) \ throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__) -#define __PADDLE_THROW_ERROR(COND, ...) \ +#define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_; + +#define __THROW_ON_ERROR_ONE_ARG(COND, ARG) \ + ::paddle::platform::throw_on_error(COND, "%s", std::string(ARG)); + +#define __PADDLE_THROW_ON_ERROR(COND, ...) \ __PADDLE_THROW_ERROR_I( \ __VA_ARGS__, ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ @@ -268,15 +273,13 @@ inline void throw_on_error(T e) { ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ - ::paddle::platform::throw_on_error(COND)) - -#define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_; + __THROW_ON_ERROR_ONE_ARG(COND, __VA_ARGS__)) #define __PADDLE_UNARY_COMPARE(COND, ...) \ do { \ auto __cond = COND; \ if (UNLIKELY(::paddle::platform::is_error(__cond))) { \ - __PADDLE_THROW_ERROR(__cond, __VA_ARGS__); \ + __PADDLE_THROW_ON_ERROR(__cond, __VA_ARGS__); \ } \ } while (0) diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index d52182965552e9ec945cb7d0b421d8addcb758e9..1091badae54a809c4a9da6d0398bcbb538420af0 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -37,6 +37,25 @@ TEST(ENFORCE, FAILED) { HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all")); } EXPECT_TRUE(caught_exception); + + caught_exception = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok at all"); + } catch (paddle::platform::EnforceNotMet error) { + caught_exception = true; + EXPECT_TRUE( + HasPrefix(StringPiece(error.what()), "Enforce is not ok at all")); + } + EXPECT_TRUE(caught_exception); + + caught_exception = false; + try { + PADDLE_ENFORCE(false); + } catch (paddle::platform::EnforceNotMet error) { + caught_exception = true; + EXPECT_NE(std::string(error.what()).find(" at "), 0); + } + EXPECT_TRUE(caught_exception); } TEST(ENFORCE, NO_ARG_OK) {