From 0e776965e5a821364204bfeddc85c6fa093c6dc4 Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <72954905+Asthestarsfalll@users.noreply.github.com>
Date: Mon, 10 Apr 2023 11:23:49 +0800
Subject: [PATCH] [PaddlePaddle Hackathon 4 No.44] Optimize the GPU computation
 performance of the logsumexp op for Paddle (#52509)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Optimize the performance of logsumexp

* Support zero-dim tensor
---
 paddle/phi/kernels/gpu/logsumexp_kernel.cu | 97 +++++++++++++++++++++-
 1 file changed, 94 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/gpu/logsumexp_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_kernel.cu
index 7963808476d..4806593469b 100644
--- a/paddle/phi/kernels/gpu/logsumexp_kernel.cu
+++ b/paddle/phi/kernels/gpu/logsumexp_kernel.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,12 +14,103 @@
 
 #include "paddle/phi/kernels/logsumexp_kernel.h"
 
-#include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
+#include "paddle/phi/kernels/elementwise_add_kernel.h"
+#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/gpu/reduce.h"
 
 using float16 = phi::dtype::float16;
 
+namespace phi {
+
+template <typename T>
+struct LogCUDAFunctor {
+  HOSTDEVICE inline T operator()(const T x) const { return std::log(x); }
+};
+
+template <>
+struct LogCUDAFunctor<float16> {
+  HOSTDEVICE inline float16 operator()(const float16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<float16>(std::log(x_));
+  }
+};
+
+template <typename T, typename Context>
+void LogsumexpKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const std::vector<int64_t>& axis,
+                     bool keepdim,
+                     bool reduce_all,
+                     DenseTensor* out) {
+  auto* in_x = &x;
+  auto* out_y = out;
+  auto xdim = in_x->dims();
+  for (size_t i = 0; i < xdim.size(); i++)
+    PADDLE_ENFORCE_LT(0,
+                      xdim[i],
+                      errors::InvalidArgument(
+                          "The dims of Input(X) should be greater than 0."));
+
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
+  std::vector<int64_t> outdim_vec, keeped_outdim_vec;
+  std::vector<int64_t> axis_vec;
+  for (auto i : axis) {
+    auto v = i >= 0 ? i : i + xdim.size();
+    axis_vec.push_back(v);
+  }
+  if (axis.size() == 0 || reduce_all) {
+    for (size_t i = 0; i < xdim.size(); i++) {
+      axis_vec.push_back(i);
+    }
+  }
+  for (size_t i = 0; i < xdim.size(); i++) {
+    bool flag = false;
+    for (auto v : axis_vec) {
+      if (v == i) {
+        flag = true;
+        break;
+      }
+    }
+    if (flag) {
+      keeped_outdim_vec.push_back(1);
+      if (keepdim) outdim_vec.push_back(1);
+    } else {
+      outdim_vec.push_back(xdim[i]);
+      keeped_outdim_vec.push_back(xdim[i]);
+    }
+  }
+
+  auto outdim = phi::make_ddim(outdim_vec);
+  auto keeped_outdim = phi::make_ddim(keeped_outdim_vec);
+  out->Resize(outdim);
+  dev_ctx.template Alloc<T>(out_y);
+
+  DenseTensor max_x;
+  max_x.Resize(outdim);
+  dev_ctx.template Alloc<T>(&max_x);
+
+  phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
+      dev_ctx, *in_x, &max_x, kps::IdentityFunctor<T>(), axis_vec);
+
+  max_x.Resize(keeped_outdim);
+  DenseTensor temp_x = Subtract<T, Context>(dev_ctx, *in_x, max_x);
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
+      dev_ctx, temp_x, out_y, kps::ExpFunctor<T>(), axis_vec);
+
+  const std::vector<const DenseTensor*> inputs = {out_y};
+  std::vector<DenseTensor*> outputs = {&temp_x};
+  phi::funcs::ElementwiseKernel<T>(
+      dev_ctx, inputs, &outputs, LogCUDAFunctor<T>());
+  temp_x.Resize(outdim);
+  out->Resize(outdim);
+  phi::AddKernel<T, Context>(dev_ctx, temp_x, max_x, out);
+}
+
+}  // namespace phi
+
 PD_REGISTER_KERNEL(
     logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
-- 
GitLab
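
Note (not part of the patch above): the new kernel computes logsumexp with the usual max-shift identity, logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), realised as a max reduction, an exp-plus-sum reduction, an elementwise log, and an elementwise add of the maximum. Below is a minimal host-side C++ sketch of that identity for a 1-D input, included only for illustration; every name in it is hypothetical and does not exist in the Paddle codebase.

// Reference sketch of the max-shift logsumexp trick used by the GPU kernel
// in the patch. Assumes a non-empty input, mirroring the kernel's dimension
// check; hypothetical helper, not Paddle API.
#include <algorithm>
#include <cmath>
#include <vector>

double LogSumExpReference(const std::vector<double>& x) {
  double max_x = *std::max_element(x.begin(), x.end());  // reduce: max
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - max_x);  // reduce: sum of exp(x - max)
  return max_x + std::log(sum);                   // log, then add the max back
}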