From 374e757fd6388761f55e61ce7aca1e31221c0406 Mon Sep 17 00:00:00 2001
From: mayang002 <77949147+mayang002@users.noreply.github.com>
Date: Fri, 10 Mar 2023 09:31:49 +0800
Subject: [PATCH] Xpu ernie3: support fp16 for xpu kernels:
 full_like/stack/where (#51271)

* [xpu-ernie3] support fp16 for full_like/stack/where xpu kernels

* [xpu-ernie3] support fp16 for full_like/stack/where xpu kernels
---
 paddle/phi/backends/xpu/xpu2_op_list.cc | 17 ++++++++--
 paddle/phi/kernels/xpu/full_kernel.cc   | 41 ++++++++++++++++++++++---
 paddle/phi/kernels/xpu/stack_kernel.cc  | 10 ++++--
 paddle/phi/kernels/xpu/where_kernel.cc  | 19 ++++++++----
 4 files changed, 72 insertions(+), 15 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 8bc17379e02..f48a91c8285 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -319,6 +319,16 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT32})},
+      {"full_batch_size_like",
+       XPUKernelSet({phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16})},
+      {"fill_constant_batch_size_like",
+       XPUKernelSet({phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16})},
       {"unfold",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"unfold_grad",
@@ -595,6 +605,7 @@ XPUOpMap& get_kl2_ops() {
     {"shape",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT64,
+                   phi::DataType::INT32,
                    phi::DataType::FLOAT16})},
     {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
     {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -668,7 +679,8 @@ XPUOpMap& get_kl2_ops() {
     {"stack",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT64,
-                   phi::DataType::INT32})},
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT16})},
     {"stack_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
     {"strided_slice",
@@ -810,7 +822,8 @@ XPUOpMap& get_kl2_ops() {
     {"where",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
     {"sin", XPUKernelSet({phi::DataType::FLOAT32})},
     {"cos", XPUKernelSet({phi::DataType::FLOAT32})},
     {"linspace",
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index 45e5e4893b4..241894a0437 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -88,13 +88,32 @@ void FullLikeKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("The filled value is Inf."));
 
   auto out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  int r = xpu::constant(dev_ctx.x_context(),
-                        out_data,
-                        out->numel(),
-                        static_cast<XPUType>(value));
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  if (out->numel() > 0) {
+    int r = xpu::constant(dev_ctx.x_context(),
+                          out_data,
+                          out->numel(),
+                          static_cast<XPUType>(value));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  }
 }
 
+template <typename T, typename Context>
+void FullBatchSizeLikeKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const std::vector<int>& shape,
+                             const Scalar& val,
+                             DataType dtype,
+                             int x_batch_size_dim,
+                             int out_batch_size_dim,
+                             DenseTensor* out) {
+  if (x.lod().size() && x_batch_size_dim == 0) {
+    // set the correct batch size for the LoDTensor.
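+    // (Illustrative note, not in the upstream patch: a level-0 LoD
+    // stores sequence offsets, e.g. [0, 2, 5, 6] for three sequences
+    // of lengths 2, 3 and 1, so the batch size is the number of
+    // offsets minus one, i.e. x.lod().back().size() - 1 = 3.)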
+    auto odims = out->dims();
+    odims[out_batch_size_dim] = static_cast<int>(x.lod().back().size()) - 1;
+    FullKernel<T, Context>(dev_ctx, phi::vectorize(odims), val, dtype, out);
+  }
+  FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, out);
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(full,
@@ -122,3 +141,15 @@ PD_REGISTER_KERNEL(full_like,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
+
+PD_REGISTER_KERNEL(full_batch_size_like,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FullBatchSizeLikeKernel,
+                   float,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/paddle/phi/kernels/xpu/stack_kernel.cc b/paddle/phi/kernels/xpu/stack_kernel.cc
index b908a6a080e..454422739ab 100644
--- a/paddle/phi/kernels/xpu/stack_kernel.cc
+++ b/paddle/phi/kernels/xpu/stack_kernel.cc
@@ -55,5 +55,11 @@ void StackKernel(const Context& dev_ctx,
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    stack, XPU, ALL_LAYOUT, phi::StackKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(stack,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::StackKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/where_kernel.cc b/paddle/phi/kernels/xpu/where_kernel.cc
index e322fece53a..8b644d0cf7f 100644
--- a/paddle/phi/kernels/xpu/where_kernel.cc
+++ b/paddle/phi/kernels/xpu/where_kernel.cc
@@ -25,10 +25,11 @@ void WhereKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const bool* cond_data = condition.data<bool>();
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  T* out_data = ctx.template Alloc<T>(out);
+  const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
+  XPUType* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
 
   auto cond_dims = phi::vectorize(condition.dims());
   auto x_dims = phi::vectorize(x.dims());
@@ -44,10 +45,16 @@ void WhereKernel(const Context& ctx,
 
   int ret = xpu::select(
       ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);
-  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    where, XPU, ALL_LAYOUT, phi::WhereKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(where,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::WhereKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
-- 
GitLab
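
Note on the fp16 plumbing above: all three kernels follow the same
pattern of reinterpreting the framework's fp16 pointers as the device
library's fp16 type via XPUTypeTrait, rather than converting or copying
the data. Below is a minimal standalone sketch of that pattern; the
struct and trait names are stand-ins invented for illustration, not the
real phi/XDNN definitions.

#include <cstdint>
#include <cstdio>

struct FrameworkFP16 { uint16_t bits; };  // stand-in for phi::dtype::float16
struct DeviceFP16 { uint16_t bits; };     // stand-in for the XDNN fp16 type

// Maps a framework element type to the device-side type, in the same
// spirit as XPUTypeTrait<T>::Type in the kernels above.
template <typename T> struct TypeTraitSketch { using Type = T; };
template <> struct TypeTraitSketch<FrameworkFP16> { using Type = DeviceFP16; };

template <typename T>
void FillOnes(T* out, int n) {
  using XPUType = typename TypeTraitSketch<T>::Type;
  // The two fp16 types share size and layout, so a reinterpret_cast is
  // enough; no element-wise conversion or copy is performed.
  auto* dev = reinterpret_cast<XPUType*>(out);
  for (int i = 0; i < n; ++i) dev[i].bits = 0x3C00;  // 1.0 in IEEE fp16
}

int main() {
  FrameworkFP16 buf[4];
  FillOnes(buf, 4);
  std::printf("0x%04x\n", buf[0].bits);  // prints 0x3c00
  return 0;
}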