未验证 提交 806073d6 编写于 作者: L limingshu 提交者: GitHub

Optimize memcpy operation in Eigh (#42853)

* 1st commit

* fix useless change in the header file transpose_kernel.h

* add sync
上级 3591a252
...@@ -27,10 +27,10 @@ ...@@ -27,10 +27,10 @@
namespace phi { namespace phi {
namespace funcs { namespace funcs {
inline int64_t GetBatchSize(phi::DDim dims) { inline int64_t GetBatchSize(const phi::DDim &dims) {
int64_t batch_size = 1; int64_t batch_size = 1;
auto dim_size = dims.size(); auto dim_size = dims.size();
for (int i = 0; i < dim_size - 2; i++) { for (int i = 0; i < dim_size - 2; ++i) {
batch_size *= dims[i]; batch_size *= dims[i];
} }
return batch_size; return batch_size;
...@@ -54,6 +54,24 @@ static void CheckEighResult(const int batch, const int info) { ...@@ -54,6 +54,24 @@ static void CheckEighResult(const int batch, const int info) {
info)); info));
} }
#ifdef PADDLE_WITH_CUDA
// Validates the per-matrix cusolver `info` status codes for a batched Eigh.
//
// Copies all `batch_size` device-side info values to the host in ONE batched
// memcpy (instead of one copy per matrix), waits for the copy to finish, and
// then delegates each entry to the scalar CheckEighResult(batch, info)
// overload, which raises on a nonzero status.
//
// dev_ctx:    GPU context supplying the device place and the stream on which
//             the async copy is enqueued.
// batch_size: number of matrices in the batch (== number of info entries).
// info:       device pointer to `batch_size` ints written by cusolver.
static void CheckEighResult(const GPUContext &dev_ctx,
                            const int64_t batch_size,
                            int *info) {
  std::vector<int> error_info(batch_size);
  // Single device-to-host copy for the whole batch.
  paddle::memory::Copy(phi::CPUPlace(),
                       error_info.data(),
                       dev_ctx.GetPlace(),
                       info,
                       sizeof(int) * batch_size,
                       dev_ctx.stream());
  // The copy above is asynchronous on the stream; block until it completes
  // before reading error_info on the host.
  dev_ctx.Wait();
  // Use int64_t for the index: `auto i = 0` would deduce int and could
  // truncate/overflow against the int64_t batch_size.
  for (int64_t i = 0; i < batch_size; ++i) {
    CheckEighResult(static_cast<int>(i), error_info[i]);
  }
}
#endif
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct MatrixEighFunctor { struct MatrixEighFunctor {
void operator()(const DeviceContext &dev_ctx, void operator()(const DeviceContext &dev_ctx,
...@@ -95,7 +113,8 @@ struct MatrixEighFunctor<CPUContext, T> { ...@@ -95,7 +113,8 @@ struct MatrixEighFunctor<CPUContext, T> {
char jobz = has_vectors ? 'V' : 'N'; char jobz = has_vectors ? 'V' : 'N';
int n = dims[dim_size - 1]; int n = dims[dim_size - 1];
int64_t lda = std::max<int64_t>(1, n); int64_t lda = std::max<int64_t>(1, n);
// if work = -1, it means that you need to use the lapack function to query // if work = -1, it means that you need to use the lapack function to
// query
// the optimal value // the optimal value
int lwork = -1; // The length of the array work int lwork = -1; // The length of the array work
int lrwork = -1; // The dimension of the array rwork,rwork is REAL array int lrwork = -1; // The dimension of the array rwork,rwork is REAL array
...@@ -188,97 +207,92 @@ struct MatrixEighFunctor<GPUContext, T> { ...@@ -188,97 +207,92 @@ struct MatrixEighFunctor<GPUContext, T> {
bool is_lower, bool is_lower,
bool has_vectors) { bool has_vectors) {
using ValueType = phi::dtype::Real<T>; using ValueType = phi::dtype::Real<T>;
ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
DenseTensor input_trans; int workspace_size = 0;
input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
T *input_vector = input_trans.data<T>();
auto &dims = input.dims(); auto &dims = input.dims();
int dim_size = dims.size(); int dim_size = dims.size();
int64_t batch_size = GetBatchSize(dims); int64_t batch_size = GetBatchSize(dims);
int last_dim = dims[dim_size - 1];
int lda = std::max<int>(1, last_dim);
auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
auto values_stride = dims[dim_size - 1];
cublasFillMode_t uplo = cublasFillMode_t uplo =
is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
cusolverEigMode_t jobz = cusolverEigMode_t jobz =
has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
int n = dims[dim_size - 1]; ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
int lda = std::max<int>(1, n);
auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
auto values_stride = dims[dim_size - 1];
int lwork = 0;
auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_size); auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_size);
auto *info_ptr = reinterpret_cast<int *>(info->ptr()); auto *info_ptr = reinterpret_cast<int *>(info->ptr());
// When the input type is float32, and the feature value input dimension DenseTensor input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
// is greater than or equal to [*,32,32] and less than or equal to T *input_vector = input_trans.data<T>();
// [*,512,512], Syevj has better performance.
// Once input data type is float32, and the last dimension of
// input is located in range [32, 512], Syevj works better.
bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 && bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
values_stride >= 32 && values_stride <= 512); values_stride >= 32 && values_stride <= 512);
auto handle = dev_ctx.cusolver_dn_handle();
syevjInfo_t syevj_params; syevjInfo_t syevj_params;
if (use_syevj) { if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateSyevjInfo(&syevj_params)); dynload::cusolverDnCreateSyevjInfo(&syevj_params));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
dev_ctx.cusolver_dn_handle(), dev_ctx.cusolver_dn_handle(),
jobz, jobz,
uplo, uplo,
n, last_dim,
reinterpret_cast<const float *>(input_vector), reinterpret_cast<const float *>(input_vector),
lda, lda,
reinterpret_cast<const float *>(out_value), reinterpret_cast<const float *>(out_value),
&lwork, &workspace_size,
syevj_params)); syevj_params));
} else { } else {
EvdBuffer(dev_ctx.cusolver_dn_handle(), EvdBuffer(dev_ctx.cusolver_dn_handle(),
jobz, jobz,
uplo, uplo,
n, last_dim,
input_vector, input_vector,
lda, lda,
out_value, out_value,
&lwork); &workspace_size);
} }
auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * lwork); auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * workspace_size);
auto *work_ptr = reinterpret_cast<T *>(work->ptr()); auto *work_ptr = reinterpret_cast<T *>(work->ptr());
for (auto i = 0; i < batch_size; i++) {
for (auto i = 0; i < batch_size; ++i) {
auto *input_data = input_vector + i * vector_stride; auto *input_data = input_vector + i * vector_stride;
auto *value_data = out_value + i * values_stride; auto *value_data = out_value + i * values_stride;
auto handle = dev_ctx.cusolver_dn_handle();
if (use_syevj) { if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnSsyevj(handle, dynload::cusolverDnSsyevj(handle,
jobz, jobz,
uplo, uplo,
n, last_dim,
reinterpret_cast<float *>(input_data), reinterpret_cast<float *>(input_data),
lda, lda,
reinterpret_cast<float *>(value_data), reinterpret_cast<float *>(value_data),
reinterpret_cast<float *>(work_ptr), reinterpret_cast<float *>(work_ptr),
lwork, workspace_size,
info_ptr, &info_ptr[i],
syevj_params)); syevj_params));
} else { } else {
Evd(handle, Evd(handle,
jobz, jobz,
uplo, uplo,
n, last_dim,
input_data, input_data,
lda, lda,
value_data, value_data,
work_ptr, work_ptr,
lwork, workspace_size,
info_ptr); &info_ptr[i]);
} }
int error_info = 0;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info_ptr,
sizeof(int),
dev_ctx.stream());
CheckEighResult(i, error_info);
} }
CheckEighResult(dev_ctx, batch_size, info_ptr);
if (use_syevj) { if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册