未验证 提交 16e364d3 编写于 作者: 傅剑寒 提交者: GitHub

Optimization of Eigh op with ssyevj_batched runtime api (#48560)

* fix codestyle

* add double complex<float> complex<double> dtype support for syevj_batched

* fix use_syevj flag for precision loss when input dtype of syevj_batch is complex128 in some case

* optimize eigh in different case

* fix missing ; bug

* fix use_syevj bug

* fix use_cusolver_syevj_batched flag
上级 8498ea4f
...@@ -102,7 +102,16 @@ CUSOLVER_ROUTINE_EACH_R1(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) ...@@ -102,7 +102,16 @@ CUSOLVER_ROUTINE_EACH_R1(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
__macro(cusolverDnDsyevj_bufferSize); \ __macro(cusolverDnDsyevj_bufferSize); \
__macro(cusolverDnSsyevj); \ __macro(cusolverDnSsyevj); \
__macro(cusolverDnDsyevj); \ __macro(cusolverDnDsyevj); \
__macro(cusolverDnDestroySyevjInfo); __macro(cusolverDnDestroySyevjInfo); \
__macro(cusolverDnXsyevjSetSortEig); \
__macro(cusolverDnSsyevjBatched_bufferSize); \
__macro(cusolverDnDsyevjBatched_bufferSize); \
__macro(cusolverDnCheevjBatched_bufferSize); \
__macro(cusolverDnZheevjBatched_bufferSize); \
__macro(cusolverDnSsyevjBatched); \
__macro(cusolverDnDsyevjBatched); \
__macro(cusolverDnCheevjBatched); \
__macro(cusolverDnZheevjBatched);
CUSOLVER_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) CUSOLVER_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
#endif #endif
......
...@@ -114,7 +114,16 @@ CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) ...@@ -114,7 +114,16 @@ CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
__macro(cusolverDnDsyevj_bufferSize); \ __macro(cusolverDnDsyevj_bufferSize); \
__macro(cusolverDnSsyevj); \ __macro(cusolverDnSsyevj); \
__macro(cusolverDnDsyevj); \ __macro(cusolverDnDsyevj); \
__macro(cusolverDnDestroySyevjInfo); __macro(cusolverDnDestroySyevjInfo); \
__macro(cusolverDnXsyevjSetSortEig); \
__macro(cusolverDnSsyevjBatched_bufferSize); \
__macro(cusolverDnDsyevjBatched_bufferSize); \
__macro(cusolverDnCheevjBatched_bufferSize); \
__macro(cusolverDnZheevjBatched_bufferSize); \
__macro(cusolverDnSsyevjBatched); \
__macro(cusolverDnDsyevjBatched); \
__macro(cusolverDnCheevjBatched); \
__macro(cusolverDnZheevjBatched);
CUSOLVER_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) CUSOLVER_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
#endif #endif
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cusolver.h" #include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/core/errors.h"
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
...@@ -54,6 +54,137 @@ static void CheckEighResult(const int batch, const int info) { ...@@ -54,6 +54,137 @@ static void CheckEighResult(const int batch, const int info) {
info)); info));
} }
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 11031
static bool use_cusolver_syevj_batched = true;
#else
static bool use_cusolver_syevj_batched = false;
#endif
#define CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
int n, const scalar_t *A, int lda, const value_t *W, int *lwork, \
syevjInfo_t params, int batchsize
template <class scalar_t, class value_t = scalar_t>
void syevjBatched_bufferSize(
CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
PADDLE_THROW(phi::errors::InvalidArgument(
"syevjBatched_bufferSize: not implemented for %s",
typeid(scalar_t).name()));
}
template <>
inline void syevjBatched_bufferSize<float>(
CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(float, float)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevjBatched_bufferSize(
handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
}
template <>
inline void syevjBatched_bufferSize<double>(
CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(double, double)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevjBatched_bufferSize(
handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
}
template <>
inline void syevjBatched_bufferSize<phi::dtype::complex<float>, float>(
CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(phi::dtype::complex<float>,
float)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCheevjBatched_bufferSize(
handle,
jobz,
uplo,
n,
reinterpret_cast<const cuComplex *>(A),
lda,
W,
lwork,
params,
batchsize));
}
template <>
inline void syevjBatched_bufferSize<phi::dtype::complex<double>, double>(
CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(phi::dtype::complex<double>,
double)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZheevjBatched_bufferSize(
handle,
jobz,
uplo,
n,
reinterpret_cast<const cuDoubleComplex *>(A),
lda,
W,
lwork,
params,
batchsize));
}
#define CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t) \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \
int *info, syevjInfo_t params, int batchsize
template <class scalar_t, class value_t = scalar_t>
void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) {
PADDLE_THROW(phi::errors::InvalidArgument(
"syevjBatched: not implemented for %s", typeid(scalar_t).name()));
}
template <>
inline void syevjBatched<float>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(float,
float)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevjBatched(
handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
}
template <>
inline void syevjBatched<double>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(double,
double)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevjBatched(
handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
}
template <>
inline void syevjBatched<phi::dtype::complex<float>, float>(
CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(phi::dtype::complex<float>, float)) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCheevjBatched(handle,
jobz,
uplo,
n,
reinterpret_cast<cuComplex *>(A),
lda,
W,
reinterpret_cast<cuComplex *>(work),
lwork,
info,
params,
batchsize));
}
template <>
inline void syevjBatched<phi::dtype::complex<double>, double>(
CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(phi::dtype::complex<double>, double)) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZheevjBatched(
handle,
jobz,
uplo,
n,
reinterpret_cast<cuDoubleComplex *>(A),
lda,
W,
reinterpret_cast<cuDoubleComplex *>(work),
lwork,
info,
params,
batchsize));
}
#endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
static void CheckEighResult(const GPUContext &dev_ctx, static void CheckEighResult(const GPUContext &dev_ctx,
const int64_t batch_size, const int64_t batch_size,
...@@ -232,17 +363,33 @@ struct MatrixEighFunctor<GPUContext, T> { ...@@ -232,17 +363,33 @@ struct MatrixEighFunctor<GPUContext, T> {
DenseTensor input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input); DenseTensor input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
T *input_vector = input_trans.data<T>(); T *input_vector = input_trans.data<T>();
// Once input data type is float32, and the last dimension of // Precision loss will occur in some cases while using
// input is located in range [32, 512], Syevj works better. // cusolverDnZheevjBatched to calculate in Paddle(cuda11.7) but it works
bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 && // well in Paddle(cuda10.2)
values_stride >= 32 && values_stride <= 512); use_cusolver_syevj_batched = (use_cusolver_syevj_batched) &&
(batch_size > 1) &&
(input.dtype() != phi::DataType::COMPLEX128);
bool use_cusolver_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
last_dim >= 32 && last_dim <= 512);
auto handle = dev_ctx.cusolver_dn_handle(); auto handle = dev_ctx.cusolver_dn_handle();
syevjInfo_t syevj_params; syevjInfo_t syevj_params;
if (use_syevj) { if (use_cusolver_syevj_batched) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateSyevjInfo(&syevj_params));
syevjBatched_bufferSize<T>(handle,
jobz,
uplo,
last_dim,
input_vector,
lda,
out_value,
&workspace_size,
syevj_params,
batch_size);
} else if (use_cusolver_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateSyevjInfo(&syevj_params)); dynload::cusolverDnCreateSyevjInfo(&syevj_params));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
dev_ctx.cusolver_dn_handle(), dev_ctx.cusolver_dn_handle(),
jobz, jobz,
...@@ -272,7 +419,21 @@ struct MatrixEighFunctor<GPUContext, T> { ...@@ -272,7 +419,21 @@ struct MatrixEighFunctor<GPUContext, T> {
for (auto i = 0; i < batch_size; ++i) { for (auto i = 0; i < batch_size; ++i) {
auto *input_data = input_vector + i * vector_stride; auto *input_data = input_vector + i * vector_stride;
auto *value_data = out_value + i * values_stride; auto *value_data = out_value + i * values_stride;
if (use_syevj) { if (use_cusolver_syevj_batched) {
syevjBatched<T>(handle,
jobz,
uplo,
last_dim,
input_data,
lda,
value_data,
work_ptr,
workspace_size,
&info_ptr[i],
syevj_params,
batch_size);
break;
} else if (use_cusolver_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnSsyevj(handle, dynload::cusolverDnSsyevj(handle,
jobz, jobz,
...@@ -300,7 +461,7 @@ struct MatrixEighFunctor<GPUContext, T> { ...@@ -300,7 +461,7 @@ struct MatrixEighFunctor<GPUContext, T> {
} }
CheckEighResult(dev_ctx, batch_size, info_ptr); CheckEighResult(dev_ctx, batch_size, info_ptr);
if (use_syevj) { if (use_cusolver_syevj_batched || use_cusolver_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDestroySyevjInfo(syevj_params)); dynload::cusolverDnDestroySyevjInfo(syevj_params));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册