eigen_values_vectors.h 15.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/svd_helper.h"
19
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
#endif  // PADDLE_WITH_CUDA

namespace paddle {
namespace operators {
namespace math {

// Returns the product of every dimension of `dims` except the last two.
// When `dims` has two or fewer dimensions the loop body never runs and the
// result is 1.
inline int64_t GetBatchSize(framework::DDim dims) {
  const auto num_leading_axes = dims.size() - 2;
  int64_t product = 1;
  int axis = 0;
  while (axis < num_leading_axes) {
    product *= dims[axis];
    ++axis;
  }
  return product;
}

37 38
static void CheckEighResult(const int batch, const int info) {
  PADDLE_ENFORCE_LE(
39 40
      info,
      0,
41 42 43
      platform::errors::PreconditionNotMet(
          "For batch [%d]: the [%d] off-diagonal elements of an intermediate"
          "tridiagonal form did not converge to zero",
44 45
          batch,
          info));
46
  PADDLE_ENFORCE_GE(
47 48
      info,
      0,
49
      platform::errors::PreconditionNotMet(
50 51
          "For batch [%d]: the [%d] argument had an illegal value",
          batch,
52
          info));
53 54 55
}

template <typename DeviceContext, typename T>
56
struct MatrixEighFunctor {
57 58 59 60 61
  void operator()(const framework::ExecutionContext &ctx,
                  const Tensor &input,
                  Tensor *eigen_values,
                  Tensor *eigen_vectors,
                  bool is_lower,
62 63 64
                  bool has_vectors);
};

65 66 67
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
68
template <typename T>
L
Leo Chen 已提交
69
struct MatrixEighFunctor<phi::CPUContext, T> {
70
 public:
71 72 73 74 75
  void operator()(const framework::ExecutionContext &ctx,
                  const Tensor &input,
                  Tensor *eigen_values,
                  Tensor *eigen_vectors,
                  bool is_lower,
76
                  bool has_vectors) {
77
    using ValueType = phi::dtype::Real<T>;
78
    auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
79

80
    auto dito =
L
Leo Chen 已提交
81
        math::DeviceIndependenceTensorOperations<phi::CPUContext, T>(ctx);
82

83 84 85 86 87
    Tensor input_trans;
    // lapack is a column-major storge, transpose make the input to
    // have a continuous memory layout
    input_trans = dito.Transpose(input);
    auto *input_vector = input_trans.data<T>();
88

89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
    auto dims = input.dims();
    int dim_size = dims.size();
    int64_t batch_size = GetBatchSize(dims);

    int vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
    int values_stride = dims[dim_size - 1];
    char uplo = is_lower ? 'L' : 'U';
    char jobz = has_vectors ? 'V' : 'N';
    auto n = dims[dim_size - 1];
    auto lda = std::max<int64_t>(1, n);
    // if work = -1, it means that you need to use the lapack function to query
    // the optimal value
    int lwork = -1;      // The length of the array work
    int lrwork = -1;     // The dimension of the array rwork,rwork is REAL array
    int liwork = -1;     // The dimension of the array iwork
    int iwork_opt = -1;  // The optimal length of the array liwork
    T lwork_opt = static_cast<T>(-1);  // The optimal length of the array work
    ValueType rwork_opt =
        static_cast<ValueType>(-1);  // The optimal length of the array rwork

    int info = 0;
    // Call lapackEigh to get the optimal size of work data
111 112 113 114 115 116 117 118 119 120 121 122 123
    phi::funcs::lapackEigh<T, ValueType>(jobz,
                                         uplo,
                                         n,
                                         input_vector,
                                         lda,
                                         out_value,
                                         &lwork_opt,
                                         lwork,
                                         &rwork_opt,
                                         lrwork,
                                         &iwork_opt,
                                         liwork,
                                         &info);
124 125 126 127 128 129 130
    lwork = std::max<int>(1, static_cast<int>(lwork_opt));
    liwork = std::max<int>(1, iwork_opt);

    Tensor rwork_tensor;
    ValueType *rwork_data = nullptr;

    // complex type
131 132
    if (framework::IsComplexType(
            framework::TransToProtoVarType(input.dtype()))) {
133 134
      lrwork = std::max<int>(1, static_cast<int>(rwork_opt));
      rwork_data = rwork_tensor.mutable_data<ValueType>(
135
          phi::make_ddim({lrwork}), ctx.GetPlace());
136 137
    }
    Tensor iwork_tensor, work_tensor;
138
    auto *iwork_data = iwork_tensor.mutable_data<int>(phi::make_ddim({liwork}),
139 140
                                                      ctx.GetPlace());
    auto *work_data =
141
        work_tensor.mutable_data<T>(phi::make_ddim({lwork}), ctx.GetPlace());
142 143 144 145

    for (auto i = 0; i < batch_size; i++) {
      auto *value_data = out_value + i * values_stride;
      auto *input_data = input_vector + i * vector_stride;
146 147 148 149 150 151 152 153 154 155 156 157 158
      phi::funcs::lapackEigh<T, phi::dtype::Real<T>>(jobz,
                                                     uplo,
                                                     n,
                                                     input_data,
                                                     lda,
                                                     value_data,
                                                     work_data,
                                                     lwork,
                                                     rwork_data,
                                                     lrwork,
                                                     iwork_data,
                                                     liwork,
                                                     &info);
159 160 161 162 163 164 165 166 167 168
      CheckEighResult(i, info);
    }
    if (has_vectors) {
      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
                              platform::errors::InvalidArgument(
                                  "When has_vectors is true,"
                                  "the eigenvectors needs to be calculated, "
                                  "so the eigenvectors must be provided."));
      input_trans = dito.Transpose(input_trans);
      eigen_vectors->ShareDataWith(input_trans);
169 170 171 172 173 174 175 176 177
    }
  }
};

#ifdef PADDLE_WITH_CUDA

// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
178
template <typename T>
L
Leo Chen 已提交
179
struct MatrixEighFunctor<phi::GPUContext, T> {
180
 public:
181 182 183 184 185
  void operator()(const framework::ExecutionContext &ctx,
                  const Tensor &input,
                  Tensor *eigen_values,
                  Tensor *eigen_vectors,
                  bool is_lower,
186
                  bool has_vectors) {
187
    using ValueType = phi::dtype::Real<T>;
188 189
    auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());

L
Leo Chen 已提交
190
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
191
    auto dito =
L
Leo Chen 已提交
192
        math::DeviceIndependenceTensorOperations<phi::GPUContext, T>(ctx);
193 194 195
    Tensor input_trans;
    input_trans = dito.Transpose(input);
    auto *input_vector = input_trans.data<T>();
196 197 198 199 200 201 202 203 204 205 206 207 208 209
    auto &dims = input.dims();
    int dim_size = dims.size();
    int64_t batch_size = GetBatchSize(dims);

