diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 82c7d4a2ec648449fc65ca2ae0de397b2f6fa120..48713f2c2ac62a37b7b7a4602f7f6a325aecb0b8 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -5,14 +5,14 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-if (WITH_GPU)
+if(WITH_GPU)
   nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory
       device_context framework_proto)
 else()
   cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory
       device_context framework_proto)
-endif ()
+endif()
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
-if (WITH_GPU)
+if(WITH_GPU)
   nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
 else()
   cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
@@ -39,8 +39,13 @@ cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
 nv_test(data_device_transform_test SRCS data_device_transform_test.cu
         DEPS operator op_registry init math_function)
 
-cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
-cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+if(WITH_GPU)
+  nv_library(data_type_transform SRCS data_type_transform.cu DEPS tensor)
+  nv_test(data_type_transform_test SRCS data_type_transform_test.cc data_type_transform_test.cu DEPS data_type_transform)
+else()
+  cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
+  cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+endif()
 
 cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
 cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc
index 0475fc1d9aaeda39b3fe845b70461c434bcaafc4..bfad9ac1e9cad1936ed961ad1da55787d2faa23e 100644
--- a/paddle/fluid/framework/data_transform.cc
+++ b/paddle/fluid/framework/data_transform.cc
@@ -42,6 +42,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
     PassTensorData(&out, &in);
   }
 
+  // do data type transform
   if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
     TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
     transformed = true;
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 1dec766a345d8d6c6fdafc1b50450a9dde91fe5d..4c1b3e7581fe716271c62389c6053a24158913d2 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -16,13 +16,16 @@ limitations under the License. */
 #include <typeindex>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace framework {
 
 inline proto::VarType::Type ToDataType(std::type_index type) {
   using namespace paddle::framework::proto;
-  if (typeid(float).hash_code() == type.hash_code()) {
+  if (typeid(platform::float16).hash_code() == type.hash_code()) {
+    return proto::VarType::FP16;
+  } else if (typeid(float).hash_code() == type.hash_code()) {
     return proto::VarType::FP32;
   } else if (typeid(double).hash_code() == type.hash_code()) {
     return proto::VarType::FP64;
@@ -40,6 +43,8 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
 inline std::type_index ToTypeIndex(proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
+    case proto::VarType::FP16:
+      return typeid(platform::float16);
     case proto::VarType::FP32:
       return typeid(float);
     case proto::VarType::FP64:
@@ -59,6 +64,9 @@ template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   using namespace paddle::framework::proto;
   switch (type) {
+    case proto::VarType::FP16:
+      visitor.template operator()<platform::float16>();
+      break;
    case proto::VarType::FP32:
      visitor.template operator()<float>();
      break;
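For context on the dispatch that the new `FP16` cases plug into: `VisitDataType` switches on the runtime `proto::VarType::Type` tag and invokes the visitor's templated `operator()` with the matching C++ element type. A minimal, self-contained sketch of that pattern, using simplified stand-in names rather than Paddle's actual types:

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for the runtime type tag (proto::VarType::Type in Paddle).
enum class VarType { FP16, FP32, FP64, INT32 };

// Stand-in for platform::float16: a 2-byte struct wrapping the raw bits.
struct float16_stub {
  uint16_t x;
};

// A visitor exposes a templated operator(); the dispatcher instantiates it
// with the concrete element type that corresponds to the runtime tag.
struct PrintSizeVisitor {
  template <typename T>
  void operator()() const {
    std::cout << "element size = " << sizeof(T) << " bytes\n";
  }
};

template <typename Visitor>
void VisitType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::FP16:
      visitor.template operator()<float16_stub>();
      break;
    case VarType::FP32:
      visitor.template operator()<float>();
      break;
    case VarType::FP64:
      visitor.template operator()<double>();
      break;
    case VarType::INT32:
      visitor.template operator()<int>();
      break;
  }
}

int main() {
  VisitType(VarType::FP16, PrintSizeVisitor{});  // prints: element size = 2 bytes
  VisitType(VarType::FP64, PrintSizeVisitor{});  // prints: element size = 8 bytes
  return 0;
}
```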
diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc
index 54cc1575d880242c01cd4752f522380a673b7c75..554cd58916c5a1ba09a411b4dc0b3a834ccc486a 100644
--- a/paddle/fluid/framework/data_type_transform.cc
+++ b/paddle/fluid/framework/data_type_transform.cc
@@ -47,9 +47,15 @@ struct CastDataType {
       auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
       trans(*context, in_begin, in_end, out_begin,
             CastDataTypeFunctor<InType, OutType>());
+#ifdef __NVCC__
+    } else if (platform::is_gpu_place(in_.place())) {
+      platform::Transform<platform::CUDADeviceContext> trans;
+      auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_);
+      trans(*context, in_begin, in_end, out_begin,
+            CastDataTypeFunctor<InType, OutType>());
+#endif
     } else {
-      // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
-      PADDLE_THROW("Unsupport CPU <-> GPU!");
+      PADDLE_THROW("Unsupported place!");
     }
   }
 };
@@ -65,6 +71,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
   auto ctx = pool.Get(in.place());
 
   switch (src_type) {
+    case proto::VarType::FP16:
+      framework::VisitDataType(dst_type,
+                               CastDataType<platform::float16>(in, out, ctx));
+      break;
    case proto::VarType::FP32:
      framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
      break;
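`TransDataType` pairs that dispatch with `CastDataType`: the source type picks the `CastDataType<InType>` functor, `VisitDataType` on the destination type instantiates its `operator()<OutType>`, and the body runs an element-wise `CastDataTypeFunctor` through `platform::Transform` on either the CPU or (under nvcc) the CUDA device context. A CPU-only sketch of that element-wise cast step, using `std::transform` in place of `platform::Transform` (stand-in code, not part of the patch):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Mirrors CastDataTypeFunctor<InType, OutType>: a plain static_cast per element.
template <typename InType, typename OutType>
struct CastFunctor {
  OutType operator()(InType v) const { return static_cast<OutType>(v); }
};

int main() {
  std::vector<float> in = {0.0f, 0.5f, 2.0f};
  std::vector<double> out(in.size());

  // On CPU, platform::Transform boils down to exactly this kind of loop.
  std::transform(in.begin(), in.end(), out.begin(), CastFunctor<float, double>{});

  assert(out[1] == 0.5);
  assert(out[2] == 2.0);
  return 0;
}
```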
diff --git a/paddle/fluid/framework/data_type_transform.cu b/paddle/fluid/framework/data_type_transform.cu
new file mode 120000
index 0000000000000000000000000000000000000000..f46491293ef4ad688c1bce9327f5f28011dec809
--- /dev/null
+++ b/paddle/fluid/framework/data_type_transform.cu
@@ -0,0 +1 @@
+data_type_transform.cc
\ No newline at end of file
diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc
index 724c8c301f25cca21704de398fc416dff46c330c..c992cba9a3611d50839a8ec056ee6ab954cd88b6 100644
--- a/paddle/fluid/framework/data_type_transform_test.cc
+++ b/paddle/fluid/framework/data_type_transform_test.cc
@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) {
   auto place = CPUPlace();
 
-  Tensor in;
-  Tensor out;
-
-  float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
-  int data_number = 2 * 3;
-
-  for (int i = 0; i < data_number; ++i) {
-    ptr[i] = i / 3;
-  }
-
+  auto kernel_fp16 = OpKernelType(proto::VarType::FP16, place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
                                    DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_int64 = OpKernelType(proto::VarType::INT64, place,
+                                   DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_bool = OpKernelType(proto::VarType::BOOL, place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
 
-  TransDataType(kernel_fp32, kernel_fp64, in, &out);
-  double* out_data_double = out.data<double>();
-  for (int i = 0; i < data_number; ++i) {
-    ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
+  // data type transform from float32
+  {
+    Tensor in;
+    Tensor out;
+
+    float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i / 3;
+    }
+
+    TransDataType(kernel_fp32, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
+    }
+
+    TransDataType(kernel_fp32, kernel_int32, in, &out);
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
+    }
   }
 
-  TransDataType(kernel_fp32, kernel_int32, in, &out);
-  int* out_data_int = out.data<int>();
-  for (int i = 0; i < data_number; ++i) {
-    ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
+  // data type transform from/to float16
+  {
+    Tensor in;
+    Tensor out;
+
+    float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from float16 to other data types
+    TransDataType(kernel_fp16, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int32, in, &out);
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to float16
+    float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    TransDataType(kernel_fp32, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
+    }
+
+    // transform double to float16
+    double* in_data_double = in.mutable_data<double>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    TransDataType(kernel_fp64, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
+    }
+
+    // transform int to float16
+    int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int[i] = i;
+    }
+
+    TransDataType(kernel_int32, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
+    }
+
+    // transform int64 to float16
+    int64_t* in_data_int64 = in.mutable_data<int64_t>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    TransDataType(kernel_int64, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
+    }
+
+    // transform bool to float16
+    bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    TransDataType(kernel_bool, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
+    }
   }
 }
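The `data_type_transform.cu` symlink above lets the GPU build compile the very same `data_type_transform.cc` source with nvcc, so the `#ifdef __NVCC__` branch adds the `CUDADeviceContext` path without duplicating code. A tiny illustration of that single-source pattern (illustrative only, not part of the patch):

```cpp
#include <cstdio>

// The same translation unit behaves differently depending on which compiler
// front end processes it: nvcc defines __NVCC__, a host-only compiler does not.
void ReportCompiler() {
#ifdef __NVCC__
  std::printf("compiled by nvcc: the GPU branch is compiled in\n");
#else
  std::printf("compiled by a host C++ compiler: CPU-only branch\n");
#endif
}

int main() {
  ReportCompiler();
  return 0;
}
```

Note also that the float16 assertions in the test above compare raw `.x` bit patterns against `static_cast<float16>(value).x`, so rounding during the conversion to half precision cannot cause spurious failures.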
diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3939bc5e754cd3c2829cfcb3353f83969af055a9
--- /dev/null
+++ b/paddle/fluid/framework/data_type_transform_test.cu
@@ -0,0 +1,215 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/framework/tensor_util.h"
+
+#include "gtest/gtest.h"
+
+TEST(DataTypeTransform, GPUTransform) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+
+  auto cpu_place = CPUPlace();
+  auto gpu_place = CUDAPlace(0);
+  CUDADeviceContext context(gpu_place);
+
+  auto kernel_fp16 = OpKernelType(proto::VarType::FP16, gpu_place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_fp32 = OpKernelType(proto::VarType::FP32, gpu_place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_fp64 = OpKernelType(proto::VarType::FP64, gpu_place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_int32 = OpKernelType(proto::VarType::INT32, gpu_place,
+                                   DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_int64 = OpKernelType(proto::VarType::INT64, gpu_place,
+                                   DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_bool = OpKernelType(proto::VarType::BOOL, gpu_place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
+
+  // data type transform from float32
+  {
+    Tensor in;
+    Tensor in_gpu;
+    Tensor out_gpu;
+    Tensor out;
+
+    float* in_ptr = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
+    float arr[6] = {0, 1, 2, 3, 4, 5};
+    int data_number = sizeof(arr) / sizeof(arr[0]);
+    memcpy(in_ptr, arr, sizeof(arr));
+    TensorCopy(in, gpu_place, context, &in_gpu);
+
+    TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(arr[i]));
+    }
+
+    TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(arr[i]));
+    }
+  }
+
+  // data type transform from/to float16
+  {
+    Tensor in;
+    Tensor in_gpu;
+    Tensor out_gpu;
+    Tensor out;
+
+    float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), cpu_place);
+    float16 arr[6] = {float16(0), float16(1), float16(2),
+                      float16(3), float16(4), float16(5)};
+    int data_number = sizeof(arr) / sizeof(arr[0]);
+    memcpy(ptr, arr, sizeof(arr));
+    TensorCopy(in, gpu_place, context, &in_gpu);
+
+    // transform from float16 to other data types
+    TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to float16
+    float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    TensorCopy(in, gpu_place, context, &in_gpu);
+    TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
+    }
+
+    // transform double to float16
+    double* in_data_double =
+        in.mutable_data<double>(make_ddim({2, 3}), cpu_place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    TensorCopy(in, gpu_place, context, &in_gpu);
+    TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
+    }
+
+    // transform int to float16
+    int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), cpu_place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int[i] = i;
+    }
+
+    TensorCopy(in, gpu_place, context, &in_gpu);
+    TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
+    }
+
+    // transform int64 to float16
+    int64_t* in_data_int64 =
+        in.mutable_data<int64_t>(make_ddim({2, 3}), cpu_place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    TensorCopy(in, gpu_place, context, &in_gpu);
+    TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
+    }
+
+    // transform bool to float16
+    bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), cpu_place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    TensorCopy(in, gpu_place, context, &in_gpu);
+    TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu);
+    TensorCopy(out_gpu, cpu_place, context, &out);
+    context.Wait();
+
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
+    }
+  }
+}
diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc
index 8aebfcb3b624f49b747e0874e807782b8d4a9951..9687a86ca25be7886e67028a38e54b3065c8e4b5 100644
--- a/paddle/fluid/framework/tensor_util_test.cc
+++ b/paddle/fluid/framework/tensor_util_test.cc
@@ -235,27 +235,53 @@ TEST(TensorToVector, Tensor) {
 TEST(TensorContainsNAN, CPU) {
   using namespace paddle::framework;
   using namespace paddle::platform;
-  Tensor src;
-  float* buf = src.mutable_data<float>({3}, CPUPlace());
-  buf[0] = 0.0;
-  buf[1] = NAN;
-  buf[2] = 0.0;
-  ASSERT_TRUE(TensorContainsNAN(src));
-  buf[1] = 0.0;
-  ASSERT_FALSE(TensorContainsNAN(src));
+  {
+    Tensor src;
+    float* buf = src.mutable_data<float>({3}, CPUPlace());
+    buf[0] = 0.0;
+    buf[1] = NAN;
+    buf[2] = 0.0;
+    ASSERT_TRUE(TensorContainsNAN(src));
+    buf[1] = 0.0;
+    ASSERT_FALSE(TensorContainsNAN(src));
+  }
+
+  {
+    Tensor src;
+    float16* buf = src.mutable_data<float16>({3}, CPUPlace());
+    buf[0] = 0.0;
+    buf[1].x = 0x7fff;
+    buf[2] = 0.0;
+    ASSERT_TRUE(TensorContainsNAN(src));
+    buf[1] = 0.0;
+    ASSERT_FALSE(TensorContainsNAN(src));
+  }
 }
 
 TEST(TensorContainsInf, CPU) {
   using namespace paddle::framework;
   using namespace paddle::platform;
-  Tensor src;
-  double* buf = src.mutable_data<double>({3}, CPUPlace());
-  buf[0] = 1.0;
-  buf[1] = INFINITY;
-  buf[2] = 0.0;
-  ASSERT_TRUE(TensorContainsInf(src));
-  buf[1] = 1.0;
-  ASSERT_FALSE(TensorContainsInf(src));
+  {
+    Tensor src;
+    double* buf = src.mutable_data<double>({3}, CPUPlace());
+    buf[0] = 1.0;
+    buf[1] = INFINITY;
+    buf[2] = 0.0;
+    ASSERT_TRUE(TensorContainsInf(src));
+    buf[1] = 1.0;
+    ASSERT_FALSE(TensorContainsInf(src));
+  }
+
+  {
+    Tensor src;
+    float16* buf = src.mutable_data<float16>({3}, CPUPlace());
+    buf[0] = 1.0;
+    buf[1].x = 0x7c00;
+    buf[2] = 0.0;
+    ASSERT_TRUE(TensorContainsInf(src));
+    buf[1] = 1.0;
+    ASSERT_FALSE(TensorContainsInf(src));
+  }
 }
 
 TEST(Tensor, FromAndToStream) {
diff --git a/paddle/fluid/framework/tensor_util_test.cu b/paddle/fluid/framework/tensor_util_test.cu
index d630ec44a2aa6f2567eedbb58b9ae181700f5fc9..4766ec28aa3cff6be3259f258f1c9543ae471f5d 100644
--- a/paddle/fluid/framework/tensor_util_test.cu
+++ b/paddle/fluid/framework/tensor_util_test.cu
@@ -25,32 +25,65 @@ static __global__ void FillNAN(float* buf) {
   buf[1] = 0.1;
   buf[2] = NAN;
 }
+
 static __global__ void FillInf(float* buf) {
   buf[0] = 0.0;
   buf[1] = INFINITY;
   buf[2] = 0.5;
 }
 
+static __global__ void FillNAN(platform::float16* buf) {
+  buf[0] = 0.0;
+  buf[1] = 0.1;
+  buf[2].x = 0x7fff;
+}
+
+static __global__ void FillInf(platform::float16* buf) {
+  buf[0] = 0.0;
+  buf[1].x = 0x7c00;
+  buf[2] = 0.5;
+}
+
 TEST(TensorContainsNAN, GPU) {
-  Tensor tensor;
-  platform::CUDAPlace gpu(0);
-  auto& pool = platform::DeviceContextPool::Instance();
+  using namespace paddle::platform;
+  CUDAPlace gpu(0);
+  auto& pool = DeviceContextPool::Instance();
   auto* cuda_ctx = pool.GetByPlace(gpu);
-  float* buf = tensor.mutable_data<float>({3}, gpu);
-  FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
-  cuda_ctx->Wait();
-  ASSERT_TRUE(TensorContainsNAN(tensor));
+  {
+    Tensor tensor;
+    float* buf = tensor.mutable_data<float>({3}, gpu);
+    FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+    cuda_ctx->Wait();
+    ASSERT_TRUE(TensorContainsNAN(tensor));
+  }
+  {
+    Tensor tensor;
+    float16* buf = tensor.mutable_data<float16>({3}, gpu);
+    FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+    cuda_ctx->Wait();
+    ASSERT_TRUE(TensorContainsNAN(tensor));
+  }
 }
 
 TEST(TensorContainsInf, GPU) {
-  Tensor tensor;
-  platform::CUDAPlace gpu(0);
-  auto& pool = platform::DeviceContextPool::Instance();
+  using namespace paddle::platform;
+  CUDAPlace gpu(0);
+  auto& pool = DeviceContextPool::Instance();
   auto* cuda_ctx = pool.GetByPlace(gpu);
-  float* buf = tensor.mutable_data<float>({3}, gpu);
-  FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
-  cuda_ctx->Wait();
-  ASSERT_TRUE(TensorContainsInf(tensor));
+  {
+    Tensor tensor;
+    float* buf = tensor.mutable_data<float>({3}, gpu);
+    FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+    cuda_ctx->Wait();
+    ASSERT_TRUE(TensorContainsInf(tensor));
+  }
+  {
+    Tensor tensor;
+    float16* buf = tensor.mutable_data<float16>({3}, gpu);
+    FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+    cuda_ctx->Wait();
+    ASSERT_TRUE(TensorContainsInf(tensor));
+  }
 }
 
 }  // namespace framework
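The magic numbers in `FillNAN`/`FillInf` come straight from the IEEE binary16 layout: with the sign bit masked off (`0x7fff`), a value equal to `0x7c00` has all exponent bits set and a zero mantissa (infinity), while anything greater has a non-zero mantissa (NaN). A standalone check of those masks, with `uint16_t` standing in for `float16::x`:

```cpp
#include <cassert>
#include <cstdint>

// Same predicates as the new float16 isnan/isinf, expressed on raw bits.
bool IsNanBits(uint16_t x) { return (x & 0x7fff) > 0x7c00; }
bool IsInfBits(uint16_t x) { return (x & 0x7fff) == 0x7c00; }

int main() {
  assert(IsNanBits(0x7fff));   // exponent all ones, non-zero mantissa -> NaN
  assert(IsInfBits(0x7c00));   // exponent all ones, zero mantissa -> +Inf
  assert(IsInfBits(0xfc00));   // -Inf: the sign bit is masked away
  assert(!IsNanBits(0x3c00));  // 0x3c00 is 1.0 in binary16
  assert(!IsInfBits(0x0000));  // zero
  return 0;
}
```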
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index 41eab3ade207a592e1a72cd22580f18723167eff..f7f33917d7ef5bbcc7fb5d6e3d0a7f3ae63cde34 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -245,11 +245,13 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
 
-#define DEFINE_CPU_TRANS(RANK)                                            \
-  template struct Transpose<platform::CPUDeviceContext, float, RANK>;     \
-  template struct Transpose<platform::CPUDeviceContext, double, RANK>;    \
-  template struct Transpose<platform::CPUDeviceContext, int, RANK>;       \
-  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;   \
+#define DEFINE_CPU_TRANS(RANK)                                              \
+  template struct Transpose<platform::CPUDeviceContext, platform::float16, \
+                            RANK>;                                          \
+  template struct Transpose<platform::CPUDeviceContext, float, RANK>;       \
+  template struct Transpose<platform::CPUDeviceContext, double, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, int, RANK>;         \
+  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;     \
   template struct Transpose<platform::CPUDeviceContext, bool, RANK>;
 
 DEFINE_CPU_TRANS(1);
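Adding `platform::float16` to `DEFINE_CPU_TRANS` explicitly instantiates the CPU `Transpose` functor for half-precision element types, so layout transforms keep linking once fp16 data starts flowing through `DataTransform`. A minimal standalone example of the explicit-instantiation-via-macro pattern (stand-in `Transpose`, not Paddle's):

```cpp
#include <cstddef>

// Stand-in element type (2 bytes, like platform::float16).
struct float16_stub {
  unsigned short x;
};

// Stand-in for the Transpose functor template.
template <typename T, std::size_t Rank>
struct Transpose {
  void operator()(const T* in, T* out, std::size_t n) const {
    for (std::size_t i = 0; i < n; ++i) out[i] = in[i];  // placeholder body
  }
};

// Explicit instantiation per rank; the fp16 line mirrors the newly added one.
#define DEFINE_TRANS(RANK)                       \
  template struct Transpose<float16_stub, RANK>; \
  template struct Transpose<float, RANK>;        \
  template struct Transpose<double, RANK>;

DEFINE_TRANS(1);
DEFINE_TRANS(2);

int main() { return 0; }
```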
diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h
index 5832bd9ce303d48496e8c15324fd7e4a39484eb4..52fb8c2531357ad7a2b2f8613e5c7fbcef52c6bb 100644
--- a/paddle/fluid/platform/float16.h
+++ b/paddle/fluid/platform/float16.h
@@ -20,10 +20,6 @@ limitations under the License. */
 #include <cuda.h>
 #endif  // PADDLE_WITH_CUDA
 
-#include "unsupported/Eigen/CXX11/Tensor"
-
-#include "paddle/fluid/platform/hostdevice.h"
-
 #ifdef __GNUC__
 #define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
 #else
@@ -64,6 +60,18 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+// Forward declare float16 for eigen.h
+struct float16;
+
+}  // namespace platform
+}  // namespace paddle
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace platform {
+
 // Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
 // and aligned at least on a 2-byte boundary, which leads to efficient
 // memory access of float16 struct and also makes float16 compatible
@@ -729,6 +737,22 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
 }
 #endif
 
+HOSTDEVICE inline bool(isnan)(const float16& a) {
+#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hisnan(half(a));
+#else
+  return (a.x & 0x7fff) > 0x7c00;
+#endif
+}
+
+HOSTDEVICE inline bool(isinf)(const float16& a) {
+  return (a.x & 0x7fff) == 0x7c00;
+}
+
+HOSTDEVICE inline bool(isfinite)(const float16& a) {
+  return !((isnan)(a)) && !((isinf)(a));
+}
+
 }  // namespace platform
 }  // namespace paddle
 
@@ -750,3 +774,27 @@ struct is_pod<paddle::platform::float16> {
 };
 
 }  // namespace std
+
+namespace Eigen {
+namespace numext {
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
+    const paddle::platform::float16& a) {
+  return (paddle::platform::isnan)(a);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
+    const paddle::platform::float16& a) {
+  return (paddle::platform::isinf)(a);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
+    const paddle::platform::float16& a) {
+  return (paddle::platform::isfinite)(a);
+}
+
+}  // namespace numext
+}  // namespace Eigen
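The `Eigen::numext` specializations are what let the tensor-level checks exercised in `tensor_util_test` see half-precision elements: the Eigen expressions behind `TensorContainsNAN`/`TensorContainsInf` resolve `numext::isnan`/`isinf`, and these overloads route them to the bit-level `float16` predicates above. A hypothetical host-side usage of the new predicates (assumes a Paddle build where `paddle/fluid/platform/float16.h` is on the include path; the `float16(int)` constructor and the public `x` member are taken from the patch above):

```cpp
#include <cassert>

#include "paddle/fluid/platform/float16.h"

int main() {
  using paddle::platform::float16;

  float16 one(1);      // integer-converting constructor, as used in the tests
  float16 nan_val(0);
  nan_val.x = 0x7fff;  // all exponent bits set, non-zero mantissa -> NaN
  float16 inf_val(0);
  inf_val.x = 0x7c00;  // all exponent bits set, zero mantissa -> +Inf

  // The extra parentheses keep the C isnan/isinf macros from expanding,
  // mirroring how the header itself declares these functions.
  assert((paddle::platform::isnan)(nan_val));
  assert((paddle::platform::isinf)(inf_val));
  assert((paddle::platform::isfinite)(one));
  assert(!(paddle::platform::isfinite)(inf_val));
  return 0;
}
```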