diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 7e03bf44251d83bea5a42d6d8c9a17b148418b07..a0d93a38ce9c78d20dffc2af2efd2fcc67baca0f 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -162,10 +162,11 @@ inline void throw_on_error(T e) { } \ } while (0) -#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1) \ - PADDLE_ENFORCE((__VAL0) == (__VAL1), "enforce %s == %s failed, %s != %s", \ - #__VAL0, #__VAL1, std::to_string(__VAL0), \ - std::to_string(__VAL1)); +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ + PADDLE_ENFORCE((__VAL0) == (__VAL1), \ + "enforce %s == %s failed, %s != %s\n%s", #__VAL0, #__VAL1, \ + std::to_string(__VAL0), std::to_string(__VAL1), \ + paddle::string::Sprintf("" __VA_ARGS__)); } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc index 2ac31812a80d8dd57ce82234cb5835e029a46067..c44fb4360de5ad906560ae95d9986406ecf32ff2 100644 --- a/paddle/platform/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -34,3 +34,50 @@ TEST(ENFORCE, FAILED) { } ASSERT_TRUE(in_catch); } + +TEST(ENFORCE, NO_ARG_OK) { + int a = 2; + int b = 2; + PADDLE_ENFORCE_EQ(a, b); + // test enforce with extra message. + PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info"); +} + +TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce a == 1 + 3 failed, 2 != 4"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their"); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = + "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +}