eigen_values_vectors.h 12.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/memory/memory.h"
18
#include "paddle/fluid/operators/math/lapack_function.h"
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
#include "paddle/fluid/operators/svd_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
#endif  // PADDLE_WITH_CUDA

namespace paddle {
namespace operators {
namespace math {

// Returns the product of every dimension of `dims` except the last two,
// i.e. the number of matrices stacked in a batched matrix tensor.
// The result is 1 when there are no leading dimensions to multiply.
inline int64_t GetBatchSize(framework::DDim dims) {
  const auto rank = dims.size();
  int64_t count = 1;
  int axis = 0;
  // Only the leading (rank - 2) axes are batch axes; the trailing two
  // describe a single matrix.
  while (axis < rank - 2) {
    count *= dims[axis];
    ++axis;
  }
  return count;
}

37 38 39 40 41 42 43 44 45 46 47 48 49 50
// Validates the status code `info` reported by an eigen-solver call for the
// matrix with index `batch`, raising PreconditionNotMet when it is non-zero.
//   info > 0: the solver did not converge; `info` is the number of
//             off-diagonal elements that failed to reach zero.
//   info < 0: argument number `-info` of the solver call was illegal.
static void CheckEighResult(const int batch, const int info) {
  PADDLE_ENFORCE_LE(
      info, 0,
      platform::errors::PreconditionNotMet(
          "For batch [%d]: the [%d] off-diagonal elements of an intermediate "
          "tridiagonal form did not converge to zero",
          batch, info));
  // The solver encodes the offending argument position as a negative value,
  // so negate it to report the actual (1-based) argument index.
  PADDLE_ENFORCE_GE(
      info, 0, platform::errors::PreconditionNotMet(
                   "For batch [%d]: the [%d] argument had an illegal value",
                   batch, -info));
}

template <typename DeviceContext, typename T>
51 52 53 54 55 56
struct MatrixEighFunctor {
  void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                  Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                  bool has_vectors);
};

57 58 59
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
60 61
template <typename T>
struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
62 63 64 65
 public:
  void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                  Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                  bool has_vectors) {
66 67
    using ValueType = math::Real<T>;
    auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
68

69
    auto dito =
70 71
        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
            ctx);
72

73 74 75 76 77
    Tensor input_trans;
    // lapack is a column-major storge, transpose make the input to
    // have a continuous memory layout
    input_trans = dito.Transpose(input);
    auto *input_vector = input_trans.data<T>();
78

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
    auto dims = input.dims();
    int dim_size = dims.size();
    int64_t batch_size = GetBatchSize(dims);

    int vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
    int values_stride = dims[dim_size - 1];
    char uplo = is_lower ? 'L' : 'U';
    char jobz = has_vectors ? 'V' : 'N';
    auto n = dims[dim_size - 1];
    auto lda = std::max<int64_t>(1, n);
    // if work = -1, it means that you need to use the lapack function to query
    // the optimal value
    int lwork = -1;      // The length of the array work
    int lrwork = -1;     // The dimension of the array rwork,rwork is REAL array
    int liwork = -1;     // The dimension of the array iwork
    int iwork_opt = -1;  // The optimal length of the array liwork
    T lwork_opt = static_cast<T>(-1);  // The optimal length of the array work
    ValueType rwork_opt =
        static_cast<ValueType>(-1);  // The optimal length of the array rwork

    int info = 0;
    // Call lapackEigh to get the optimal size of work data
    math::lapackEigh<T, ValueType>(jobz, uplo, n, input_vector, lda, out_value,
                                   &lwork_opt, lwork, &rwork_opt, lrwork,
                                   &iwork_opt, liwork, &info);
    lwork = std::max<int>(1, static_cast<int>(lwork_opt));
    liwork = std::max<int>(1, iwork_opt);

    Tensor rwork_tensor;
    ValueType *rwork_data = nullptr;