    cublasFillMode_t uplo =
        is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
    cusolverEigMode_t jobz =
        has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

    int n = dims[dim_size - 1];
    int lda = std::max<int>(1, n);
    auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
    auto values_stride = dims[dim_size - 1];
    int lwork = 0;
L
Leo Chen 已提交
210 211 212 213
    auto info = memory::Alloc(
        dev_ctx.GetPlace(),
        sizeof(int) * batch_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
214 215 216 217 218
    auto *info_ptr = reinterpret_cast<int *>(info->ptr());

    // When the input type is float32, and the feature value input dimension is
    // greater than or equal to [*,32,32]  and less than or equal to
    // [*,512,512], Syevj has better performance.
219 220
    bool use_syevj = (framework::TransToProtoVarType(input.dtype()) ==
                          framework::proto::VarType::FP32 &&
221
                      values_stride >= 32 && values_stride <= 512);
222 223
    syevjInfo_t syevj_params;
    if (use_syevj) {
224
      PADDLE_ENFORCE_GPU_SUCCESS(
225
          platform::dynload::cusolverDnCreateSyevjInfo(&syevj_params));
226
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj_bufferSize(
227 228 229 230 231 232 233 234 235
          dev_ctx.cusolver_dn_handle(),
          jobz,
          uplo,
          n,
          reinterpret_cast<const float *>(input_vector),
          lda,
          reinterpret_cast<const float *>(out_value),
          &lwork,
          syevj_params));
236
    } else {
237 238 239 240 241 242 243 244
      EvdBuffer(dev_ctx.cusolver_dn_handle(),
                jobz,
                uplo,
                n,
                input_vector,
                lda,
                out_value,
                &lwork);
245
    }
L
Leo Chen 已提交
246 247 248 249
    auto work = memory::Alloc(
        dev_ctx.GetPlace(),
        sizeof(T) * lwork,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
250 251
    auto *work_ptr = reinterpret_cast<T *>(work->ptr());
    for (auto i = 0; i < batch_size; i++) {
252 253
      auto *input_data = input_vector + i * vector_stride;
      auto *value_data = out_value + i * values_stride;
254 255
      auto handle = dev_ctx.cusolver_dn_handle();
      if (use_syevj) {
256
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj(
257 258 259 260 261 262
            handle,
            jobz,
            uplo,
            n,
            reinterpret_cast<float *>(input_data),
            lda,
263
            reinterpret_cast<float *>(value_data),
264 265 266
            reinterpret_cast<float *>(work_ptr),
            lwork,
            info_ptr,
267 268
            syevj_params));
      } else {
269 270 271 272 273 274 275 276 277
        Evd(handle,
            jobz,
            uplo,
            n,
            input_data,
            lda,
            value_data,
            work_ptr,
            lwork,
278
            info_ptr);
279
      }
280
      int error_info = 0;
281 282 283 284 285 286
      memory::Copy(platform::CPUPlace(),
                   &error_info,
                   dev_ctx.GetPlace(),
                   info_ptr,
                   sizeof(int),
                   dev_ctx.stream());
287
      CheckEighResult(i, error_info);
288 289 290
    }

