提交 3d282ec4 编写于 作者: Y Yang Yu

Add is_nan/is_inf

上级 e54bb6cc
......@@ -277,21 +277,23 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return platform::VisitPlace(place, visitor);
}
struct HasNanPredicate {
struct HasNANPredicate {
template <typename T>
auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isnan()) {
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) {
return eigen_vec.isnan();
}
};
inline bool HasNan(const framework::Tensor& tensor) {
HasNanPredicate predicate;
inline bool HasNAN(const framework::Tensor& tensor) {
HasNANPredicate predicate;
return Any(tensor, predicate);
}
struct HasInfPredicate {
template <typename T>
auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isinf()) {
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) {
return eigen_vec.isinf();
}
};
......
......@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
namespace paddle {
......@@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) {
#endif
}
TEST(IsNAN, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[0] = 0.0;
buf[1] = NAN;
buf[2] = 0.0;
ASSERT_TRUE(HasNAN(src));
}
TEST(IsInf, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
ASSERT_TRUE(HasInf(src));
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册