From b2a1c9e8b7f0fab5f81282783acadc2d35f3e207 Mon Sep 17 00:00:00 2001
From: Kexin Zhao
Date: Fri, 6 Apr 2018 13:25:29 -0700
Subject: [PATCH] Add float16 support to non-cudnn softmax op on GPU (#9686)

* initial commit

* fix error

* fix typo and order
---
 paddle/fluid/operators/math/softmax.cu        |   3 +
 paddle/fluid/operators/math/softmax_impl.h    |   2 +-
 paddle/fluid/operators/softmax_op.cc          |   9 +-
 paddle/fluid/operators/softmax_op.cu.cc       |  11 +-
 paddle/fluid/platform/float16.h               | 227 +++++++++++++-----
 .../fluid/tests/unittests/test_softmax_op.py  |  11 +
 6 files changed, 189 insertions(+), 74 deletions(-)

diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index 5518ebed3f7..a579182ec1b 100644
--- a/paddle/fluid/operators/math/softmax.cu
+++ b/paddle/fluid/operators/math/softmax.cu
@@ -14,6 +14,8 @@ limitations under the License. */

 #define EIGEN_USE_GPU

+#include <vector>
+
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
@@ -95,6 +97,7 @@ template class SoftmaxCUDNNFunctor<float>;
 template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxGradCUDNNFunctor<float>;
 template class SoftmaxGradCUDNNFunctor<platform::float16>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 3e123f7bf55..dd9971ba091 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -27,7 +27,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 template <typename T>
 struct ValueClip {
   HOSTDEVICE T operator()(const T& x) const {
-    const T kThreshold = -64.;
+    const T kThreshold = static_cast<T>(-64.);
     return x < kThreshold ? kThreshold : x;
   }
 };
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index e2c0f915d96..6bdefc0f239 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/softmax_op.h"
+
+#include <string>
+
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
@@ -20,6 +23,7 @@ limitations under the License. */

 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+
 namespace paddle {
 namespace operators {
@@ -60,8 +64,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
-                        "float16 can only be used when CUDNN is used");
+      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                     "float16 can only be used on GPU place");
     }

     std::string data_format = ctx.Attr<std::string>("data_format");
@@ -70,6 +74,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
                                    library_);
   }
 };
+
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc
index dbd13fd38a3..0c1f7cef7ab 100644
--- a/paddle/fluid/operators/softmax_op.cu.cc
+++ b/paddle/fluid/operators/softmax_op.cu.cc
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
*/ #include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - softmax, ops::SoftmaxKernel); +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - softmax_grad, - ops::SoftmaxGradKernel); + softmax, ops::SoftmaxKernel, + ops::SoftmaxKernel); +REGISTER_OP_CUDA_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 2cf311c7e56..e77f768bf9f 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #ifdef PADDLE_WITH_CUDA #include @@ -293,39 +294,39 @@ struct PADDLE_ALIGN(2) float16 { HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; } HOSTDEVICE inline explicit operator int8_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator uint8_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator int16_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator uint16_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator int32_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator uint32_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator int64_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator uint64_t() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } HOSTDEVICE inline explicit operator double() const { - return static_cast(float(*this)); + return static_cast(static_cast(*this)); } private: @@ -370,7 +371,7 @@ DEVICE inline half operator+(const half& a, const half& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hadd(a, b); #else - float res = float(float16(a)) + float(float16(b)); + float res = static_cast(float16(a)) + static_cast(float16(b)); return half(float16(res)); #endif } @@ -379,7 +380,7 @@ DEVICE inline half operator-(const half& a, const half& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hsub(a, b); #else - float res = float(float16(a)) - float(float16(b)); + float res = static_cast(float16(a)) - static_cast(float16(b)); return half(float16(res)); #endif } @@ -388,7 +389,7 @@ DEVICE inline half operator*(const half& a, const half& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hmul(a, b); #else - float res = float(float16(a)) * float(float16(b)); + float res = static_cast(float16(a)) * static_cast(float16(b)); return half(float16(res)); #endif } @@ -399,7 +400,7 @@ DEVICE inline half operator/(const half& a, const half& b) { float denom = __half2float(b); return __float2half(num / denom); #else - float res = float(float16(a)) / float(float16(b)); + float res = static_cast(float16(a)) / static_cast(float16(b)); return half(float16(res)); #endif } @@ -408,27 +409,27 @@ DEVICE inline half operator-(const half& a) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hneg(a); #else - float res = -float(float16(a)); + float res = -static_cast(float16(a)); return 
   return half(float16(res));
 #endif
 }

-DEVICE inline half& operator+=(half& a, const half& b) {
+DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
   a = a + b;
   return a;
 }

-DEVICE inline half& operator-=(half& a, const half& b) {
+DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
   a = a - b;
   return a;
 }

-DEVICE inline half& operator*=(half& a, const half& b) {
+DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
   a = a * b;
   return a;
 }

-DEVICE inline half& operator/=(half& a, const half& b) {
+DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
   a = a / b;
   return a;
 }
@@ -437,7 +438,7 @@ DEVICE inline bool operator==(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(a, b);
 #else
-  return float(float16(a)) == float(float16(b));
+  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
 #endif
 }
@@ -445,7 +446,7 @@ DEVICE inline bool operator!=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(a, b);
 #else
-  return float(float16(a)) != float(float16(b));
+  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
 #endif
 }
@@ -453,7 +454,7 @@ DEVICE inline bool operator<(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(a, b);
 #else
-  return float(float16(a)) < float(float16(b));
+  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
 #endif
 }
@@ -461,7 +462,7 @@ DEVICE inline bool operator<=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(a, b);
 #else
-  return float(float16(a)) <= float(float16(b));
+  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
 #endif
 }
@@ -469,7 +470,7 @@ DEVICE inline bool operator>(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(a, b);
 #else
-  return float(float16(a)) > float(float16(b));
+  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
 #endif
 }
@@ -477,7 +478,7 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(a, b);
 #else
-  return float(float16(a)) >= float(float16(b));
+  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
 #endif
 }
@@ -489,7 +490,7 @@ HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hadd(half(a), half(b)));
 #else
-  return float16(float(a) + float(b));
+  return float16(static_cast<float>(a) + static_cast<float>(b));
 #endif
 }
@@ -497,7 +498,7 @@ HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hsub(half(a), half(b)));
 #else
-  return float16(float(a) - float(b));
+  return float16(static_cast<float>(a) - static_cast<float>(b));
 #endif
 }
@@ -505,7 +506,7 @@ HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hmul(half(a), half(b)));
 #else
-  return float16(float(a) * float(b));
+  return float16(static_cast<float>(a) * static_cast<float>(b));
 #endif
 }
@@ -516,7 +517,7 @@ HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
   float num = __half2float(half(a));
   float denom = __half2float(half(b));
   return float16(num / denom);
 #else
-  return float16(float(a) / float(b));
+  return float16(static_cast<float>(a) / static_cast<float>(b));
 #endif
 }
@@ -530,22 +531,22 @@ HOSTDEVICE inline float16 operator-(const float16& a) {
 #endif
 }

-HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
   a = a + b;
   return a;
 }

-HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
   a = a - b;
   return a;
 }

-HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
   a = a * b;
   return a;
 }

-HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
   a = a / b;
   return a;
 }
@@ -554,7 +555,7 @@ HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(half(a), half(b));
 #else
-  return float(a) == float(b);
+  return static_cast<float>(a) == static_cast<float>(b);
 #endif
 }
@@ -562,7 +563,7 @@ HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(half(a), half(b));
 #else
-  return float(a) != float(b);
+  return static_cast<float>(a) != static_cast<float>(b);
 #endif
 }
@@ -570,7 +571,7 @@ HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(half(a), half(b));
 #else
-  return float(a) < float(b);
+  return static_cast<float>(a) < static_cast<float>(b);
 #endif
 }
@@ -578,7 +579,7 @@ HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(half(a), half(b));
 #else
-  return float(a) <= float(b);
+  return static_cast<float>(a) <= static_cast<float>(b);
 #endif
 }
@@ -586,7 +587,7 @@ HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(half(a), half(b));
 #else
-  return float(a) > float(b);
+  return static_cast<float>(a) > static_cast<float>(b);
 #endif
 }
@@ -594,7 +595,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(half(a), half(b));
 #else
-  return float(a) >= float(b);
+  return static_cast<float>(a) >= static_cast<float>(b);
 #endif
 }
@@ -679,22 +680,22 @@ inline float16 operator-(const float16& a) {
   return res;
 }

-inline float16& operator+=(float16& a, const float16& b) {
+inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
   a = a + b;
   return a;
 }

-inline float16& operator-=(float16& a, const float16& b) {
+inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
   a = a - b;
   return a;
 }

-inline float16& operator*=(float16& a, const float16& b) {
+inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
   a = a * b;
   return a;
 }

-inline float16& operator/=(float16& a, const float16& b) {
+inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
   a = a / b;
   return a;
 }
@@ -784,19 +785,19 @@ inline bool operator>=(const float16& a, const float16& b) {
 // Arithmetic operators for float16, software emulated on other CPU
 #else
 inline float16 operator+(const float16& a, const float16& b) {
-  return float16(float(a) + float(b));
+  return float16(static_cast<float>(a) + static_cast<float>(b));
 }

 inline float16 operator-(const float16& a, const float16& b) {
-  return float16(float(a) - float(b));
+  return float16(static_cast<float>(a) - static_cast<float>(b));
 }

 inline float16 operator*(const float16& a, const float16& b) {
-  return float16(float(a) * float(b));
+  return float16(static_cast<float>(a) * static_cast<float>(b));
 }

 inline float16 operator/(const float16& a, const float16& b) {
-  return float16(float(a) / float(b));
+  return float16(static_cast<float>(a) / static_cast<float>(b));
 }

 inline float16 operator-(const float16& a) {
@@ -805,51 +806,57 @@ inline float16 operator-(const float16& a) {
   return res;
 }

-inline float16& operator+=(float16& a, const float16& b) {
-  a = float16(float(a) + float(b));
+inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) + static_cast<float>(b));
   return a;
 }

-inline float16& operator-=(float16& a, const float16& b) {
-  a = float16(float(a) - float(b));
+inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) - static_cast<float>(b));
   return a;
 }

-inline float16& operator*=(float16& a, const float16& b) {
-  a = float16(float(a) * float(b));
+inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) * static_cast<float>(b));
   return a;
 }

-inline float16& operator/=(float16& a, const float16& b) {
-  a = float16(float(a) / float(b));
+inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) / static_cast<float>(b));
   return a;
 }

 inline bool operator==(const float16& a, const float16& b) {
-  return float(a) == float(b);
+  return static_cast<float>(a) == static_cast<float>(b);
 }

 inline bool operator!=(const float16& a, const float16& b) {
-  return float(a) != float(b);
+  return static_cast<float>(a) != static_cast<float>(b);
 }

 inline bool operator<(const float16& a, const float16& b) {
-  return float(a) < float(b);
+  return static_cast<float>(a) < static_cast<float>(b);
 }

 inline bool operator<=(const float16& a, const float16& b) {
-  return float(a) <= float(b);
+  return static_cast<float>(a) <= static_cast<float>(b);
 }

 inline bool operator>(const float16& a, const float16& b) {
-  return float(a) > float(b);
+  return static_cast<float>(a) > static_cast<float>(b);
 }

 inline bool operator>=(const float16& a, const float16& b) {
-  return float(a) >= float(b);
+  return static_cast<float>(a) >= static_cast<float>(b);
 }
 #endif

+HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
+  float16 res;
+  res.x = a;
+  return res;
+}
+
 HOSTDEVICE inline bool(isnan)(const float16& a) {
 #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hisnan(half(a));
@@ -886,28 +893,116 @@ struct is_pod {
       is_standard_layout<paddle::platform::float16>::value;
 };

+template <>
+struct numeric_limits<paddle::platform::float16> {
+  static const bool is_specialized = true;
+  static const bool is_signed = true;
+  static const bool is_integer = false;
+  static const bool is_exact = false;
+  static const bool has_infinity = true;
+  static const bool has_quiet_NaN = true;
+  static const bool has_signaling_NaN = true;
+  static const float_denorm_style has_denorm = denorm_present;
+  static const bool has_denorm_loss = false;
+  static const std::float_round_style round_style = std::round_to_nearest;
+  static const bool is_iec559 = false;
+  static const bool is_bounded = false;
+  static const bool is_modulo = false;
+  static const int digits = 11;
+  static const int digits10 = 3;
+  static const int max_digits10 = 5;
+  static const int radix = 2;
+  static const int min_exponent = -13;
+  static const int min_exponent10 = -4;
+  static const int max_exponent = 16;
+  static const int max_exponent10 = 4;
+  static const bool traps = true;
+  static const bool tinyness_before = false;
+
+  static paddle::platform::float16(min)() {
+    return paddle::platform::raw_uint16_to_float16(0x400);
+  }
+  static paddle::platform::float16 lowest() {
+    return paddle::platform::raw_uint16_to_float16(0xfbff);
+  }
+  static paddle::platform::float16(max)() {
+    return paddle::platform::raw_uint16_to_float16(0x7bff);
+  }
+  static paddle::platform::float16 epsilon() {
+    return paddle::platform::raw_uint16_to_float16(0x0800);
+  }
+  static paddle::platform::float16 round_error() {
+    return paddle::platform::float16(0.5);
+  }
+  static paddle::platform::float16 infinity() {
+    return paddle::platform::raw_uint16_to_float16(0x7c00);
+  }
+  static paddle::platform::float16 quiet_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7e00);
+  }
+  static paddle::platform::float16 signaling_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7e00);
+  }
+  static paddle::platform::float16 denorm_min() {
+    return paddle::platform::raw_uint16_to_float16(0x1);
+  }
+};
+
 }  // namespace std

 namespace Eigen {
+
+using float16 = paddle::platform::float16;
+
+template <>
+struct NumTraits<float16> : GenericNumTraits<float16> {
+  enum {
+    IsSigned = true,
+    IsInteger = false,
+    IsComplex = false,
+    RequireInitialization = false
+  };
+
+  HOSTDEVICE static inline float16 epsilon() {
+    return paddle::platform::raw_uint16_to_float16(0x0800);
+  }
+  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
+  HOSTDEVICE static inline float16 highest() {
+    return paddle::platform::raw_uint16_to_float16(0x7bff);
+  }
+  HOSTDEVICE static inline float16 lowest() {
+    return paddle::platform::raw_uint16_to_float16(0xfbff);
+  }
+  HOSTDEVICE static inline float16 infinity() {
+    return paddle::platform::raw_uint16_to_float16(0x7c00);
+  }
+  HOSTDEVICE static inline float16 quiet_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7c01);
+  }
+};
+
 namespace numext {

 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isnan)(const float16& a) {
   return (paddle::platform::isnan)(a);
 }

 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isinf)(const float16& a) {
   return (paddle::platform::isinf)(a);
 }

 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isfinite)(const float16& a) {
   return (paddle::platform::isfinite)(a);
 }

+template <>
+HOSTDEVICE inline float16 exp(const float16& a) {
+  return float16(::expf(static_cast<float>(a)));
+}
+
 }  // namespace numext
+
 }  // namespace Eigen
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 33d60c7e31c..279f3073f73 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -68,6 +68,17 @@ class TestSoftmaxCUDNNOp(TestSoftmaxOp):
         self.use_cudnn = True


+class TestSoftmaxFP16Op(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
+
 class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
     def init_kernel_type(self):
         self.use_cudnn = True
--
GitLab
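Reviewer note (not part of the patch): the raw bit patterns passed to raw_uint16_to_float16() in the new std::numeric_limits<float16> specialization can be cross-checked against the IEEE binary16 encoding. A minimal sanity check, assuming a Python interpreter with numpy is available:

    import numpy as np

    # Decode the uint16 bit patterns used in the specialization with numpy's
    # IEEE binary16 (float16) type.
    patterns = {
        "min": 0x0400,         # smallest positive normal, 2^-14
        "lowest": 0xfbff,      # -65504
        "max": 0x7bff,         # 65504
        "infinity": 0x7c00,
        "quiet_NaN": 0x7e00,
        "denorm_min": 0x0001,  # smallest subnormal, 2^-24
    }
    for name, bits in patterns.items():
        value = np.array([bits], dtype=np.uint16).view(np.float16)[0]
        print("%-10s 0x%04x -> %s" % (name, bits, value))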