From e592534a5af4204c8cf4f667c824e7cefcc06309 Mon Sep 17 00:00:00 2001
From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com>
Date: Tue, 16 May 2023 19:00:01 +0800
Subject: [PATCH] [PaddlePaddle Hackathon 4 No.34] Optimize the performance of
 the Lerp OP on GPU for Paddle (#53154)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* modify lerp_kernel.cu

* pre-commit

* fix some CI issues

* fix some CI issues

* fix some CI issues

* fix some CI issues

* fix some CI issues

* fix some CI issues

* fix some CI issues

* fix some CI issues

* Add files via upload

fix some CI issues
---
 paddle/phi/kernels/gpu/lerp_kernel.cu | 109 +++++++++++++++++++++++++-
 1 file changed, 108 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu
index 25f37bb1704..4304dd38238 100644
--- a/paddle/phi/kernels/gpu/lerp_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_kernel.cu
@@ -15,8 +15,115 @@
 #include "paddle/phi/kernels/lerp_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/lerp_kernel_impl.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+struct BroadcastMinElementWiseDirectCUDAFunctor {
+  HOSTDEVICE inline T operator()(const T min) const { return min; }
+};
+
+template <typename T>
+struct LerpElementWiseDirectCUDAFunctor {
+  HOSTDEVICE inline T operator()(const T x, const T y, const T weight) const {
+    return x + weight * (y - x);
+  }
+};
+
+template <typename T>
+struct LerpScalarDirectCUDAFunctor {
+  const T *weight_;
+
+  HOSTDEVICE inline LerpScalarDirectCUDAFunctor(const T *weight)
+      : weight_(weight) {}
+
+  HOSTDEVICE inline T operator()(const T x, const T y) const {
+    return x + weight_[0] * (y - x);
+  }
+};
+
+template <typename T, typename Context>
+void LerpKernel(const Context &ctx,
+                const DenseTensor &x,
+                const DenseTensor &y,
+                const DenseTensor &weight,
+                DenseTensor *out) {
+  int rank = out->dims().size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      0,
+      phi::errors::InvalidArgument(
+          "The number of dimensions for LerpOp must be "
+          "greater than or equal to 0, but the value received is %d.",
+          rank));
+
+  ctx.template Alloc<T>(out);
+  std::vector<DenseTensor *> outputs = {out};
+
+  std::vector<const DenseTensor *> inputs;
+  if (weight.numel() == 1) {
+    const T *weight_ptr = weight.data<T>();
+    inputs.reserve(2);
+    inputs.emplace_back(&x);
+    inputs.emplace_back(&y);
+    auto functor = LerpScalarDirectCUDAFunctor<T>(weight_ptr);
+    phi::funcs::BroadcastKernel<T>(ctx, inputs, &outputs, functor);
+  } else {
+    inputs.reserve(3);
+    auto functor = LerpElementWiseDirectCUDAFunctor<T>();
+    DenseTensor b_min = phi::EmptyLike<T>(ctx, *out);
+    if (x.dims().size() != y.dims().size() &&
+        weight.dims().size() != y.dims().size()) {
+      std::vector<const DenseTensor *> broadcast_min_inputs;
+      broadcast_min_inputs.reserve(1);
+      std::vector<DenseTensor *> broadcast_min_outputs = {&b_min};
+      auto broadcast_min_functor =
+          BroadcastMinElementWiseDirectCUDAFunctor<T>();
+      if (x.dims().size() < y.dims().size() &&
+          x.dims().size() < weight.dims().size()) {
+        broadcast_min_inputs.emplace_back(&x);
+        phi::funcs::BroadcastKernel<T>(ctx,
+                                       broadcast_min_inputs,
+                                       &broadcast_min_outputs,
+                                       broadcast_min_functor);
+        inputs.emplace_back(&b_min);
+        inputs.emplace_back(&y);
+        inputs.emplace_back(&weight);
+      } else if (y.dims().size() < weight.dims().size()) {
+        broadcast_min_inputs.emplace_back(&y);
+        phi::funcs::BroadcastKernel<T>(ctx,
+                                       broadcast_min_inputs,
+                                       &broadcast_min_outputs,
+                                       broadcast_min_functor);
+        inputs.emplace_back(&x);
+        inputs.emplace_back(&b_min);
+        inputs.emplace_back(&weight);
+      } else {
+        broadcast_min_inputs.emplace_back(&weight);
+        phi::funcs::BroadcastKernel<T>(ctx,
+                                       broadcast_min_inputs,
+                                       &broadcast_min_outputs,
+                                       broadcast_min_functor);
+        inputs.emplace_back(&x);
+        inputs.emplace_back(&y);
+        inputs.emplace_back(&b_min);
+      }
+    } else {
+      inputs.emplace_back(&x);
+      inputs.emplace_back(&y);
+      inputs.emplace_back(&weight);
+    }
+    phi::funcs::BroadcastKernel<T>(ctx, inputs, &outputs, functor);
+  }
+}
+
+}  // namespace phi
 
 PD_REGISTER_KERNEL(lerp,
                    GPU,
-- 
GitLab
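
Usage sketch (not part of the patch): a minimal smoke test, written against
the public paddle.lerp API, that exercises the two dispatch paths this kernel
adds. It assumes a CUDA build of Paddle that includes this commit; the shapes
and values below are illustrative only and do not come from the PR.

    import paddle

    # paddle.lerp dispatches to the LerpKernel in lerp_kernel.cu on GPU.
    paddle.set_device("gpu")

    x = paddle.full([2, 3, 4], 1.0)
    y = paddle.full([2, 3, 4], 10.0)

    # weight.numel() == 1: taken by the LerpScalarDirectCUDAFunctor branch,
    # which reads the scalar weight through a raw pointer instead of
    # broadcasting it to the output shape.
    out = paddle.lerp(x, y, paddle.to_tensor(0.5))
    print(out)  # every element is 1.0 + 0.5 * (10.0 - 1.0) = 5.5

    # All three ranks differ (x: 2, y: 3, weight: 1): the kernel first
    # broadcasts weight into b_min via BroadcastMinElementWiseDirectCUDAFunctor,
    # then applies the element-wise functor x + w * (y - x) over the inputs.
    x_small = paddle.full([3, 4], 1.0)
    w = paddle.full([4], 0.25)
    print(paddle.lerp(x_small, y, w))  # output shape [2, 3, 4]

The scalar branch never materializes a broadcast weight tensor, which is
presumably the main saving over the generic lerp_kernel_impl.h implementation
that this patch replaces.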