未验证 提交 a46e30aa 编写于 作者: D dzhwinter 提交者: GitHub

enhance isinf/isnan in tensor util, avoid copy back to cpu (#12688)

* "avoid copy back to cpu"

* "add infinity support"

* "fix ci"

* "add cpu macro"

* rerun ci; test=develop

* "fix api"
test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop
上级 9ff5184f
......@@ -198,6 +198,9 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.has_inf ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.has_nan ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.isfinite ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......
......@@ -165,10 +165,12 @@ inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
}
template <typename Predicate>
struct AnyVisitor : public boost::static_visitor<bool> {
class AnyVisitor : public boost::static_visitor<bool> {
private:
const framework::Tensor& tensor_;
Predicate predicate_;
public:
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
: tensor_(tensor), predicate_(std::move(predicate)) {}
......@@ -206,6 +208,27 @@ struct AnyVisitor : public boost::static_visitor<bool> {
}
};
template <typename Predicate>
class AnyOutVisitor : public boost::static_visitor<> {
private:
const framework::Tensor& tensor_;
mutable framework::Tensor* out_;
Predicate predicate_;
public:
AnyOutVisitor(const framework::Tensor& tensor, Predicate predicate,
framework::Tensor* out)
: tensor_(tensor), out_(out), predicate_(std::move(predicate)) {}
template <typename Place>
void operator()(const Place& place) const {
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
out_->Resize({1});
out_->mutable_data<bool>(place);
AnyImpl(predicate_, tensor_, *ctx, out_);
}
};
template <typename Predicate>
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
AnyVisitor<Predicate> visitor(tensor, predicate);
......@@ -213,6 +236,14 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return platform::VisitPlace(place, visitor);
}
template <typename Predicate>
inline void Any(const framework::Tensor& tensor, Predicate predicate,
framework::Tensor* out) {
AnyOutVisitor<Predicate> visitor(tensor, predicate, out);
auto place = tensor.place();
platform::VisitPlace(place, visitor);
}
struct ContainsNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
......@@ -227,6 +258,12 @@ bool TensorContainsNAN(const framework::Tensor& tensor) {
return Any(tensor, predicate);
}
void TensorContainsNAN(const framework::Tensor& tensor,
framework::Tensor* out) {
ContainsNANPredicate predicate;
Any(tensor, predicate, out);
}
struct ContainsInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
......@@ -241,6 +278,71 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
return Any(tensor, predicate);
}
void TensorContainsInf(const framework::Tensor& tensor,
framework::Tensor* out) {
ContainsInfPredicate predicate;
Any(tensor, predicate, out);
}
// NOTE(dzhwinter):
// Isfinite need a AllVisitor to loop through all the elements.
// We choose two cuda call instead of one allvisitor. The AllVisitor
// should be implemented if the performance hurts.
bool TensorIsfinite(const framework::Tensor& tensor) {
ContainsInfPredicate pred_inf;
ContainsNANPredicate pred_nan;
return !Any(tensor, pred_inf) && !Any(tensor, pred_nan);
}
#ifdef PADDLE_WITH_CUDA
template <typename T>
static inline void __global__ BothFalse(const T* cmp, T* out) {
out[0] = (!cmp[0]) && (!out[0]);
}
#endif
struct BothFalseVisitor : public boost::static_visitor<> {
const framework::Tensor& in_;
mutable framework::Tensor* out_;
BothFalseVisitor(const framework::Tensor& in, framework::Tensor* out)
: in_(in), out_(out) {}
template <typename Place>
void operator()(const Place& place) const {
VisitorImpl(place);
}
void VisitorImpl(const platform::CUDAPlace& gpu) const {
#ifdef PADDLE_WITH_CUDA
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu);
BothFalse<bool><<<1, 1, 0, ctx->stream()>>>(in_.data<bool>(),
out_->mutable_data<bool>(gpu));
#endif
}
void VisitorImpl(const platform::CPUPlace& cpu) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
}
void VisitorImpl(
const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
}
};
void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) {
framework::Tensor tmp;
TensorContainsInf(tensor, &tmp);
TensorContainsNAN(tensor, out);
BothFalseVisitor visitor(tmp, out);
auto place = tensor.place();
platform::VisitPlace(place, visitor);
}
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
......
......@@ -57,8 +57,15 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
template <typename T>
void TesnorToVector(const Tensor& src, std::vector<T>* dst);
// copy the result bool to cpu
bool TensorContainsNAN(const framework::Tensor& tensor);
bool TensorContainsInf(const framework::Tensor& tensor);
bool TensorIsfinite(const framework::Tensor& tensor);
// store the result bool in gpu tensor, async operation. Faster than above ones.
void TensorContainsNAN(const framework::Tensor& tensor, framework::Tensor* out);
void TensorContainsInf(const framework::Tensor& tensor, framework::Tensor* out);
void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out);
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx);
......
......@@ -36,7 +36,7 @@ TEST(TensorCopy, Tensor) {
TensorCopy(src_tensor, *cpu_place, &dst_tensor);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
EXPECT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
......@@ -47,7 +47,7 @@ TEST(TensorCopy, Tensor) {
TensorCopy(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
EXPECT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
......@@ -77,7 +77,7 @@ TEST(TensorCopy, Tensor) {
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
EXPECT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
......@@ -94,7 +94,7 @@ TEST(TensorCopy, Tensor) {
gpu_ctx.Wait();
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
EXPECT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
......@@ -117,7 +117,7 @@ TEST(TensorFromVector, Tensor) {
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
const int* src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
EXPECT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
......@@ -127,7 +127,7 @@ TEST(TensorFromVector, Tensor) {
paddle::framework::TensorFromVector<int>(src_vec, &cpu_tensor);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
EXPECT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
......@@ -161,8 +161,8 @@ TEST(TensorFromVector, Tensor) {
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
EXPECT_NE(src_ptr, cpu_ptr);
EXPECT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
......@@ -181,8 +181,8 @@ TEST(TensorFromVector, Tensor) {
src_ptr = src_vec.data();
cpu_ptr = cpu_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
EXPECT_NE(src_ptr, cpu_ptr);
EXPECT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
......@@ -235,9 +235,9 @@ TEST(TensorContainsNAN, CPU) {
buf[0] = 0.0;
buf[1] = NAN;
buf[2] = 0.0;
ASSERT_TRUE(paddle::framework::TensorContainsNAN(src));
EXPECT_TRUE(paddle::framework::TensorContainsNAN(src));
buf[1] = 0.0;
ASSERT_FALSE(paddle::framework::TensorContainsNAN(src));
EXPECT_FALSE(paddle::framework::TensorContainsNAN(src));
}
{
......@@ -248,9 +248,9 @@ TEST(TensorContainsNAN, CPU) {
buf[0] = 0.0;
buf[1].x = 0x7fff;
buf[2] = 0.0;
ASSERT_TRUE(paddle::framework::TensorContainsNAN(src));
EXPECT_TRUE(paddle::framework::TensorContainsNAN(src));
buf[1] = 0.0;
ASSERT_FALSE(paddle::framework::TensorContainsNAN(src));
EXPECT_FALSE(paddle::framework::TensorContainsNAN(src));
}
}
......@@ -261,9 +261,9 @@ TEST(TensorContainsInf, CPU) {
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
ASSERT_TRUE(paddle::framework::TensorContainsInf(src));
EXPECT_TRUE(paddle::framework::TensorContainsInf(src));
buf[1] = 1.0;
ASSERT_FALSE(paddle::framework::TensorContainsInf(src));
EXPECT_FALSE(paddle::framework::TensorContainsInf(src));
}
{
......@@ -274,9 +274,55 @@ TEST(TensorContainsInf, CPU) {
buf[0] = 1.0;
buf[1].x = 0x7c00;
buf[2] = 0.0;
ASSERT_TRUE(paddle::framework::TensorContainsInf(src));
EXPECT_TRUE(paddle::framework::TensorContainsInf(src));
buf[1] = 1.0;
ASSERT_FALSE(paddle::framework::TensorContainsInf(src));
EXPECT_FALSE(paddle::framework::TensorContainsInf(src));
}
}
TEST(TensorIsfinite, CPU) {
{
paddle::framework::Tensor src, out;
double* buf = src.mutable_data<double>({3}, paddle::platform::CPUPlace());
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], false);
buf[1] = 1.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], true);
}
{
paddle::framework::Tensor src, out;
double* buf = src.mutable_data<double>({3}, paddle::platform::CPUPlace());
buf[0] = 1.0;
buf[1] = NAN;
buf[2] = 0.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], false);
buf[1] = 1.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], true);
}
{
paddle::framework::Tensor src, out;
paddle::platform::float16* buf =
src.mutable_data<paddle::platform::float16>(
{3}, paddle::platform::CPUPlace());
buf[0] = 1.0;
buf[1].x = 0x7c00;
buf[2] = 0.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], false);
buf[1] = 1.0;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], true);
buf[1].x = 0x7fff;
paddle::framework::TensorIsfinite(src, &out);
EXPECT_EQ(out.data<bool>()[0], false);
}
}
......@@ -299,9 +345,9 @@ TEST(Tensor, FromAndToStream) {
TensorFromStream(iss, &dst_tensor, cpu_ctx);
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 5; ++i) {
ASSERT_EQ(dst_ptr[i], array[i]);
EXPECT_EQ(dst_ptr[i], array[i]);
}
ASSERT_EQ(dst_tensor.dims(), src_tensor.dims());
EXPECT_EQ(dst_tensor.dims(), src_tensor.dims());
delete place;
}
#ifdef PADDLE_WITH_CUDA
......@@ -323,7 +369,7 @@ TEST(Tensor, FromAndToStream) {
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) {
ASSERT_EQ(dst_ptr[i], array[i]);
EXPECT_EQ(dst_ptr[i], array[i]);
}
delete gpu_place;
}
......
......@@ -27,9 +27,9 @@ static __global__ void FillNAN(float* buf) {
}
static __global__ void FillInf(float* buf) {
buf[0] = 0.0;
buf[1] = INFINITY;
buf[2] = 0.5;
buf[0] = INFINITY;
buf[1] = 0.1;
buf[2] = 0.2;
}
static __global__ void FillNAN(platform::float16* buf) {
......@@ -44,6 +44,18 @@ static __global__ void FillInf(platform::float16* buf) {
buf[2] = 0.5;
}
static __global__ void FillFinite(float* buf) {
buf[0] = 0.0;
buf[1] = 0.1;
buf[2] = 0.2;
}
static __global__ void FillFinite(platform::float16* buf) {
buf[0] = 0.0;
buf[1] = 0.1;
buf[2] = 0.2;
}
TEST(TensorContainsNAN, GPU) {
paddle::platform::CUDAPlace gpu(0);
auto& pool = paddle::platform::DeviceContextPool::Instance();
......@@ -86,5 +98,163 @@ TEST(TensorContainsInf, GPU) {
}
}
TEST(TensorIsfinite, GPU) {
paddle::platform::CUDAPlace gpu(0);
using paddle::platform::float16;
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
// contains inf
{
Tensor tensor;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(!TensorIsfinite(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(!TensorIsfinite(tensor));
}
// contains nan
{
Tensor tensor;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(!TensorIsfinite(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(!TensorIsfinite(tensor));
}
// all element are finite
{
Tensor tensor;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(TensorIsfinite(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
EXPECT_TRUE(TensorIsfinite(tensor));
}
}
TEST(TensorContainsInf, GPUWithoutWait) {
paddle::platform::CUDAPlace gpu(0);
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
{
Tensor tensor, out;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorContainsInf(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
ASSERT_EQ(tmp.data<bool>()[0], true);
}
{
Tensor tensor, out;
paddle::platform::float16* buf =
tensor.mutable_data<paddle::platform::float16>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorContainsInf(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
ASSERT_EQ(tmp.data<bool>()[0], true);
}
}
TEST(TensorContainsNAN, GPUWithoutWait) {
paddle::platform::CUDAPlace gpu(0);
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
{
Tensor tensor, out;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorContainsNAN(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
ASSERT_EQ(tmp.data<bool>()[0], true);
}
{
Tensor tensor, out;
paddle::platform::float16* buf =
tensor.mutable_data<paddle::platform::float16>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorContainsNAN(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
ASSERT_EQ(tmp.data<bool>()[0], true);
}
}
TEST(TensorIsfinite, GPUWithoutWait) {
paddle::platform::CUDAPlace gpu(0);
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
{
Tensor tensor, out;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorIsfinite(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
EXPECT_EQ(tmp.data<bool>()[0], false);
}
{
Tensor tensor, out;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorIsfinite(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
EXPECT_EQ(tmp.data<bool>()[0], false);
}
{
Tensor tensor, out;
float* buf = tensor.mutable_data<float>({3}, gpu);
FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
TensorIsfinite(tensor, &out);
platform::CPUPlace cpu;
Tensor tmp;
TensorCopy(out, cpu, *cuda_ctx, &tmp);
cuda_ctx->Wait();
EXPECT_EQ(tmp.data<bool>()[0], true);
}
}
} // namespace framework
} // namespace paddle
......@@ -22,8 +22,8 @@ limitations under the License. */
#include <algorithm>
#include <memory>
#include <thread> //NOLINT
#include "paddle/fluid/inference/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_bool(use_gpu, false, "Whether use gpu.");
......@@ -62,17 +62,17 @@ void Main(bool use_gpu) {
CHECK(predictor->Run(slots, &outputs));
//# 4. Get output.
PADDLE_ENFORCE(outputs.size(), 1UL);
CHECK_EQ(outputs.size(), 1UL);
// Check the output buffer size and result of each tid.
PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
CHECK_EQ(outputs.front().data.length(), 33168UL);
float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
0.000932706};
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(static_cast<size_t>(5), num_elements);
i++) {
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
CHECK_NEAR(static_cast<float*>(outputs.front().data.data())[i], result[i],
0.001);
}
}
}
......@@ -108,9 +108,9 @@ void MainThreads(int num_threads, bool use_gpu) {
CHECK(predictor->Run(inputs, &outputs));
// 4. Get output.
PADDLE_ENFORCE(outputs.size(), 1UL);
CHECK_EQ(outputs.size(), 1UL);
// Check the output buffer size and result of each tid.
PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
CHECK_EQ(outputs.front().data.length(), 33168UL);
float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
0.000932706};
const size_t num_elements =
......@@ -118,8 +118,8 @@ void MainThreads(int num_threads, bool use_gpu) {
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(static_cast<size_t>(5), num_elements);
i++) {
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
CHECK_NEAR(static_cast<float*>(outputs.front().data.data())[i],
result[i], 0.001);
}
}
});
......
......@@ -17,11 +17,12 @@ limitations under the License. */
*/
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
#include <glog/logging.h> // use glog instead of CHECK to avoid importing other paddle header files.
#include <fstream>
#include <iostream>
// #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/inference/demo_ci/utils.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
......@@ -78,18 +79,17 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
size_t numel = output.data.length() / PaddleDtypeSize(output.dtype);
VLOG(3) << "predictor output numel " << numel;
VLOG(3) << "reference output numel " << refer.data.size();
PADDLE_ENFORCE_EQ(numel, refer.data.size());
CHECK_EQ(numel, refer.data.size());
switch (output.dtype) {
case PaddleDType::INT64: {
for (size_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE_EQ(static_cast<int64_t*>(output.data.data())[i],
refer.data[i]);
CHECK_EQ(static_cast<int64_t*>(output.data.data())[i], refer.data[i]);
}
break;
}
case PaddleDType::FLOAT32:
for (size_t i = 0; i < numel; ++i) {
PADDLE_ENFORCE_LT(
CHECK_LT(
fabs(static_cast<float*>(output.data.data())[i] - refer.data[i]),
1e-5);
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/isfinite_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class OverflowOp : public framework::OperatorWithKernel {
public:
OverflowOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of OverflowOp should not be null.");
ctx->SetOutputDim("Out", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
int dtype = -1;
auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type());
} else if (x_var->IsType<framework::SelectedRows>()) {
dtype = framework::ToDataType(
x_var->Get<framework::SelectedRows>().value().type());
} else {
PADDLE_THROW("Cannot find the input data type by all input data");
}
return framework::OpKernelType(framework::proto::VarType::Type(dtype),
ctx.GetPlace());
}
};
class OverflowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensors of overflow operator.");
AddOutput("Out",
"(Tensor) 1-dim tensor, contains a bool scalar. The output "
"tensor of overflow operator.");
AddComment(string::Sprintf(R"DOC(
Overflow operator.
$$Out = any(X)$$
If any X contains Inf or Nan, the Out will generate a indicator.
Out = Inf if any X contains Inf,
Out = Nan if any X contains Nan,
Out = 0 if no Inf/Nan detected.
If X contains both Inf/Nan, it will return the first indicator it meeted.
)DOC",
GetName(), GetComments()));
}
protected:
virtual std::string GetName() const = 0;
virtual std::string GetComments() const = 0;
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define REGISTER_OP_MAKER(op_type, comment) \
namespace paddle { \
namespace operators { \
class _##op_type##OverflowOpMaker \
: public ::paddle::operators::OverflowOpMaker { \
protected: \
std::string GetName() const { return #op_type; } \
std::string GetComments() const { return comment; } \
}; \
} \
} \
REGISTER_OPERATOR(op_type, ops::OverflowOp, \
ops::_##op_type##OverflowOpMaker, \
paddle::framework::EmptyGradOpMaker)
#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \
op_type, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, float, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CPUDeviceContext, double, \
ops::functor>);
REGISTER_OP_MAKER(isinf, "isinf(X)");
REGISTER_OP_MAKER(isnan, "isnan(X)");
REGISTER_OP_MAKER(isfinite, "isfinite(X)");
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16, \
ops::functor>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
struct InfinityFunctor {
void operator()(const framework::Tensor& tensor, framework::Tensor* out) {
framework::TensorContainsInf(tensor, out);
}
};
struct NANFunctor {
void operator()(const framework::Tensor& tensor, framework::Tensor* out) {
framework::TensorContainsNAN(tensor, out);
}
};
struct IsfiniteFunctor {
void operator()(const framework::Tensor& tensor, framework::Tensor* out) {
framework::TensorIsfinite(tensor, out);
}
};
template <typename DeviceContext, typename T, typename Functor>
class OverflowKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* x = ctx.InputVar("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
Functor functor;
if (x->IsType<framework::LoDTensor>()) {
auto* in = ctx.Input<framework::Tensor>("X");
functor(*in, out);
} else if (x->IsType<framework::SelectedRows>()) {
auto& in = ctx.Input<framework::SelectedRows>("X")->value();
functor(in, out);
} else {
PADDLE_THROW("Unsupported input type.");
}
}
};
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(isinf, InfinityFunctor); \
__macro(isnan, NANFunctor); \
__macro(isfinite, IsfiniteFunctor);
......@@ -24,21 +24,10 @@ from .layer_function_generator import templatedoc
import numpy
__all__ = [
'create_tensor',
'create_parameter',
'create_global_var',
'cast',
'concat',
'sums',
'assign',
'fill_constant_batch_size_like',
'fill_constant',
'argmin',
'argmax',
'argsort',
'ones',
'zeros',
'reverse',
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'concat',
'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant',
'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf',
'has_nan', 'isfinite'
]
......@@ -652,3 +641,52 @@ def load_combine(out, file_path):
inputs={},
output={"Out": out},
args={"file_path": file_path})
def has_inf(x):
"""
Test if any of x contains an infinity number
Args:
x(variable): The Tensor/LoDTensor to be checked.
Returns:
Variable: The tensor variable storing the output, only a bool value.
"""
helper = LayerHelper("isinf", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="isinf", inputs={"X": x}, outputs={"Out": out})
return out
def has_nan(x):
"""
Test if any of x contains a NAN
Args:
x(variable): The Tensor/LoDTensor to be checked.
Returns:
Variable: The tensor variable storing the output, only a bool value.
"""
helper = LayerHelper("isnan", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="isnan", inputs={"X": x}, outputs={"Out": out})
return out
def isfinite(x):
"""
Test if any of x contains an infinity/NAN number. If all the elements are finite,
returns true, else false.
Args:
x(variable): The Tensor/LoDTensor to be checked.
Returns:
Variable: The tensor variable storing the output, contains a bool value.
"""
helper = LayerHelper("isfinite", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestInf(OpTest):
def setUp(self):
self.op_type = "isinf"
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
x[0] = np.inf
x[-1] = np.inf
self.inputs = {'X': x}
self.outputs = {'Out': np.array(True).astype(self.dtype)}
def init_dtype(self):
pass
def test_output(self):
self.check_output()
class TestFP16Inf(TestInf):
def init_dtype(self):
self.dtype = np.float16
class TestNAN(OpTest):
def setUp(self):
self.op_type = "isnan"
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
x[0] = np.nan
x[-1] = np.nan
self.inputs = {'X': x}
self.outputs = {'Out': np.array(True).astype(self.dtype)}
def init_dtype(self):
pass
def test_output(self):
self.check_output()
class TestFP16NAN(TestNAN):
def init_dtype(self):
self.dtype = np.float16
class TestIsfinite(OpTest):
def setUp(self):
self.op_type = "isfinite"
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
x[0] = np.inf
x[-1] = np.nan
out = np.isinf(x) | np.isnan(x)
self.inputs = {'X': x}
self.outputs = {'Out': np.array(False).astype(self.dtype)}
def init_dtype(self):
pass
def test_output(self):
self.check_output()
class TestFP16Isfinite(TestIsfinite):
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册