Commit b2a1c9e8 authored by Kexin Zhao, committed by Yi Wang

Add float16 support to non-cudnn softmax op on GPU (#9686)

* initial commit

* fix error

* fix typo and order
Parent 797a7184
```diff
@@ -14,6 +14,8 @@ limitations under the License. */
 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
@@ -95,6 +97,7 @@ template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
 template class SoftmaxGradCUDNNFunctor<double>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
```
...
```diff
@@ -27,7 +27,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 template <typename T>
 struct ValueClip {
   HOSTDEVICE T operator()(const T& x) const {
-    const T kThreshold = -64.;
+    const T kThreshold = static_cast<T>(-64.);
     return x < kThreshold ? kThreshold : x;
   }
 };
```
...
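The cast matters because `float16`'s converting constructors are explicit, so initializing a `float16` from the double literal `-64.` would not compile; `static_cast<T>` works uniformly for every instantiated type. A minimal standalone sketch of the same pattern (plain C++, outside Paddle, shown here for `float`):

```cpp
#include <iostream>

template <typename T>
struct ValueClip {
  T operator()(const T& x) const {
    // Explicit conversion compiles for any T, including types (like float16)
    // that reject implicit conversion from a double literal.
    const T kThreshold = static_cast<T>(-64.);
    return x < kThreshold ? kThreshold : x;
  }
};

int main() {
  ValueClip<float> clip;
  std::cout << clip(-100.0f) << " " << clip(3.5f) << std::endl;  // prints: -64 3.5
}
```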
```diff
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/softmax_op.h"
+#include <string>
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
@@ -20,6 +23,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
 namespace paddle {
 namespace operators {
@@ -60,8 +64,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
-                        "float16 can only be used when CUDNN is used");
+      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                     "float16 can only be used on GPU place");
     }
     std::string data_format = ctx.Attr<std::string>("data_format");
@@ -70,6 +74,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
                    library_);
   }
 };
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
```
...
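The effect of this change: fp16 softmax is now accepted whenever the op runs on a GPU place, whether or not the cuDNN library path was chosen. A hypothetical, Paddle-free sketch of the new guard's logic (names are illustrative only):

```cpp
#include <iostream>
#include <stdexcept>

enum class Place { kCPU, kGPU };
enum class DataType { kFP16, kFP32 };

// Mirrors the intent of the updated PADDLE_ENFORCE: fp16 requires a GPU
// place, but no longer requires the cuDNN library in particular.
void CheckKernel(DataType dtype, Place place) {
  if (dtype == DataType::kFP16 && place != Place::kGPU) {
    throw std::runtime_error("float16 can only be used on GPU place");
  }
}

int main() {
  CheckKernel(DataType::kFP16, Place::kGPU);    // ok: plain CUDA kernel is allowed
  try {
    CheckKernel(DataType::kFP16, Place::kCPU);  // rejected
  } catch (const std::exception& e) {
    std::cout << e.what() << std::endl;
  }
}
```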
```diff
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    softmax, ops::SoftmaxKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(
-    softmax_grad,
-    ops::SoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
+    ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(softmax_grad,
+                        ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>);
```
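`REGISTER_OP_CUDA_KERNEL` is variadic, so the forward op now registers one kernel per element type and the framework dispatches on the input dtype at run time. An illustrative dispatch sketch (not Paddle's actual registry; names are made up):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// A toy registry keyed by "op/dtype", standing in for what variadic kernel
// registration achieves: one callable per (op, element type) pair.
std::map<std::string, std::function<void()>> kernel_registry;

void RegisterSoftmaxKernels() {
  kernel_registry["softmax/float"]   = [] { std::cout << "fp32 kernel\n"; };
  kernel_registry["softmax/float16"] = [] { std::cout << "fp16 kernel\n"; };
}

int main() {
  RegisterSoftmaxKernels();
  kernel_registry["softmax/float16"]();  // dispatch by the input's dtype key
}
```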
```diff
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <stdint.h>
+#include <limits>
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
```
```diff
@@ -293,39 +294,39 @@ struct PADDLE_ALIGN(2) float16 {
   HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }
   HOSTDEVICE inline explicit operator int8_t() const {
-    return static_cast<int8_t>(float(*this));
+    return static_cast<int8_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator uint8_t() const {
-    return static_cast<uint8_t>(float(*this));
+    return static_cast<uint8_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator int16_t() const {
-    return static_cast<int16_t>(float(*this));
+    return static_cast<int16_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator uint16_t() const {
-    return static_cast<uint16_t>(float(*this));
+    return static_cast<uint16_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator int32_t() const {
-    return static_cast<int32_t>(float(*this));
+    return static_cast<int32_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator uint32_t() const {
-    return static_cast<uint32_t>(float(*this));
+    return static_cast<uint32_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator int64_t() const {
-    return static_cast<int64_t>(float(*this));
+    return static_cast<int64_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator uint64_t() const {
-    return static_cast<uint64_t>(float(*this));
+    return static_cast<uint64_t>(static_cast<float>(*this));
   }
   HOSTDEVICE inline explicit operator double() const {
-    return static_cast<double>(float(*this));
+    return static_cast<double>(static_cast<float>(*this));
   }

  private:
```
```diff
@@ -370,7 +371,7 @@ DEVICE inline half operator+(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hadd(a, b);
 #else
-  float res = float(float16(a)) + float(float16(b));
+  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
   return half(float16(res));
 #endif
 }
@@ -379,7 +380,7 @@ DEVICE inline half operator-(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hsub(a, b);
 #else
-  float res = float(float16(a)) - float(float16(b));
+  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
   return half(float16(res));
 #endif
 }
@@ -388,7 +389,7 @@ DEVICE inline half operator*(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hmul(a, b);
 #else
-  float res = float(float16(a)) * float(float16(b));
+  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
   return half(float16(res));
 #endif
 }
@@ -399,7 +400,7 @@ DEVICE inline half operator/(const half& a, const half& b) {
   float denom = __half2float(b);
   return __float2half(num / denom);
 #else
-  float res = float(float16(a)) / float(float16(b));
+  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
   return half(float16(res));
 #endif
 }
@@ -408,27 +409,27 @@ DEVICE inline half operator-(const half& a) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hneg(a);
 #else
-  float res = -float(float16(a));
+  float res = -static_cast<float>(float16(a));
   return half(float16(res));
 #endif
 }
-DEVICE inline half& operator+=(half& a, const half& b) {
+DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
   a = a + b;
   return a;
 }
-DEVICE inline half& operator-=(half& a, const half& b) {
+DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
   a = a - b;
   return a;
 }
-DEVICE inline half& operator*=(half& a, const half& b) {
+DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
   a = a * b;
   return a;
 }
-DEVICE inline half& operator/=(half& a, const half& b) {
+DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
   a = a / b;
   return a;
 }
```
```diff
@@ -437,7 +438,7 @@ DEVICE inline bool operator==(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(a, b);
 #else
-  return float(float16(a)) == float(float16(b));
+  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
 #endif
 }
@@ -445,7 +446,7 @@ DEVICE inline bool operator!=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(a, b);
 #else
-  return float(float16(a)) != float(float16(b));
+  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
 #endif
 }
@@ -453,7 +454,7 @@ DEVICE inline bool operator<(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(a, b);
 #else
-  return float(float16(a)) < float(float16(b));
+  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
 #endif
 }
@@ -461,7 +462,7 @@ DEVICE inline bool operator<=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(a, b);
 #else
-  return float(float16(a)) <= float(float16(b));
+  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
 #endif
 }
@@ -469,7 +470,7 @@ DEVICE inline bool operator>(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(a, b);
 #else
-  return float(float16(a)) > float(float16(b));
+  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
 #endif
 }
@@ -477,7 +478,7 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(a, b);
 #else
-  return float(float16(a)) >= float(float16(b));
+  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
 #endif
 }
```
```diff
@@ -489,7 +490,7 @@ HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hadd(half(a), half(b)));
 #else
-  return float16(float(a) + float(b));
+  return float16(static_cast<float>(a) + static_cast<float>(b));
 #endif
 }
@@ -497,7 +498,7 @@ HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hsub(half(a), half(b)));
 #else
-  return float16(float(a) - float(b));
+  return float16(static_cast<float>(a) - static_cast<float>(b));
 #endif
 }
@@ -505,7 +506,7 @@ HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return float16(__hmul(half(a), half(b)));
 #else
-  return float16(float(a) * float(b));
+  return float16(static_cast<float>(a) * static_cast<float>(b));
 #endif
 }
@@ -516,7 +517,7 @@ HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
   float denom = __half2float(half(b));
   return float16(num / denom);
 #else
-  return float16(float(a) / float(b));
+  return float16(static_cast<float>(a) / static_cast<float>(b));
 #endif
 }
@@ -530,22 +531,22 @@ HOSTDEVICE inline float16 operator-(const float16& a) {
 #endif
 }
-HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
   a = a + b;
   return a;
 }
-HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
   a = a - b;
   return a;
 }
-HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
   a = a * b;
   return a;
 }
-HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
   a = a / b;
   return a;
 }
```
```diff
@@ -554,7 +555,7 @@ HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(half(a), half(b));
 #else
-  return float(a) == float(b);
+  return static_cast<float>(a) == static_cast<float>(b);
 #endif
 }
@@ -562,7 +563,7 @@ HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(half(a), half(b));
 #else
-  return float(a) != float(b);
+  return static_cast<float>(a) != static_cast<float>(b);
 #endif
 }
@@ -570,7 +571,7 @@ HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(half(a), half(b));
 #else
-  return float(a) < float(b);
+  return static_cast<float>(a) < static_cast<float>(b);
 #endif
 }
@@ -578,7 +579,7 @@ HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(half(a), half(b));
 #else
-  return float(a) <= float(b);
+  return static_cast<float>(a) <= static_cast<float>(b);
 #endif
 }
@@ -586,7 +587,7 @@ HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(half(a), half(b));
 #else
-  return float(a) > float(b);
+  return static_cast<float>(a) > static_cast<float>(b);
 #endif
 }
@@ -594,7 +595,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(half(a), half(b));
 #else
-  return float(a) >= float(b);
+  return static_cast<float>(a) >= static_cast<float>(b);
 #endif
 }
```
```diff
@@ -679,22 +680,22 @@ inline float16 operator-(const float16& a) {
   return res;
 }
-inline float16& operator+=(float16& a, const float16& b) {
+inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
   a = a + b;
   return a;
 }
-inline float16& operator-=(float16& a, const float16& b) {
+inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
   a = a - b;
   return a;
 }
-inline float16& operator*=(float16& a, const float16& b) {
+inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
   a = a * b;
   return a;
 }
-inline float16& operator/=(float16& a, const float16& b) {
+inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
   a = a / b;
   return a;
 }
@@ -784,19 +785,19 @@ inline bool operator>=(const float16& a, const float16& b) {
 // Arithmetic operators for float16, software emulated on other CPU
 #else
 inline float16 operator+(const float16& a, const float16& b) {
-  return float16(float(a) + float(b));
+  return float16(static_cast<float>(a) + static_cast<float>(b));
 }
 inline float16 operator-(const float16& a, const float16& b) {
-  return float16(float(a) - float(b));
+  return float16(static_cast<float>(a) - static_cast<float>(b));
 }
 inline float16 operator*(const float16& a, const float16& b) {
-  return float16(float(a) * float(b));
+  return float16(static_cast<float>(a) * static_cast<float>(b));
 }
 inline float16 operator/(const float16& a, const float16& b) {
-  return float16(float(a) / float(b));
+  return float16(static_cast<float>(a) / static_cast<float>(b));
 }
```
```diff
@@ -805,51 +806,57 @@ inline float16 operator-(const float16& a) {
   return res;
 }
-inline float16& operator+=(float16& a, const float16& b) {
-  a = float16(float(a) + float(b));
+inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) + static_cast<float>(b));
   return a;
 }
-inline float16& operator-=(float16& a, const float16& b) {
-  a = float16(float(a) - float(b));
+inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) - static_cast<float>(b));
   return a;
 }
-inline float16& operator*=(float16& a, const float16& b) {
-  a = float16(float(a) * float(b));
+inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) * static_cast<float>(b));
   return a;
 }
-inline float16& operator/=(float16& a, const float16& b) {
-  a = float16(float(a) / float(b));
+inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
+  a = float16(static_cast<float>(a) / static_cast<float>(b));
   return a;
 }
 inline bool operator==(const float16& a, const float16& b) {
-  return float(a) == float(b);
+  return static_cast<float>(a) == static_cast<float>(b);
 }
 inline bool operator!=(const float16& a, const float16& b) {
-  return float(a) != float(b);
+  return static_cast<float>(a) != static_cast<float>(b);
 }
 inline bool operator<(const float16& a, const float16& b) {
-  return float(a) < float(b);
+  return static_cast<float>(a) < static_cast<float>(b);
 }
 inline bool operator<=(const float16& a, const float16& b) {
-  return float(a) <= float(b);
+  return static_cast<float>(a) <= static_cast<float>(b);
 }
 inline bool operator>(const float16& a, const float16& b) {
-  return float(a) > float(b);
+  return static_cast<float>(a) > static_cast<float>(b);
 }
 inline bool operator>=(const float16& a, const float16& b) {
-  return float(a) >= float(b);
+  return static_cast<float>(a) >= static_cast<float>(b);
 }
 #endif

+HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
+  float16 res;
+  res.x = a;
+  return res;
+}
```
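The new `raw_uint16_to_float16` helper builds a `float16` directly from an IEEE 754 binary16 bit pattern, with no numeric conversion; the limits and traits below use it for values such as infinity that have no exact literal. A small usage sketch (assumes this header is on the include path):

```cpp
#include <iostream>
#include "paddle/fluid/platform/float16.h"  // include path as used in this diff

using paddle::platform::float16;
using paddle::platform::raw_uint16_to_float16;

int main() {
  // Standard binary16 bit patterns:
  float16 one  = raw_uint16_to_float16(0x3c00);  // 1.0
  float16 fmax = raw_uint16_to_float16(0x7bff);  // 65504, largest finite value
  float16 inf  = raw_uint16_to_float16(0x7c00);  // +infinity
  std::cout << static_cast<float>(one) << " " << static_cast<float>(fmax) << " "
            << static_cast<float>(inf) << std::endl;  // prints: 1 65504 inf
}
```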
```diff
 HOSTDEVICE inline bool(isnan)(const float16& a) {
 #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hisnan(half(a));
@@ -886,28 +893,116 @@ struct is_pod<paddle::platform::float16> {
       is_standard_layout<paddle::platform::float16>::value;
 };
```
```diff
+template <>
+struct numeric_limits<paddle::platform::float16> {
+  static const bool is_specialized = true;
+  static const bool is_signed = true;
+  static const bool is_integer = false;
+  static const bool is_exact = false;
+  static const bool has_infinity = true;
+  static const bool has_quiet_NaN = true;
+  static const bool has_signaling_NaN = true;
+  static const float_denorm_style has_denorm = denorm_present;
+  static const bool has_denorm_loss = false;
+  static const std::float_round_style round_style = std::round_to_nearest;
+  static const bool is_iec559 = false;
+  static const bool is_bounded = false;
+  static const bool is_modulo = false;
+  static const int digits = 11;
+  static const int digits10 = 3;
+  static const int max_digits10 = 5;
+  static const int radix = 2;
+  static const int min_exponent = -13;
+  static const int min_exponent10 = -4;
+  static const int max_exponent = 16;
+  static const int max_exponent10 = 4;
+  static const bool traps = true;
+  static const bool tinyness_before = false;
+
+  static paddle::platform::float16(min)() {
+    return paddle::platform::raw_uint16_to_float16(0x400);
+  }
+  static paddle::platform::float16 lowest() {
+    return paddle::platform::raw_uint16_to_float16(0xfbff);
+  }
+  static paddle::platform::float16(max)() {
+    return paddle::platform::raw_uint16_to_float16(0x7bff);
+  }
+  static paddle::platform::float16 epsilon() {
+    return paddle::platform::raw_uint16_to_float16(0x0800);
+  }
+  static paddle::platform::float16 round_error() {
+    return paddle::platform::float16(0.5);
+  }
+  static paddle::platform::float16 infinity() {
+    return paddle::platform::raw_uint16_to_float16(0x7c00);
+  }
+  static paddle::platform::float16 quiet_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7e00);
+  }
+  static paddle::platform::float16 signaling_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7e00);
+  }
+  static paddle::platform::float16 denorm_min() {
+    return paddle::platform::raw_uint16_to_float16(0x1);
+  }
+};
```
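With this specialization in place, generic code can interrogate `float16` through the standard interface instead of hard-coding constants. A brief sketch (again assuming the header is available):

```cpp
#include <iostream>
#include <limits>
#include "paddle/fluid/platform/float16.h"

using float16 = paddle::platform::float16;

int main() {
  std::cout << std::numeric_limits<float16>::digits << std::endl;  // 11 significand bits
  std::cout << static_cast<float>(std::numeric_limits<float16>::max())
            << std::endl;  // 65504, from bit pattern 0x7bff
  std::cout << std::numeric_limits<float16>::has_infinity << std::endl;  // 1
}
```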
```diff
 }  // namespace std

 namespace Eigen {
+using float16 = paddle::platform::float16;
+
+template <>
+struct NumTraits<float16> : GenericNumTraits<float16> {
+  enum {
+    IsSigned = true,
+    IsInteger = false,
+    IsComplex = false,
+    RequireInitialization = false
+  };
+
+  HOSTDEVICE static inline float16 epsilon() {
+    return paddle::platform::raw_uint16_to_float16(0x0800);
+  }
+  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
+  HOSTDEVICE static inline float16 highest() {
+    return paddle::platform::raw_uint16_to_float16(0x7bff);
+  }
+  HOSTDEVICE static inline float16 lowest() {
+    return paddle::platform::raw_uint16_to_float16(0xfbff);
+  }
+  HOSTDEVICE static inline float16 infinity() {
+    return paddle::platform::raw_uint16_to_float16(0x7c00);
+  }
+  HOSTDEVICE static inline float16 quiet_NaN() {
+    return paddle::platform::raw_uint16_to_float16(0x7c01);
+  }
+};
```
```diff
 namespace numext {
 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isnan)(const float16& a) {
   return (paddle::platform::isnan)(a);
 }
 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isinf)(const float16& a) {
   return (paddle::platform::isinf)(a);
 }
 template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
-    const paddle::platform::float16& a) {
+HOSTDEVICE inline bool(isfinite)(const float16& a) {
   return (paddle::platform::isfinite)(a);
 }
+
+template <>
+HOSTDEVICE inline float16 exp(const float16& a) {
+  return float16(::expf(static_cast<float>(a)));
+}
 }  // namespace numext
 }  // namespace Eigen
```
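The `exp` overload is what lets Eigen's softmax expression evaluate on float16 tensors: each element is widened to float, exponentiated with `expf`, and narrowed back. For reference, the numerically stable computation that `SoftmaxFunctor` performs, sketched in plain C++ (shift by the row max, clip the shifted logits with the ValueClip threshold, exponentiate, normalize); a sketch only, not Paddle's Eigen-based implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<float> Softmax(const std::vector<float>& x) {
  // Shift by the max so the largest exponent is 0, clip like ValueClip does,
  // exponentiate, then normalize by the sum.
  float x_max = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    float shifted = std::max(x[i] - x_max, -64.f);  // kThreshold from softmax_impl.h
    out[i] = std::exp(shifted);
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}

int main() {
  for (float v : Softmax({1.f, 2.f, 3.f})) std::cout << v << " ";
  std::cout << std::endl;  // approximately: 0.0900 0.2447 0.6652
}
```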
```diff
@@ -68,6 +68,17 @@ class TestSoftmaxCUDNNOp(TestSoftmaxOp):
         self.use_cudnn = True


+class TestSoftmaxFP16Op(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
+
 class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
     def init_kernel_type(self):
         self.use_cudnn = True
```
...