// eigvals_op.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include <string>
18 19 20 21
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
22 23
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/for_range.h"
24
#include "paddle/pten/kernels/funcs/complex_functors.h"
25
#include "paddle/pten/kernels/funcs/lapack/lapack_function.h"
26 27 28 29 30 31

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;

32 33
// Trait that selects the complex element type associated with a scalar type
// T. This primary template is declaration-only; the second parameter is the
// SFINAE slot used by the partial specializations.
template <typename T, typename Enabler = void>
struct PaddleComplex;
34 35

template <typename T>
36 37 38
struct PaddleComplex<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using type = paddle::platform::complex<T>;
39
};
40 41 42 43 44 45
// Specialization for scalars that are already platform::complex<float> or
// platform::complex<double>: `type` is the scalar itself.
template <typename ComplexT>
struct PaddleComplex<
    ComplexT,
    typename std::enable_if<
        std::is_same<ComplexT, platform::complex<double>>::value ||
        std::is_same<ComplexT, platform::complex<float>>::value>::type> {
  using type = ComplexT;
};

template <typename T>
49
using PaddleCType = typename PaddleComplex<T>::type;
50
template <typename T>
51
using Real = typename pten::funcs::Real<T>;
52

53 54
static void SpiltBatchSquareMatrix(const Tensor& input,
                                   std::vector<Tensor>* output) {
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
  DDim input_dims = input.dims();
  int last_dim = input_dims.size() - 1;
  int n_dim = input_dims[last_dim];

  DDim flattened_input_dims, flattened_output_dims;
  if (input_dims.size() > 2) {
    flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim);
  } else {
    flattened_input_dims = framework::make_ddim({1, n_dim, n_dim});
  }

  Tensor flattened_input;
  flattened_input.ShareDataWith(input);
  flattened_input.Resize(flattened_input_dims);
  (*output) = flattened_input.Split(1, 0);
}

72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
// Turns the LAPACK `info` status returned by an eigenvalue routine into a
// Paddle error. Zero passes both checks; the two failure cases are mutually
// exclusive, so the order of the checks does not affect which error is
// raised.
//
// info: status code written by the routine.
// name: routine name interpolated into the error message.
static void CheckLapackEigResult(const int info, const std::string& name) {
  // info < 0: the argument at position -info was rejected.
  PADDLE_ENFORCE_GE(
      info, 0, platform::errors::InvalidArgument(
                   "The %d-th argument has an illegal value in function %s.",
                   -info, name.c_str()));
  // info > 0: the iteration did not converge for every eigenvalue.
  PADDLE_ENFORCE_LE(info, 0, platform::errors::PreconditionNotMet(
                                 "The QR algorithm failed to compute all the "
                                 "eigenvalues in function %s.",
                                 name.c_str()));
}

// Computes the eigenvalues of a single real square matrix through
// pten::funcs::lapackEig and stores them in `output` as complex numbers.
//
// ctx:    execution context; supplies the place for the temporary buffer and
//         the device context for the final conversion pass.
// input:  one matrix; its order n is read from dims()[1].
// output: destination accessed as PaddleCType<T> elements; it is read with
//         data<>(), so it must already be allocated by the caller.
// work:   LAPACK workspace, at least 3 * n elements of T (enforced below).
// rwork:  unused for real inputs.
//
// Raises InvalidArgument if `work` is too small, and whatever
// CheckLapackEigResult raises for a non-zero LAPACK status.
template <typename DeviceContext, typename T>
static typename std::enable_if<std::is_floating_point<T>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
              Tensor* output, Tensor* work, Tensor* rwork /*unused*/) {
  Tensor a;  // scratch copy: lapackEig overwrites the matrix it is given
  framework::TensorCopy(input, input.place(), &a);

  // w holds 2 * n values; the two halves (w_data and w_data + n_dim) are
  // handed separately to the real/imag-to-complex functor at the end.
  Tensor w;
  const int64_t n_dim = input.dims()[1];
  auto* w_data =
      w.mutable_data<T>(framework::make_ddim({n_dim << 1}), ctx.GetPlace());

  const int64_t work_mem = work->memory_size();
  const int64_t required_work_mem =
      3 * n_dim * static_cast<int64_t>(sizeof(T));
  // Compare two signed 64-bit values; the previous inline expression was
  // unsigned (size_t) and mixed signedness with work_mem.
  PADDLE_ENFORCE_GE(
      work_mem, required_work_mem,
      platform::errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
          required_work_mem, work_mem));

  int info = 0;
  pten::funcs::lapackEig<T>('N', 'N', static_cast<int>(n_dim),
                            a.template data<T>(), static_cast<int>(n_dim),
                            w_data, nullptr, 1, nullptr, 1,
                            work->template data<T>(),
                            static_cast<int>(work_mem / sizeof(T)),
                            static_cast<T*>(nullptr), &info);

  // Name the routine that matches the input precision in error messages:
  // dgeev_ is the double-precision routine, sgeev_ the single-precision one.
  std::string name = "framework::platform::dynload::sgeev_";
  if (framework::TransToProtoVarType(input.dtype()) ==
      framework::proto::VarType::FP64) {
    name = "framework::platform::dynload::dgeev_";
  }
  CheckLapackEigResult(info, name);

  // Combine the two halves of w element-wise into complex output values.
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), n_dim);
  pten::funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
      w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
  for_range(functor);
}

