From ba89a3d3148ac085c9d90dedf140eb79e1bf6174 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 21 Jul 2022 19:00:26 +0800 Subject: [PATCH] [ Phi ] svd transfer (#44392) * svd cpu forward * svd gpu forward * transfer the backward of svd * remove cusolver in svd_grad * svd kernel bug fix * fix bugs * fix bugs. * fix bug --- paddle/fluid/operators/svd_helper.h | 65 ----- paddle/fluid/operators/svd_op.cc | 11 +- paddle/fluid/operators/svd_op.cu | 269 ------------------ paddle/fluid/operators/svd_op.h | 165 ----------- paddle/phi/kernels/activation_kernel.h | 11 + paddle/phi/kernels/cpu/svd_grad_kernel.cc | 22 ++ paddle/phi/kernels/cpu/svd_kernel.cc | 132 +++++++++ paddle/phi/kernels/diag_kernel.h | 13 + .../phi/kernels/gpu/matrix_rank_tol_kernel.cu | 22 +- paddle/phi/kernels/gpu/svd_grad_kernel.cu | 22 ++ paddle/phi/kernels/gpu/svd_kernel.cu | 253 ++++++++++++++++ .../phi/kernels/impl/svd_grad_kernel_impl.h | 177 ++++++++++++ paddle/phi/kernels/slice_kernel.h | 18 ++ paddle/phi/kernels/svd_grad_kernel.h | 32 +++ paddle/phi/kernels/svd_kernel.h | 29 ++ paddle/phi/ops/compat/svd_sig.cc | 27 ++ 16 files changed, 748 insertions(+), 520 deletions(-) delete mode 100644 paddle/fluid/operators/svd_op.cu delete mode 100644 paddle/fluid/operators/svd_op.h create mode 100644 paddle/phi/kernels/cpu/svd_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/svd_kernel.cc create mode 100644 paddle/phi/kernels/gpu/svd_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/svd_kernel.cu create mode 100644 paddle/phi/kernels/impl/svd_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/svd_grad_kernel.h create mode 100644 paddle/phi/kernels/svd_kernel.h create mode 100644 paddle/phi/ops/compat/svd_sig.cc diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index ea5625a09a2..a796aa9d544 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -45,71 +45,6 @@ template using EigenVector = 
framework::EigenVector; -template -void LapackSvd( - const T* X, T* U, T* VH, T* S, int rows, int cols, int full = false) { - char jobz = full ? 'A' : 'S'; - int mx = std::max(rows, cols); - int mn = std::min(rows, cols); - T* a = const_cast(X); - int lda = rows; - int ldu = rows; - int ldvt = full ? cols : mn; - int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn); - std::vector work(lwork); - std::vector iwork(8 * mn); - int info; - phi::funcs::lapackSvd(jobz, - rows, - cols, - a, - lda, - S, - U, - ldu, - VH, - ldvt, - work.data(), - lwork, - iwork.data(), - &info); - if (info < 0) { - PADDLE_THROW(platform::errors::InvalidArgument( - "This %s-th argument has an illegal value", info)); - } - if (info > 0) { - PADDLE_THROW(platform::errors::InvalidArgument( - "DBDSDC/SBDSDC did not converge, updating process failed. May be you " - "passes a invalid matrix.")); - } -} - -template -void BatchSvd(const T* X, - T* U, - T* VH, - T* S, - int rows, - int cols, - int batches, - int full = false) { - // NOTE: this function is row major, because this function called the lapack. - int stride = rows * cols; - int k = std::min(rows, cols); - int stride_u = full ? rows * rows : k * rows; - int stride_v = full ? cols * cols : k * cols; - for (int i = 0; i < batches; ++i) { - LapackSvd(X + i * stride, - U + i * stride_u, - VH + i * stride_v, - S + i * k, - rows, - cols, - full); - } - return; -} - template struct PowFunctor { PowFunctor(const T* input, T* output, int64_t numel, T exp) diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc index 7ae85343e04..6c250675b62 100644 --- a/paddle/fluid/operators/svd_op.cc +++ b/paddle/fluid/operators/svd_op.cc @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/operators/svd_op.h" - #include #include #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/ddim.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -167,11 +166,3 @@ REGISTER_OPERATOR(svd, ops::SvdGradMaker); REGISTER_OPERATOR(svd_grad, ops::SvdGradOp); - -REGISTER_OP_CPU_KERNEL(svd, - ops::SvdCPUKernel, - ops::SvdCPUKernel); - -REGISTER_OP_CPU_KERNEL(svd_grad, - ops::SvdGradKernel, - ops::SvdGradKernel); diff --git a/paddle/fluid/operators/svd_op.cu b/paddle/fluid/operators/svd_op.cu deleted file mode 100644 index 02851891619..00000000000 --- a/paddle/fluid/operators/svd_op.cu +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include - -#include -#include - -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/svd_op.h" -#include "paddle/fluid/platform/dynload/cusolver.h" - -namespace paddle { -namespace operators { - -template -class SvdGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = - context.template device_context(); - - const Tensor* x = context.Input("X"); - Tensor* U = context.Output("U"); - Tensor* VH = context.Output("VH"); - Tensor* S = context.Output("S"); - const bool full_matrices = context.Attr("full_matrices"); - - auto& dims = x->dims(); - int batch_count = 1; - for (int i = 0; i < dims.size() - 2; i++) { - batch_count *= dims[i]; - } - int rank = dims.size(); - int m = dims[rank - 2]; - int n = dims[rank - 1]; - - auto* vh_data = VH->mutable_data(context.GetPlace()); - auto* s_data = S->mutable_data(context.GetPlace()); - auto* u_data = U->mutable_data(context.GetPlace()); - // NOTE:(@xiongkun03) - // matrices are assumed to be stored in column-major order in cusolver - // then view A as n x m and do A^T SVD, we can avoid transpose - // Must Copy X once, because the gesvdj will change the origin input matrix - Tensor x_tmp; - paddle::framework::TensorCopy(*x, context.GetPlace(), &x_tmp); - auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count); - int* info_ptr = reinterpret_cast(info->ptr()); - - GesvdjBatched(dev_ctx, - batch_count, - n, - m, - std::min(m, n), - x_tmp.mutable_data(context.GetPlace()), - vh_data, - u_data, - s_data, - info_ptr, - !full_matrices); - - framework::DDim UT_dim = U->dims(); - std::swap(UT_dim[rank - 1], UT_dim[rank - 2]); // Get the dim of UT_dim - U->Resize(UT_dim); // U is entirely UT - auto dito = - math::DeviceIndependenceTensorOperations(context); - auto tmp_U = dito.Transpose(*U); - U->ShareDataWith(tmp_U); // U becomse UT, aka VT - } - void 
GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - T* A, - T* U, - T* V, - T* S, - int* info, - int thin_UV = 1) const; -}; - -template <> -void SvdGPUKernel::GesvdjBatched( - const platform::CUDADeviceContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - float* A, - float* U, - float* V, - float* S, - int* info, - int thin_UV) const { - /* compute singular vectors */ - const cusolverEigMode_t jobz = - CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ - gesvdjInfo_t gesvdj_params = NULL; - int lda = m; - int ldu = m; - int ldt = n; - int lwork = 0; - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSgesvdj_bufferSize(handle, - jobz, - thin_UV, - m, - n, - A, - lda, - S, - U, - ldu, - V, - ldt, - &lwork, - gesvdj_params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - int stride_A = lda * n; - int stride_U = ldu * (thin_UV ? k : m); - int stride_V = ldt * (thin_UV ? k : n); - for (int i = 0; i < batchSize; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSgesvdj(handle, - jobz, - thin_UV, - m, - n, - A + stride_A * i, - lda, - S + k * i, - U + stride_U * i, - ldu, - V + stride_V * i, - ldt, - workspace_ptr, - lwork, - info, - gesvdj_params)); - // check the error info - int error_info; - memory::Copy(platform::CPUPlace(), - &error_info, - dev_ctx.GetPlace(), - info, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver SVD is not zero. 
[%d]", i, error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); -} - -template <> -void SvdGPUKernel::GesvdjBatched( - const platform::CUDADeviceContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - double* A, - double* U, - double* V, - double* S, - int* info, - int thin_UV) const { - /* compute singular vectors */ - const cusolverEigMode_t jobz = - CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ - gesvdjInfo_t gesvdj_params = NULL; - int lda = m; - int ldu = m; - int ldt = n; - int lwork = 0; - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDgesvdj_bufferSize(handle, - jobz, - thin_UV, - m, - n, - A, - lda, - S, - U, - ldu, - V, - ldt, - &lwork, - gesvdj_params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - int stride_A = lda * n; - int stride_U = ldu * (thin_UV ? k : m); - int stride_V = ldt * (thin_UV ? k : n); - for (int i = 0; i < batchSize; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDgesvdj(handle, - jobz, - thin_UV, - m, - n, - A + stride_A * i, - lda, - S + k * i, - U + stride_U * i, - ldu, - V + stride_V * i, - ldt, - workspace_ptr, - lwork, - info, - gesvdj_params)); - // check the error info - int error_info; - memory::Copy(platform::CPUPlace(), - &error_info, - dev_ctx.GetPlace(), - info, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver SVD is not zero. 
[%d]", i, error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(svd, - ops::SvdGPUKernel, - ops::SvdGPUKernel); -REGISTER_OP_CUDA_KERNEL( - svd_grad, - ops::SvdGradKernel, - ops::SvdGradKernel); -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/svd_op.h b/paddle/fluid/operators/svd_op.h deleted file mode 100644 index b7d3b7d3e5a..00000000000 --- a/paddle/fluid/operators/svd_op.h +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/transpose_kernel.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -template -class SvdCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* U = context.Output("U"); - Tensor* VH = context.Output("VH"); - Tensor* S = context.Output("S"); - int full = context.Attr("full_matrices"); - - /*Create Tensors and output, set the dim ...*/ - auto numel = x->numel(); - auto& orig_dev_ctx = context.template device_context(); - auto& dev_ctx = static_cast< - const typename framework::ConvertToPhiContext::TYPE&>( - orig_dev_ctx); - Tensor trans_x = ::phi::TransposeLast2Dim(dev_ctx, *x); - auto* x_data = trans_x.data(); - auto x_dims = x->dims(); - int rows = x_dims[x_dims.size() - 2]; - int cols = x_dims[x_dims.size() - 1]; - int k = std::min(rows, cols); - int col_u = full ? rows : k; - int col_v = full ? cols : k; - int batches = numel / (rows * cols); - auto* U_out = U->mutable_data>( - context.GetPlace(), - size_t(batches * rows * col_u * sizeof(phi::dtype::Real))); - auto* VH_out = VH->mutable_data>( - context.GetPlace(), - size_t(batches * col_v * cols * sizeof(phi::dtype::Real))); - auto* S_out = S->mutable_data>( - context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real))); - /*SVD Use the Eigen Library*/ - math::BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); - /* let C[m, n] as a col major matrix with m rows and n cols. - * let R[m, n] is row major matrix with m rows and n cols. 
- * then we have: R[m,n] = C[m, n].resize((n,m)).tranpose_last_two() - * */ - auto col_major_to_row_major = [&dev_ctx](Tensor* out) { - auto origin_dim = out->dims(); - int64_t& x = origin_dim[origin_dim.size() - 1]; - int64_t& y = origin_dim[origin_dim.size() - 2]; - std::swap(x, y); - out->Resize(origin_dim); - return ::phi::TransposeLast2Dim(dev_ctx, *out); - }; - *U = col_major_to_row_major(U); - *VH = col_major_to_row_major(VH); - } -}; - -template -class SvdGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor& U_const = *ctx.Input("U"); - const framework::Tensor& VH_const = *ctx.Input("VH"); - const framework::Tensor& S = *ctx.Input("S"); - framework::Tensor& dX = - *ctx.Output(framework::GradVarName("X")); - const framework::Tensor& dU_const = - *ctx.Input(framework::GradVarName("U")); - const framework::Tensor& dVH_const = - *ctx.Input(framework::GradVarName("VH")); - - const bool full = ctx.Attr("full_matrices"); - int m = dX.dims()[dX.dims().size() - 2]; - int n = dX.dims()[dX.dims().size() - 1]; - int k = S.dims()[S.dims().size() - 1]; - auto dito = math::DeviceIndependenceTensorOperations(ctx); - framework::Tensor U, VH, dU, dV, dVH; - if (full) { - // if full_matrices is set, slice the U and VT to k columns - U = dito.Slice(U_const, {-1}, {0}, {k}); - VH = dito.Slice(VH_const, {-2}, {0}, {k}); - dU = dito.Slice(dU_const, {-1}, {0}, {k}); - dVH = dito.Slice(dVH_const, {-2}, {0}, {k}); - } else { - U = U_const; - VH = VH_const; - dU = dU_const; - dVH = dVH_const; - } - auto s_inverse = dito.Pow(S, -1); - auto s_square = dito.Pow(S, 2); - auto F = - dito.Sub(dito.Unsqueeze(s_square, -2), dito.Unsqueeze(s_square, -1)); - F = dito.Add(F, dito.Diag(dito.Infinits({k}))); - F = dito.Pow(F, -1); - Tensor sigma_term; - Tensor u_term; - Tensor v_term; - - if (ctx.HasInput(framework::GradVarName("S"))) { - const framework::Tensor& gS = - *ctx.Input(framework::GradVarName("S")); 
- sigma_term = dito.Mul(dito.Unsqueeze(gS, -2), U); - sigma_term = dito.Matmul(sigma_term, VH); - } - - if (ctx.HasInput(framework::GradVarName("U"))) { - auto UTG = dito.Matmul(U, dU, true, false); - auto GTU = dito.Matmul(dU, U, true, false); - u_term = dito.Mul(dito.Mul(dito.Sub(UTG, GTU), F), dito.Unsqueeze(S, -2)); - u_term = dito.Matmul(U, u_term); - if (m > k) { - auto project = dito.Sub(dito.Eye(m), dito.Matmul(U, U, false, true)); - u_term = dito.Add( - u_term, - dito.Mul(dito.Matmul(project, dU), dito.Unsqueeze(s_inverse, -2))); - } - u_term = dito.Matmul(u_term, VH); - } - - if (ctx.HasInput(framework::GradVarName("VH"))) { - auto UTG = dito.Matmul(VH, dVH, false, true); - auto GTU = dito.Matmul(dVH, VH, false, true); - v_term = dito.Mul(dito.Matmul(dito.Mul(dito.Sub(UTG, GTU), F), VH), - dito.Unsqueeze(S, -1)); - if (n > k) { - auto project = dito.Sub(dito.Eye(n), dito.Matmul(VH, VH, true, false)); - v_term = dito.Add( - v_term, - dito.Mul(dito.Matmul(dVH, project), dito.Unsqueeze(s_inverse, -1))); - } - v_term = dito.Matmul(U, v_term); - } - - dX.ShareDataWith(dito.Add(dito.Add(u_term, sigma_term), v_term)); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 8e5913e10fd..6d0a380e28b 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -103,4 +103,15 @@ void PowKernel(const Context& dev_ctx, const Scalar& factor, DenseTensor* out); +template +DenseTensor Pow(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& factor) { + DenseTensor out; + MetaTensor meta_out(out); + UnchangedInferMeta(x, &meta_out); + PowKernel(dev_ctx, x, factor, &out); + return out; +} + } // namespace phi diff --git a/paddle/phi/kernels/cpu/svd_grad_kernel.cc b/paddle/phi/kernels/cpu/svd_grad_kernel.cc new file mode 100644 index 00000000000..546a9be9fde --- /dev/null +++ b/paddle/phi/kernels/cpu/svd_grad_kernel.cc @@ 
-0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/kernels/svd_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/svd_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + svd_grad, CPU, ALL_LAYOUT, phi::SvdGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc new file mode 100644 index 00000000000..814a9c451e7 --- /dev/null +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/phi/kernels/svd_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +void LapackSvd( + const T* X, T* U, T* VH, T* S, int rows, int cols, int full = false) { + char jobz = full ? 'A' : 'S'; + int mx = std::max(rows, cols); + int mn = std::min(rows, cols); + T* a = const_cast(X); + int lda = rows; + int ldu = rows; + int ldvt = full ? cols : mn; + int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn); + std::vector work(lwork); + std::vector iwork(8 * mn); + int info; + phi::funcs::lapackSvd(jobz, + rows, + cols, + a, + lda, + S, + U, + ldu, + VH, + ldvt, + work.data(), + lwork, + iwork.data(), + &info); + if (info < 0) { + PADDLE_THROW(phi::errors::InvalidArgument( + "This %s-th argument has an illegal value", info)); + } + if (info > 0) { + PADDLE_THROW(phi::errors::InvalidArgument( + "DBDSDC/SBDSDC did not converge, updating process failed. May be you " + "passes a invalid matrix.")); + } +} + +template +void BatchSvd(const T* X, + T* U, + T* VH, + T* S, + int rows, + int cols, + int batches, + int full = false) { + // NOTE: this function is row major, because this function called the lapack. + int stride = rows * cols; + int k = std::min(rows, cols); + int stride_u = full ? rows * rows : k * rows; + int stride_v = full ? 
cols * cols : k * cols; + for (int i = 0; i < batches; ++i) { + LapackSvd(X + i * stride, + U + i * stride_u, + VH + i * stride_v, + S + i * k, + rows, + cols, + full); + } + return; +} + +template +void SvdKernel(const Context& dev_ctx, + const DenseTensor& X, + bool full_matrices, + DenseTensor* U, + DenseTensor* S, + DenseTensor* VH) { + int full = full_matrices; + /*Create Tensors and output, set the dim ...*/ + auto numel = X.numel(); + DenseTensor trans_x = ::phi::TransposeLast2Dim(dev_ctx, X); + auto* x_data = trans_x.data(); + auto x_dims = X.dims(); + int rows = x_dims[x_dims.size() - 2]; + int cols = x_dims[x_dims.size() - 1]; + // int k = std::min(rows, cols); + // int col_u = full ? rows : k; + // int col_v = full ? cols : k; + int batches = numel / (rows * cols); + auto* U_out = dev_ctx.template Alloc>(U); + auto* VH_out = dev_ctx.template Alloc>(VH); + auto* S_out = dev_ctx.template Alloc>(S); + /*SVD Use the Eigen Library*/ + BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); + /* let C[m, n] as a col major matrix with m rows and n cols. + * let R[m, n] is row major matrix with m rows and n cols. 
+ * then we have: R[m,n] = C[m, n].resize((n,m)).tranpose_last_two() + * */ + auto col_major_to_row_major = [&dev_ctx](DenseTensor* out) { + auto origin_dim = out->dims(); + int64_t& x = origin_dim[origin_dim.size() - 1]; + int64_t& y = origin_dim[origin_dim.size() - 2]; + std::swap(x, y); + out->Resize(origin_dim); + return ::phi::TransposeLast2Dim(dev_ctx, *out); + }; + *U = col_major_to_row_major(U); + *VH = col_major_to_row_major(VH); +} + +} // namespace phi + +PD_REGISTER_KERNEL(svd, CPU, ALL_LAYOUT, phi::SvdKernel, float, double) {} diff --git a/paddle/phi/kernels/diag_kernel.h b/paddle/phi/kernels/diag_kernel.h index 3168aea54e6..704632a5208 100644 --- a/paddle/phi/kernels/diag_kernel.h +++ b/paddle/phi/kernels/diag_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -45,4 +46,16 @@ void DiagKernel(const Context& dev_ctx, float padding_value, DenseTensor* out); +template +DenseTensor Diag(const Context& dev_ctx, + const DenseTensor& x, + int offset, + float padding_value) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + DiagInferMeta(x, offset, padding_value, &meta_out); + DiagKernel(dev_ctx, x, offset, padding_value, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu index c3aa64fe185..5661d61c4e8 100644 --- a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu @@ -35,17 +35,17 @@ namespace phi { template -void GesvdjBatched(const phi::GPUContext& dev_ctx, - int batchSize, - int m, - int n, - int k, - T* A, - T* U, - T* V, - T* S, - int* info, - int thin_UV = 1); +static void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + T* A, + T* U, + T* V, + T* S, + int* info, + int thin_UV = 1); template void SyevjBatched(const phi::GPUContext& dev_ctx, 
diff --git a/paddle/phi/kernels/gpu/svd_grad_kernel.cu b/paddle/phi/kernels/gpu/svd_grad_kernel.cu new file mode 100644 index 00000000000..cc2051c5f8d --- /dev/null +++ b/paddle/phi/kernels/gpu/svd_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/svd_grad_kernel.h" + +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/svd_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + svd_grad, GPU, ALL_LAYOUT, phi::SvdGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/svd_kernel.cu b/paddle/phi/kernels/gpu/svd_kernel.cu new file mode 100644 index 00000000000..d7fd3c9dffd --- /dev/null +++ b/paddle/phi/kernels/gpu/svd_kernel.cu @@ -0,0 +1,253 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/phi/kernels/svd_kernel.h" + +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +static void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + T* A, + T* U, + T* V, + T* S, + int* info, + int thin_UV = 1); + +template <> +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + float* A, + float* U, + float* V, + float* S, + int* info, + int thin_UV) { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnSgesvdj_bufferSize(handle, + jobz, + thin_UV, + m, + n, + A, + lda, + S, + U, + ldu, + V, + ldt, + &lwork, + gesvdj_params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? 
k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgesvdj(handle, + jobz, + thin_UV, + m, + n, + A + stride_A * i, + lda, + S + k * i, + U + stride_U * i, + ldu, + V + stride_V * i, + ldt, + workspace_ptr, + lwork, + info, + gesvdj_params)); + // check the error info + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + double* A, + double* U, + double* V, + double* S, + int* info, + int thin_UV) { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnDgesvdj_bufferSize(handle, + jobz, + thin_UV, + m, + n, + A, + lda, + S, + U, + ldu, + V, + ldt, + &lwork, + gesvdj_params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? 
k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgesvdj(handle, + jobz, + thin_UV, + m, + n, + A + stride_A * i, + lda, + S + k * i, + U + stride_U * i, + ldu, + V + stride_V * i, + ldt, + workspace_ptr, + lwork, + info, + gesvdj_params)); + // check the error info + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template +void SvdKernel(const Context& dev_ctx, + const DenseTensor& X, + bool full_matrices, + DenseTensor* U, + DenseTensor* S, + DenseTensor* VH) { + auto& dims = X.dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + int rank = dims.size(); + int m = dims[rank - 2]; + int n = dims[rank - 1]; + + auto* u_data = dev_ctx.template Alloc>(U); + auto* vh_data = dev_ctx.template Alloc>(VH); + auto* s_data = dev_ctx.template Alloc>(S); + // NOTE:(@xiongkun03) + // matrices are assumed to be stored in column-major order in cusolver + // then view A as n x m and do A^T SVD, we can avoid transpose + // Must Copy X once, because the gesvdj will change the origin input matrix + DenseTensor x_tmp; + Copy(dev_ctx, X, dev_ctx.GetPlace(), false, &x_tmp); + auto info = Empty(dev_ctx, {batch_count}); + int* info_ptr = reinterpret_cast(info.data()); + + GesvdjBatched(dev_ctx, + batch_count, + n, + m, + std::min(m, n), + dev_ctx.template Alloc(&x_tmp), + vh_data, + u_data, + s_data, + info_ptr, + !full_matrices); + + auto UT_dim = U->dims(); + std::swap(UT_dim[rank - 1], UT_dim[rank - 2]); // Get the dim of UT_dim + U->Resize(UT_dim); // U is entirely UT + auto tmp_U = TransposeLast2Dim(dev_ctx, *U); + U->ShareDataWith(tmp_U); // U 
becomse UT, aka VT; +} +} // namespace phi + +PD_REGISTER_KERNEL(svd, // cuda_only + GPU, + ALL_LAYOUT, + phi::SvdKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/svd_grad_kernel_impl.h b/paddle/phi/kernels/impl/svd_grad_kernel_impl.h new file mode 100644 index 00000000000..f87a8910ebe --- /dev/null +++ b/paddle/phi/kernels/impl/svd_grad_kernel_impl.h @@ -0,0 +1,177 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/activation_kernel.h"
+#include "paddle/phi/kernels/diag_kernel.h"
+#include "paddle/phi/kernels/elementwise_add_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/slice_kernel.h"
+
+namespace phi {
+
+// NOTE(review): in this dump the angle-bracket argument lists are missing
+// (bare `template`, `std::vector shape`, `static_cast(`, `Fill(dev_ctx, ...)`
+// without explicit arguments). They look stripped by text extraction -
+// restore them from the upstream file before compiling.
+
+// Returns a freshly allocated tensor of the given shape on which
+// funcs::SetConstant is run with T(fill_value).
+// fill_value is received as float and only then converted to T.
+template
+static DenseTensor Fill(const Context& ctx,
+                        std::vector shape,
+                        float fill_value) {
+  DenseTensor ret;
+  ret.Resize(make_ddim(shape));
+  ctx.template Alloc(&ret);
+  funcs::SetConstant()(ctx, &ret, T(fill_value));
+  return ret;
+}
+
+// Builds a length-n tensor filled with 1 and hands it to Diag(..., 0, 0).
+// NOTE(review): presumably the n x n identity matrix - Diag is declared in
+// diag_kernel.h, outside this chunk; confirm the meaning of its two trailing
+// arguments there.
+template
+static DenseTensor Eye(const Context& dev_ctx, int n) {
+  auto output = Fill(dev_ctx, {n}, 1);
+  auto ret = Diag(dev_ctx, output, 0, 0);
+  return ret;
+}
+
+// Returns Fill(ctx, shape, value) where value is derived from
+// std::numeric_limits::infinity(); the value passes through Fill's float
+// parameter on the way.
+template
+static DenseTensor Infinits(const Context& ctx, std::vector shape) {
+  auto value = static_cast(std::numeric_limits::infinity());
+  return Fill(ctx, shape, value);
+}
+
+// Returns a tensor that shares x's data and has one extra dim of size 1.
+// axis >= 0 inserts the new dim before position `axis`; axis < 0 counts from
+// the end, so -1 makes the new dim the last one and -2 the second to last.
+static DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
+  // don't copy data, only change the dims
+  DenseTensor out;
+  out.ShareDataWith(x);
+  std::vector out_shape = phi::vectorize(x.dims());
+  if (axis >= 0) {
+    auto index = (out_shape.begin() + axis);
+    out_shape.insert(index, 1);
+  } else if (axis < 0) {
+    // end() + axis + 1: axis == -1 maps to end(), i.e. append
+    auto index = (out_shape.end() + axis + 1);
+    out_shape.insert(index, 1);
+  }
+  out.Resize(phi::make_ddim(out_shape));
+  return out;
+}
+
+// Backward pass of SVD: assigns to *x_grad the sum of three terms, one per
+// incoming gradient - sigma_term (from s_grad), u_term (from u_grad) and
+// v_term (from vh_grad).
+//
+// dev_ctx:       device context passed to every helper kernel call.
+// x:             not referenced in this body.
+// u, vh, s:      tensors from the forward pass; k is the last dim of s.
+// u_grad, vh_grad, s_grad: incoming gradients for u, vh and s.
+// full_matrices: when true, u / u_grad are sliced to [0, k) along their last
+//                axis and vh / vh_grad along their second-to-last axis.
+// x_grad:        output; m and n are read from its last two dims on entry,
+//                and it is overwritten with the final sum.
+template
+void SvdGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& u,
+                   const DenseTensor& vh,
+                   const DenseTensor& s,
+                   const DenseTensor& u_grad,
+                   const DenseTensor& vh_grad,
+                   const DenseTensor& s_grad,
+                   bool full_matrices,
+                   DenseTensor* x_grad) {
+  const auto& dX = *x_grad;
+  int m = dX.dims()[dX.dims().size() - 2];
+  int n = dX.dims()[dX.dims().size() - 1];
+  int k = s.dims()[s.dims().size() - 1];
+  // NOTE(review): dV is declared here but never assigned or read below.
+  DenseTensor U, VH, dU, dV, dVH;
+  if (full_matrices) {
+    // if full_matrices is set, slice the U and VT to k columns
+    U = SliceKernel(
+        dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {});
+    VH = SliceKernel(
+        dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
+    dU = SliceKernel(
+        dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {});
+    // NOTE(review): the axis is computed from vh.dims(), not vh_grad.dims(),
+    // unlike the dU line above - confirm this is intentional.
+    dVH = SliceKernel(
+        dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
+  } else {
+    U = u;
+    VH = vh;
+    dU = u_grad;
+    dVH = vh_grad;
+  }
+  auto s_inverse = Pow(dev_ctx, s, -1);
+  auto s_square = Pow(dev_ctx, s, 2);
+  // F starts as s_square viewed as a row ([..., 1, k]) minus s_square viewed
+  // as a column ([..., k, 1]). A diagonal of infinities is then added before
+  // the element-wise power of -1.
+  // NOTE(review): presumably this makes the diagonal of the final F zero
+  // rather than 1/0 - confirm against the Diag / Pow kernels.
+  auto F = Subtract(
+      dev_ctx, Unsqueeze(s_square, -2), Unsqueeze(s_square, -1));
+  F = Add(
+      dev_ctx,
+      F,
+      Diag(dev_ctx, Infinits(dev_ctx, {k}), 0, 0));
+  F = Pow(dev_ctx, F, -1);
+  DenseTensor sigma_term;
+  DenseTensor u_term;
+  DenseTensor v_term;
+
+  // The HasInput guards below are commented out, so all three terms are
+  // always computed.
+
+  // if (ctx.HasInput(framework::GradVarName("S")))
+  {
+    // sigma_term = (U * gS broadcast over rows) x VH
+    const DenseTensor& gS = s_grad;
+    sigma_term = Multiply(dev_ctx, Unsqueeze(gS, -2), U);
+    sigma_term = Matmul(dev_ctx, sigma_term, VH);
+  }
+
+  // if (ctx.HasInput(framework::GradVarName("U"))) {
+  {
+    // NOTE(review): the two trailing bools of Matmul are presumably
+    // transpose_x / transpose_y - confirm in matmul_kernel.h.
+    auto UTG = Matmul(dev_ctx, U, dU, true, false);
+    auto GTU = Matmul(dev_ctx, dU, U, true, false);
+    u_term = Multiply(
+        dev_ctx,
+        Multiply(
+            dev_ctx, Subtract(dev_ctx, UTG, GTU), F),
+        Unsqueeze(s, -2));
+    u_term = Matmul(dev_ctx, U, u_term);
+    if (m > k) {
+      // extra term, only when m > k: (Eye(m) - Matmul(U, U, false, true))
+      // applied to dU, scaled by s_inverse
+      auto project =
+          Subtract(dev_ctx,
+                   Eye(dev_ctx, m),
+                   Matmul(dev_ctx, U, U, false, true));
+      u_term = Add(
+          dev_ctx,
+          u_term,
+          Multiply(dev_ctx,
+                   Matmul(dev_ctx, project, dU),
+                   Unsqueeze(s_inverse, -2)));
+    }
+    u_term = Matmul(dev_ctx, u_term, VH);
+  }
+  // }
+
+  // if (ctx.HasInput(framework::GradVarName("VH"))) {
+  {
+    // same structure as the U branch, with VH / dVH and s applied along the
+    // other axis (Unsqueeze(s, -1))
+    auto UTG = Matmul(dev_ctx, VH, dVH, false, true);
+    auto GTU = Matmul(dev_ctx, dVH, VH, false, true);
+    v_term = Multiply(
+        dev_ctx,
+        Matmul(
+            dev_ctx,
+            Multiply(
+                dev_ctx, Subtract(dev_ctx, UTG, GTU), F),
+            VH),
+        Unsqueeze(s, -1));
+    if (n > k) {
+      // extra term, only when n > k: dVH applied to
+      // (Eye(n) - Matmul(VH, VH, true, false)), scaled by s_inverse
+      auto project = Subtract(
+          dev_ctx,
+          Eye(dev_ctx, n),
+          Matmul(dev_ctx, VH, VH, true, false));
+      v_term = Add(
+          dev_ctx,
+          v_term,
+          Multiply(dev_ctx,
+                   Matmul(dev_ctx, dVH, project),
+                   Unsqueeze(s_inverse, -1)));
+    }
+    v_term = Matmul(dev_ctx, U, v_term);
+  }
+
+  *x_grad = Add(
+      dev_ctx, Add(dev_ctx, u_term, sigma_term), v_term);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/slice_kernel.h b/paddle/phi/kernels/slice_kernel.h
index c2a96312cdd..e01ff3d74fb 100644
--- a/paddle/phi/kernels/slice_kernel.h
+++ b/paddle/phi/kernels/slice_kernel.h
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace phi {
 
@@ -29,4 +30,21 @@ void SliceRawKernel(const Context& ctx,
                     const std::vector& decrease_axis,
                     DenseTensor* out);
 
+// Convenience wrapper that returns the slice as a new DenseTensor instead of
+// writing through an out-pointer: it calls SliceRawInferMeta on a MetaTensor
+// wrapping dense_out, then SliceRawKernel with the same arguments.
+template
+DenseTensor SliceKernel(const Context& ctx,
+                        const DenseTensor& input,
+                        const std::vector& axes,
+                        const IntArray& starts,
+                        const IntArray& ends,
+                        const std::vector& infer_flags,
+                        const std::vector& decrease_axis) {
+  DenseTensor dense_out;
+  MetaTensor meta_out(&dense_out);
+  SliceRawInferMeta(
+      input, axes, starts, ends, infer_flags, decrease_axis, &meta_out);
+  SliceRawKernel(
+      ctx, input, axes, starts, ends, infer_flags, decrease_axis, &dense_out);
+  return dense_out;
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/svd_grad_kernel.h b/paddle/phi/kernels/svd_grad_kernel.h
new file mode 100644
index 00000000000..474fd6ff03d
--- /dev/null
+++ b/paddle/phi/kernels/svd_grad_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+// Declaration of the SVD backward kernel.
+// The parameter order follows the "svd_grad" KernelSignature in
+// paddle/phi/ops/compat/svd_sig.cc: inputs X, U, VH, S, U@GRAD, VH@GRAD,
+// S@GRAD; attribute full_matrices; output X@GRAD.
+// NOTE(review): the bare `template` below has lost its parameter list in this
+// dump - restore it from the upstream file.
+template
+void SvdGradKernel(const Context& dev_ctx,
+                   const DenseTensor& X,
+                   const DenseTensor& U,
+                   const DenseTensor& VH,
+                   const DenseTensor& S,
+                   const DenseTensor& U_grad,
+                   const DenseTensor& VH_grad,
+                   const DenseTensor& S_grad,
+                   bool full_matrices,
+                   DenseTensor* X_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/svd_kernel.h b/paddle/phi/kernels/svd_kernel.h
new file mode 100644
index 00000000000..1497f5a604f
--- /dev/null
+++ b/paddle/phi/kernels/svd_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+// Declaration of the SVD forward kernel: takes X and the full_matrices flag
+// and writes its results through the U, S and VH pointers.
+// NOTE(review): the bare `template` below has lost its parameter list in this
+// dump - restore it from the upstream file.
+template
+void SvdKernel(const Context& dev_ctx,
+               const DenseTensor& X,
+               bool full_matrices,
+               DenseTensor* U,
+               DenseTensor* S,
+               DenseTensor* VH);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/svd_sig.cc b/paddle/phi/ops/compat/svd_sig.cc
new file mode 100644
index 00000000000..2b97d23f8b8
--- /dev/null
+++ b/paddle/phi/ops/compat/svd_sig.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Returns the kernel signature for the svd_grad op: kernel name "svd_grad",
+// seven inputs, one attribute (full_matrices) and one output (X@GRAD).
+// ctx is not consulted; the same signature is returned unconditionally.
+KernelSignature SvdGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("svd_grad",
+                         {"X", "U", "VH", "S", "U@GRAD", "VH@GRAD", "S@GRAD"},
+                         {"full_matrices"},
+                         {"X@GRAD"});
+}
+}  // namespace phi
+
+// Registers the mapping function above for the op name svd_grad.
+PD_REGISTER_ARG_MAPPING_FN(svd_grad, phi::SvdGradOpArgumentMapping);
-- 
GitLab