diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 5c7822814c4468a70110c6bd179b09d86edf739b..7d786ad6141fa41f51ce8ec45ba7a37d513bec35 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -277,21 +277,23 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { return platform::VisitPlace(place, visitor); } -struct HasNanPredicate { +struct HasNANPredicate { template - auto operator()(T eigen_vec) const -> decltype(std::declval().isnan()) { + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isnan()) { return eigen_vec.isnan(); } }; -inline bool HasNan(const framework::Tensor& tensor) { - HasNanPredicate predicate; +inline bool HasNAN(const framework::Tensor& tensor) { + HasNANPredicate predicate; return Any(tensor, predicate); } struct HasInfPredicate { template - auto operator()(T eigen_vec) const -> decltype(std::declval().isinf()) { + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isinf()) { return eigen_vec.isinf(); } }; diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index f388c19f28ed28335818733f946d8eaf18464627..01dfd4deb9d126b9e28391b3643156cd1b0dfc9d 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -13,6 +13,7 @@ #include "paddle/framework/tensor_util.h" #include +#include #include namespace paddle { @@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) { #endif } +TEST(IsNAN, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + float* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 0.0; + buf[1] = NAN; + buf[2] = 0.0; + + ASSERT_TRUE(HasNAN(src)); +} + +TEST(IsInf, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + double* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 1.0; + buf[1] = INFINITY; + buf[2] = 0.0; + ASSERT_TRUE(HasInf(src)); +} + } // namespace framework } // namespace paddle