// Computes the eigenvalues of a single complex square matrix through
// pten::funcs::lapackEig, writing them directly into `output`.
//
// ctx:    execution context (not used by this overload's body).
// input:  one matrix; its order n is read from dims()[1].
// output: destination accessed as T elements; it is read with data<>(), so
//         it must already be allocated by the caller.
// work:   LAPACK workspace, at least 3 * n elements of T (enforced below).
// rwork:  real-valued workspace, at least 2 * n elements of
//         pten::funcs::Real<T> (enforced below).
//
// Raises InvalidArgument if either workspace is too small, and whatever
// CheckLapackEigResult raises for a non-zero LAPACK status.
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_same<T, platform::complex<float>>::value ||
                        std::is_same<T, platform::complex<double>>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
              Tensor* output, Tensor* work, Tensor* rwork) {
  Tensor a;  // scratch copy: lapackEig overwrites the matrix it is given
  framework::TensorCopy(input, input.place(), &a);

  const int64_t n_dim = input.dims()[1];

  const int64_t work_mem = work->memory_size();
  const int64_t required_work_mem =
      3 * n_dim * static_cast<int64_t>(sizeof(T));
  // Compare two signed 64-bit values; the previous inline expression was
  // unsigned (size_t) and mixed signedness with work_mem.
  PADDLE_ENFORCE_GE(
      work_mem, required_work_mem,
      platform::errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
          required_work_mem, work_mem));

  const int64_t rwork_mem = rwork->memory_size();
  const int64_t required_rwork_mem =
      (n_dim << 1) * sizeof(pten::funcs::Real<T>);
  PADDLE_ENFORCE_GE(
      rwork_mem, required_rwork_mem,
      platform::errors::InvalidArgument(
          "The memory size of the rwork tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received rwork\'s memory size = %" PRId64 " bytes.",
          required_rwork_mem, rwork_mem));

  int info = 0;
  pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
      'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
      static_cast<int>(n_dim), output->template data<T>(), nullptr, 1,
      nullptr, 1, work->template data<T>(),
      static_cast<int>(work_mem / sizeof(T)),
      rwork->template data<pten::funcs::Real<T>>(), &info);

  // Name the routine that matches the input precision in error messages:
  // cgeev_ is the single-precision complex routine (COMPLEX64), zgeev_ the
  // double-precision complex one.
  std::string name = "framework::platform::dynload::zgeev_";
  if (framework::TransToProtoVarType(input.dtype()) ==
      framework::proto::VarType::COMPLEX64) {
    name = "framework::platform::dynload::cgeev_";
  }
  CheckLapackEigResult(info, name);
}

170 171 172
// Kernel that computes the eigenvalues of every square matrix in input "X"
// and writes them, as PaddleCType<T> values, to output "Out".
//
// The input is split into individual matrices, a single LAPACK workspace is
// sized once via a workspace query and reused for every matrix, and the
// output is temporarily viewed as {n_batch, n} so each row receives the
// eigenvalues of one matrix. The original output shape is restored at the
// end.
template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");
    output->mutable_data<PaddleCType<T>>(ctx.GetPlace());

    std::vector<Tensor> input_matrices;
    SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
    // The code below indexes input_matrices[0]; with no matrices there is
    // nothing to compute.
    if (input_matrices.empty()) {
      return;
    }

    const int64_t n_dim = input_matrices[0].dims()[1];
    const int64_t n_batch = input_matrices.size();
    const DDim output_dims = output->dims();
    output->Resize(framework::make_ddim({n_batch, n_dim}));
    std::vector<Tensor> output_vectors = output->Split(1, 0);

    // Query the workspace size: lwork == -1 asks the routine to report its
    // preferred size in qwork instead of computing anything.
    T qwork;
    int info = 0;
    pten::funcs::lapackEig<T, pten::funcs::Real<T>>(
        'N', 'N', static_cast<int>(n_dim), input_matrices[0].template data<T>(),
        static_cast<int>(n_dim), nullptr, nullptr, 1, nullptr, 1, &qwork, -1,
        static_cast<Real<T>*>(nullptr), &info);
    // qwork is only meaningful if the query itself succeeded.
    CheckLapackEigResult(info, "pten::funcs::lapackEig (workspace query)");

    // LapackEigvals enforces a workspace of at least 3 * n elements, so never
    // allocate less than that even if the query reports a smaller value.
    const int64_t min_lwork = 3 * n_dim;
    int64_t lwork = static_cast<int64_t>(qwork);
    if (lwork < min_lwork) {
      lwork = min_lwork;
    }

    Tensor work, rwork;
    try {
      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
    } catch (memory::allocation::BadAlloc&) {
      // Fall back to the smallest workspace LapackEigvals accepts.
      LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
                   << "memory size = " << lwork * sizeof(T) << " bytes, "
                   << "try reallocating a smaller workspace with the minimum "
                   << "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
                   << "this may lead to bad performance.";
      lwork = min_lwork;
      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
    }
    // Only the complex overload of LapackEigvals reads rwork.
    if (framework::IsComplexType(
            framework::TransToProtoVarType(input->dtype()))) {
      rwork.mutable_data<pten::funcs::Real<T>>(
          framework::make_ddim({n_dim << 1}), ctx.GetPlace());
    }

    for (int64_t i = 0; i < n_batch; ++i) {
      LapackEigvals<DeviceContext, T>(ctx, input_matrices[i],
                                      &output_vectors[i], &work, &rwork);
    }
    // Undo the temporary {n_batch, n} view of the output.
    output->Resize(output_dims);
  }
};
}  // namespace operators
}  // namespace paddle