diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 570b37ff1189b664581b59b6dd5ca7ae1ccd12fb..4864e04fa05164655c28fef565c12c4d9c80566b 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -198,6 +198,9 @@ copy(inference_lib_dist
 copy(inference_lib_dist
         SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
         DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
+copy(inference_lib_dist
+        SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
+        DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 
 # CAPI inference library for only inference
 set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
diff --git a/paddle/fluid/extension/include/ext_dispatch.h b/paddle/fluid/extension/include/ext_dispatch.h
index 7b3893e2839c194c4ab8ae4553a113897f98335a..9b3e199708adc93356c214df3be217f67d2e8949 100644
--- a/paddle/fluid/extension/include/ext_dispatch.h
+++ b/paddle/fluid/extension/include/ext_dispatch.h
@@ -47,6 +47,22 @@ namespace paddle {
     }                                                                         \
   }()
 
+#define PD_DISPATCH_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...)                 \
+  [&] {                                                                      \
+    const auto& __dtype__ = TYPE;                                            \
+    switch (__dtype__) {                                                     \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float,         \
+                           __VA_ARGS__)                                      \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double,        \
+                           __VA_ARGS__)                                      \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT16, paddle::float16, \
+                           __VA_ARGS__)                                      \
+      default:                                                               \
+        PD_THROW("function " #NAME " is not implemented for data type `",    \
+                 ::paddle::ToString(__dtype__), "`");                        \
+    }                                                                         \
+  }()
+
 ///////// Integral Dispatch Marco ///////////
 
 #define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)                          \
diff --git a/paddle/fluid/extension/include/ext_dtype.h b/paddle/fluid/extension/include/ext_dtype.h
index a1e58fbacdff0c1e5670fb83c13779b79778381f..3890631a6f8a9e99948e32cdd3cb8c1e00c2de75 100644
--- a/paddle/fluid/extension/include/ext_dtype.h
+++ b/paddle/fluid/extension/include/ext_dtype.h
@@ -19,11 +19,13 @@ limitations under the License. */
 #include "complex128.h"     // NOLINT
 #include "complex64.h"      // NOLINT
 #include "ext_exception.h"  // NOLINT
