lstsq_op.h 12.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <math.h>

#include <algorithm>
#include <complex>
#include <string>
#include <unordered_map>

#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"

#define EPSILON 1e-6

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };

using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim) {
  // Rebuilds a DDim with the same extents as `x_dim` by round-tripping it
  // through its vectorized form. Used to restore a tensor's shape after the
  // kernel has resized it internally.
  auto x_vec = vectorize(x_dim);
  return phi::make_ddim(x_vec);
}

template <typename DeviceContext, typename T>
class LstsqCPUKernel : public framework::OpKernel<T> {
 public:
  // Solves the batched least-squares problem X * Solution ~= Y on the CPU
  // with the LAPACK driver selected by the "driver" attribute
  // ("gels", "gelsy", "gelsd" or "gelss").
  //
  // X holds batches of m x n matrices (m, n are its last two dims); nrhs is
  // the last dim of Y. Y is assumed to hold m x nrhs matrices (shape checks
  // presumably happen in InferShape - TODO confirm).
  //
  // Outputs:
  //   Solution       - restored to the dims it had on entry (UDDim).
  //   Rank           - written for every driver except "gels".
  //   SingularValues - written only for "gelsd" and "gelss".
  void Compute(const framework::ExecutionContext& context) const override {
    using ValueType = phi::dtype::Real<T>;

    const Tensor& x = *context.Input<Tensor>("X");
    auto y = context.Input<Tensor>("Y");
    auto rcond = context.Attr<float>("rcond");
    auto driver_string = context.Attr<std::string>("driver");

    // Read-only lookup table. It must not be queried through operator[]: that
    // would insert unknown keys into a shared static (racy across concurrent
    // kernels) and silently fall back to the "gels" driver.
    static const std::unordered_map<std::string, LapackDriverType>
        driver_type = {{"gels", LapackDriverType::Gels},
                       {"gelsy", LapackDriverType::Gelsy},
                       {"gelsd", LapackDriverType::Gelsd},
                       {"gelss", LapackDriverType::Gelss}};
    auto driver_it = driver_type.find(driver_string);
    if (driver_it == driver_type.end()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported lstsq driver [%s]; expected one of "
          "\"gels\", \"gelsy\", \"gelsd\" or \"gelss\".",
          driver_string));
    }
    const LapackDriverType driver = driver_it->second;

    auto solution = context.Output<Tensor>("Solution");
    auto* rank = context.Output<Tensor>("Rank");
    auto* singular_values = context.Output<Tensor>("SingularValues");

    auto dito =
        math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);

    auto x_dims = x.dims();
    auto y_dims = y->dims();
    int dim_size = x_dims.size();
    int x_stride = MatrixStride(x);
    int y_stride = MatrixStride(*y);
    int batch_count = BatchCount(x);
    auto solution_dim = solution->dims();
    int ori_solu_stride = MatrixStride(*solution);
    int max_solu_stride = std::max(y_stride, ori_solu_stride);
    int min_solu_stride = std::min(y_stride, ori_solu_stride);

    // LAPACK uses column-major storage; the inputs are transposed below so
    // each matrix has the memory layout LAPACK expects.
    int info = 0;
    int m = x_dims[dim_size - 2];
    int n = x_dims[dim_size - 1];
    int nrhs = y_dims[dim_size - 1];
    int lda = std::max<int>(m, 1);
    int ldb = std::max<int>(1, std::max(m, n));

    Tensor new_x;
    new_x.mutable_data<T>(context.GetPlace(),
                          size_t(batch_count * m * n * sizeof(T)));
    framework::TensorCopy(x, context.GetPlace(), &new_x);

    // Solution doubles as LAPACK's in/out B buffer, so each batch gets
    // max(m, n) * nrhs slots.
    solution->mutable_data<T>(
        context.GetPlace(),
        size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));

    if (m >= n) {
      framework::TensorCopy(*y, context.GetPlace(), solution);
    } else {
      // Place each Y matrix at the start of its (larger) per-batch slot.
      auto* solu_data = solution->data<T>();
      auto* y_data = y->data<T>();
      for (auto i = 0; i < batch_count; i++) {
        for (auto j = 0; j < min_solu_stride; j++) {
          solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j];
        }
      }
    }

    Tensor input_x_trans = dito.Transpose(new_x);
    Tensor input_y_trans = dito.Transpose(*solution);
    framework::TensorCopy(input_x_trans, context.GetPlace(), &new_x);
    framework::TensorCopy(input_y_trans, context.GetPlace(), solution);

    auto* x_vector = new_x.data<T>();
    auto* y_vector = solution->data<T>();

    // The "gels" driver does not compute the rank.
    int rank_32 = 0;
    int* rank_data = nullptr;
    if (driver != LapackDriverType::Gels) {
      rank_data = rank->mutable_data<int>(context.GetPlace());
    }

    // Only the "gelsd" and "gelss" drivers compute singular values.
    ValueType* s_data = nullptr;
    int s_stride = 0;
    if (driver == LapackDriverType::Gelsd ||
        driver == LapackDriverType::Gelss) {
      s_data = singular_values->mutable_data<ValueType>(context.GetPlace());
      auto s_dims = singular_values->dims();
      s_stride = s_dims[s_dims.size() - 1];
    }

    // "jpvt" is only used by the "gelsy" driver. It is an in/out argument:
    // nonzero entries on entry pin columns to the front, and on exit it holds
    // the column permutation. It is therefore zeroed here (freshly allocated
    // memory is not guaranteed to be zero) and again before every batch.
    Tensor jpvt;
    int* jpvt_data = nullptr;
    int jpvt_len = 0;
    if (driver == LapackDriverType::Gelsy) {
      jpvt_len = std::max<int>(1, n);
      jpvt.Resize(phi::make_ddim({jpvt_len}));
      jpvt_data = jpvt.mutable_data<int>(context.GetPlace());
      std::fill(jpvt_data, jpvt_data + jpvt_len, 0);
    }

    // Run the driver once with lwork = -1 to query optimal workspace sizes.
    int lwork = -1;
    T wkopt;
    ValueType rwkopt = 0;
    int iwkopt = 0;

    if (driver == LapackDriverType::Gels) {
      phi::funcs::lapackGels(
          'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info);
    } else if (driver == LapackDriverType::Gelsd) {
      phi::funcs::lapackGelsd(m,
                              n,
                              nrhs,
                              x_vector,
                              lda,
                              y_vector,
                              ldb,
                              s_data,
                              static_cast<ValueType>(rcond),
                              &rank_32,
                              &wkopt,
                              lwork,
                              &rwkopt,
                              &iwkopt,
                              &info);
    } else if (driver == LapackDriverType::Gelsy) {
      phi::funcs::lapackGelsy(m,
                              n,
                              nrhs,
                              x_vector,
                              lda,
                              y_vector,
                              ldb,
                              jpvt_data,
                              static_cast<ValueType>(rcond),
                              &rank_32,
                              &wkopt,
                              lwork,
                              &rwkopt,
                              &info);
    } else if (driver == LapackDriverType::Gelss) {
      phi::funcs::lapackGelss(m,
                              n,
                              nrhs,
                              x_vector,
                              lda,
                              y_vector,
                              ldb,
                              s_data,
                              static_cast<ValueType>(rcond),
                              &rank_32,
                              &wkopt,
                              lwork,
                              &rwkopt,
                              &info);
    }

    // A failed query leaves wkopt undefined, so it must not size the
    // workspace.
    PADDLE_ENFORCE_EQ(
        info,
        0,
        platform::errors::PreconditionNotMet(
            "Lapack workspace query failed: info is not zero but [%d]",
            info));

    lwork = std::max<int>(1, static_cast<int>(phi::dtype::Real<T>(wkopt)));
    Tensor work;
    work.Resize(phi::make_ddim({lwork}));
    T* work_data = work.mutable_data<T>(context.GetPlace());

    // "rwork" is only used for complex inputs with the gelsy/gelsd/gelss
    // drivers.
    Tensor rwork;
    ValueType* rwork_data = nullptr;
    if (framework::IsComplexType(framework::TransToProtoVarType(x.dtype())) &&
        driver != LapackDriverType::Gels) {
      int rwork_len = 0;
      if (driver == LapackDriverType::Gelsy) {
        rwork_len = std::max<int>(1, 2 * n);
      } else if (driver == LapackDriverType::Gelss) {
        rwork_len = std::max<int>(1, 5 * std::min(m, n));
      } else if (driver == LapackDriverType::Gelsd) {
        rwork_len = std::max<int>(1, static_cast<int>(rwkopt));
      }
      rwork.Resize(phi::make_ddim({rwork_len}));
      rwork_data = rwork.mutable_data<ValueType>(context.GetPlace());
    }

    // The "iwork" workspace array is only relevant for the "gelsd" driver.
    Tensor iwork;
    int* iwork_data = nullptr;
    if (driver == LapackDriverType::Gelsd) {
      iwork.Resize(phi::make_ddim({std::max<int>(1, iwkopt)}));
      iwork_data = iwork.mutable_data<int>(context.GetPlace());
    }

    for (auto i = 0; i < batch_count; ++i) {
      auto* x_input = &x_vector[i * x_stride];
      auto* y_input = &y_vector[i * max_solu_stride];
      int* rank_working_ptr = rank_data ? &rank_data[i] : nullptr;
      ValueType* s_working_ptr = s_data ? &s_data[i * s_stride] : nullptr;

      // Reset the pivot array so the previous batch's permutation does not
      // pin columns for this one.
      if (jpvt_data) {
        std::fill(jpvt_data, jpvt_data + jpvt_len, 0);
      }

      if (driver == LapackDriverType::Gels) {
        phi::funcs::lapackGels('N',
                               m,
                               n,
                               nrhs,
                               x_input,
                               lda,
                               y_input,
                               ldb,
                               work_data,
                               lwork,
                               &info);
      } else if (driver == LapackDriverType::Gelsd) {
        phi::funcs::lapackGelsd(m,
                                n,
                                nrhs,
                                x_input,
                                lda,
                                y_input,
                                ldb,
                                s_working_ptr,
                                static_cast<ValueType>(rcond),
                                &rank_32,
                                work_data,
                                lwork,
                                rwork_data,
                                iwork_data,
                                &info);
      } else if (driver == LapackDriverType::Gelsy) {
        phi::funcs::lapackGelsy(m,
                                n,
                                nrhs,
                                x_input,
                                lda,
                                y_input,
                                ldb,
                                jpvt_data,
                                static_cast<ValueType>(rcond),
                                &rank_32,
                                work_data,
                                lwork,
                                rwork_data,
                                &info);
      } else if (driver == LapackDriverType::Gelss) {
        phi::funcs::lapackGelss(m,
                                n,
                                nrhs,
                                x_input,
                                lda,
                                y_input,
                                ldb,
                                s_working_ptr,
                                static_cast<ValueType>(rcond),
                                &rank_32,
                                work_data,
                                lwork,
                                rwork_data,
                                &info);
      }

      PADDLE_ENFORCE_EQ(
          info,
          0,
          platform::errors::PreconditionNotMet(
              "For batch [%d]: Lapack info is not zero but [%d]", i, info));

      if (rank_working_ptr) *rank_working_ptr = static_cast<int>(rank_32);
    }

    // Back to row-major layout.
    Tensor tmp_s = dito.Transpose(*solution);
    framework::TensorCopy(tmp_s, context.GetPlace(), solution);

    // For m > n each batch's result occupies the first min_solu_stride
    // entries of a max_solu_stride slot; compact the batches so they are
    // contiguous (batch 0 is already in place).
    if (m > n) {
      auto* solu_data = solution->data<T>();
      for (auto i = 1; i < batch_count; i++) {
        for (auto j = 0; j < min_solu_stride; j++) {
          solu_data[i * min_solu_stride + j] =
              solu_data[i * max_solu_stride + j];
        }
      }
    }

    solution->Resize(UDDim(solution_dim));
  }
};

329
// Declaration only; the definition is not in this section of the file
// (presumably provided per device type - TODO confirm).
// Judging by the name and parameters, this presumably applies the
// orthogonal/unitary factor of a batched QR factorization (stored in `a` and
// `tau`, LAPACK ORMQR-style) to `other`, from the left or right and optionally
// transposed - NOTE(review): confirm against the definition.
//   batch_size           number of matrices in the batch.
//   m, n, k              matrix sizes forwarded to the underlying routine.
//   a / a_stride         per-batch factor data and its element stride.
//   tau / tau_stride     per-batch scalar factors and their element stride.
//   other / other_stride per-batch matrices to update and their stride.
template <typename DeviceContext, typename T>
void BatchedOrmqr(const DeviceContext& dev_ctx,
                  bool left,
                  bool transpose,
                  int batch_size,
                  int m,
                  int n,
                  int k,
                  T* a,
                  int a_stride,
                  T* tau,
                  int tau_stride,
                  T* other,
                  int other_stride);
343

344 345
}  // namespace operators
}  // namespace paddle