Unverified commit 858e4358, authored by crystal, committed by GitHub

use elementwise to optimize gelu backward implementation on GPU (#38263)

* optimize gelu backward

* optimize gelu backward

* optimize code

* Number to expression

* Replacement number
Parent d48d7128
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
-#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -27,9 +26,11 @@ struct GeluWithApproximateFunctor {
     // this function is tanh approximation of gelu
     MPType x = static_cast<MPType>(arg_x);
     MPType one = static_cast<MPType>(1);
-    MPType out = x * static_cast<MPType>(0.5) *
-                 (one + tanh(static_cast<MPType>(0.79788456) * x *
-                             (one + static_cast<MPType>(0.044715) * x * x)));
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto tanh_out =
+        tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
+    MPType out = x * half * (one + tanh_out);
     return static_cast<T>(out);
   }
 };
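For reference, GeluWithApproximateFunctor evaluates the standard tanh approximation of GELU; the refactor only names the constants: kAlpha = M_2_SQRTPI * M_SQRT1_2 = sqrt(2/pi), of which the old literal 0.79788456 was a rounded value, and GELU_CONSTANT = 0.044715.

\mathrm{GELU_{tanh}}(x) = \tfrac{1}{2}\,x\,\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr)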
@@ -40,9 +41,10 @@ struct GeluWithoutApproximateFunctor {
   inline HOSTDEVICE T operator()(T arg_x) {
     // actual gelu with approximation = false
     MPType x = static_cast<MPType>(arg_x);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
     MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
-    MPType out =
-        x * static_cast<MPType>(0.5) * (static_cast<MPType>(1) + erf_out);
+    MPType out = x * half * (one + erf_out);
     return static_cast<T>(out);
   }
 };
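GeluWithoutApproximateFunctor computes the exact definition through the error function (M_SQRT1_2 = 1/sqrt(2)):

\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{1}{2}\,x\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr)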
@@ -71,6 +73,68 @@ class GeluKernel<platform::CUDADeviceContext, T>
   }
 };
 
+template <typename T>
+struct GeluWithApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    MPType kBeta =
+        kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
+    auto cube_x = x * x * x;
+    auto tanh_out =
+        tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
+    auto ans =
+        half * (one + tanh_out +
+                (one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
+    return static_cast<T>(ans * dout);
+  }
+};
+
+template <typename T>
+struct GeluWithoutApproximateGradFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
+    MPType x = static_cast<MPType>(arg_x);
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType one = static_cast<MPType>(1);
+    MPType half = static_cast<MPType>(0.5);
+    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
+    auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
+               half * kAlpha * x * exp(-half * x * x);
+    return static_cast<T>(ans * dout);
+  }
+};
+
+template <typename T>
+class GeluGradKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* dout =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto approximate = context.Attr<bool>("approximate");
+    dx->mutable_data<T>(dout->place());
+
+    std::vector<const framework::Tensor*> ins = {x, dout};
+    std::vector<framework::Tensor*> outs = {dx};
+    const auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    if (approximate) {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    } else {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
...
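The two gradient functors added above evaluate the closed-form derivatives of the forward expressions and multiply by the incoming gradient dout. With c = GELU_CONSTANT, k_alpha = sqrt(2/pi) and k_beta = 3 c k_alpha (as named in the code), the formulas they implement are:

\frac{d}{dx}\,\mathrm{GELU_{tanh}}(x) = \tfrac{1}{2}\Bigl(1 + \tanh u + (1 - \tanh^{2} u)\,(k_{\alpha}\,x + k_{\beta}\,x^{3})\Bigr), \qquad u = k_{\alpha}\,(x + c\,x^{3})

\frac{d}{dx}\,\mathrm{GELU}(x) = \tfrac{1}{2}\Bigl(1 + \operatorname{erf}(x/\sqrt{2})\Bigr) + \frac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2}

The last term matches half * kAlpha * x * exp(-half * x * x) in the code, since (1/2) sqrt(2/pi) = 1/sqrt(2*pi).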
@@ -30,6 +30,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+#define GELU_CONSTANT 0.044715
+
 template <typename T>
 struct GeluFunctor {
   template <typename Device, typename X, typename Out>
@@ -41,14 +43,14 @@ struct GeluFunctor {
       auto casted_x = x.template cast<float>();
       auto temp =
           (static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
-           (casted_x + static_cast<float>(0.044715) * casted_x.cube()))
+           (casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
              .tanh();
       out.device(d) = (casted_x * static_cast<float>(0.5) *
                        (static_cast<float>(1) + temp))
                          .template cast<T>();
     } else {
       auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
-                   (x + static_cast<T>(0.044715) * x.cube()))
+                   (x + static_cast<T>(GELU_CONSTANT) * x.cube()))
                      .tanh();
       out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
     }
@@ -101,10 +103,10 @@ struct GeluGradFunctor {
       const float kAlpha = static_cast<float>(M_2_SQRTPI * M_SQRT1_2);
       const float kBeta =
-          kAlpha * static_cast<float>(0.044715) * static_cast<float>(3);
+          kAlpha * static_cast<float>(GELU_CONSTANT) * static_cast<float>(3);
       const auto y =
           (kAlpha *
-           ((static_cast<float>(0.044715) * casted_x.cube()) + casted_x))
+           ((static_cast<float>(GELU_CONSTANT) * casted_x.cube()) + casted_x))
              .tanh();
       dx.device(d) = (static_cast<float>(0.5) * casted_dout *
                       (static_cast<float>(1) + y +
@@ -113,9 +115,10 @@ struct GeluGradFunctor {
                          .template cast<T>();
     } else {
       const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
-      const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
+      const T kBeta =
+          kAlpha * static_cast<T>(GELU_CONSTANT) * static_cast<T>(3);
       const auto y =
-          (kAlpha * ((static_cast<T>(0.044715) * x.cube()) + x)).tanh();
+          (kAlpha * ((static_cast<T>(GELU_CONSTANT) * x.cube()) + x)).tanh();
       dx.device(d) = static_cast<T>(0.5) * dout *
                      (static_cast<T>(1) + y +
                       (x - x * y.square()) * (kAlpha + kBeta * x.square()));
...
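Below is a minimal standalone host-side sketch (not part of this commit; plain C++ in double precision, with hypothetical helper names) that cross-checks the closed-form gradients used by GeluWithApproximateGradFunctor and GeluWithoutApproximateGradFunctor against central finite differences of the corresponding forward formulas.

// Standalone sanity check, compiled with an ordinary C++ compiler (no CUDA needed).
// M_2_SQRTPI and M_SQRT1_2 are the POSIX math constants also used in the patch;
// some toolchains require _USE_MATH_DEFINES before including <cmath>.
#include <cmath>
#include <cstdio>

namespace {
constexpr double kGeluConstant = 0.044715;     // GELU_CONSTANT
const double kAlpha = M_2_SQRTPI * M_SQRT1_2;  // sqrt(2/pi)

// Forward: tanh approximation and exact erf form.
double gelu_tanh(double x) {
  return 0.5 * x * (1.0 + std::tanh(kAlpha * x * (1.0 + kGeluConstant * x * x)));
}
double gelu_erf(double x) { return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2)); }

// Closed-form gradients, mirroring the two grad functors (without the dout factor).
double gelu_tanh_grad(double x) {
  const double kBeta = kAlpha * kGeluConstant * 3.0;
  const double cube_x = x * x * x;
  const double tanh_out = std::tanh(kAlpha * (kGeluConstant * cube_x + x));
  return 0.5 * (1.0 + tanh_out +
                (1.0 - tanh_out * tanh_out) * (kAlpha * x + kBeta * cube_x));
}
double gelu_erf_grad(double x) {
  return 0.5 * (1.0 + std::erf(x * M_SQRT1_2)) +
         0.5 * kAlpha * x * std::exp(-0.5 * x * x);
}

// Central finite difference used as the reference.
double central_diff(double (*f)(double), double x, double h = 1e-6) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}
}  // namespace

int main() {
  for (double x = -3.0; x <= 3.0; x += 0.5) {
    std::printf("x=%+.2f  tanh-grad err=%.2e  erf-grad err=%.2e\n", x,
                std::fabs(gelu_tanh_grad(x) - central_diff(gelu_tanh, x)),
                std::fabs(gelu_erf_grad(x) - central_diff(gelu_erf, x)));
  }
  return 0;
}

Both reported errors should stay near the finite-difference noise floor (around 1e-9 or smaller), which is one quick way to convince yourself the analytic gradients above match the forward formulas.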