+#include "float16.h"        // NOLINT
 
 namespace paddle {
 
 using complex64 = paddle::platform::complex64;
 using complex128 = paddle::platform::complex128;
+using float16 = paddle::platform::float16;
 
 enum class DataType {
   BOOL,
@@ -32,6 +34,7 @@ enum class DataType {
   INT16,
   INT32,
   INT64,
+  FLOAT16,
   FLOAT32,
   FLOAT64,
   COMPLEX64,
@@ -53,6 +56,8 @@ inline std::string ToString(DataType dtype) {
       return "int32_t";
     case DataType::INT64:
       return "int64_t";
+    case DataType::FLOAT16:
+      return "float16";
     case DataType::FLOAT32:
       return "float";
     case DataType::FLOAT64:
@@ -73,6 +78,7 @@ inline std::string ToString(DataType dtype) {
   _(int16_t, DataType::INT16)                  \
   _(int, DataType::INT32)                      \
   _(int64_t, DataType::INT64)                  \
+  _(float16, DataType::FLOAT16)                \
   _(float, DataType::FLOAT32)                  \
   _(double, DataType::FLOAT64)                 \
   _(complex64, DataType::COMPLEX64)            \
diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc
index cb37bf180c3798b38ef4b0e0a24e1b4ee8b9739a..0cae8f4af7b97de19d4daaad5422fd866ff0124a 100644
--- a/paddle/fluid/extension/src/ext_tensor.cc
+++ b/paddle/fluid/extension/src/ext_tensor.cc
@@ -22,6 +22,7 @@ limitations under the License. */
*/ #include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/transform.h" namespace paddle { @@ -170,6 +171,8 @@ DataType Tensor::type() const { return DataType::COMPLEX64; } else if (type == framework::proto::VarType::COMPLEX128) { return DataType::COMPLEX128; + } else if (type == framework::proto::VarType::FP16) { + return DataType::FLOAT16; } // TODO(JiabinYang) Support more dtype here return DataType::FLOAT32; @@ -229,6 +232,8 @@ template PD_DLL_DECL Tensor Tensor::copy_to( const PlaceType &target_place) const; template PD_DLL_DECL Tensor Tensor::copy_to( const PlaceType &target_place) const; +template PD_DLL_DECL Tensor +Tensor::copy_to(const PlaceType &target_place) const; template PD_DLL_DECL float *Tensor::data() const; template PD_DLL_DECL double *Tensor::data() const; @@ -242,6 +247,8 @@ template PD_DLL_DECL paddle::platform::complex64 * Tensor::data() const; template PD_DLL_DECL paddle::platform::complex128 * Tensor::data() const; +template PD_DLL_DECL paddle::platform::float16 * +Tensor::data() const; template PD_DLL_DECL float *Tensor::mutable_data(); template PD_DLL_DECL double *Tensor::mutable_data(); @@ -255,6 +262,8 @@ template PD_DLL_DECL paddle::platform::complex64 * Tensor::mutable_data(); template PD_DLL_DECL paddle::platform::complex128 * Tensor::mutable_data(); +template PD_DLL_DECL paddle::platform::float16 * +Tensor::mutable_data(); template PD_DLL_DECL float *Tensor::mutable_data(const PlaceType &place); template PD_DLL_DECL double *Tensor::mutable_data( @@ -274,6 +283,8 @@ template PD_DLL_DECL paddle::platform::complex64 * Tensor::mutable_data(const PlaceType &place); template PD_DLL_DECL paddle::platform::complex128 * Tensor::mutable_data(const PlaceType &place); +template PD_DLL_DECL paddle::platform::float16 * +Tensor::mutable_data(const PlaceType &place); std::vector Tensor::shape() const { GET_CASTED_TENSOR @@ -344,6 +355,11 @@ Tensor Tensor::cast(const DataType &target_type) const { CastDataType( *tensor, rlt_tensor_, ctx)); break; + case framework::proto::VarType::FP16: + framework::VisitDataType( + dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; // TODO(JiabinYang) Support more dtype here default: PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index 7da565886008b042d975e67282af088de65361ab..8d6fd4efd5ae3d6dd5e5335fdb8d20c595b9b0a2 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -113,6 +113,8 @@ void GroupTestCopy() { TestCopyTensor(); VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); + VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); } void GroupTestCast() { @@ -134,6 +136,8 @@ void GroupTestCast() { TestCast(paddle::DataType::FLOAT32); VLOG(2) << "complex128 cast"; TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "float16 cast"; + TestCast(paddle::DataType::FLOAT16); } void GroupTestDtype() { @@ -146,6 +150,7 @@ void GroupTestDtype() { CHECK(TestDtype() == paddle::DataType::UINT8); CHECK(TestDtype() == paddle::DataType::COMPLEX64); CHECK(TestDtype() == paddle::DataType::COMPLEX128); + CHECK(TestDtype() == paddle::DataType::FLOAT16); } void GroupTestDtypeConvert() { @@ -178,6 +183,9 @@ void GroupTestDtypeConvert() { CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( 
             paddle::DataType::COMPLEX128) ==
         paddle::framework::proto::VarType::COMPLEX128);
+  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
+            paddle::DataType::FLOAT16) ==
+        paddle::framework::proto::VarType::FP16);
   // proto -> enum
   CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
             paddle::framework::proto::VarType::FP64) ==
         paddle::DataType::FLOAT64);
@@ -207,6 +215,9 @@
   CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
             paddle::framework::proto::VarType::COMPLEX128) ==
         paddle::DataType::COMPLEX128);
