提交 ddacdf17 编写于 作者: S Superjom

init enforce eq

上级 5485caf7
......@@ -162,10 +162,11 @@ inline void throw_on_error(T e) {
} \
} while (0)
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1) \
PADDLE_ENFORCE((__VAL0) == (__VAL1), "enforce %s == %s failed, %s != %s", \
#__VAL0, #__VAL1, std::to_string(__VAL0), \
std::to_string(__VAL1));
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
PADDLE_ENFORCE((__VAL0) == (__VAL1), \
"enforce %s == %s failed, %s != %s\n%s", #__VAL0, #__VAL1, \
std::to_string(__VAL0), std::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
} // namespace platform
} // namespace paddle
......@@ -34,3 +34,50 @@ TEST(ENFORCE, FAILED) {
}
ASSERT_TRUE(in_catch);
}
TEST(ENFORCE, NO_ARG_OK) {
int a = 2;
int b = 2;
PADDLE_ENFORCE_EQ(a, b);
// test enforce with extra message.
PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info");
}
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce a == 1 + 3 failed, 2 != 4";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their");
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg =
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册