From ef61df301d79e206e296a7fe7afc9ae1e0f97976 Mon Sep 17 00:00:00 2001
From: Rayman
Date: Mon, 10 Oct 2022 10:17:04 +0800
Subject: [PATCH] [Hackathon No.36] Optimize the computation performance of
 the lerp_grad op on GPU (#45946)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu    | 243 +++++++++++++++++-
 .../impl/broadcast_tensors_kernel_impl.h      |   3 +-
 .../fluid/tests/unittests/test_lerp_op.py     |  30 +++
 3 files changed, 274 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index 39abe1e055d..b097e4ce4d0 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -15,8 +15,249 @@
 #include "paddle/phi/kernels/lerp_grad_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/lerp_grad_kernel_impl.h"
+
+#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/gpu/reduce.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void LerpGradKernelImpl(const T* weight,
+                                   const T* dout,
+                                   T* dx,
+                                   T* dy,
+                                   const int out_size,
+                                   const int x_size,
+                                   const int y_size) {
+  CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
+    T temp_dx = weight[idx] * dout[idx];
+    if (dx) {
+      if (idx < x_size) {
+        dx[idx] = dout[idx] - temp_dx;
+      }
+    }
+    if (dy) {
+      if (idx < y_size) {
+        dy[idx] = temp_dx;
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void LerpGradScalarKernelImpl(const T* weight,
+                                         const T* dout,
+                                         T* dx,
+                                         T* dy,
+                                         const int out_size,
+                                         const int x_size,
+                                         const int y_size) {
+  T weight_scalar = weight[0];
+  CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
+    T temp_dx = weight_scalar * dout[idx];
+    if (dx) {
+      if (idx < x_size) {
+        dx[idx] = dout[idx] - temp_dx;
+      }
+    }
+    if (dy) {
+      if (idx < y_size) {
+        dy[idx] = temp_dx;
+      }
+    }
+  }
+}
+
+bool XYNeedReduce(const DenseTensor& x,
+                  const DenseTensor& y,
+                  const DenseTensor& out) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  auto out_dims = out.dims();
+  int x_rank = x_dims.size();
+  int y_rank = y_dims.size();
+  int out_rank = out_dims.size();
+  int smaller_rank = std::min(x_rank, y_rank);
+  if (std::max(x_rank, y_rank) < out_rank) {
+    return true;
+  }
+  for (int i = 1; i <= smaller_rank; ++i) {
+    int x_idx = x_rank - i;
+    int y_idx = y_rank - i;
+    int out_idx = out_rank - i;
+    if (x_dims[x_idx] != y_dims[y_idx]) {
+      return true;
+    }
+    if (x_dims[x_idx] == 1 && y_dims[y_idx] == 1 && out_dims[out_idx] != 1) {
+      return true;
+    }
+  }
+  return false;
+}
+
+template <typename T, typename Context>
+void SwitchKernel(const Context& ctx,
+                  const DenseTensor& weight,
+                  const DenseTensor& out_grad,
+                  const int x_grad_size,
+                  const int y_grad_size,
+                  T* x_grad_data,
+                  T* y_grad_data) {
+  if (weight.numel() == 1) {
+    // condition when weight is a scalar
+    const T* weight_data = weight.data<T>();
+    const T* out_grad_data = out_grad.data<T>();
+    const int64_t out_size = out_grad.numel();
+    const int64_t weight_size = weight.numel();
+    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
+    LerpGradScalarKernelImpl<T><<<gpu_config.GetGridSize(),
+                                  gpu_config.GetBlockSize(),
+                                  0,
+                                  ctx.stream()>>>(weight_data,
+                                                  out_grad_data,
+                                                  x_grad_data,
+                                                  y_grad_data,
+                                                  out_size,
+                                                  x_grad_size,
+                                                  y_grad_size);
+  } else {
+    // broadcast weight with out_grad's dimensions
+    const std::vector<const DenseTensor*> in_tensors = {&weight, &out_grad};
+    DenseTensor b_weight = phi::EmptyLike<T, Context>(ctx, out_grad);
+    DenseTensor b_out = phi::EmptyLike<T, Context>(ctx, out_grad);
+    std::vector<DenseTensor*> out_tensors = {&b_weight, &b_out};
+
+    phi::BroadcastTensorsKernel<T, Context>(ctx, in_tensors, out_tensors);
+
+    const T* weight_data = b_weight.data<T>();
+    const T* out_grad_data = b_out.data<T>();
+    const int out_size = out_grad.numel();
+    const int weight_size = weight.numel();
+
+    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
+    LerpGradKernelImpl<T><<<gpu_config.GetGridSize(),
+                            gpu_config.GetBlockSize(),
+                            0,
+                            ctx.stream()>>>(weight_data,
+                                            out_grad_data,
+                                            x_grad_data,
+                                            y_grad_data,
+                                            out_size,
+                                            x_grad_size,
+                                            y_grad_size);
+  }
+}
+
+template <typename T, typename Context>
+void LerpGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const DenseTensor& weight,
+                    const DenseTensor& out,
+                    const DenseTensor& out_grad,
+                    DenseTensor* x_grad,
+                    DenseTensor* y_grad) {
+  const int rank = out.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      1,
+      phi::errors::InvalidArgument(
+          "The number of dimensions for LerpGradOp must be "
+          "greater than or equal to 1, but the value received is %d.",
+          rank));
+  PADDLE_ENFORCE_LE(
+      rank,
+      6,
+      phi::errors::InvalidArgument(
+          "The number of dimensions for LerpGradOp must be "
+          "less than or equal to 6, but the value received is %d.",
+          rank));
+
+  // check if x_grad and y_grad need to be reduced
+  // if x has a different dimension with y or weight in the middle axis, then
+  // they need to be broadcast and then reduced.
+  bool reduce_flag = XYNeedReduce(x, y, out);
+  if (!reduce_flag) {
+    int x_grad_size = 0, y_grad_size = 0;
+    T* x_grad_data = NULL;
+    T* y_grad_data = NULL;
+
+    if (x_grad) {
+      x_grad_data = ctx.template Alloc<T>(x_grad);
+      x_grad_size = x.numel();
+    }
+
+    if (y_grad) {
+      y_grad_data = ctx.template Alloc<T>(y_grad);
+      y_grad_size = y.numel();
+    }
+
+    SwitchKernel<T, Context>(ctx,
+                             weight,
+                             out_grad,
+                             x_grad_size,
+                             y_grad_size,
+                             x_grad_data,
+                             y_grad_data);
+
+  } else {
+    int x_grad_size = 0, y_grad_size = 0;
+    DenseTensor b_xgrad = phi::EmptyLike<T, Context>(ctx, out_grad);
+    DenseTensor b_ygrad = phi::EmptyLike<T, Context>(ctx, out_grad);
+    T* x_grad_data = NULL;
+    T* y_grad_data = NULL;
+
+    if (x_grad) {
+      x_grad_data = ctx.template Alloc<T>(&b_xgrad);
+      x_grad_size = out.numel();
+    }
+
+    if (y_grad) {
+      y_grad_data = ctx.template Alloc<T>(&b_ygrad);
+      y_grad_size = out.numel();
+    }
+
+    SwitchKernel<T, Context>(ctx,
+                             weight,
+                             out_grad,
+                             x_grad_size,
+                             y_grad_size,
+                             x_grad_data,
+                             y_grad_data);
+
+    if (x_grad) {
+      std::vector<int> reduce_axis_x =
+          funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1);
+      if (!reduce_axis_x.empty()) {
+        phi::funcs::
+            ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+                ctx, b_xgrad, x_grad, kps::IdentityFunctor<T>(), reduce_axis_x);
+      } else {
+        x_grad->ShareDataWith(b_xgrad);
+      }
+    }
+
+    if (y_grad) {
+      std::vector<int> reduce_axis_y =
+          funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1);
+      if (!reduce_axis_y.empty()) {
+        phi::funcs::
+            ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+                ctx, b_ygrad, y_grad, kps::IdentityFunctor<T>(), reduce_axis_y);
+      } else {
+        y_grad->ShareDataWith(b_ygrad);
+      }
+    }
+  }
+}
+
+}  // namespace phi

 PD_REGISTER_KERNEL(
     lerp_grad, GPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
index 0e8b4c216fa..652f0b2eee9 100644
--- a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
+++ b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
@@ -106,10 +106,11 @@ void BroadcastTensorsKernel(const Context& ctx,
     SWITCH_OUT_RANK_CASE(3)
     SWITCH_OUT_RANK_CASE(4)
     SWITCH_OUT_RANK_CASE(5)
+    SWITCH_OUT_RANK_CASE(6)
     default: {
       PADDLE_THROW(paddle::platform::errors::InvalidArgument(
           "Target tensor rank out of range"
-          "Maximum supported rank for broadcast is: 5"));
+          "Maximum supported rank for broadcast is: 6"));
     }
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py
index 6e73fb7b7e5..f7f096a8d5a 100644
--- a/python/paddle/fluid/tests/unittests/test_lerp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+
 import unittest
 import numpy as np
 from op_test import OpTest
@@ -78,6 +80,34 @@ class TestLerpWithDim6(TestLerp):
         self.shape = [2, 1, 2, 5, 1, 5]


+class TestLerpBroadXY(TestLerp):
+
+    def setUp(self):
+        self.op_type = "lerp"
+        self.python_api = paddle.lerp
+        self.init_dtype()
+        self.init_shape()
+        x = np.arange(1., 201.).astype(self.dtype).reshape([2, 1, 2, 50])
+        y = np.full(200, 10.).astype(self.dtype).reshape([2, 2, 1, 50])
+        w = np.asarray([0.5]).astype(self.dtype)
+        self.inputs = {'X': x, 'Y': y, 'Weight': w}
+        self.outputs = {'Out': x + w * (y - x)}
+
+
+class TestLerpBroadWToXY(TestLerp):
+
+    def setUp(self):
+        self.op_type = "lerp"
+        self.python_api = paddle.lerp
+        self.init_dtype()
+        self.init_shape()
+        x = np.full(600, 2.5).astype(self.dtype).reshape([50, 2, 2, 3])
+        y = np.full(600, 1.).astype(self.dtype).reshape([50, 2, 2, 3])
+        w = np.random.random([3]).astype(self.dtype)
+        self.inputs = {'X': x, 'Y': y, 'Weight': w}
+        self.outputs = {'Out': x + w * (y - x)}
+
+
 class TestLerpAPI(unittest.TestCase):

     def init_dtype(self):
--
GitLab
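For reference only (not part of the patch): a minimal NumPy sketch of the gradients the new kernels compute. With out = x + w * (y - x), the backward pass is dx = dout * (1 - w) and dy = dout * w, and when x or y was broadcast up to out's shape the gradient is summed back over the broadcast axes, mirroring the broadcast-then-reduce path guarded by XYNeedReduce. The helper name reduce_to_shape and the shapes below (taken from TestLerpBroadXY) are illustrative assumptions.

```python
import numpy as np

def reduce_to_shape(grad, shape):
    # Hypothetical helper: sum grad over broadcast axes so it matches `shape`.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, dim in enumerate(shape):
        if dim == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

# Shapes mirroring TestLerpBroadXY: x and y broadcast against each other.
x = np.arange(1., 201.).reshape([2, 1, 2, 50])
y = np.full(200, 10.).reshape([2, 2, 1, 50])
w = np.asarray([0.5])
dout = np.ones(np.broadcast_shapes(x.shape, y.shape, w.shape))

dx = reduce_to_shape(dout * (1.0 - w), x.shape)  # kernel: dx[i] = dout[i] - w * dout[i]
dy = reduce_to_shape(dout * w, y.shape)          # kernel: dy[i] = w * dout[i]
print(dx.shape, dy.shape)  # (2, 1, 2, 50) (2, 2, 1, 50)
```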