    if (use_syevj) {
291
      PADDLE_ENFORCE_GPU_SUCCESS(
292 293 294
          platform::dynload::cusolverDnDestroySyevjInfo(syevj_params));
    }
    if (has_vectors) {
295 296 297 298 299 300 301
      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
                              platform::errors::InvalidArgument(
                                  "When has_vectors is true,"
                                  "the eigenvectors needs to be calculated,"
                                  "so the eigenvectors must be provided."));
      input_trans = dito.Transpose(input_trans);
      eigen_vectors->ShareDataWith(input_trans);
302 303 304
    }
  }

305
  using ValueType = phi::dtype::Real<T>;
306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
  inline void EvdBuffer(cusolverDnHandle_t handle,
                        cusolverEigMode_t jobz,
                        cublasFillMode_t uplo,
                        int n,
                        const T *A,
                        int lda,
                        const ValueType *W,
                        int *lwork) const;

  inline void Evd(cusolverDnHandle_t handle,
                  cusolverEigMode_t jobz,
                  cublasFillMode_t uplo,
                  int n,
                  T *A,
                  int lda,
                  ValueType *W,
                  T *work,
                  int lwork,
                  int *devInfo) const;
325 326
};

327 328 329 330
// X-macro: expands `INSTANTIATE(ElementType, RoutinePrefix, CudaType)` once
// for each of the four supported element types.
#define FUNC_WITH_TYPES(INSTANTIATE)                            \
  INSTANTIATE(float, Ssy, float)                                \
  INSTANTIATE(double, Dsy, double)                              \
  INSTANTIATE(paddle::platform::complex<float>, Che, cuComplex) \
  INSTANTIATE(paddle::platform::complex<double>, Zhe, cuDoubleComplex)
331

L
Leo Chen 已提交
332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352
// Defines the explicit specialization of
// MatrixEighFunctor<phi::GPUContext, T>::EvdBuffer for one element type.
//   T        - element type of the functor specialization.
//   C        - routine prefix pasted into the cuSOLVER symbol, so the call
//              becomes cusolverDn<C>evd_bufferSize (e.g. Ssy, Dsy, Che, Zhe).
//   CastType - CUDA-side type that the matrix pointer is reinterpreted as.
// The generated function forwards its arguments to that routine and enforces
// a successful return status.
#define EVDBUFFER_INSTANCE(T, C, CastType)                      \
  template <>                                                   \
  inline void MatrixEighFunctor<phi::GPUContext, T>::EvdBuffer( \
      cusolverDnHandle_t handle,                                \
      cusolverEigMode_t jobz,                                   \
      cublasFillMode_t uplo,                                    \
      int n,                                                    \
      const T *A,                                               \
      int lda,                                                  \
      const ValueType *W,                                       \
      int *lwork) const {                                       \
    PADDLE_ENFORCE_GPU_SUCCESS(                                 \
        platform::dynload::cusolverDn##C##evd_bufferSize(       \
            handle,                                             \
            jobz,                                               \
            uplo,                                               \
            n,                                                  \
            reinterpret_cast<const CastType *>(A),              \
            lda,                                                \
            W,                                                  \
            lwork));                                            \
  }

// Instantiate EvdBuffer for every type listed in FUNC_WITH_TYPES.
FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);

357 358
// Defines the explicit specialization of
// MatrixEighFunctor<phi::GPUContext, T>::Evd for one element type.
//   T        - element type of the functor specialization.
//   C        - routine prefix pasted into the cuSOLVER symbol, so the call
//              becomes cusolverDn<C>evd (e.g. Ssy, Dsy, Che, Zhe).
//   CastType - CUDA-side type that the matrix and workspace pointers are
//              reinterpreted as.
// The generated function forwards its arguments to that routine and enforces
// a successful return status.
#define EVD_INSTANCE(T, C, CastType)                                  \
  template <>                                                         \
  inline void MatrixEighFunctor<phi::GPUContext, T>::Evd(             \
      cusolverDnHandle_t handle,                                      \
      cusolverEigMode_t jobz,                                         \
      cublasFillMode_t uplo,                                          \
      int n,                                                          \
      T *A,                                                           \
      int lda,                                                        \
      ValueType *W,                                                   \
      T *work,                                                        \
      int lwork,                                                      \
      int *devInfo) const {                                           \
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##C##evd( \
        handle,                                                       \
        jobz,                                                         \
        uplo,                                                         \
        n,                                                            \
        reinterpret_cast<CastType *>(A),                              \
        lda,                                                          \
        W,                                                            \
        reinterpret_cast<CastType *>(work),                           \
        lwork,                                                        \
        devInfo));                                                    \
  }

// Instantiate Evd for every type listed in FUNC_WITH_TYPES.
FUNC_WITH_TYPES(EVD_INSTANCE);

#undef FUNC_WITH_TYPES
#undef EVDBUFFER_INSTANCE
#undef EVD_INSTANCE

#endif  // PADDLE_WITH_CUDA

}  // namespace math
}  // namespace operators
}  // namespace paddle