提交 3d282ec4 编写于 作者: Y Yang Yu

Add is_nan/is_inf

上级 e54bb6cc
...@@ -277,21 +277,23 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { ...@@ -277,21 +277,23 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return platform::VisitPlace(place, visitor); return platform::VisitPlace(place, visitor);
} }
struct HasNanPredicate { struct HasNANPredicate {
template <typename T> template <typename T>
auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isnan()) { auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) {
return eigen_vec.isnan(); return eigen_vec.isnan();
} }
}; };
inline bool HasNan(const framework::Tensor& tensor) { inline bool HasNAN(const framework::Tensor& tensor) {
HasNanPredicate predicate; HasNANPredicate predicate;
return Any(tensor, predicate); return Any(tensor, predicate);
} }
struct HasInfPredicate { struct HasInfPredicate {
template <typename T> template <typename T>
auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isinf()) { auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) {
return eigen_vec.isinf(); return eigen_vec.isinf();
} }
}; };
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h" #include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath>
#include <string> #include <string>
namespace paddle { namespace paddle {
...@@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) { ...@@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) {
#endif #endif
} }
TEST(IsNAN, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[0] = 0.0;
buf[1] = NAN;
buf[2] = 0.0;
ASSERT_TRUE(HasNAN(src));
}
TEST(IsInf, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
ASSERT_TRUE(HasInf(src));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册