    // complex type
111 112
    if (framework::IsComplexType(
            framework::TransToProtoVarType(input.dtype()))) {
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
      lrwork = std::max<int>(1, static_cast<int>(rwork_opt));
      rwork_data = rwork_tensor.mutable_data<ValueType>(
          framework::make_ddim({lrwork}), ctx.GetPlace());
    }
    Tensor iwork_tensor, work_tensor;
    auto *iwork_data = iwork_tensor.mutable_data<int>(
        framework::make_ddim({liwork}), ctx.GetPlace());
    auto *work_data = work_tensor.mutable_data<T>(framework::make_ddim({lwork}),
                                                  ctx.GetPlace());

    for (auto i = 0; i < batch_size; i++) {
      auto *value_data = out_value + i * values_stride;
      auto *input_data = input_vector + i * vector_stride;
      math::lapackEigh<T, Real<T>>(jobz, uplo, n, input_data, lda, value_data,
                                   work_data, lwork, rwork_data, lrwork,
                                   iwork_data, liwork, &info);
      CheckEighResult(i, info);
    }
    if (has_vectors) {
      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
                              platform::errors::InvalidArgument(
                                  "When has_vectors is true,"
                                  "the eigenvectors needs to be calculated, "
                                  "so the eigenvectors must be provided."));
      input_trans = dito.Transpose(input_trans);
      eigen_vectors->ShareDataWith(input_trans);
139 140 141 142 143 144 145 146 147
    }
  }
};

#ifdef PADDLE_WITH_CUDA

// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
148 149
template <typename T>
struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
150 151 152 153
 public:
  void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                  Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                  bool has_vectors) {
154
    using ValueType = math::Real<T>;
155 156
    auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());

157 158 159 160 161 162 163
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto dito =
        math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
                                                 T>(ctx);
    Tensor input_trans;
    input_trans = dito.Transpose(input);
    auto *input_vector = input_trans.data<T>();
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
    auto &dims = input.dims();
    int dim_size = dims.size();
    int64_t batch_size = GetBatchSize(dims);

    cublasFillMode_t uplo =
        is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
    cusolverEigMode_t jobz =
        has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

    int n = dims[dim_size - 1];
    int lda = std::max<int>(1, n);
    auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
    auto values_stride = dims[dim_size - 1];
    int lwork = 0;
    auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size);
    auto *info_ptr = reinterpret_cast<int *>(info->ptr());

    // When the input type is float32, and the feature value input dimension is
    // greater than or equal to [*,32,32]  and less than or equal to
    // [*,512,512], Syevj has better performance.
184 185
    bool use_syevj = (framework::TransToProtoVarType(input.dtype()) ==
                          framework::proto::VarType::FP32 &&
186
                      values_stride >= 32 && values_stride <= 512);
187 188
    syevjInfo_t syevj_params;
    if (use_syevj) {
189
      PADDLE_ENFORCE_GPU_SUCCESS(
190
          platform::dynload::cusolverDnCreateSyevjInfo(&syevj_params));
191 192 193 194
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj_bufferSize(
          dev_ctx.cusolver_dn_handle(), jobz, uplo, n,
          reinterpret_cast<const float *>(input_vector), lda,
          reinterpret_cast<const float *>(out_value), &lwork, syevj_params));
195
    } else {
196
      EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, input_vector, lda,
197 198 199 200 201
                out_value, &lwork);
    }
    auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork);
    auto *work_ptr = reinterpret_cast<T *>(work->ptr());
    for (auto i = 0; i < batch_size; i++) {
202 203
      auto *input_data = input_vector + i * vector_stride;
      auto *value_data = out_value + i * values_stride;
204 205
      auto handle = dev_ctx.cusolver_dn_handle();
      if (use_syevj) {
206
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj(
207
            handle, jobz, uplo, n, reinterpret_cast<float *>(input_data), lda,
208 209 210 211
            reinterpret_cast<float *>(value_data),
            reinterpret_cast<float *>(work_ptr), lwork, info_ptr,
            syevj_params));
      } else {
212 213
        Evd(handle, jobz, uplo, n, input_data, lda, value_data, work_ptr, lwork,
            info_ptr);
214
      }
215
      int error_info = 0;
216
      memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(),
