diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 468c658e5e640a493f0a7d4cb74c0753f0d565d9..4466085873814674d1dc63217a63140ac272c6cc 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -30,6 +30,7 @@ #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -44,41 +45,43 @@ template ; template -void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, - int full = false) { - auto flag = Eigen::DecompositionOptions::ComputeThinU | - Eigen::DecompositionOptions::ComputeThinV; - if (full) { - flag = Eigen::DecompositionOptions::ComputeFullU | - Eigen::DecompositionOptions::ComputeFullV; - } - Eigen::BDCSVD< - Eigen::Matrix> - svd(2, 2, flag); - /*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/ - T* input = const_cast(X); - auto m = Eigen::Map< - Eigen::Matrix>( - input, rows, cols); - svd.compute(m); - Eigen::Matrix V_trans = - svd.matrixV().transpose(); - memcpy(U, svd.matrixU().data(), svd.matrixU().size() * sizeof(T)); - memcpy(VH, V_trans.data(), V_trans.size() * sizeof(T)); - memcpy(S, svd.singularValues().data(), - svd.singularValues().size() * sizeof(T)); +void LapackSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, + int full = false) { + char jobz = full ? 'A' : 'S'; + int mx = std::max(rows, cols); + int mn = std::min(rows, cols); + T* a = const_cast(X); + int lda = rows; + int ldu = rows; + int ldvt = full ? cols : mn; + int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn); + std::vector work(lwork); + std::vector iwork(8 * mn); + int info; + phi::funcs::lapackSvd(jobz, rows, cols, a, lda, S, U, ldu, VH, ldvt, + work.data(), lwork, iwork.data(), &info); + if (info < 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "This %s-th argument has an illegal value", info)); + } + if (info > 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "DBDSDC/SBDSDC did not converge, updating process failed. May be you " + "passes a invalid matrix.")); + } } template void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches, int full = false) { + // NOTE: this function is row major, because this function called the lapack. int stride = rows * cols; int k = std::min(rows, cols); int stride_u = full ? rows * rows : k * rows; int stride_v = full ? cols * cols : k * cols; for (int i = 0; i < batches; ++i) { - EigenSvd(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k, - rows, cols, full); + LapackSvd(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k, + rows, cols, full); } return; } diff --git a/paddle/fluid/operators/svd_op.h b/paddle/fluid/operators/svd_op.h index 1008a69e6de0fa7d0de4d9edf8307b5705762aac..2cc30ea0bc4639946a2e9338fd3edf4c55ce874b 100644 --- a/paddle/fluid/operators/svd_op.h +++ b/paddle/fluid/operators/svd_op.h @@ -21,6 +21,7 @@ #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/transpose_kernel.h" namespace paddle { namespace operators { @@ -39,7 +40,12 @@ class SvdCPUKernel : public framework::OpKernel { /*Create Tensors and output, set the dim ...*/ auto numel = x->numel(); - auto* x_data = x->data(); + auto& orig_dev_ctx = + context.template device_context(); + auto& dev_ctx = static_cast::TYPE&>(orig_dev_ctx); + Tensor trans_x = ::phi::TransposeLast2Dim(dev_ctx, *x); + auto* x_data = trans_x.data(); auto x_dims = x->dims(); int rows = x_dims[x_dims.size() - 2]; int cols = x_dims[x_dims.size() - 1]; @@ -57,6 +63,20 @@ class SvdCPUKernel : public framework::OpKernel { context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real))); /*SVD Use the Eigen Library*/ math::BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); + /* let C[m, n] as a col major matrix with m rows and n cols. + * let R[m, n] is row major matrix with m rows and n cols. + * then we have: R[m,n] = C[m, n].resize((n,m)).tranpose_last_two() + * */ + auto col_major_to_row_major = [&dev_ctx](Tensor* out) { + auto origin_dim = out->dims(); + int64_t& x = origin_dim[origin_dim.size() - 1]; + int64_t& y = origin_dim[origin_dim.size() - 2]; + std::swap(x, y); + out->Resize(origin_dim); + return ::phi::TransposeLast2Dim(dev_ctx, *out); + }; + *U = col_major_to_row_major(U); + *VH = col_major_to_row_major(VH); } }; diff --git a/paddle/phi/backends/dynload/lapack.h b/paddle/phi/backends/dynload/lapack.h index f0e1e9ad7a4c0009203d170a45ee792b51bbac11..1a680e32d1c32ee0d4e4cea01c02fbab58911fd4 100644 --- a/paddle/phi/backends/dynload/lapack.h +++ b/paddle/phi/backends/dynload/lapack.h @@ -280,6 +280,34 @@ extern "C" void spotrs_(char *uplo, float *b, int *ldb, int *info); +extern "C" void dgesdd_(char *, + int *, + int *, + double *, + int *, + double *, + double *, + int *, + double *, + int *, + double *, + int *, + int *, + int *); +extern "C" void sgesdd_(char *, + int *, + int *, + float *, + int *, + float *, + float *, + int *, + float *, + int *, + float *, + int *, + int *, + int *); namespace phi { namespace dynload { @@ -328,6 +356,8 @@ extern void *lapack_dso_handle; __macro(sgelsy_); \ __macro(dgelss_); \ __macro(sgelss_); \ + __macro(sgesdd_); \ + __macro(dgesdd_); \ __macro(zpotrs_); \ __macro(cpotrs_); \ __macro(dpotrs_); \ diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.cc b/paddle/phi/kernels/funcs/lapack/lapack_function.cc index 247bb52153c3e5619aa730ee601ab3efe54c0b9d..09d45fcf24be991a1b0842225cadff970ddad8ee 100644 --- a/paddle/phi/kernels/funcs/lapack/lapack_function.cc +++ b/paddle/phi/kernels/funcs/lapack/lapack_function.cc @@ -499,5 +499,43 @@ void lapackCholeskySolve(char uplo, dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); } +template <> +void lapackSvd(char jobz, + int m, + int n, + double *a, + int lda, + double *s, + double *u, + int ldu, + double *vt, + int ldvt, + double *work, + int lwork, + int *iwork, + int *info) { + dynload::dgesdd_( + &jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info); +} + +template <> +void lapackSvd(char jobz, + int m, + int n, + float *a, + int lda, + float *s, + float *u, + int ldu, + float *vt, + int ldvt, + float *work, + int lwork, + int *iwork, + int *info) { + dynload::sgesdd_( + &jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info); +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.h b/paddle/phi/kernels/funcs/lapack/lapack_function.h index a57b52f04a11e9b5cf5c92f7e6918afd0b65444b..d251095bb79f066331997400a196ec7489e63c3c 100644 --- a/paddle/phi/kernels/funcs/lapack/lapack_function.h +++ b/paddle/phi/kernels/funcs/lapack/lapack_function.h @@ -120,6 +120,22 @@ void lapackGelss(int m, T2 *rwork, int *info); +template +void lapackSvd(char jobz, + int m, + int n, + T *a, + int lda, + T *s, + T *u, + int ldu, + T *vt, + int ldvt, + T *work, + int lwork, + int *iwork, + int *info); + template void lapackCholeskySolve( char uplo, int n, int nrhs, T *a, int lda, T *b, int ldb, int *info);