+  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
+            paddle::framework::proto::VarType::FP16) ==
+        paddle::DataType::FLOAT16);
 }
 
 TEST(CustomTensor, copyTest) {
diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h
index a252d6aef4ef4734e64c226144a58e1a44e7f2e9..fad1e3ee3496cd410e4fca77c09f61c6ff53a402 100644
--- a/paddle/fluid/framework/custom_tensor_utils.h
+++ b/paddle/fluid/framework/custom_tensor_utils.h
@@ -60,6 +60,8 @@ class CustomTensorUtils {
         return framework::proto::VarType::COMPLEX64;
       case paddle::DataType::COMPLEX128:
         return framework::proto::VarType::COMPLEX128;
+      case paddle::DataType::FLOAT16:
+        return framework::proto::VarType::FP16;
       case paddle::DataType::BOOL:
         return framework::proto::VarType::BOOL;
       default:
@@ -91,6 +93,8 @@ class CustomTensorUtils {
         return paddle::DataType::COMPLEX64;
       case framework::proto::VarType::COMPLEX128:
         return paddle::DataType::COMPLEX128;
+      case framework::proto::VarType::FP16:
+        return paddle::DataType::FLOAT16;
       case framework::proto::VarType::BOOL:
         return paddle::DataType::BOOL;
       default:
diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
index 4ba537930cef5251f852ac7fd1936e58e5d927c9..36496ec499fd99ac3c65ac1ef7598a75316d6274 100644
--- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt
+++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -13,24 +13,15 @@ endif()
 
 py_test(test_sysconfig SRCS test_sysconfig.py)
 
-# 'test_dispatch' compile .cc file
+# CPU custom op tests: only compile .cc file
 py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
-set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)
-
 py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
-set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
-
 py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
-set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
-
 py_test(test_custom_concat SRCS test_custom_concat.py)
-set_tests_properties(test_custom_concat PROPERTIES TIMEOUT 120)
-
 py_test(test_custom_conj SRCS test_custom_conj.py)
-set_tests_properties(test_custom_conj PROPERTIES TIMEOUT 120)
 
+# other tests
 py_test(test_check_abi SRCS test_check_abi.py)
-
 cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
 
 if(NOT LINUX)
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index be3309d84f57d6f4f000f920339b06dc370c85a8..4ec7d0884582e7c4970865523111279412e027e7 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -20,7 +20,7 @@ __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          const int num) {
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
-    y[i] = max(x[i], static_cast<data_t>(0.));
+    y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
   }
 }
 
@@ -31,7 +31,8 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
                                           const int num) {
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
-    dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
+    dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
+                                                    : static_cast<data_t>(0.));
   }
 }
 
@@ -42,7 +43,7 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   int numel = x.size();
   int block = 512;
   int grid = (numel + block - 1) / block;
