diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 59e6269ea04cf4dfeb2dddee1f256acf8b5a742a..638bd0db9d7025199c31a9327b96062512aa5adb 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { @@ -52,7 +53,9 @@ struct SizeOfTypeFunctor { }; static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor functor; + SizeOfTypeFunctor + functor; size_t size = functor(type); PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); return size; diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index c36bfad4bc155877f734f9faec9f56588206d284..cf6a4b09dbd2d5b7d22081ff4713e3e644f4800e 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -62,6 +62,7 @@ limitations under the License. */ #define PADDLE_ALIGN(x) __attribute__((aligned(x))) namespace paddle { +namespace platform { // Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated // and aligned at least on a 2-byte boundary, which leads to efficient @@ -71,11 +72,21 @@ struct PADDLE_ALIGN(2) float16 { public: uint16_t x; - // Constructors - HOSTDEVICE inline float16() : x(0) {} + // The following defaulted special class member functions + // are added to make float16 pass the std::is_trivial test + HOSTDEVICE inline float16() = default; - HOSTDEVICE inline float16(const float16& h) : x(h.x) {} + HOSTDEVICE inline float16(const float16&) = default; + HOSTDEVICE inline float16& operator=(const float16&) = default; + + HOSTDEVICE inline float16(float16&&) = default; + + HOSTDEVICE inline float16& operator=(float16&&) = default; + + HOSTDEVICE inline ~float16() = default; + +// Constructors #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline explicit float16(const half& h) { #if CUDA_VERSION >= 9000 @@ -136,11 +147,6 @@ struct PADDLE_ALIGN(2) float16 { HOSTDEVICE inline explicit float16(const T& val) : x(float16(static_cast(val)).x) {} - HOSTDEVICE inline float16& operator=(const float16& rhs) { - x = rhs.x; - return *this; - } - // Assignment operators #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline float16& operator=(const half& rhs) { @@ -727,4 +733,25 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { return float(a) >= float(b); } #endif + +} // namespace platform } // namespace paddle + +namespace std { + +// Override the std::is_pod::value for float16 +// The reason is that different compilers implemented std::is_pod based on +// different C++ standards. float16 class is a plain old data in C++11 given +// that it is both trivial and standard_layout. +// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is +// more restricted in that you cannot provide any customized +// constructor in float16. Hence, we override is_pod here following C++11 +// so that .cu files can be successfully compiled by nvcc. +template <> +struct is_pod { + static const bool value = + is_trivial::value && + is_standard_layout::value; +}; + +} // namespace std diff --git a/paddle/fluid/platform/float16_test.cc b/paddle/fluid/platform/float16_test.cc index bed29dbfa7ed57ac98ff9ce37945cc74a8968704..b716ad9df41330bd6e22937381d24e33fa3a7914 100644 --- a/paddle/fluid/platform/float16_test.cc +++ b/paddle/fluid/platform/float16_test.cc @@ -10,10 +10,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/framework/init.h" +#include "paddle/fluid/framework/lod_tensor.h" #include namespace paddle { +namespace platform { TEST(float16, conversion_cpu) { // Explicit conversion from Eigen::half @@ -54,13 +57,9 @@ TEST(float16, conversion_cpu) { EXPECT_EQ(float16(true).x, 0x3c00); EXPECT_EQ(float16(false).x, 0x0000); - // Default constructor - float16 v_def; - EXPECT_EQ(v_def.x, 0x0000); - // Assignment operator float16 v_assign; - v_assign = v_def; + v_assign = float16(0); EXPECT_EQ(v_assign.x, 0x0000); v_assign = Eigen::half(1.0f); EXPECT_EQ(v_assign.x, 0x3c00); @@ -116,4 +115,27 @@ TEST(float16, comparison_cpu) { EXPECT_FALSE(float16(-0.0f) > float16(0.0f)); } +TEST(float16, lod_tensor_cpu) { + framework::LoDTensor lod_tensor; + + std::vector input_data = {float16(1.0f), float16(0.5f), + float16(0.33333f), float16(0.0f)}; + EXPECT_EQ(input_data[0].x, 0x3c00); + EXPECT_EQ(input_data[1].x, 0x3800); + EXPECT_EQ(input_data[2].x, 0x3555); + EXPECT_EQ(input_data[3].x, 0x0000); + + lod_tensor.Resize({4, 1}); + lod_tensor.set_lod(framework::LoD({{0, 2, 4}})); + float16* data_ptr = lod_tensor.mutable_data(CPUPlace()); + + EXPECT_NE(data_ptr, nullptr); + EXPECT_EQ(input_data.size(), static_cast(lod_tensor.numel())); + for (size_t i = 0; i < input_data.size(); ++i) { + data_ptr[i] = input_data[i]; + EXPECT_EQ(data_ptr[i].x, input_data[i].x); + } +} + +} // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu index 7e6c9f58aca3a73fa260be375275c8e4886d2133..567209df4edc483bcb5c6264c62034ddff50c413 100644 --- a/paddle/fluid/platform/float16_test.cu +++ b/paddle/fluid/platform/float16_test.cu @@ -13,6 +13,8 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/utils/Logging.h" #define ARITHMETIC_KERNEL(op_type, sign) \ @@ -108,6 +110,7 @@ limitations under the License. */ #ifdef PADDLE_CUDA_FP16 namespace paddle { +namespace platform { #if CUDA_VERSION < 9000 ARITHMETIC_KERNEL(Add, +) @@ -209,5 +212,35 @@ TEST(float16, conversion_on_gpu) { EXPECT_EQ(v_assign.x, 0x3c00); } +TEST(float16, lod_tensor_on_gpu) { + framework::LoDTensor src_tensor; + framework::LoDTensor gpu_tensor; + framework::LoDTensor dst_tensor; + + float16* src_ptr = src_tensor.mutable_data( + framework::make_ddim({2, 2}), CPUPlace()); + + float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f), + float16(0.0f)}; + memcpy(src_ptr, arr, 4 * sizeof(float16)); + + // CPU LoDTensor to GPU LoDTensor + CUDAPlace gpu_place(0); + CUDADeviceContext gpu_ctx(gpu_place); + framework::TensorCopy(src_tensor, gpu_place, gpu_ctx, &gpu_tensor); + + // GPU LoDTensor to CPU LoDTensor + framework::TensorCopy(gpu_tensor, CPUPlace(), gpu_ctx, &dst_tensor); + + // Sync before comparing LoDTensors + gpu_ctx.Wait(); + const float16* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x); + } +} + +} // namespace platform } // namespace paddle #endif // PADDLE_CUDA_FP16