diff --git a/paddle/phi/kernels/gpu/erfinv_kernel.cu b/paddle/phi/kernels/gpu/erfinv_kernel.cu index 2316e960396fff4acc66e3f78eb58b47162630fa..087eb46c33fdc78c510cca280b7992176c04273e 100644 --- a/paddle/phi/kernels/gpu/erfinv_kernel.cu +++ b/paddle/phi/kernels/gpu/erfinv_kernel.cu @@ -13,9 +13,25 @@ // limitations under the License. #include "paddle/phi/kernels/erfinv_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/erfinv_kernel_impl.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { + +template +struct ErfinvFunctor { + HOSTDEVICE inline T operator()(const T x) const { return erfinv(x); } +}; + +template +void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { + ctx.template Alloc(out); + std::vector ins = {&x}; + std::vector outs = {out}; + phi::funcs::ElementwiseKernel(ctx, ins, &outs, ErfinvFunctor()); +} + +} // namespace phi PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}