diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index 6a4a322b327a0dab8e8930c552b244de69a51739..d533d79a036fd91050065e926ba7203efa7bb893 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
-#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -27,9 +26,11 @@ struct GeluWithApproximateFunctor {
     // this function is tanh approximation of gelu
     MPType x = static_cast<MPType>(arg_x);
     MPType one = static_cast<MPType>(1);
-    MPType out = x * static_cast<MPType>(0.5) *
-                 (one + tanh(static_cast<MPType>(0.79788456) * x *
-                             (one + static_cast<MPType>(0.044715) * x * x)));
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto tanh_out =
+        tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
+    MPType out = x * half * (one + tanh_out);
     return static_cast<T>(out);
   }
 };
@@ -40,9 +41,10 @@ struct GeluWithoutApproximateFunctor {
   inline HOSTDEVICE T operator()(T arg_x) {
     // actual gelu with approximation = false
     MPType x = static_cast<MPType>(arg_x);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
     MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
-    MPType out =
-        x * static_cast<MPType>(0.5) * (static_cast<MPType>(1) + erf_out);
+    MPType out = x * half * (one + erf_out);
     return static_cast<T>(out);
   }
 };
@@ -71,6 +73,68 @@ class GeluKernel<platform::CUDADeviceContext, T>
   }
 };
 
+template <typename T>
+struct GeluWithApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    MPType kBeta =
+        kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
+    auto cube_x = x * x * x;
+    auto tanh_out =
+        tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
+    auto ans =
+        half * (one + tanh_out +
+                (one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
+    return static_cast<T>(ans * dout);
+  }
+};
+
+template <typename T>
+struct GeluWithoutApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
+               half * kAlpha * x * exp(-half * x * x);
+    return static_cast<T>(ans * dout);
+  }
+};
+
+template <typename T>
+class GeluGradKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* dout =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto approximate = context.Attr<bool>("approximate");
+    dx->mutable_data<T>(dout->place());
+
+    std::vector<const framework::Tensor*> ins = {x, dout};
+    std::vector<framework::Tensor*> outs = {dx};
+    const auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    if (approximate) {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    } else {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
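For reference (not part of the patch), the formulas the rewritten forward functors and the new grad functors implement, writing `c` for `GELU_CONSTANT` (0.044715) and `k_a` for `M_2_SQRTPI * M_SQRT1_2` = sqrt(2/pi):

```latex
% Reference only, derived from the functors above; not code in the patch.
% c = GELU_CONSTANT = 0.044715,  k_a = M_2_SQRTPI * M_SQRT1_2 = \sqrt{2/\pi}.
\begin{align*}
\text{approximate = true:}\quad
  \mathrm{gelu}(x)  &= \tfrac{1}{2}\,x\,(1 + t),
  \qquad t = \tanh\!\bigl(k_a\,(x + c\,x^{3})\bigr) \\
  \mathrm{gelu}'(x) &= \tfrac{1}{2}(1 + t)
  + \tfrac{1}{2}\,(1 - t^{2})\,\bigl(k_a\,x + 3\,c\,k_a\,x^{3}\bigr) \\
\text{approximate = false:}\quad
  \mathrm{gelu}(x)  &= \tfrac{1}{2}\,x\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr) \\
  \mathrm{gelu}'(x) &= \tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr)
  + \tfrac{1}{2}\,k_a\,x\,e^{-x^{2}/2}
\end{align*}
```

The code's `kBeta = kAlpha * GELU_CONSTANT * 3` is the `3 c k_a` factor above, and `half * kAlpha` equals `1/sqrt(2*pi)`, which is why the erf-form gradient multiplies `x * exp(-half * x * x)` by `half * kAlpha`.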
diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h
index 0446d7d284b2237c033865b1d2280e0c661b1002..a913b8a1112793a835eb6638e8a5d18664f6eb34 100644
--- a/paddle/fluid/operators/gelu_op.h
+++ b/paddle/fluid/operators/gelu_op.h
@@ -30,6 +30,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+#define GELU_CONSTANT 0.044715
+
 template <typename T>
 struct GeluFunctor {
   template <typename Device, typename X, typename Out>
@@ -41,14 +43,14 @@ struct GeluFunctor {
         auto casted_x = x.template cast<float>();
         auto temp =
             (static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
-             (casted_x + static_cast<float>(0.044715) * casted_x.cube()))
+             (casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
                 .tanh();
         out.device(d) = (casted_x * static_cast<float>(0.5) *
                          (static_cast<float>(1) + temp))
                             .template cast<T>();
       } else {
         auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
-                     (x + static_cast<T>(0.044715) * x.cube()))
+                     (x + static_cast<T>(GELU_CONSTANT) * x.cube()))
                         .tanh();
         out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
       }
@@ -101,10 +103,10 @@ struct GeluGradFunctor {
 
       const float kAlpha = static_cast<float>(M_2_SQRTPI * M_SQRT1_2);
       const float kBeta =
-          kAlpha * static_cast<float>(0.044715) * static_cast<float>(3);
+          kAlpha * static_cast<float>(GELU_CONSTANT) * static_cast<float>(3);
       const auto y =
           (kAlpha *
-           ((static_cast<float>(0.044715) * casted_x.cube()) + casted_x))
+           ((static_cast<float>(GELU_CONSTANT) * casted_x.cube()) + casted_x))
               .tanh();
       dx.device(d) = (static_cast<float>(0.5) * casted_dout *
                       (static_cast<float>(1) + y +
@@ -113,9 +115,10 @@ struct GeluGradFunctor {
                          .template cast<T>();
     } else {
       const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
-      const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
+      const T kBeta =
+          kAlpha * static_cast<T>(GELU_CONSTANT) * static_cast<T>(3);
       const auto y =
-          (kAlpha * ((static_cast<T>(0.044715) * x.cube()) + x)).tanh();
+          (kAlpha * ((static_cast<T>(GELU_CONSTANT) * x.cube()) + x)).tanh();
       dx.device(d) = static_cast<T>(0.5) * dout *
                      (static_cast<T>(1) + y +
                       (x - x * y.square()) * (kAlpha + kBeta * x.square()));
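As a quick sanity check of the constant being factored out into `GELU_CONSTANT`, here is a small standalone host-side C++ sketch (not part of the patch; it uses plain `<cmath>` rather than Paddle's Eigen/CUDA kernels, and `GeluTanh`/`GeluExact` are hypothetical helper names) comparing the tanh approximation against the exact erf form:

```cpp
// Standalone sketch, not part of the patch: evaluates the same tanh
// approximation that gelu_op.h expresses via GELU_CONSTANT and compares it
// with the exact erf-based GELU on a few sample points.
#include <cmath>
#include <cstdio>

namespace {

constexpr double kGeluConstant = 0.044715;  // same value as GELU_CONSTANT

// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
double GeluTanh(double x) {
  const double kAlpha = std::sqrt(2.0 / std::acos(-1.0));  // sqrt(2/pi)
  return 0.5 * x * (1.0 + std::tanh(kAlpha * (x + kGeluConstant * x * x * x)));
}

// 0.5 * x * (1 + erf(x / sqrt(2)))
double GeluExact(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

}  // namespace

int main() {
  for (double x : {-3.0, -1.0, -0.25, 0.0, 0.5, 2.0}) {
    std::printf("x = % .2f   tanh approx = % .6f   exact = % .6f\n", x,
                GeluTanh(x), GeluExact(x));
  }
  return 0;
}
```

The two curves agree to a few decimal places over typical activation ranges, which is why the op exposes both paths behind the `approximate` attribute.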