From c57e12bec4ce78db3db8165b6e9821d4fddd660c Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Sun, 10 Jul 2022 22:58:45 -0500
Subject: [PATCH] refine dist_grad kernel (#44182)

* refine dist_grad kernel

* fix cpu kernel bug
---
 paddle/phi/kernels/cpu/dist_grad_kernel.cc    |  22 --
 paddle/phi/kernels/dist_grad_kernel.cc        |  93 ++++++++
 paddle/phi/kernels/gpu/dist_grad_kernel.cu    |  26 --
 .../phi/kernels/impl/dist_grad_kernel_impl.h  | 223 ------------------
 4 files changed, 93 insertions(+), 271 deletions(-)
 delete mode 100644 paddle/phi/kernels/cpu/dist_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/dist_grad_kernel.cc
 delete mode 100644 paddle/phi/kernels/gpu/dist_grad_kernel.cu
 delete mode 100644 paddle/phi/kernels/impl/dist_grad_kernel_impl.h

diff --git a/paddle/phi/kernels/cpu/dist_grad_kernel.cc b/paddle/phi/kernels/cpu/dist_grad_kernel.cc
deleted file mode 100644
index c1aaa2adf75..00000000000
--- a/paddle/phi/kernels/cpu/dist_grad_kernel.cc
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/dist_grad_kernel.h"
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
-
-PD_REGISTER_KERNEL(
-    dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/dist_grad_kernel.cc b/paddle/phi/kernels/dist_grad_kernel.cc
new file mode 100644
index 00000000000..ba468ad299e
--- /dev/null
+++ b/paddle/phi/kernels/dist_grad_kernel.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/dist_grad_kernel.h" + +#include +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/p_norm_grad_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/scale_kernel.h" + +namespace phi { + +std::pair, std::vector> GetReduceDims( + const DDim& src_dim, const DDim& dst_dim) { + std::vector reduce_dims, new_dims; + auto pre_dims = src_dim.size() - dst_dim.size(); + for (auto i = 0; i < pre_dims; ++i) { + reduce_dims.push_back(i); + } + + for (auto i = pre_dims; i < src_dim.size(); ++i) { + if (dst_dim[i - pre_dims] == 1 && src_dim[i] != 1) { + reduce_dims.push_back(i); + } else { + new_dims.push_back(dst_dim[i - pre_dims]); + } + } + return {reduce_dims, new_dims}; +} + +template +void DistGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& out_grad, + float p, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto t = Subtract(dev_ctx, x, y); + DenseTensor x_grad_tmp; + x_grad_tmp.Resize(t.dims()); + DenseTensor y_grad_tmp; + y_grad_tmp.Resize(t.dims()); + PNormGradKernel( + dev_ctx, t, out, out_grad, p, -1, 1e-12, false, true, &x_grad_tmp); + ScaleKernel(dev_ctx, x_grad_tmp, -1.0, 0.0, false, &y_grad_tmp); + // do reduce, the implemetation of cpu SumKernel has bug, it changes + // the dims of output iternally, so we Resize x/y_grad twice. + auto res_x = GetReduceDims(x_grad_tmp.dims(), x.dims()); + if (!std::get<0>(res_x).empty()) { + x_grad->Resize(phi::make_ddim(std::get<1>(res_x))); + SumKernel( + dev_ctx, x_grad_tmp, std::get<0>(res_x), x.dtype(), false, x_grad); + x_grad->Resize(x.dims()); + } else { + x_grad->ShareBufferWith(x_grad_tmp); + } + auto res_y = GetReduceDims(y_grad_tmp.dims(), y.dims()); + if (!std::get<0>(res_y).empty()) { + y_grad->Resize(phi::make_ddim(std::get<1>(res_y))); + SumKernel( + dev_ctx, y_grad_tmp, std::get<0>(res_y), y.dtype(), false, y_grad); + y_grad->Resize(y.dims()); + } else { + y_grad->ShareBufferWith(y_grad_tmp); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL( + dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} +#endif diff --git a/paddle/phi/kernels/gpu/dist_grad_kernel.cu b/paddle/phi/kernels/gpu/dist_grad_kernel.cu deleted file mode 100644 index df422e8b2da..00000000000 --- a/paddle/phi/kernels/gpu/dist_grad_kernel.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-
-#include "paddle/phi/kernels/dist_grad_kernel.h"
-
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
-
-#ifdef PADDLE_WITH_HIP
-PD_REGISTER_KERNEL(dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float) {}
-#else
-PD_REGISTER_KERNEL(
-    dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
-#endif
diff --git a/paddle/phi/kernels/impl/dist_grad_kernel_impl.h b/paddle/phi/kernels/impl/dist_grad_kernel_impl.h
deleted file mode 100644
index fc118a832dc..00000000000
--- a/paddle/phi/kernels/impl/dist_grad_kernel_impl.h
+++ /dev/null
@@ -1,223 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace phi {
-
-template <typename T,
-          size_t D,
-          int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using ETensor = phi::EigenTensor<T, D, MajorType, IndexType>;
-
-template <int Rank>
-static void GetBraodcastDims(const phi::DDim& x_dims,
-                             const phi::DDim& y_dims,
-                             Eigen::DSizes<int, Rank>* x_bcast_dims,
-                             Eigen::DSizes<int, Rank>* y_bcast_dims) {
-  int bcast_dims_remainder = 0;
-  for (int i = 0; i < x_dims.size(); ++i) {
-    if (x_dims[i] >= y_dims[i]) {
-      (*x_bcast_dims)[i] = 1;
-      (*y_bcast_dims)[i] = x_dims[i] / y_dims[i];
-      bcast_dims_remainder += x_dims[i] % y_dims[i];
-    } else {
-      (*y_bcast_dims)[i] = 1;
-      (*x_bcast_dims)[i] = y_dims[i] / x_dims[i];
-      bcast_dims_remainder += y_dims[i] % x_dims[i];
-    }
-  }
-  PADDLE_ENFORCE_EQ(bcast_dims_remainder,
-                    0,
-                    phi::errors::PreconditionNotMet(
-                        "The input tensor of Op(dist) could not be broadcast, "
-                        "X's shape is [%s], Y's shape is [%s].",
-                        x_dims,
-                        y_dims));
-}
-
-static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) {
-  std::vector<int64_t> new_dims_vec(rank);
-  if (in_dims.size() < rank) {
-    for (int i = 0; i < rank - in_dims.size(); ++i) {
-      new_dims_vec[i] = 1;
-    }
-    for (int i = 0; i < in_dims.size(); ++i) {
-      new_dims_vec[i + rank - in_dims.size()] = in_dims[i];
-    }
-  } else {
-    new_dims_vec = vectorize(in_dims);
-  }
-  return phi::make_ddim(new_dims_vec);
-}
-
-template <typename Context, typename T, int Rank>
-static void DistGradFunction(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& y,
-                             const DenseTensor& out,
-                             const DenseTensor& out_grad,
-                             float p,
-                             DenseTensor* x_grad,
-                             DenseTensor* y_grad) {
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
-  auto out_dims = out.dims();
-
-  phi::DDim x_new_dims = GetNewDims(x_dims, Rank);
-  phi::DDim y_new_dims = GetNewDims(y_dims, Rank);
-  phi::DDim out_new_dims = GetNewDims(out_dims, Rank);
-  auto x_t = ETensor<T, Rank>::From(x, x_new_dims);
-  auto y_t = ETensor<T, Rank>::From(y, y_new_dims);
-  auto out_t = ETensor<T, Rank>::From(out, out_new_dims);
-
-  Eigen::DSizes<int, Rank> x_bcast_dims;
-  Eigen::DSizes<int, Rank> y_bcast_dims;
-  Eigen::DSizes<int, Rank> out_bcast_dims;
-
-  GetBraodcastDims(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims);
-  std::vector<int64_t> new_dims_vec(Rank);
-  for (int i = 0; i < Rank; ++i) {
-    new_dims_vec[i] = std::max(x_new_dims[i], y_new_dims[i]);
-    out_bcast_dims[i] = new_dims_vec[i];
-  }
-  phi::DDim new_dims = phi::make_ddim(new_dims_vec);
-
-  auto& place = *dev_ctx.eigen_device();
-  auto out_grad_t = ETensor<T, Rank>::From(out_grad, out_new_dims);
-  DenseTensor grad;
-  grad.Resize(new_dims);
-  dev_ctx.template Alloc<T>(&grad);
-  auto grad_t = ETensor<T, Rank>::From(grad);
-
-  auto x_minux_y = x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims);
-  auto x_minux_y_abs = x_minux_y.abs();
-  auto sign =
-      (x_minux_y > static_cast<T>(0)).template cast<T>() * static_cast<T>(1.0) +
-      (x_minux_y < static_cast<T>(0)).template cast<T>() * static_cast<T>(-1.0);
-  T epsilon = static_cast<T>(1.0e-10f);
-
-  // 1: Lp-norm(z), z = x-y, compute dz
-  if (p == 0) {
-    phi::funcs::SetConstant<Context, T> set_zero;
-    set_zero(dev_ctx, &grad, static_cast<T>(0));
-  } else if (p == INFINITY || p == -INFINITY) {
-    // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if
-    // j!=i, or equals to sign(z_i) * dout if j=i.
-    if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
-      grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
-                                 .template cast<T>() *
-                             sign.eval() * out_grad_t.broadcast(out_bcast_dims);
-    } else {
-      grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
-                                 .template cast<T>() *
-                             sign * out_grad_t.broadcast(out_bcast_dims);
-    }
-  } else {
-    // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
-    if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
-      grad_t.device(place) =
-          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
-              .pow(p - 1) *
-          sign.eval() * out_grad_t.broadcast(out_bcast_dims);
-    } else {
-      grad_t.device(place) =
-          (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
-              .pow(p - 1) *
-          sign * out_grad_t.broadcast(out_bcast_dims);
-    }
-  }
-
-  Eigen::DSizes<int, Rank * 2> x_reshape_dims;
-  Eigen::DSizes<int, Rank * 2> y_reshape_dims;
-  Eigen::DSizes<int, Rank> reduce_dims;
-  for (int i = 0; i < x_new_dims.size(); ++i) {
-    x_reshape_dims[2 * i] = x_bcast_dims[i];
-    x_reshape_dims[2 * i + 1] = x_new_dims[i];
-    y_reshape_dims[2 * i] = y_bcast_dims[i];
-    y_reshape_dims[2 * i + 1] = y_new_dims[i];
-    reduce_dims[i] = 2 * i;
-  }
-
-  // 2: if x or y is broadcasted in forward function,
-  // the grad need to be sum along the broadcasted dimensions
-  if (x_grad) {
-    dev_ctx.template Alloc<T>(x_grad);
-    auto x_grad_t = ETensor<T, Rank>::From(*x_grad, x_new_dims);
-    x_grad_t.device(place) = grad_t.reshape(x_reshape_dims)
-                                 .sum(reduce_dims)
-                                 .reshape(x_grad_t.dimensions());
-  }
-  if (y_grad) {
-    dev_ctx.template Alloc<T>(y_grad);
-    auto y_grad_t = ETensor<T, Rank>::From(*y_grad, y_new_dims);
-    y_grad_t.device(place) = -grad_t.reshape(y_reshape_dims)
-                                  .sum(reduce_dims)
-                                  .reshape(y_grad_t.dimensions());
-  }
-}
-
-template <typename T, typename Context>
-void DistGradKernel(const Context& dev_ctx,
-                    const DenseTensor& x,
-                    const DenseTensor& y,
-                    const DenseTensor& out,
-                    const DenseTensor& out_grad,
-                    float p,
-                    DenseTensor* x_grad,
-                    DenseTensor* y_grad) {
-  auto x_rank = x.dims().size();
-  auto y_rank = y.dims().size();
-  auto rank = std::max(x_rank, y_rank);
-  PADDLE_ENFORCE_LE(rank,
-                    6,
-                    phi::errors::Unimplemented(
-                        "Op(dist) only support tensors with no more than 6 "
-                        "dimensions, but X's rank is %d, Y's rank is %d.",
-                        x_rank,
-                        y_rank));
-  switch (rank) {
-    case 1:
-      DistGradFunction<Context, T, 1>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-    case 2:
-      DistGradFunction<Context, T, 2>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-    case 3:
-      DistGradFunction<Context, T, 3>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-    case 4:
-      DistGradFunction<Context, T, 4>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-    case 5:
-      DistGradFunction<Context, T, 5>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-    case 6:
-      DistGradFunction<Context, T, 6>(
-          dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
-      break;
-  }
-}
-
-}  // namespace phi
--
GitLab
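
Note on the new reduction step (not part of the patch): the sketch below mirrors the GetReduceDims logic from the added dist_grad_kernel.cc on plain std::vector<int64_t> values, to show which axes of the broadcasted gradient are summed before the gradient is resized back to the input shape. The standalone GetReduceDimsSketch helper, the Dims alias, the example shapes, and the main driver are illustrative assumptions only; they are not code from the repository.

// Standalone illustration (hypothetical, not part of the patch): mirrors the
// GetReduceDims rule to show which axes get summed when the forward op
// broadcast x or y.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using Dims = std::vector<int64_t>;

std::pair<Dims, Dims> GetReduceDimsSketch(const Dims& src, const Dims& dst) {
  Dims reduce_dims, new_dims;
  int64_t pre =
      static_cast<int64_t>(src.size()) - static_cast<int64_t>(dst.size());
  // Leading axes that exist only in the broadcasted gradient are always
  // reduced.
  for (int64_t i = 0; i < pre; ++i) reduce_dims.push_back(i);
  // Remaining axes are reduced where the original input had size 1 but the
  // gradient does not; otherwise the input's size is kept.
  for (int64_t i = pre; i < static_cast<int64_t>(src.size()); ++i) {
    if (dst[i - pre] == 1 && src[i] != 1) {
      reduce_dims.push_back(i);
    } else {
      new_dims.push_back(dst[i - pre]);
    }
  }
  return {reduce_dims, new_dims};
}

int main() {
  // Forward: x has shape [2, 3, 4], y has shape [3, 1]; x - y broadcasts y,
  // so the intermediate gradient has shape [2, 3, 4].
  auto res = GetReduceDimsSketch({2, 3, 4}, {3, 1});
  std::cout << "reduce over axes:";
  for (auto d : res.first) std::cout << ' ' << d;   // prints: 0 2
  std::cout << "\nkept dims:";
  for (auto d : res.second) std::cout << ' ' << d;  // prints: 3
  std::cout << '\n';
  return 0;
}

For these example shapes, y's gradient is summed over axes 0 and 2 of the [2, 3, 4] intermediate, which corresponds to the Resize-to-kept-dims, SumKernel, Resize-back-to-y.dims() sequence in the new kernel.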