未验证 提交 266ccaa8 编写于 作者: K kexinzhao 提交者: GitHub

Integrate float16 into data_type_transform (#8619)

* test cpu float16 data transform

* add isnan etc

* small fix

* fix containsNAN test error

* add data_type transform GPU test

* add float16 GPU example

* fix error

* fix GPU test error

* add context wait
上级 78c884d7
...@@ -5,14 +5,14 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) ...@@ -5,14 +5,14 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
if (WITH_GPU) if(WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto) nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
else() else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto) cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
endif () endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
if (WITH_GPU) if(WITH_GPU)
nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor) nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
else() else()
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
...@@ -39,8 +39,13 @@ cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor) ...@@ -39,8 +39,13 @@ cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry init math_function) DEPS operator op_registry init math_function)
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor) if(WITH_GPU)
cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform) nv_library(data_type_transform SRCS data_type_transform.cu DEPS tensor)
nv_test(data_type_transform_test SRCS data_type_transform_test.cc data_type_transform_test.cu DEPS data_type_transform)
else()
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
endif()
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function) cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform) cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
......
...@@ -42,6 +42,7 @@ void DataTransform(const OpKernelType& expected_kernel_type, ...@@ -42,6 +42,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
// do data type transform
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) { if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out); TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
transformed = true; transformed = true;
......
...@@ -16,13 +16,16 @@ limitations under the License. */ ...@@ -16,13 +16,16 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline proto::VarType::Type ToDataType(std::type_index type) { inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) { if (typeid(platform::float16).hash_code() == type.hash_code()) {
return proto::VarType::FP16;
} else if (typeid(float).hash_code() == type.hash_code()) {
return proto::VarType::FP32; return proto::VarType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) { } else if (typeid(double).hash_code() == type.hash_code()) {
return proto::VarType::FP64; return proto::VarType::FP64;
...@@ -40,6 +43,8 @@ inline proto::VarType::Type ToDataType(std::type_index type) { ...@@ -40,6 +43,8 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
inline std::type_index ToTypeIndex(proto::VarType::Type type) { inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case proto::VarType::FP16:
return typeid(platform::float16);
case proto::VarType::FP32: case proto::VarType::FP32:
return typeid(float); return typeid(float);
case proto::VarType::FP64: case proto::VarType::FP64:
...@@ -59,6 +64,9 @@ template <typename Visitor> ...@@ -59,6 +64,9 @@ template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case proto::VarType::FP16:
visitor.template operator()<platform::float16>();
break;
case proto::VarType::FP32: case proto::VarType::FP32:
visitor.template operator()<float>(); visitor.template operator()<float>();
break; break;
......
...@@ -47,9 +47,15 @@ struct CastDataType { ...@@ -47,9 +47,15 @@ struct CastDataType {
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_); auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin, trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>()); CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
} else if (platform::is_gpu_place(in_.place())) {
platform::Transform<platform::CUDADeviceContext> trans;
auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
#endif
} else { } else {
// TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? PADDLE_THROW("Unsupported place!");
PADDLE_THROW("Unsupport CPU <-> GPU!");
} }
} }
}; };
...@@ -65,6 +71,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -65,6 +71,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
auto ctx = pool.Get(in.place()); auto ctx = pool.Get(in.place());
switch (src_type) { switch (src_type) {
case proto::VarType::FP16:
framework::VisitDataType(dst_type,
CastDataType<platform::float16>(in, out, ctx));
break;
case proto::VarType::FP32: case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break; break;
......
data_type_transform.cc
\ No newline at end of file
...@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) {
auto place = CPUPlace(); auto place = CPUPlace();
Tensor in; auto kernel_fp16 = OpKernelType(proto::VarType::FP16, place,
Tensor out; DataLayout::kAnyLayout, LibraryType::kPlain);
float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i / 3;
}
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place, auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place, auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::VarType::INT32, place, auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int64 = OpKernelType(proto::VarType::INT64, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_bool = OpKernelType(proto::VarType::BOOL, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
TransDataType(kernel_fp32, kernel_fp64, in, &out); // data type transform from float32
double* out_data_double = out.data<double>(); {
for (int i = 0; i < data_number; ++i) { Tensor in;
ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3)); Tensor out;
float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i / 3;
}
TransDataType(kernel_fp32, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
}
TransDataType(kernel_fp32, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
}
} }
TransDataType(kernel_fp32, kernel_int32, in, &out); // data type transform from/to float16
int* out_data_int = out.data<int>(); {
for (int i = 0; i < data_number; ++i) { Tensor in;
ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3)); Tensor out;
float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from float16 to other data types
TransDataType(kernel_fp16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to float16
float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
TransDataType(kernel_fp32, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
}
// transform double to float16
double* in_data_double = in.mutable_data<double>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
TransDataType(kernel_fp64, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
}
// transform int to float16
int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
TransDataType(kernel_int32, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
}
// transform int64 to float16
int64_t* in_data_int64 = in.mutable_data<int64_t>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
TransDataType(kernel_int64, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
}
// transform bool to float16
bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
TransDataType(kernel_bool, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
}
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "gtest/gtest.h"
TEST(DataTypeTransform, GPUTransform) {
using namespace paddle::framework;
using namespace paddle::platform;
auto cpu_place = CPUPlace();
auto gpu_place = CUDAPlace(0);
CUDADeviceContext context(gpu_place);
auto kernel_fp16 = OpKernelType(proto::VarType::FP16, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::VarType::INT32, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int64 = OpKernelType(proto::VarType::INT64, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_bool = OpKernelType(proto::VarType::BOOL, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
// data type transform from float32
{
Tensor in;
Tensor in_gpu;
Tensor out_gpu;
Tensor out;
float* in_ptr = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
int data_number = sizeof(arr) / sizeof(arr[0]);
memcpy(in_ptr, arr, sizeof(arr));
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(arr[i]));
}
TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(arr[i]));
}
}
// data type transform from/to float16
{
Tensor in;
Tensor in_gpu;
Tensor out_gpu;
Tensor out;
float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), cpu_place);
float16 arr[6] = {float16(0), float16(1), float16(2),
float16(3), float16(4), float16(5)};
int data_number = sizeof(arr) / sizeof(arr[0]);
memcpy(ptr, arr, sizeof(arr));
TensorCopy(in, gpu_place, context, &in_gpu);
// transform from float16 to other data types
TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to float16
float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
}
// transform double to float16
double* in_data_double =
in.mutable_data<double>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
}
// transform int to float16
int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
}
// transform int64 to float16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
}
// transform bool to float16
bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
}
}
}
...@@ -235,27 +235,53 @@ TEST(TensorToVector, Tensor) { ...@@ -235,27 +235,53 @@ TEST(TensorToVector, Tensor) {
TEST(TensorContainsNAN, CPU) { TEST(TensorContainsNAN, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; {
float* buf = src.mutable_data<float>({3}, CPUPlace()); Tensor src;
buf[0] = 0.0; float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[1] = NAN; buf[0] = 0.0;
buf[2] = 0.0; buf[1] = NAN;
ASSERT_TRUE(TensorContainsNAN(src)); buf[2] = 0.0;
buf[1] = 0.0; ASSERT_TRUE(TensorContainsNAN(src));
ASSERT_FALSE(TensorContainsNAN(src)); buf[1] = 0.0;
ASSERT_FALSE(TensorContainsNAN(src));
}
{
Tensor src;
float16* buf = src.mutable_data<float16>({3}, CPUPlace());
buf[0] = 0.0;
buf[1].x = 0x7fff;
buf[2] = 0.0;
ASSERT_TRUE(TensorContainsNAN(src));
buf[1] = 0.0;
ASSERT_FALSE(TensorContainsNAN(src));
}
} }
TEST(TensorContainsInf, CPU) { TEST(TensorContainsInf, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; {
double* buf = src.mutable_data<double>({3}, CPUPlace()); Tensor src;
buf[0] = 1.0; double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[1] = INFINITY; buf[0] = 1.0;
buf[2] = 0.0; buf[1] = INFINITY;
ASSERT_TRUE(TensorContainsInf(src)); buf[2] = 0.0;
buf[1] = 1.0; ASSERT_TRUE(TensorContainsInf(src));
ASSERT_FALSE(TensorContainsInf(src)); buf[1] = 1.0;
ASSERT_FALSE(TensorContainsInf(src));
}
{
Tensor src;
float16* buf = src.mutable_data<float16>({3}, CPUPlace());
buf[0] = 1.0;
buf[1].x = 0x7c00;
buf[2] = 0.0;
ASSERT_TRUE(TensorContainsInf(src));
buf[1] = 1.0;
ASSERT_FALSE(TensorContainsInf(src));
}
} }
TEST(Tensor, FromAndToStream) { TEST(Tensor, FromAndToStream) {
......
...@@ -25,32 +25,65 @@ static __global__ void FillNAN(float* buf) { ...@@ -25,32 +25,65 @@ static __global__ void FillNAN(float* buf) {
buf[1] = 0.1; buf[1] = 0.1;
buf[2] = NAN; buf[2] = NAN;
} }
static __global__ void FillInf(float* buf) { static __global__ void FillInf(float* buf) {
buf[0] = 0.0; buf[0] = 0.0;
buf[1] = INFINITY; buf[1] = INFINITY;
buf[2] = 0.5; buf[2] = 0.5;
} }
static __global__ void FillNAN(platform::float16* buf) {
buf[0] = 0.0;
buf[1] = 0.1;
buf[2].x = 0x7fff;
}
static __global__ void FillInf(platform::float16* buf) {
buf[0] = 0.0;
buf[1].x = 0x7c00;
buf[2] = 0.5;
}
TEST(TensorContainsNAN, GPU) { TEST(TensorContainsNAN, GPU) {
Tensor tensor; using namespace paddle::platform;
platform::CUDAPlace gpu(0); CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu); auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu); {
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); Tensor tensor;
cuda_ctx->Wait(); float* buf = tensor.mutable_data<float>({3}, gpu);
ASSERT_TRUE(TensorContainsNAN(tensor)); FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsNAN(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsNAN(tensor));
}
} }
TEST(TensorContainsInf, GPU) { TEST(TensorContainsInf, GPU) {
Tensor tensor; using namespace paddle::platform;
platform::CUDAPlace gpu(0); CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu); auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu); {
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); Tensor tensor;
cuda_ctx->Wait(); float* buf = tensor.mutable_data<float>({3}, gpu);
ASSERT_TRUE(TensorContainsInf(tensor)); FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsInf(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsInf(tensor));
}
} }
} // namespace framework } // namespace framework
......
...@@ -245,11 +245,13 @@ template struct SetConstant<platform::CPUDeviceContext, int>; ...@@ -245,11 +245,13 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>; template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>; template struct SetConstant<platform::CPUDeviceContext, bool>;
#define DEFINE_CPU_TRANS(RANK) \ #define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \ template struct Transpose<platform::CPUDeviceContext, platform::float16, \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \ RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \ template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; template struct Transpose<platform::CPUDeviceContext, bool, RANK>;
DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(1);
......
...@@ -20,10 +20,6 @@ limitations under the License. */ ...@@ -20,10 +20,6 @@ limitations under the License. */
#include <cuda.h> #include <cuda.h>
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __GNUC__ #ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__) #define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else #else
...@@ -64,6 +60,18 @@ limitations under the License. */ ...@@ -64,6 +60,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Forward declare float16 for eigen.h
struct float16;
} // namespace platform
} // namespace paddle
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace platform {
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated // Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient // and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible // memory access of float16 struct and also makes float16 compatible
...@@ -729,6 +737,22 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { ...@@ -729,6 +737,22 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
} }
#endif #endif
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hisnan(half(a));
#else
return (a.x & 0x7fff) > 0x7c00;
#endif
}
HOSTDEVICE inline bool(isinf)(const float16& a) {
return (a.x & 0x7fff) == 0x7c00;
}
HOSTDEVICE inline bool(isfinite)(const float16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -750,3 +774,27 @@ struct is_pod<paddle::platform::float16> { ...@@ -750,3 +774,27 @@ struct is_pod<paddle::platform::float16> {
}; };
} // namespace std } // namespace std
namespace Eigen {
namespace numext {
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
const paddle::platform::float16& a) {
return (paddle::platform::isnan)(a);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
const paddle::platform::float16& a) {
return (paddle::platform::isinf)(a);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
const paddle::platform::float16& a) {
return (paddle::platform::isfinite)(a);
}
} // namespace numext
} // namespace Eigen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册