// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gelu_kernel.h"

#include <algorithm>
#include <cstring>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas_impl.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace phi {

template <typename T>
struct GeluFunctor {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out, bool approximate) const {
    if (approximate) {
      // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
      if (std::is_same<T, dtype::float16>::value) {
        VLOG(4) << "cast from float16 to float before computing";
        auto casted_x = x.template cast<float>();
        auto temp =
            (static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
             (casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
                .tanh();
        out.device(d) = (casted_x * static_cast<float>(0.5) *
                         (static_cast<float>(1) + temp))
                            .template cast<T>();
      } else {
        auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
                     (x + static_cast<T>(GELU_CONSTANT) * x.cube()))
                        .tanh();
        out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
      }
    } else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) &&                       \
    !defined(PADDLE_WITH_HIP)
      // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), evaluated with MKL
      // vector-math routines: out = erf(x * sqrt(1/2)); out += 1;
      // out *= x; out *= 0.5.
      auto x_data = x.data();
      auto out_data = out.data();
      int n = std::min(x.size(), out.size());
      std::memset(out_data, 0, n * sizeof(T));
      phi::funcs::CBlas<T>::AXPY(
          n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
      phi::funcs::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
      for (int i = 0; i < n; i++) {
        out_data[i] += static_cast<T>(1);
      }
      phi::funcs::CBlas<T>::VMUL(n, x_data, out_data, out_data);
      for (int i = 0; i < n; i++) {
        out_data[i] *= static_cast<T>(0.5);
      }
#else
      // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
      if (std::is_same<T, dtype::float16>::value) {
        VLOG(4) << "cast from float16 to float before computing";
        auto casted_x = x.template cast<float>();
        auto temp = (casted_x * static_cast<float>(M_SQRT1_2)).erf();
        out.device(d) = (casted_x * static_cast<float>(0.5) *
                         (static_cast<float>(1) + temp))
                            .template cast<T>();
      } else {
        auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
        out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
      }
#endif
    }
  }
};

template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
                const DenseTensor& x,
                bool approximate,
                DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  auto eigen_out = EigenVector<T>::Flatten(*out);
  auto eigen_x = EigenVector<T>::Flatten(x);
  auto& dev = *dev_ctx.eigen_device();
  GeluFunctor<T> functor;
  functor(dev, eigen_x, eigen_out, approximate);
}

}  // namespace phi

PD_REGISTER_KERNEL(gelu, CPU, ALL_LAYOUT, phi::GeluKernel, float, double) {}
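
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the build): a minimal scalar reference for
// the two GELU variants computed above, handy for spot-checking the kernel
// numerically. The names gelu_erf_ref and gelu_tanh_ref are hypothetical and
// do not exist in Paddle; only <cmath> is assumed.
//
//   #include <cmath>
//
//   // Exact form: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//   double gelu_erf_ref(double x) {
//     return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
//   }
//
//   // Tanh approximation:
//   // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//   double gelu_tanh_ref(double x) {
//     const double k = M_2_SQRTPI * M_SQRT1_2;  // == sqrt(2 / pi)
//     return 0.5 * x * (1.0 + std::tanh(k * (x + 0.044715 * x * x * x)));
//   }
// ---------------------------------------------------------------------------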