diff --git a/paddle/phi/kernels/funcs/values_vectors_functor.h b/paddle/phi/kernels/funcs/values_vectors_functor.h index 336e9c809427c68be79bc8eaddd98193462f5405..a6a6d4097030b49ad394d491a0b4e9c3051a8b2d 100644 --- a/paddle/phi/kernels/funcs/values_vectors_functor.h +++ b/paddle/phi/kernels/funcs/values_vectors_functor.h @@ -27,10 +27,10 @@ namespace phi { namespace funcs { -inline int64_t GetBatchSize(phi::DDim dims) { +inline int64_t GetBatchSize(const phi::DDim &dims) { int64_t batch_size = 1; auto dim_size = dims.size(); - for (int i = 0; i < dim_size - 2; i++) { + for (int i = 0; i < dim_size - 2; ++i) { batch_size *= dims[i]; } return batch_size; @@ -54,6 +54,24 @@ static void CheckEighResult(const int batch, const int info) { info)); } +#ifdef PADDLE_WITH_CUDA +static void CheckEighResult(const GPUContext &dev_ctx, + const int64_t batch_size, + int *info) { + std::vector error_info(batch_size); + paddle::memory::Copy(phi::CPUPlace(), + error_info.data(), + dev_ctx.GetPlace(), + info, + sizeof(int) * batch_size, + dev_ctx.stream()); + dev_ctx.Wait(); + for (auto i = 0; i < batch_size; ++i) { + CheckEighResult(i, error_info[i]); + } +} +#endif + template struct MatrixEighFunctor { void operator()(const DeviceContext &dev_ctx, @@ -95,7 +113,8 @@ struct MatrixEighFunctor { char jobz = has_vectors ? 'V' : 'N'; int n = dims[dim_size - 1]; int64_t lda = std::max(1, n); - // if work = -1, it means that you need to use the lapack function to query + // if work = -1, it means that you need to use the lapack function to + // query // the optimal value int lwork = -1; // The length of the array work int lrwork = -1; // The dimension of the array rwork,rwork is REAL array @@ -188,97 +207,92 @@ struct MatrixEighFunctor { bool is_lower, bool has_vectors) { using ValueType = phi::dtype::Real; - ValueType *out_value = dev_ctx.template Alloc(eigen_values); - DenseTensor input_trans; - input_trans = phi::TransposeLast2Dim(dev_ctx, input); - T *input_vector = input_trans.data(); + int workspace_size = 0; auto &dims = input.dims(); int dim_size = dims.size(); int64_t batch_size = GetBatchSize(dims); + int last_dim = dims[dim_size - 1]; + int lda = std::max(1, last_dim); + auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; + auto values_stride = dims[dim_size - 1]; cublasFillMode_t uplo = is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; cusolverEigMode_t jobz = has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - int n = dims[dim_size - 1]; - int lda = std::max(1, n); - auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; - auto values_stride = dims[dim_size - 1]; - int lwork = 0; + ValueType *out_value = dev_ctx.template Alloc(eigen_values); auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_size); auto *info_ptr = reinterpret_cast(info->ptr()); - // When the input type is float32, and the feature value input dimension - // is greater than or equal to [*,32,32] and less than or equal to - // [*,512,512], Syevj has better performance. + DenseTensor input_trans = phi::TransposeLast2Dim(dev_ctx, input); + T *input_vector = input_trans.data(); + + // Once input data type is float32, and the last dimension of + // input is located in range [32, 512], Syevj works better. bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 && values_stride >= 32 && values_stride <= 512); + auto handle = dev_ctx.cusolver_dn_handle(); + syevjInfo_t syevj_params; if (use_syevj) { PADDLE_ENFORCE_GPU_SUCCESS( dynload::cusolverDnCreateSyevjInfo(&syevj_params)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize( dev_ctx.cusolver_dn_handle(), jobz, uplo, - n, + last_dim, reinterpret_cast(input_vector), lda, reinterpret_cast(out_value), - &lwork, + &workspace_size, syevj_params)); } else { EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, - n, + last_dim, input_vector, lda, out_value, - &lwork); + &workspace_size); } - auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * lwork); + auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * workspace_size); auto *work_ptr = reinterpret_cast(work->ptr()); - for (auto i = 0; i < batch_size; i++) { + + for (auto i = 0; i < batch_size; ++i) { auto *input_data = input_vector + i * vector_stride; auto *value_data = out_value + i * values_stride; - auto handle = dev_ctx.cusolver_dn_handle(); if (use_syevj) { PADDLE_ENFORCE_GPU_SUCCESS( dynload::cusolverDnSsyevj(handle, jobz, uplo, - n, + last_dim, reinterpret_cast(input_data), lda, reinterpret_cast(value_data), reinterpret_cast(work_ptr), - lwork, - info_ptr, + workspace_size, + &info_ptr[i], syevj_params)); } else { Evd(handle, jobz, uplo, - n, + last_dim, input_data, lda, value_data, work_ptr, - lwork, - info_ptr); + workspace_size, + &info_ptr[i]); } - int error_info = 0; - paddle::memory::Copy(phi::CPUPlace(), - &error_info, - dev_ctx.GetPlace(), - info_ptr, - sizeof(int), - dev_ctx.stream()); - CheckEighResult(i, error_info); } + CheckEighResult(dev_ctx, batch_size, info_ptr); if (use_syevj) { PADDLE_ENFORCE_GPU_SUCCESS(