From aff4368441446622e08a001649ada1bc045d9097 Mon Sep 17 00:00:00 2001
From: crystal <62974595+Zjq9409@users.noreply.github.com>
Date: Tue, 21 Dec 2021 19:13:51 +0800
Subject: [PATCH] use elementwise to optimize gelu forward implementation on
 GPU (#38188)

* relu forward opt

* add gelu functor

* optimize code
---
 paddle/fluid/operators/gelu_op.cu | 59 +++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index 5bb2fd24793..6a4a322b327 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -12,9 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct GeluWithApproximateFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x) {
+    // tanh approximation of gelu
+    MPType x = static_cast<MPType>(arg_x);
+    MPType one = static_cast<MPType>(1);
+    MPType out = x * static_cast<MPType>(0.5) *
+                 (one + tanh(static_cast<MPType>(0.79788456) * x *
+                             (one + static_cast<MPType>(0.044715) * x * x)));
+    return static_cast<T>(out);
+  }
+};
+
+template <typename T>
+struct GeluWithoutApproximateFunctor {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T arg_x) {
+    // exact gelu, used when approximate == false
+    MPType x = static_cast<MPType>(arg_x);
+    MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
+    MPType out =
+        x * static_cast<MPType>(0.5) * (static_cast<MPType>(1) + erf_out);
+    return static_cast<T>(out);
+  }
+};
+
+template <typename T>
+class GeluKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* in = context.Input<framework::Tensor>("X");
+    auto approximate = context.Attr<bool>("approximate");
+    out->mutable_data<T>(in->place());
+
+    std::vector<const framework::Tensor*> ins = {in};
+    std::vector<framework::Tensor*> outs = {out};
+    const auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    if (approximate) {
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
+    } else {
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     gelu, ops::GeluKernel<paddle::platform::CUDADeviceContext, float>,
-- 
GitLab
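
Reviewer note: below is a minimal host-side C++ sketch, not part of the patch, that cross-checks the two code paths above, namely the tanh approximation used by GeluWithApproximateFunctor against the exact erf form used by GeluWithoutApproximateFunctor, with the same constants. The helper names gelu_tanh and gelu_erf are illustrative only.

#include <cmath>
#include <cstdio>

// Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// 0.79788456 approximates sqrt(2/pi), matching the functor's constant.
double gelu_tanh(double x) {
  return 0.5 * x *
         (1.0 + std::tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)));
}

// Exact definition: gelu(x) = x * Phi(x), Phi being the standard normal CDF.
double gelu_erf(double x) {
  const double kSqrt1_2 = 0.70710678118654752440;  // 1/sqrt(2), i.e. M_SQRT1_2
  return 0.5 * x * (1.0 + std::erf(x * kSqrt1_2));
}

int main() {
  for (double x = -4.0; x <= 4.0; x += 1.0) {
    std::printf("x = %5.1f  tanh approx = %.6f  exact = %.6f  diff = %+.2e\n",
                x, gelu_tanh(x), gelu_erf(x), gelu_tanh(x) - gelu_erf(x));
  }
  return 0;
}

On the MPType indirection in both functors: MPTypeTrait<T>::Type resolves to float when T is platform::float16, so the tanh/erf intermediate math for half-precision inputs runs at float precision and only the final result is cast back to T. The sketch above simply uses double throughout.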