未验证 提交 91a0911c 编写于 作者: Z Zeng Jinle 提交者: GitHub

Make PADDLE_ENFORCE_EQ support types that cannot be converted to std::string (#19243)

* make PADDLE_ENFORCE_EQ support cannot to string types, test=develop

* follow huihuang's comments, test=develop
上级 8a89ca94
......@@ -27,6 +27,7 @@ limitations under the License. */
#endif // PADDLE_WITH_CUDA
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
......@@ -307,7 +308,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \
} while (false)
} while (0)
/*
* Some enforce helpers here, usage:
......@@ -366,28 +367,72 @@ using CommonType1 = typename std::add_lvalue_reference<
template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;
// Here, we use SFINAE to check whether T can be converted to std::string
template <typename T>
struct CanToString {
private:
using YesType = uint8_t;
using NoType = uint16_t;
template <typename U>
static YesType Check(decltype(std::cout << std::declval<U>())) {
return 0;
}
template <typename U>
static NoType Check(...) {
return 0;
}
public:
static constexpr bool kValue =
std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
};
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
template <typename T>
static std::string Convert(const char* expression, const T& value) {
return expression + std::string(":") + string::to_string(value);
}
};
template <>
struct BinaryCompareMessageConverter<false> {
template <typename T>
static const char* Convert(const char* expression, const T& value) {
return expression;
}
};
} // namespace details
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \
#__VAL1, #__VAL2, #__VAL1, \
::paddle::string::to_string(__val1), #__VAL2, \
::paddle::string::to_string(__val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
constexpr bool __kCanToString__ = \
::paddle::platform::details::CanToString<__TYPE1__>::kValue && \
::paddle::platform::details::CanToString<__TYPE2__>::kValue; \
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
" %s, but received %s " #__INV_CMP " %s.\n%s", \
#__VAL1, #__VAL2, \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL1, __val1), \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL2, __val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
......
......@@ -11,7 +11,9 @@ limitations under the License. */
#include <array>
#include <iostream>
#include <list>
#include <memory>
#include <set>
#include "gtest/gtest.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -296,3 +298,64 @@ TEST(enforce, cuda_success) {
#endif
}
#endif
struct CannotToStringType {
explicit CannotToStringType(int num) : num_(num) {}
bool operator==(const CannotToStringType& other) const {
return num_ == other.num_;
}
bool operator!=(const CannotToStringType& other) const {
return num_ != other.num_;
}
private:
int num_;
};
TEST(enforce, cannot_to_string_type) {
static_assert(
!paddle::platform::details::CanToString<CannotToStringType>::kValue,
"CannotToStringType must not be converted to string");
static_assert(paddle::platform::details::CanToString<int>::kValue,
"int can be converted to string");
CannotToStringType obj1(3), obj2(4), obj3(3);
PADDLE_ENFORCE_NE(obj1, obj2, "Object 1 is not equal to Object 2");
PADDLE_ENFORCE_EQ(obj1, obj3, "Object 1 is equal to Object 3");
std::string msg = "Compare obj1 with obj2";
try {
PADDLE_ENFORCE_EQ(obj1, obj2, msg);
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
EXPECT_TRUE(ex_msg.find(msg) != std::string::npos);
EXPECT_TRUE(
ex_msg.find("Expected obj1 == obj2, but received obj1 != obj2") !=
std::string::npos);
}
msg = "Compare x with y";
try {
int x = 3, y = 2;
PADDLE_ENFORCE_EQ(x, y, msg);
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
EXPECT_TRUE(ex_msg.find(msg) != std::string::npos);
EXPECT_TRUE(ex_msg.find("Expected x == y, but received x:3 != y:2") !=
std::string::npos);
}
std::set<int> set;
PADDLE_ENFORCE_EQ(set.begin(), set.end());
set.insert(3);
PADDLE_ENFORCE_NE(set.begin(), set.end());
std::list<float> list;
PADDLE_ENFORCE_EQ(list.begin(), list.end());
list.push_back(4);
PADDLE_ENFORCE_NE(list.begin(), list.end());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册