未验证 提交 91a0911c 编写于 作者: Z Zeng Jinle 提交者: GitHub

Make PADDLE_ENFORCE_EQ support types that cannot be converted to std::string (#19243)

* make PADDLE_ENFORCE_EQ support cannot to string types, test=develop

* follow huihuang's comments, test=develop
上级 8a89ca94
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#include <iomanip> #include <iomanip>
#include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
...@@ -307,7 +308,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); ...@@ -307,7 +308,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
do { \ do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \ __LINE__); \
} while (false) } while (0)
/* /*
* Some enforce helpers here, usage: * Some enforce helpers here, usage:
...@@ -366,6 +367,45 @@ using CommonType1 = typename std::add_lvalue_reference< ...@@ -366,6 +367,45 @@ using CommonType1 = typename std::add_lvalue_reference<
template <typename T1, typename T2> template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference< using CommonType2 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type; typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;
// Here, we use SFINAE to check whether T can be converted to std::string
template <typename T>
struct CanToString {
private:
using YesType = uint8_t;
using NoType = uint16_t;
template <typename U>
static YesType Check(decltype(std::cout << std::declval<U>())) {
return 0;
}
template <typename U>
static NoType Check(...) {
return 0;
}
public:
static constexpr bool kValue =
std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
};
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
template <typename T>
static std::string Convert(const char* expression, const T& value) {
return expression + std::string(":") + string::to_string(value);
}
};
template <>
struct BinaryCompareMessageConverter<false> {
template <typename T>
static const char* Convert(const char* expression, const T& value) {
return expression;
}
};
} // namespace details } // namespace details
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \ #define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
...@@ -381,11 +421,16 @@ using CommonType2 = typename std::add_lvalue_reference< ...@@ -381,11 +421,16 @@ using CommonType2 = typename std::add_lvalue_reference<
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \ bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \ static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \ if (UNLIKELY(!__is_not_error)) { \
constexpr bool __kCanToString__ = \
::paddle::platform::details::CanToString<__TYPE1__>::kValue && \
::paddle::platform::details::CanToString<__TYPE2__>::kValue; \
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \ PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ " %s, but received %s " #__INV_CMP " %s.\n%s", \
#__VAL1, #__VAL2, #__VAL1, \ #__VAL1, #__VAL2, \
::paddle::string::to_string(__val1), #__VAL2, \ ::paddle::platform::details::BinaryCompareMessageConverter< \
::paddle::string::to_string(__val2), \ __kCanToString__>::Convert(#__VAL1, __val1), \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL2, __val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \ ::paddle::string::Sprintf(__VA_ARGS__)); \
} \ } \
} while (0) } while (0)
......
...@@ -11,7 +11,9 @@ limitations under the License. */ ...@@ -11,7 +11,9 @@ limitations under the License. */
#include <array> #include <array>
#include <iostream> #include <iostream>
#include <list>
#include <memory> #include <memory>
#include <set>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -296,3 +298,64 @@ TEST(enforce, cuda_success) { ...@@ -296,3 +298,64 @@ TEST(enforce, cuda_success) {
#endif #endif
} }
#endif #endif
struct CannotToStringType {
explicit CannotToStringType(int num) : num_(num) {}
bool operator==(const CannotToStringType& other) const {
return num_ == other.num_;
}
bool operator!=(const CannotToStringType& other) const {
return num_ != other.num_;
}
private:
int num_;
};
TEST(enforce, cannot_to_string_type) {
static_assert(
!paddle::platform::details::CanToString<CannotToStringType>::kValue,
"CannotToStringType must not be converted to string");
static_assert(paddle::platform::details::CanToString<int>::kValue,
"int can be converted to string");
CannotToStringType obj1(3), obj2(4), obj3(3);
PADDLE_ENFORCE_NE(obj1, obj2, "Object 1 is not equal to Object 2");
PADDLE_ENFORCE_EQ(obj1, obj3, "Object 1 is equal to Object 3");
std::string msg = "Compare obj1 with obj2";
try {
PADDLE_ENFORCE_EQ(obj1, obj2, msg);
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
EXPECT_TRUE(ex_msg.find(msg) != std::string::npos);
EXPECT_TRUE(
ex_msg.find("Expected obj1 == obj2, but received obj1 != obj2") !=
std::string::npos);
}
msg = "Compare x with y";
try {
int x = 3, y = 2;
PADDLE_ENFORCE_EQ(x, y, msg);
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
EXPECT_TRUE(ex_msg.find(msg) != std::string::npos);
EXPECT_TRUE(ex_msg.find("Expected x == y, but received x:3 != y:2") !=
std::string::npos);
}
std::set<int> set;
PADDLE_ENFORCE_EQ(set.begin(), set.end());
set.insert(3);
PADDLE_ENFORCE_NE(set.begin(), set.end());
std::list<float> list;
PADDLE_ENFORCE_EQ(list.begin(), list.end());
list.push_back(4);
PADDLE_ENFORCE_NE(list.begin(), list.end());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册