-  PD_DISPATCH_FLOATING_TYPES(
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       x.type(), "relu_cuda_forward_kernel", ([&] {
         relu_cuda_forward_kernel<data_t><<<grid, block>>>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
@@ -60,7 +61,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   int numel = out.size();
   int block = 512;
   int grid = (numel + block - 1) / block;
-  PD_DISPATCH_FLOATING_TYPES(
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       out.type(), "relu_cuda_backward_kernel", ([&] {
         relu_cuda_backward_kernel<data_t><<<grid, block>>>(
             grad_out.data<data_t>(),
diff --git a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc
index fbf5442ac026a6e3d66796e063f91ef9b049afd2..0435f50b7c701e997c9e73ba4bfd9a8c5a998471 100644
--- a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc
+++ b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc
@@ -118,3 +118,21 @@ PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
     .Inputs({"X"})
     .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
+
+std::vector<paddle::Tensor> DispatchTestFloatAndHalf(const paddle::Tensor& x) {
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  out.reshape(x.shape());
+
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      x.type(), "assign_cpu_kernel", ([&] {
+        assign_cpu_kernel<data_t>(
+            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
+      }));
+
+  return {out};
+}
+
+PD_BUILD_OP(dispatch_test_float_and_half)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(DispatchTestFloatAndHalf));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
index 1a96fc5f0aeed318f7b5b4ec024327579917b82a..23733d20841b3afa2347fb32fcc4335491b7f8cc 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -50,11 +50,17 @@ class TestJITLoad(unittest.TestCase):
             custom_module.custom_relu, custom_module.custom_relu_dup
         ]
         self.dtypes = ['float32', 'float64']
-        self.devices = ['cpu', 'gpu']
+        if paddle.is_compiled_with_cuda():
+            self.dtypes.append('float16')
+        self.devices = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            self.devices.append('gpu')
 
     def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
                     out = custom_relu_static(custom_op, device, dtype, x)
@@ -68,6 +74,8 @@ def test_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
                     out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
@@ -87,7 +95,7 @@
         caught_exception = False
         try:
            x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
-            custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'float32', x)
+            custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'int32', x)
         except OSError as e:
             caught_exception = True
             self.assertTrue(
@@ -105,15 +113,15 @@
 
         caught_exception = False
         try:
-            x = np.random.uniform(-1, 1, [4, 8]).astype('int64')
-            custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'float32', x)
+            x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
+            custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'int32', x)
         except OSError as e:
             caught_exception = True
             self.assertTrue(
-                "function \"relu_cuda_forward_kernel\" is not implemented for data type `int64_t`"
+                "function \"relu_cuda_forward_kernel\" is not implemented for data type `int32_t`"
                 in str(e))
             self.assertTrue(
-                "python/paddle/fluid/tests/custom_op/custom_relu_op.cu:49" in
+                "python/paddle/fluid/tests/custom_op/custom_relu_op.cu:50" in
                 str(e))
         self.assertTrue(caught_exception)
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index 6781915e021c92f4c0f6a25e9f42ab940a3035d2..5c5c2d65a59574a47e6b9d3cfa7e3be67731dfb0 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -26,7 +26,7 @@ from paddle.utils.cpp_extension.extension_utils import run_cmd
 def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
     paddle.set_device(device)
 
-    t = paddle.to_tensor(np_x)
+    t = paddle.to_tensor(np_x, dtype=dtype)
     t.stop_gradient = False
 
     out = func(t) if use_func else paddle.nn.functional.relu(t)
@@ -171,7 +171,11 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
         ]
         self.dtypes = ['float32', 'float64']
-        self.devices = ['cpu', 'gpu']
+        if paddle.is_compiled_with_cuda():
+            self.dtypes.append('float16')
+        self.devices = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            self.devices.append('gpu')
 
         # config seed
         SEED = 2021
@@ -181,6 +185,8 @@ def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
                     out = custom_relu_static(custom_op, device, dtype, x)
@@ -194,6 +200,8 @@ def test_static_pe(self):
         for device in self.devices:
             for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
                     out = custom_relu_static_pe(custom_op, device, dtype, x)
@@ -207,6 +215,8 @@ def test_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
                     out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
index bc36372c6a7945cee866f486496af1cd0c80397d..12e9f50a5e4092a067c533bcdb6bcb03011d35fa 100644
--- a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
@@ -83,6 +83,12 @@ class TestJitDispatch(unittest.TestCase):
             self.run_dispatch_test(
                 dispatch_op.dispatch_test_float_and_integer_and_complex, dtype)
 
+    def test_dispatch_float_and_half(self):
+        dtypes = ["float32", "float64", "float16"]
["float32", "float64", "float16"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_float_and_half, + dtype) + if __name__ == '__main__': unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 0afc3956a01e1a14fcc80fd366ee0cb5f61e7ccc..71d4afdb283c76a845a69cde1d0d09c95b1b323c 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -453,15 +453,12 @@ class InstallHeaders(Command): def copy_data_type_headers(self, header): if os.name == 'nt': - data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h'] + data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h', 'platform\\float16.h'] else: - data_type_headers = ['platform/complex64.h', 'platform/complex128.h'] + data_type_headers = ['platform/complex64.h', 'platform/complex128.h', 'platform/float16.h'] for dtype_header in data_type_headers: if dtype_header in header: - if os.name == 'nt': - install_dir = os.path.join(self.install_dir, "paddle\\fluid\\extension\\include") - else: - install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include") + install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include") if not os.path.exists(install_dir): self.mkpath(install_dir) return self.copy_file(header, install_dir)