提交 45acfbd0 编写于 作者: M minqiyang

1. Add specific condition for one or no arg in PADDLE_ENFORCE

2. Add unit test for new enforce feature

test=develop
上级 b1d0a14c
...@@ -258,7 +258,12 @@ inline void throw_on_error(T e) { ...@@ -258,7 +258,12 @@ inline void throw_on_error(T e) {
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__) throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__)
#define __PADDLE_THROW_ERROR(COND, ...) \ #define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_;
#define __THROW_ON_ERROR_ONE_ARG(COND, ARG) \
::paddle::platform::throw_on_error(COND, "%s", std::string(ARG));
#define __PADDLE_THROW_ON_ERROR(COND, ...) \
__PADDLE_THROW_ERROR_I( \ __PADDLE_THROW_ERROR_I( \
__VA_ARGS__, ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ __VA_ARGS__, ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
...@@ -268,15 +273,13 @@ inline void throw_on_error(T e) { ...@@ -268,15 +273,13 @@ inline void throw_on_error(T e) {
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \ ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
::paddle::platform::throw_on_error(COND)) __THROW_ON_ERROR_ONE_ARG(COND, __VA_ARGS__))
#define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_;
#define __PADDLE_UNARY_COMPARE(COND, ...) \ #define __PADDLE_UNARY_COMPARE(COND, ...) \
do { \ do { \
auto __cond = COND; \ auto __cond = COND; \
if (UNLIKELY(::paddle::platform::is_error(__cond))) { \ if (UNLIKELY(::paddle::platform::is_error(__cond))) { \
__PADDLE_THROW_ERROR(__cond, __VA_ARGS__); \ __PADDLE_THROW_ON_ERROR(__cond, __VA_ARGS__); \
} \ } \
} while (0) } while (0)
......
...@@ -37,6 +37,25 @@ TEST(ENFORCE, FAILED) { ...@@ -37,6 +37,25 @@ TEST(ENFORCE, FAILED) {
HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all")); HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
} }
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
caught_exception = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok at all");
} catch (paddle::platform::EnforceNotMet error) {
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "Enforce is not ok at all"));
}
EXPECT_TRUE(caught_exception);
caught_exception = false;
try {
PADDLE_ENFORCE(false);
} catch (paddle::platform::EnforceNotMet error) {
caught_exception = true;
EXPECT_NE(std::string(error.what()).find(" at "), 0);
}
EXPECT_TRUE(caught_exception);
} }
TEST(ENFORCE, NO_ARG_OK) { TEST(ENFORCE, NO_ARG_OK) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册