217
                   info_ptr, sizeof(int), dev_ctx.stream());
218
      CheckEighResult(i, error_info);
219 220 221
    }

    if (use_syevj) {
222
      PADDLE_ENFORCE_GPU_SUCCESS(
223 224 225
          platform::dynload::cusolverDnDestroySyevjInfo(syevj_params));
    }
    if (has_vectors) {
226 227 228 229 230 231 232
      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
                              platform::errors::InvalidArgument(
                                  "When has_vectors is true,"
                                  "the eigenvectors needs to be calculated,"
                                  "so the eigenvectors must be provided."));
      input_trans = dito.Transpose(input_trans);
      eigen_vectors->ShareDataWith(input_trans);
233 234 235
    }
  }

236
  using ValueType = math::Real<T>;
237 238 239 240 241 242 243 244 245
  inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
                        cublasFillMode_t uplo, int n, const T *A, int lda,
                        const ValueType *W, int *lwork) const;

  inline void Evd(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
                  cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W,
                  T *work, int lwork, int *devInfo) const;
};

246 247 248 249
// Expands `instantiate` once per supported element type, invoking it as
// instantiate(<element type>, <cuSOLVER routine prefix>, <cuSOLVER type>).
#define FUNC_WITH_TYPES(instantiate)                            \
  instantiate(float, Ssy, float)                                \
  instantiate(double, Dsy, double)                              \
  instantiate(paddle::platform::complex<float>, Che, cuComplex) \
  instantiate(paddle::platform::complex<double>, Zhe, cuDoubleComplex)
250

251
// Defines MatrixEighFunctor<CUDADeviceContext, ElemT>::EvdBuffer by
// forwarding to cusolverDn<Prefix>evd_bufferSize, with the matrix pointer
// reinterpreted as the cuSOLVER element type CuT.
#define EVDBUFFER_INSTANCE(ElemT, Prefix, CuT)                                \
  template <>                                                                 \
  inline void                                                                 \
  MatrixEighFunctor<platform::CUDADeviceContext, ElemT>::EvdBuffer(           \
      cusolverDnHandle_t handle, cusolverEigMode_t jobz,                      \
      cublasFillMode_t uplo, int n, const ElemT *A, int lda,                  \
      const ValueType *W, int *lwork) const {                                 \
    const CuT *matrix = reinterpret_cast<const CuT *>(A);                     \
    PADDLE_ENFORCE_GPU_SUCCESS(                                               \
        platform::dynload::cusolverDn##Prefix##evd_bufferSize(                \
            handle, jobz, uplo, n, matrix, lda, W, lwork));                   \
  }

// Instantiate EvdBuffer for every supported element type.
FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);

265
// Defines MatrixEighFunctor<CUDADeviceContext, ElemT>::Evd by forwarding to
// cusolverDn<Prefix>evd, with the matrix and workspace pointers
// reinterpreted as the cuSOLVER element type CuT.
#define EVD_INSTANCE(ElemT, Prefix, CuT)                                   \
  template <>                                                              \
  inline void MatrixEighFunctor<platform::CUDADeviceContext, ElemT>::Evd(  \
      cusolverDnHandle_t handle, cusolverEigMode_t jobz,                   \
      cublasFillMode_t uplo, int n, ElemT *A, int lda, ValueType *W,       \
      ElemT *work, int lwork, int *devInfo) const {                        \
    CuT *matrix = reinterpret_cast<CuT *>(A);                              \
    CuT *workspace = reinterpret_cast<CuT *>(work);                        \
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##Prefix##evd( \
        handle, jobz, uplo, n, matrix, lda, W, workspace, lwork,           \
        devInfo));                                                         \
  }

// Instantiate Evd for every supported element type.
FUNC_WITH_TYPES(EVD_INSTANCE);

// The helper macros are local to this header; do not leak them to includers.
#undef FUNC_WITH_TYPES
#undef EVDBUFFER_INSTANCE
#undef EVD_INSTANCE

#endif  // PADDLE_WITH_CUDA

}  // namespace math
}  // namespace operators
}  // namespace paddle