diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h
index 854de23150cad7f108e72b175791bc57ef3854f8..c49c30eb65c42d278c594c34a20f040e0f7e6a2c 100644
--- a/paddle/fluid/platform/dynload/cusolver.h
+++ b/paddle/fluid/platform/dynload/cusolver.h
@@ -96,13 +96,22 @@ CUSOLVER_ROUTINE_EACH_R1(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
 #endif

 #if CUDA_VERSION >= 9020
-#define CUSOLVER_ROUTINE_EACH_R2(__macro) \
-  __macro(cusolverDnCreateSyevjInfo);     \
-  __macro(cusolverDnSsyevj_bufferSize);   \
-  __macro(cusolverDnDsyevj_bufferSize);   \
-  __macro(cusolverDnSsyevj);              \
-  __macro(cusolverDnDsyevj);              \
-  __macro(cusolverDnDestroySyevjInfo);
+#define CUSOLVER_ROUTINE_EACH_R2(__macro)      \
+  __macro(cusolverDnCreateSyevjInfo);          \
+  __macro(cusolverDnSsyevj_bufferSize);        \
+  __macro(cusolverDnDsyevj_bufferSize);        \
+  __macro(cusolverDnSsyevj);                   \
+  __macro(cusolverDnDsyevj);                   \
+  __macro(cusolverDnDestroySyevjInfo);         \
+  __macro(cusolverDnXsyevjSetSortEig);         \
+  __macro(cusolverDnSsyevjBatched_bufferSize); \
+  __macro(cusolverDnDsyevjBatched_bufferSize); \
+  __macro(cusolverDnCheevjBatched_bufferSize); \
+  __macro(cusolverDnZheevjBatched_bufferSize); \
+  __macro(cusolverDnSsyevjBatched);            \
+  __macro(cusolverDnDsyevjBatched);            \
+  __macro(cusolverDnCheevjBatched);            \
+  __macro(cusolverDnZheevjBatched);

 CUSOLVER_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
 #endif
diff --git a/paddle/phi/backends/dynload/cusolver.h b/paddle/phi/backends/dynload/cusolver.h
index 1354e310554804ea5d7402cb0cd62431365e285e..a86e85144fd7fb04535fecdc493f70e50329a9e1 100644
--- a/paddle/phi/backends/dynload/cusolver.h
+++ b/paddle/phi/backends/dynload/cusolver.h
@@ -108,13 +108,22 @@ CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
 #endif

 #if CUDA_VERSION >= 9020
-#define CUSOLVER_ROUTINE_EACH_R2(__macro) \
-  __macro(cusolverDnCreateSyevjInfo);     \
-  __macro(cusolverDnSsyevj_bufferSize);   \
-  __macro(cusolverDnDsyevj_bufferSize);   \
-  __macro(cusolverDnSsyevj);              \
-  __macro(cusolverDnDsyevj);              \
-  __macro(cusolverDnDestroySyevjInfo);
+#define CUSOLVER_ROUTINE_EACH_R2(__macro)      \
+  __macro(cusolverDnCreateSyevjInfo);          \
+  __macro(cusolverDnSsyevj_bufferSize);        \
+  __macro(cusolverDnDsyevj_bufferSize);        \
+  __macro(cusolverDnSsyevj);                   \
+  __macro(cusolverDnDsyevj);                   \
+  __macro(cusolverDnDestroySyevjInfo);         \
+  __macro(cusolverDnXsyevjSetSortEig);         \
+  __macro(cusolverDnSsyevjBatched_bufferSize); \
+  __macro(cusolverDnDsyevjBatched_bufferSize); \
+  __macro(cusolverDnCheevjBatched_bufferSize); \
+  __macro(cusolverDnZheevjBatched_bufferSize); \
+  __macro(cusolverDnSsyevjBatched);            \
+  __macro(cusolverDnDsyevjBatched);            \
+  __macro(cusolverDnCheevjBatched);            \
+  __macro(cusolverDnZheevjBatched);

 CUSOLVER_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
 #endif
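The wrappers registered above load the batched Jacobi eigensolver entry points (`cusolverDn{S,D}syevjBatched*`, `cusolverDn{C,Z}heevjBatched*`) at runtime, mirroring how the existing single-matrix `syevj` symbols in the same macro are handled. For orientation, the call sequence those symbols expect (handle plus `syevjInfo_t` descriptor, workspace query, batched solve, per-matrix `info` check) is sketched below against the plain CUDA toolkit API rather than Paddle's dynload layer; the file name and error handling are illustrative only and not part of this patch.

```cpp
// Minimal sketch of the cusolverDnSsyevjBatched call sequence used by this
// patch, written against the raw cuSOLVER API (no Paddle dynload involved).
// Build: nvcc syevj_batched_demo.cu -lcusolver
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <cusolverDn.h>

#define CUSOLVER_CHECK(expr)                                           \
  do {                                                                 \
    cusolverStatus_t s_ = (expr);                                      \
    if (s_ != CUSOLVER_STATUS_SUCCESS) {                               \
      std::printf("cusolver error %d at line %d\n", (int)s_, __LINE__); \
      return 1;                                                        \
    }                                                                  \
  } while (0)

int main() {
  const int n = 3, batch = 2, lda = n;
  // Two symmetric 3x3 matrices, column-major, stored back to back.
  std::vector<float> h_a = {4, 1, 0, 1, 3, 1, 0, 1, 2,
                            5, 2, 0, 2, 4, 1, 0, 1, 3};
  float *d_a = nullptr, *d_w = nullptr, *d_work = nullptr;
  int *d_info = nullptr;
  cudaMalloc(&d_a, sizeof(float) * h_a.size());
  cudaMalloc(&d_w, sizeof(float) * n * batch);
  cudaMalloc(&d_info, sizeof(int) * batch);
  cudaMemcpy(d_a, h_a.data(), sizeof(float) * h_a.size(),
             cudaMemcpyHostToDevice);

  cusolverDnHandle_t handle;
  syevjInfo_t params;
  CUSOLVER_CHECK(cusolverDnCreate(&handle));
  CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&params));

  // Workspace query, then the batched eigensolve over all matrices at once.
  int lwork = 0;
  CUSOLVER_CHECK(cusolverDnSsyevjBatched_bufferSize(
      handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, n, d_a, lda,
      d_w, &lwork, params, batch));
  cudaMalloc(&d_work, sizeof(float) * lwork);
  CUSOLVER_CHECK(cusolverDnSsyevjBatched(handle, CUSOLVER_EIG_MODE_VECTOR,
                                         CUBLAS_FILL_MODE_LOWER, n, d_a, lda,
                                         d_w, d_work, lwork, d_info, params,
                                         batch));
  cudaDeviceSynchronize();

  // One info value per matrix; 0 means the Jacobi iteration converged.
  std::vector<int> h_info(batch);
  std::vector<float> h_w(n * batch);
  cudaMemcpy(h_info.data(), d_info, sizeof(int) * batch,
             cudaMemcpyDeviceToHost);
  cudaMemcpy(h_w.data(), d_w, sizeof(float) * n * batch,
             cudaMemcpyDeviceToHost);
  for (int b = 0; b < batch; ++b) {
    std::printf("matrix %d: info=%d, eigenvalues:", b, h_info[b]);
    for (int i = 0; i < n; ++i) std::printf(" %f", h_w[b * n + i]);
    std::printf("\n");
  }

  cudaFree(d_a);
  cudaFree(d_w);
  cudaFree(d_work);
  cudaFree(d_info);
  cusolverDnDestroySyevjInfo(params);
  cusolverDnDestroy(handle);
  return 0;
}
```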
diff --git a/paddle/phi/kernels/funcs/values_vectors_functor.h b/paddle/phi/kernels/funcs/values_vectors_functor.h
index 88bef61fa921ff1fd5dd115b328452fb28716ac8..63202ca4a484d134cf5ce0f75b3612379be59a12 100644
--- a/paddle/phi/kernels/funcs/values_vectors_functor.h
+++ b/paddle/phi/kernels/funcs/values_vectors_functor.h
@@ -13,10 +13,10 @@
 // limitations under the License.

 #pragma once
-
 #include "paddle/fluid/memory/memory.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/phi/backends/dynload/cusolver.h"
+#include "paddle/phi/core/errors.h"
 #endif  // PADDLE_WITH_CUDA
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -54,6 +54,137 @@ static void CheckEighResult(const int batch, const int info) {
           info));
 }

+#ifdef PADDLE_WITH_CUDA
+
+#if CUDA_VERSION >= 11031
+static bool use_cusolver_syevj_batched = true;
+#else
+static bool use_cusolver_syevj_batched = false;
+#endif
+
+#define CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)      \
+  cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo,  \
+      int n, const scalar_t *A, int lda, const value_t *W, int *lwork,       \
+      syevjInfo_t params, int batchsize
+
+template <typename scalar_t, typename value_t>
+void syevjBatched_bufferSize(
+    CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "syevjBatched_bufferSize: not implemented for %s",
+      typeid(scalar_t).name()));
+}
+
+template <>
+inline void syevjBatched_bufferSize<float, float>(
+    CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(float, float)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevjBatched_bufferSize(
+      handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
+}
+
+template <>
+inline void syevjBatched_bufferSize<double, double>(
+    CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(double, double)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevjBatched_bufferSize(
+      handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
+}
+
+template <>
+inline void syevjBatched_bufferSize<phi::dtype::complex<float>, float>(
+    CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(phi::dtype::complex<float>,
+                                                 float)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCheevjBatched_bufferSize(
+      handle,
+      jobz,
+      uplo,
+      n,
+      reinterpret_cast<const cuComplex *>(A),
+      lda,
+      W,
+      lwork,
+      params,
+      batchsize));
+}
+
+template <>
+inline void syevjBatched_bufferSize<phi::dtype::complex<double>, double>(
+    CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(phi::dtype::complex<double>,
+                                                 double)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZheevjBatched_bufferSize(
+      handle,
+      jobz,
+      uplo,
+      n,
+      reinterpret_cast<const cuDoubleComplex *>(A),
+      lda,
+      W,
+      lwork,
+      params,
+      batchsize));
+}
+
+#define CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)                 \
+  cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo,  \
+      int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork,    \
+      int *info, syevjInfo_t params, int batchsize
+
+template <typename scalar_t, typename value_t>
+void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) {
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "syevjBatched: not implemented for %s", typeid(scalar_t).name()));
+}
+
+template <>
+inline void syevjBatched<float, float>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(float,
+                                                                         float)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevjBatched(
+      handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
+}
+
+template <>
+inline void syevjBatched<double, double>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(double,
+                                                                           double)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevjBatched(
+      handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
+}
+
+template <>
+inline void syevjBatched<phi::dtype::complex<float>, float>(
+    CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(phi::dtype::complex<float>, float)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      dynload::cusolverDnCheevjBatched(handle,
+                                       jobz,
+                                       uplo,
+                                       n,
+                                       reinterpret_cast<cuComplex *>(A),
+                                       lda,
+                                       W,
+                                       reinterpret_cast<cuComplex *>(work),
+                                       lwork,
+                                       info,
+                                       params,
+                                       batchsize));
+}
+
+template <>
+inline void syevjBatched<phi::dtype::complex<double>, double>(
+    CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(phi::dtype::complex<double>, double)) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZheevjBatched(
+      handle,
+      jobz,
+      uplo,
+      n,
+      reinterpret_cast<cuDoubleComplex *>(A),
+      lda,
+      W,
+      reinterpret_cast<cuDoubleComplex *>(work),
+      lwork,
+      info,
+      params,
+      batchsize));
+}
+#endif
+
 #ifdef PADDLE_WITH_CUDA
 static void CheckEighResult(const GPUContext &dev_ctx,
                             const int64_t batch_size,
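The `syevjBatched_bufferSize` / `syevjBatched` helpers added above follow a primary-template-throws, full-specialization-dispatches pattern: the generic template raises `InvalidArgument` for unsupported element types, and the four specializations forward to the matching S/D/C/Z cuSOLVER symbol, reinterpreting `phi::dtype::complex<T>` as `cuComplex` / `cuDoubleComplex`. A stripped-down, CPU-only sketch of the same shape is shown below; the names (`batched_sum`, `backend_sum_*`) are invented for illustration and merely stand in for the cuSOLVER calls.

```cpp
// CPU-only illustration (not Paddle code) of the dispatch pattern used by
// syevjBatched_bufferSize/syevjBatched above: a primary template that rejects
// unsupported element types plus full specializations that forward to a
// type-specific backend. backend_sum_f32/f64 stand in for cuSOLVER entry points.
#include <cstdio>
#include <stdexcept>
#include <string>
#include <typeinfo>

float backend_sum_f32(const float *x, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += x[i];
  return s;
}

double backend_sum_f64(const double *x, int n) {
  double s = 0.0;
  for (int i = 0; i < n; ++i) s += x[i];
  return s;
}

// Primary template: reaching it means no specialization exists for scalar_t,
// analogous to the PADDLE_THROW branch in the patch.
template <typename scalar_t>
scalar_t batched_sum(const scalar_t *, int) {
  throw std::runtime_error(std::string("batched_sum: not implemented for ") +
                           typeid(scalar_t).name());
}

template <>
inline float batched_sum<float>(const float *x, int n) {
  return backend_sum_f32(x, n);  // would be cusolverDnSsyevjBatched(...) here
}

template <>
inline double batched_sum<double>(const double *x, int n) {
  return backend_sum_f64(x, n);  // would be cusolverDnDsyevjBatched(...) here
}

int main() {
  const float xf[3] = {1.f, 2.f, 3.f};
  const double xd[3] = {1.5, 2.5, 3.5};
  std::printf("%f %f\n", batched_sum(xf, 3), batched_sum(xd, 3));
  return 0;
}
```

At the call sites later in this file, the element type `T` of `MatrixEighFunctor` selects the specialization at compile time, so an unsupported dtype fails loudly instead of silently calling the wrong routine.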
@@ -232,17 +363,33 @@ struct MatrixEighFunctor<GPUContext, T> {
     DenseTensor input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
     T *input_vector = input_trans.data<T>();

-    // Once input data type is float32, and the last dimension of
-    // input is located in range [32, 512], Syevj works better.
-    bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
-                      values_stride >= 32 && values_stride <= 512);
+    // cusolverDnZheevjBatched loses precision in some cases when Paddle is
+    // built with CUDA 11.7, although it works well with CUDA 10.2, so the
+    // batched path is disabled for complex128 inputs.
+    use_cusolver_syevj_batched = (use_cusolver_syevj_batched) &&
+                                 (batch_size > 1) &&
+                                 (input.dtype() != phi::DataType::COMPLEX128);
+    bool use_cusolver_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
+                               last_dim >= 32 && last_dim <= 512);
     auto handle = dev_ctx.cusolver_dn_handle();

     syevjInfo_t syevj_params;
-    if (use_syevj) {
+    if (use_cusolver_syevj_batched) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cusolverDnCreateSyevjInfo(&syevj_params));
+      syevjBatched_bufferSize(handle,
+                              jobz,
+                              uplo,
+                              last_dim,
+                              input_vector,
+                              lda,
+                              out_value,
+                              &workspace_size,
+                              syevj_params,
+                              batch_size);
+    } else if (use_cusolver_syevj) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           dynload::cusolverDnCreateSyevjInfo(&syevj_params));
-
       PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
           dev_ctx.cusolver_dn_handle(),
           jobz,
@@ -272,7 +419,21 @@ struct MatrixEighFunctor<GPUContext, T> {
     for (auto i = 0; i < batch_size; ++i) {
       auto *input_data = input_vector + i * vector_stride;
       auto *value_data = out_value + i * values_stride;
-      if (use_syevj) {
+      if (use_cusolver_syevj_batched) {
+        syevjBatched(handle,
+                     jobz,
+                     uplo,
+                     last_dim,
+                     input_data,
+                     lda,
+                     value_data,
+                     work_ptr,
+                     workspace_size,
+                     &info_ptr[i],
+                     syevj_params,
+                     batch_size);
+        break;
+      } else if (use_cusolver_syevj) {
         PADDLE_ENFORCE_GPU_SUCCESS(
             dynload::cusolverDnSsyevj(handle,
                                       jobz,
@@ -300,7 +461,7 @@ struct MatrixEighFunctor<GPUContext, T> {
     }
     CheckEighResult(dev_ctx, batch_size, info_ptr);

-    if (use_syevj) {
+    if (use_cusolver_syevj_batched || use_cusolver_syevj) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           dynload::cusolverDnDestroySyevjInfo(syevj_params));
     }
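Taken together, the selection logic introduced above is: use the batched Jacobi solver only when the build has CUDA 11.3.1 or newer (`CUDA_VERSION >= 11031`), the batch holds more than one matrix, and the dtype is not complex128 (per the precision note about `cusolverDnZheevjBatched` on CUDA 11.7); otherwise prefer single-matrix `syevj` for float32 inputs whose last dimension lies in [32, 512], and keep the functor's pre-existing default path for everything else. Because `syevjBatched` factorizes the whole batch in one launch, the per-matrix loop issues it once and then breaks out. The helper below restates that heuristic in isolation; the function and enum names are hypothetical, and only the thresholds come from the patch.

```cpp
// Hypothetical restatement (not Paddle code) of the eigensolver selection
// introduced in MatrixEighFunctor<GPUContext, T> above; thresholds mirror the
// patch, names are invented for illustration.
#include <cstdio>

enum class DType { kFloat32, kFloat64, kComplex64, kComplex128 };
enum class EighBackend { kSyevjBatched, kSyevj, kDefault };

EighBackend ChooseEighBackend(DType dtype,
                              long long batch_size,
                              long long last_dim,
                              bool cuda_11_3_1_or_newer) {
  // Batched Jacobi: needs CUDA_VERSION >= 11031, a real batch, and not
  // complex128 (cusolverDnZheevjBatched precision issue on CUDA 11.7).
  if (cuda_11_3_1_or_newer && batch_size > 1 && dtype != DType::kComplex128) {
    return EighBackend::kSyevjBatched;
  }
  // Single-matrix Jacobi is preferred for float32 with 32 <= n <= 512.
  if (dtype == DType::kFloat32 && last_dim >= 32 && last_dim <= 512) {
    return EighBackend::kSyevj;
  }
  // Everything else keeps the functor's pre-existing default path.
  return EighBackend::kDefault;
}

int main() {
  std::printf("float32, batch=8, n=64    -> %d\n",
              static_cast<int>(ChooseEighBackend(DType::kFloat32, 8, 64, true)));
  std::printf("complex128, batch=8, n=64 -> %d\n",
              static_cast<int>(
                  ChooseEighBackend(DType::kComplex128, 8, 64, true)));
  std::printf("float32, batch=1, n=64    -> %d\n",
              static_cast<int>(ChooseEighBackend(DType::kFloat32, 1, 64, true)));
  return 0;
}
```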