diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35..bc0715656a7d61774d53d4a0643ec1c105706085 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -162,5 +162,50 @@ inline void throw_on_error(T e) { } \ } while (0) +/* + * Some enforce helpers here, usage: + * int a = 1; + * int b = 2; + * PADDLE_ENFORCE_EQ(a, b); + * + * will raise an expression described as follows: + * "enforce a == b failed, 1 != 2" with detailed stack infomation. + * + * extra messages is also supported, for example: + * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) + */ + +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__) +#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__) +#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__) +#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__) +#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) +#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) + +// if two values have different data types, choose a compatible type for them. +template +struct CompatibleType { + static const bool t1_to_t2 = std::is_convertible::value; + typedef typename std::conditional::type type; +}; + +#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ + PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \ + __CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \ + "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ + #__VAL0, #__VAL1, std::to_string(__VAL0), \ + std::to_string(__VAL1), \ + paddle::string::Sprintf("" __VA_ARGS__)); + +#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \ + typename paddle::platform::CompatibleType::type(__VAL) + } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc index 2ac31812a80d8dd57ce82234cb5835e029a46067..7117b49474044af08ae9db79c2fae6693e966af2 100644 --- a/paddle/platform/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -34,3 +34,165 @@ TEST(ENFORCE, FAILED) { } ASSERT_TRUE(in_catch); } + +TEST(ENFORCE, NO_ARG_OK) { + int a = 2; + int b = 2; + PADDLE_ENFORCE_EQ(a, b); + // test enforce with extra message. + PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info"); +} + +TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce a == 1 + 3 failed, 2 != 4"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their"); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = + "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_NE, OK) { + PADDLE_ENFORCE_NE(1, 2); + PADDLE_ENFORCE_NE(1.0, 2UL); +} +TEST(ENFORCE_NE, FAIL) { + bool in_catch = false; + + try { + // 2UL here to check data type compatible + PADDLE_ENFORCE_NE(1.0, 1UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); } +TEST(ENFORCE_GT, FAIL) { + bool in_catch = false; + + try { + // 2UL here to check data type compatible + PADDLE_ENFORCE_GT(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_GE, OK) { + PADDLE_ENFORCE_GE(2, 2UL); + PADDLE_ENFORCE_GE(3, 2UL); + PADDLE_ENFORCE_GE(3, 2); + PADDLE_ENFORCE_GE(3.21, 2UL); +} +TEST(ENFORCE_GE, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_GE(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 >= 2UL failed, 1 < 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_LE, OK) { + PADDLE_ENFORCE_LE(1, 1); + PADDLE_ENFORCE_LE(1, 1UL); + PADDLE_ENFORCE_LE(2, 3UL); + PADDLE_ENFORCE_LE(2UL, 3); + PADDLE_ENFORCE_LE(2UL, 3.2); +} +TEST(ENFORCE_LE, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_GT(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_LT, OK) { + PADDLE_ENFORCE_LT(3, 10); + PADDLE_ENFORCE_LT(2, 3UL); + PADDLE_ENFORCE_LT(2UL, 3); +} +TEST(ENFORCE_LT, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_LT(1UL, 0.12); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +}