未验证 提交 878e117b 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Support float16 in custom op (#31725)

* support float16 in custom op

* fix failed unittests
上级 c9e1d9dc
...@@ -198,6 +198,9 @@ copy(inference_lib_dist ...@@ -198,6 +198,9 @@ copy(inference_lib_dist
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
# CAPI inference library for only inference # CAPI inference library for only inference
set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
......
...@@ -47,6 +47,22 @@ namespace paddle { ...@@ -47,6 +47,22 @@ namespace paddle {
} \ } \
}() }()
#define PD_DISPATCH_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT16, paddle::float16, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
::paddle::ToString(__dtype__), "`"); \
} \
}()
///////// Integral Dispatch Marco /////////// ///////// Integral Dispatch Marco ///////////
#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
......
...@@ -19,11 +19,13 @@ limitations under the License. */ ...@@ -19,11 +19,13 @@ limitations under the License. */
#include "complex128.h" // NOLINT #include "complex128.h" // NOLINT
#include "complex64.h" // NOLINT #include "complex64.h" // NOLINT
#include "ext_exception.h" // NOLINT #include "ext_exception.h" // NOLINT
#include "float16.h" // NOLINT
namespace paddle { namespace paddle {
using complex64 = paddle::platform::complex64; using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128; using complex128 = paddle::platform::complex128;
using float16 = paddle::platform::float16;
enum class DataType { enum class DataType {
BOOL, BOOL,
...@@ -32,6 +34,7 @@ enum class DataType { ...@@ -32,6 +34,7 @@ enum class DataType {
INT16, INT16,
INT32, INT32,
INT64, INT64,
FLOAT16,
FLOAT32, FLOAT32,
FLOAT64, FLOAT64,
COMPLEX64, COMPLEX64,
...@@ -53,6 +56,8 @@ inline std::string ToString(DataType dtype) { ...@@ -53,6 +56,8 @@ inline std::string ToString(DataType dtype) {
return "int32_t"; return "int32_t";
case DataType::INT64: case DataType::INT64:
return "int64_t"; return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32: case DataType::FLOAT32:
return "float"; return "float";
case DataType::FLOAT64: case DataType::FLOAT64:
...@@ -73,6 +78,7 @@ inline std::string ToString(DataType dtype) { ...@@ -73,6 +78,7 @@ inline std::string ToString(DataType dtype) {
_(int16_t, DataType::INT16) \ _(int16_t, DataType::INT16) \
_(int, DataType::INT32) \ _(int, DataType::INT32) \
_(int64_t, DataType::INT64) \ _(int64_t, DataType::INT64) \
_(float16, DataType::FLOAT16) \
_(float, DataType::FLOAT32) \ _(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \ _(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \ _(complex64, DataType::COMPLEX64) \
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -170,6 +171,8 @@ DataType Tensor::type() const { ...@@ -170,6 +171,8 @@ DataType Tensor::type() const {
return DataType::COMPLEX64; return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) { } else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128; return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} }
// TODO(JiabinYang) Support more dtype here // TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32; return DataType::FLOAT32;
...@@ -229,6 +232,8 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>( ...@@ -229,6 +232,8 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const; const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>( template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const; const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
template PD_DLL_DECL float *Tensor::data<float>() const; template PD_DLL_DECL float *Tensor::data<float>() const;
template PD_DLL_DECL double *Tensor::data<double>() const; template PD_DLL_DECL double *Tensor::data<double>() const;
...@@ -242,6 +247,8 @@ template PD_DLL_DECL paddle::platform::complex64 * ...@@ -242,6 +247,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const; Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 * template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const; Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL float *Tensor::mutable_data<float>(); template PD_DLL_DECL float *Tensor::mutable_data<float>();
template PD_DLL_DECL double *Tensor::mutable_data<double>(); template PD_DLL_DECL double *Tensor::mutable_data<double>();
...@@ -255,6 +262,8 @@ template PD_DLL_DECL paddle::platform::complex64 * ...@@ -255,6 +262,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(); Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 * template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(); Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place); template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
template PD_DLL_DECL double *Tensor::mutable_data<double>( template PD_DLL_DECL double *Tensor::mutable_data<double>(
...@@ -274,6 +283,8 @@ template PD_DLL_DECL paddle::platform::complex64 * ...@@ -274,6 +283,8 @@ template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place); Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 * template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place); Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
std::vector<int64_t> Tensor::shape() const { std::vector<int64_t> Tensor::shape() const {
GET_CASTED_TENSOR GET_CASTED_TENSOR
...@@ -344,6 +355,11 @@ Tensor Tensor::cast(const DataType &target_type) const { ...@@ -344,6 +355,11 @@ Tensor Tensor::cast(const DataType &target_type) const {
CastDataType<paddle::platform::complex128>( CastDataType<paddle::platform::complex128>(
*tensor, rlt_tensor_, ctx)); *tensor, rlt_tensor_, ctx));
break; break;
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type,
CastDataType<paddle::platform::float16>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here // TODO(JiabinYang) Support more dtype here
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -113,6 +113,8 @@ void GroupTestCopy() { ...@@ -113,6 +113,8 @@ void GroupTestCopy() {
TestCopyTensor<paddle::complex64>(); TestCopyTensor<paddle::complex64>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::complex128>(); TestCopyTensor<paddle::complex128>();
VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::float16>();
} }
void GroupTestCast() { void GroupTestCast() {
...@@ -134,6 +136,8 @@ void GroupTestCast() { ...@@ -134,6 +136,8 @@ void GroupTestCast() {
TestCast<paddle::complex64>(paddle::DataType::FLOAT32); TestCast<paddle::complex64>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast"; VLOG(2) << "complex128 cast";
TestCast<paddle::complex128>(paddle::DataType::FLOAT32); TestCast<paddle::complex128>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::float16>(paddle::DataType::FLOAT16);
} }
void GroupTestDtype() { void GroupTestDtype() {
...@@ -146,6 +150,7 @@ void GroupTestDtype() { ...@@ -146,6 +150,7 @@ void GroupTestDtype() {
CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8); CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
CHECK(TestDtype<paddle::complex64>() == paddle::DataType::COMPLEX64); CHECK(TestDtype<paddle::complex64>() == paddle::DataType::COMPLEX64);
CHECK(TestDtype<paddle::complex128>() == paddle::DataType::COMPLEX128); CHECK(TestDtype<paddle::complex128>() == paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::float16>() == paddle::DataType::FLOAT16);
} }
void GroupTestDtypeConvert() { void GroupTestDtypeConvert() {
...@@ -178,6 +183,9 @@ void GroupTestDtypeConvert() { ...@@ -178,6 +183,9 @@ void GroupTestDtypeConvert() {
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) == paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128); paddle::framework::proto::VarType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
// proto -> enum // proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) == paddle::framework::proto::VarType::FP64) ==
...@@ -207,6 +215,9 @@ void GroupTestDtypeConvert() { ...@@ -207,6 +215,9 @@ void GroupTestDtypeConvert() {
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) == paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128); paddle::DataType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
} }
TEST(CustomTensor, copyTest) { TEST(CustomTensor, copyTest) {
......
...@@ -60,6 +60,8 @@ class CustomTensorUtils { ...@@ -60,6 +60,8 @@ class CustomTensorUtils {
return framework::proto::VarType::COMPLEX64; return framework::proto::VarType::COMPLEX64;
case paddle::DataType::COMPLEX128: case paddle::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128; return framework::proto::VarType::COMPLEX128;
case paddle::DataType::FLOAT16:
return framework::proto::VarType::FP16;
case paddle::DataType::BOOL: case paddle::DataType::BOOL:
return framework::proto::VarType::BOOL; return framework::proto::VarType::BOOL;
default: default:
...@@ -91,6 +93,8 @@ class CustomTensorUtils { ...@@ -91,6 +93,8 @@ class CustomTensorUtils {
return paddle::DataType::COMPLEX64; return paddle::DataType::COMPLEX64;
case framework::proto::VarType::COMPLEX128: case framework::proto::VarType::COMPLEX128:
return paddle::DataType::COMPLEX128; return paddle::DataType::COMPLEX128;
case framework::proto::VarType::FP16:
return paddle::DataType::FLOAT16;
case framework::proto::VarType::BOOL: case framework::proto::VarType::BOOL:
return paddle::DataType::BOOL; return paddle::DataType::BOOL;
default: default:
......
...@@ -13,24 +13,15 @@ endif() ...@@ -13,24 +13,15 @@ endif()
py_test(test_sysconfig SRCS test_sysconfig.py) py_test(test_sysconfig SRCS test_sysconfig.py)
# 'test_dispatch' compile .cc file # CPU custom op tests: only compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py) py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)
py_test(test_multi_out_jit SRCS test_multi_out_jit.py) py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_concat SRCS test_custom_concat.py) py_test(test_custom_concat SRCS test_custom_concat.py)
set_tests_properties(test_custom_concat PROPERTIES TIMEOUT 120)
py_test(test_custom_conj SRCS test_custom_conj.py) py_test(test_custom_conj SRCS test_custom_conj.py)
set_tests_properties(test_custom_conj PROPERTIES TIMEOUT 120)
# other tests
py_test(test_check_abi SRCS test_check_abi.py) py_test(test_check_abi SRCS test_check_abi.py)
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest) cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
if(NOT LINUX) if(NOT LINUX)
......
...@@ -20,7 +20,7 @@ __global__ void relu_cuda_forward_kernel(const data_t* x, ...@@ -20,7 +20,7 @@ __global__ void relu_cuda_forward_kernel(const data_t* x,
const int num) { const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x; int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) { for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<data_t>(0.)); y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
} }
} }
...@@ -31,7 +31,8 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy, ...@@ -31,7 +31,8 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
const int num) { const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x; int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) { for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
: static_cast<data_t>(0.));
} }
} }
...@@ -42,7 +43,7 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) { ...@@ -42,7 +43,7 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
int numel = x.size(); int numel = x.size();
int block = 512; int block = 512;
int grid = (numel + block - 1) / block; int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES( PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] { x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>( relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel); x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
...@@ -60,7 +61,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x, ...@@ -60,7 +61,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
int numel = out.size(); int numel = out.size();
int block = 512; int block = 512;
int grid = (numel + block - 1) / block; int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES( PD_DISPATCH_FLOATING_AND_HALF_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] { out.type(), "relu_cuda_backward_kernel", ([&] {
relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>( relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(), grad_out.data<data_t>(),
......
...@@ -118,3 +118,21 @@ PD_BUILD_OP(dispatch_test_float_and_integer_and_complex) ...@@ -118,3 +118,21 @@ PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex)); .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndHalf(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_float_and_half)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndHalf));
...@@ -50,11 +50,17 @@ class TestJITLoad(unittest.TestCase): ...@@ -50,11 +50,17 @@ class TestJITLoad(unittest.TestCase):
custom_module.custom_relu, custom_module.custom_relu_dup custom_module.custom_relu, custom_module.custom_relu_dup
] ]
self.dtypes = ['float32', 'float64'] self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu'] if paddle.is_compiled_with_cuda():
self.dtypes.append('float16')
self.devices = ['cpu']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
def test_static(self): def test_static(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops: for custom_op in self.custom_ops:
out = custom_relu_static(custom_op, device, dtype, x) out = custom_relu_static(custom_op, device, dtype, x)
...@@ -68,6 +74,8 @@ class TestJITLoad(unittest.TestCase): ...@@ -68,6 +74,8 @@ class TestJITLoad(unittest.TestCase):
def test_dynamic(self): def test_dynamic(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops: for custom_op in self.custom_ops:
out, x_grad = custom_relu_dynamic(custom_op, device, dtype, out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
...@@ -87,7 +95,7 @@ class TestJITLoad(unittest.TestCase): ...@@ -87,7 +95,7 @@ class TestJITLoad(unittest.TestCase):
caught_exception = False caught_exception = False
try: try:
x = np.random.uniform(-1, 1, [4, 8]).astype('int32') x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'float32', x) custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'int32', x)
except OSError as e: except OSError as e:
caught_exception = True caught_exception = True
self.assertTrue( self.assertTrue(
...@@ -105,15 +113,15 @@ class TestJITLoad(unittest.TestCase): ...@@ -105,15 +113,15 @@ class TestJITLoad(unittest.TestCase):
caught_exception = False caught_exception = False
try: try:
x = np.random.uniform(-1, 1, [4, 8]).astype('int64') x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'float32', x) custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'int32', x)
except OSError as e: except OSError as e:
caught_exception = True caught_exception = True
self.assertTrue( self.assertTrue(
"function \"relu_cuda_forward_kernel\" is not implemented for data type `int64_t`" "function \"relu_cuda_forward_kernel\" is not implemented for data type `int32_t`"
in str(e)) in str(e))
self.assertTrue( self.assertTrue(
"python/paddle/fluid/tests/custom_op/custom_relu_op.cu:49" in "python/paddle/fluid/tests/custom_op/custom_relu_op.cu:50" in
str(e)) str(e))
self.assertTrue(caught_exception) self.assertTrue(caught_exception)
......
...@@ -26,7 +26,7 @@ from paddle.utils.cpp_extension.extension_utils import run_cmd ...@@ -26,7 +26,7 @@ from paddle.utils.cpp_extension.extension_utils import run_cmd
def custom_relu_dynamic(func, device, dtype, np_x, use_func=True): def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device) paddle.set_device(device)
t = paddle.to_tensor(np_x) t = paddle.to_tensor(np_x, dtype=dtype)
t.stop_gradient = False t.stop_gradient = False
out = func(t) if use_func else paddle.nn.functional.relu(t) out = func(t) if use_func else paddle.nn.functional.relu(t)
...@@ -171,7 +171,11 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -171,7 +171,11 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
] ]
self.dtypes = ['float32', 'float64'] self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu'] if paddle.is_compiled_with_cuda():
self.dtypes.append('float16')
self.devices = ['cpu']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
# config seed # config seed
SEED = 2021 SEED = 2021
...@@ -181,6 +185,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -181,6 +185,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
def test_static(self): def test_static(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops: for custom_op in self.custom_ops:
out = custom_relu_static(custom_op, device, dtype, x) out = custom_relu_static(custom_op, device, dtype, x)
...@@ -194,6 +200,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -194,6 +200,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
def test_static_pe(self): def test_static_pe(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops: for custom_op in self.custom_ops:
out = custom_relu_static_pe(custom_op, device, dtype, x) out = custom_relu_static_pe(custom_op, device, dtype, x)
...@@ -207,6 +215,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -207,6 +215,8 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
def test_dynamic(self): def test_dynamic(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops: for custom_op in self.custom_ops:
out, x_grad = custom_relu_dynamic(custom_op, device, dtype, out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
......
...@@ -83,6 +83,12 @@ class TestJitDispatch(unittest.TestCase): ...@@ -83,6 +83,12 @@ class TestJitDispatch(unittest.TestCase):
self.run_dispatch_test( self.run_dispatch_test(
dispatch_op.dispatch_test_float_and_integer_and_complex, dtype) dispatch_op.dispatch_test_float_and_integer_and_complex, dtype)
def test_dispatch_float_and_half(self):
dtypes = ["float32", "float64", "float16"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_half,
dtype)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -453,14 +453,11 @@ class InstallHeaders(Command): ...@@ -453,14 +453,11 @@ class InstallHeaders(Command):
def copy_data_type_headers(self, header): def copy_data_type_headers(self, header):
if os.name == 'nt': if os.name == 'nt':
data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h'] data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h', 'platform\\float16.h']
else: else:
data_type_headers = ['platform/complex64.h', 'platform/complex128.h'] data_type_headers = ['platform/complex64.h', 'platform/complex128.h', 'platform/float16.h']
for dtype_header in data_type_headers: for dtype_header in data_type_headers:
if dtype_header in header: if dtype_header in header:
if os.name == 'nt':
install_dir = os.path.join(self.install_dir, "paddle\\fluid\\extension\\include")
else:
install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include") install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include")
if not os.path.exists(install_dir): if not os.path.exists(install_dir):
self.mkpath(install_dir) self.mkpath(install_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册