//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
16

17
#include "paddle/phi/backends/dynload/lapack.h"
18
#include "paddle/phi/common/complex.h"
19

20
namespace phi {
21 22 23 24 25
namespace funcs {

// LU (for example)
// Specialization of lapackLu for double: forwards to dgetrf_ through the
// dynload layer. a, ipiv and info are handed through unchanged.
template <>
void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
  // The routine receives its integer scalars by pointer, so the by-value
  // parameters m, n and lda are passed by address.
  dynload::dgetrf_(&m, &n, a, &lda, ipiv, info);
}

// Specialization of lapackLu for float: forwards to sgetrf_ through the
// dynload layer. a, ipiv and info are handed through unchanged.
template <>
void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
  // The routine receives its integer scalars by pointer, so the by-value
  // parameters m, n and lda are passed by address.
  dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
}

// eigh
// Specialization of lapackEigh for float: forwards to ssyevd_ through the
// dynload layer. rwork and lrwork are part of the signature but are not
// forwarded, because the ssyevd_ call takes no such arguments.
template <>
void lapackEigh<float>(char jobz,
                       char uplo,
                       int n,
                       float *a,
                       int lda,
                       float *w,
                       float *work,
                       int lwork,
                       float *rwork,
                       int lrwork,
                       int *iwork,
                       int liwork,
                       int *info) {
  // Explicitly discard the parameters this specialization has no use for.
  static_cast<void>(rwork);
  static_cast<void>(lrwork);
  // Scalars are passed by address; buffers are handed through as-is.
  dynload::ssyevd_(
      &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

// Specialization of lapackEigh for double: forwards to dsyevd_ through the
// dynload layer. rwork and lrwork are part of the signature but are not
// forwarded, because the dsyevd_ call takes no such arguments.
template <>
void lapackEigh<double>(char jobz,
                        char uplo,
                        int n,
                        double *a,
                        int lda,
                        double *w,
                        double *work,
                        int lwork,
                        double *rwork,
                        int lrwork,
                        int *iwork,
                        int liwork,
                        int *info) {
  // Explicitly discard the parameters this specialization has no use for.
  static_cast<void>(rwork);
  static_cast<void>(lrwork);
  // Scalars are passed by address; buffers are handed through as-is.
  dynload::dsyevd_(
      &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template <>
76
void lapackEigh<phi::dtype::complex<float>, float>(
77 78 79
    char jobz,
    char uplo,
    int n,
80
    phi::dtype::complex<float> *a,
81 82
    int lda,
    float *w,
83
    phi::dtype::complex<float> *work,
84 85 86 87 88 89
    int lwork,
    float *rwork,
    int lrwork,
    int *iwork,
    int liwork,
    int *info) {
90 91 92 93 94 95 96 97 98 99 100 101 102
  dynload::cheevd_(&jobz,
                   &uplo,
                   &n,
                   reinterpret_cast<std::complex<float> *>(a),
                   &lda,
                   w,
                   reinterpret_cast<std::complex<float> *>(work),
                   &lwork,
                   rwork,
                   &lrwork,
                   iwork,
                   &liwork,
                   info);
103 104 105
}

template <>
106
void lapackEigh<phi::dtype::complex<double>, double>(
107 108 109
    char jobz,
    char uplo,
    int n,
110
    phi::dtype::complex<double> *a,
111 112
    int lda,
    double *w,
113
    phi::dtype::complex<double> *work,
114 115 116 117 118 119
    int lwork,
    double *rwork,
    int lrwork,
    int *iwork,
    int liwork,
    int *info) {
120 121 122 123 124 125 126 127 128 129 130 131 132
  dynload::zheevd_(&jobz,
                   &uplo,
                   &n,
                   reinterpret_cast<std::complex<double> *>(a),
                   &lda,
                   w,
                   reinterpret_cast<std::complex<double> *>(work),
                   &lwork,
                   rwork,
                   &lrwork,
                   iwork,
                   &liwork,
                   info);
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
}

// Eig
// Specialization of lapackEig for double: forwards to dgeev_ through the
// dynload layer.
//
// The single buffer w is handed to the routine as two separate outputs: the
// first at w and the second at w + n.
// NOTE(review): this implies w holds at least 2 * n elements - confirm
// against callers.
// rwork is part of the signature but is not forwarded, because the dgeev_
// call takes no such argument.
template <>
void lapackEig<double>(char jobvl,
                       char jobvr,
                       int n,
                       double *a,
                       int lda,
                       double *w,
                       double *vl,
                       int ldvl,
                       double *vr,
                       int ldvr,
                       double *work,
                       int lwork,
                       double *rwork,
                       int *info) {
  static_cast<void>(rwork);  // not used by this specialization
  double *const first_half = w;
  double *const second_half = w + n;
  dynload::dgeev_(&jobvl,
                  &jobvr,
                  &n,
                  a,
                  &lda,
                  first_half,
                  second_half,
                  vl,
                  &ldvl,
                  vr,
                  &ldvr,
                  work,
                  &lwork,
                  info);
}

// Specialization of lapackEig for float: forwards to sgeev_ through the
// dynload layer.
//
// The single buffer w is handed to the routine as two separate outputs: the
// first at w and the second at w + n.
// NOTE(review): this implies w holds at least 2 * n elements - confirm
// against callers.
// rwork is part of the signature but is not forwarded, because the sgeev_
// call takes no such argument.
template <>
void lapackEig<float>(char jobvl,
                      char jobvr,
                      int n,
                      float *a,
                      int lda,
                      float *w,
                      float *vl,
                      int ldvl,
                      float *vr,
                      int ldvr,
                      float *work,
                      int lwork,
                      float *rwork,
                      int *info) {
  static_cast<void>(rwork);  // not used by this specialization
  float *const first_half = w;
  float *const second_half = w + n;
  dynload::sgeev_(&jobvl,
                  &jobvr,
                  &n,
                  a,
                  &lda,
                  first_half,
                  second_half,
                  vl,
                  &ldvl,
                  vr,
                  &ldvr,
                  work,
                  &lwork,
                  info);
}

template <>
205
void lapackEig<phi::dtype::complex<double>, double>(
206 207 208
    char jobvl,
    char jobvr,
    int n,
209
    phi::dtype::complex<double> *a,
210
    int lda,
211 212
    phi::dtype::complex<double> *w,
    phi::dtype::complex<double> *vl,
213
    int ldvl,
214
    phi::dtype::complex<double> *vr,
215
    int ldvr,
216
    phi::dtype::complex<double> *work,
217 218 219
    int lwork,
    double *rwork,
    int *info) {
220 221 222 223 224 225 226 227 228 229 230 231 232 233
  dynload::zgeev_(&jobvl,
                  &jobvr,
                  &n,
                  reinterpret_cast<std::complex<double> *>(a),
                  &lda,
                  reinterpret_cast<std::complex<double> *>(w),
                  reinterpret_cast<std::complex<double> *>(vl),
                  &ldvl,
                  reinterpret_cast<std::complex<double> *>(vr),
                  &ldvr,
                  reinterpret_cast<std::complex<double> *>(work),
                  &lwork,
                  rwork,
                  info);
234 235 236
}

template <>
237
void lapackEig<phi::dtype::complex<float>, float>(
238 239 240
    char jobvl,
    char jobvr,
    int n,
241
    phi::dtype::complex<float> *a,
242
    int lda,
243 244
    phi::dtype::complex<float> *w,
    phi::dtype::complex<float> *vl,
245
    int ldvl,
246
    phi::dtype::complex<float> *vr,
247
    int ldvr,
248
    phi::dtype::complex<float> *work,
249 250 251
    int lwork,
    float *rwork,
    int *info) {
252 253 254 255 256 257 258 259 260 261 262 263 264 265
  dynload::cgeev_(&jobvl,
                  &jobvr,
                  &n,
                  reinterpret_cast<std::complex<float> *>(a),
                  &lda,
                  reinterpret_cast<std::complex<float> *>(w),
                  reinterpret_cast<std::complex<float> *>(vl),
                  &ldvl,
                  reinterpret_cast<std::complex<float> *>(vr),
                  &ldvr,
                  reinterpret_cast<std::complex<float> *>(work),
                  &lwork,
                  rwork,
                  info);
266 267 268 269 270 271 272 273 274 275 276 277 278 279
}

// Specialization of lapackGels for double: forwards to dgels_ through the
// dynload layer. a, b, work and info are handed through unchanged.
template <>
void lapackGels<double>(char trans,
                        int m,
                        int n,
                        int nrhs,
                        double *a,
                        int lda,
                        double *b,
                        int ldb,
                        double *work,
                        int lwork,
                        int *info) {
  // The routine receives its scalar arguments by pointer, so the by-value
  // parameters are passed by address.
  dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
}

// Specialization of lapackGels for float: forwards to sgels_ through the
// dynload layer. a, b, work and info are handed through unchanged.
template <>
void lapackGels<float>(char trans,
                       int m,
                       int n,
                       int nrhs,
                       float *a,
                       int lda,
                       float *b,
                       int ldb,
                       float *work,
                       int lwork,
                       int *info) {
  // The routine receives its scalar arguments by pointer, so the by-value
  // parameters are passed by address.
  dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
}

// Specialization of lapackGelsd for double: forwards to dgelsd_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, s, rank, work, iwork and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the dgelsd_
// call takes no such argument.
template <>
void lapackGelsd<double>(int m,
                         int n,
                         int nrhs,
                         double *a,
                         int lda,
                         double *b,
                         int ldb,
                         double *s,
                         double rcond,
                         int *rank,
                         double *work,
                         int lwork,
                         double *rwork,
                         int *iwork,
                         int *info) {
  (void)rwork;  // unused
  dynload::dgelsd_(&m,
                   &n,
                   &nrhs,
                   a,
                   &lda,
                   b,
                   &ldb,
                   s,
                   &rcond,
                   rank,
                   work,
                   &lwork,
                   iwork,
                   info);
}

// Specialization of lapackGelsd for float: forwards to sgelsd_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, s, rank, work, iwork and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the sgelsd_
// call takes no such argument.
template <>
void lapackGelsd<float>(int m,
                        int n,
                        int nrhs,
                        float *a,
                        int lda,
                        float *b,
                        int ldb,
                        float *s,
                        float rcond,
                        int *rank,
                        float *work,
                        int lwork,
                        float *rwork,
                        int *iwork,
                        int *info) {
  (void)rwork;  // unused
  dynload::sgelsd_(&m,
                   &n,
                   &nrhs,
                   a,
                   &lda,
                   b,
                   &ldb,
                   s,
                   &rcond,
                   rank,
                   work,
                   &lwork,
                   iwork,
                   info);
}

// Specialization of lapackGelsy for double: forwards to dgelsy_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, jpvt, rank, work and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the dgelsy_
// call takes no such argument.
template <>
void lapackGelsy<double>(int m,
                         int n,
                         int nrhs,
                         double *a,
                         int lda,
                         double *b,
                         int ldb,
                         int *jpvt,
                         double rcond,
                         int *rank,
                         double *work,
                         int lwork,
                         double *rwork,
                         int *info) {
  (void)rwork;  // unused
  dynload::dgelsy_(
      &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info);
}

// Specialization of lapackGelsy for float: forwards to sgelsy_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, jpvt, rank, work and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the sgelsy_
// call takes no such argument.
template <>
void lapackGelsy<float>(int m,
                        int n,
                        int nrhs,
                        float *a,
                        int lda,
                        float *b,
                        int ldb,
                        int *jpvt,
                        float rcond,
                        int *rank,
                        float *work,
                        int lwork,
                        float *rwork,
                        int *info) {
  (void)rwork;  // unused
  dynload::sgelsy_(
      &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info);
}

// Specialization of lapackGelss for double: forwards to dgelss_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, s, rank, work and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the dgelss_
// call takes no such argument.
template <>
void lapackGelss<double>(int m,
                         int n,
                         int nrhs,
                         double *a,
                         int lda,
                         double *b,
                         int ldb,
                         double *s,
                         double rcond,
                         int *rank,
                         double *work,
                         int lwork,
                         double *rwork,
                         int *info) {
  (void)rwork;  // unused
  dynload::dgelss_(
      &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
}

// Specialization of lapackGelss for float: forwards to sgelss_ through the
// dynload layer. Scalars (m, n, nrhs, lda, ldb, rcond, lwork) are passed by
// address; a, b, s, rank, work and info are handed through unchanged.
// rwork is part of the signature but is not forwarded, because the sgelss_
// call takes no such argument.
template <>
void lapackGelss<float>(int m,
                        int n,
                        int nrhs,
                        float *a,
                        int lda,
                        float *b,
                        int ldb,
                        float *s,
                        float rcond,
                        int *rank,
                        float *work,
                        int lwork,
                        float *rwork,
                        int *info) {
  (void)rwork;  // unused
  dynload::sgelss_(
      &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
}

template <>
439
void lapackCholeskySolve<phi::dtype::complex<double>>(
440 441 442
    char uplo,
    int n,
    int nrhs,
443
    phi::dtype::complex<double> *a,
444
    int lda,
445
    phi::dtype::complex<double> *b,
446 447
    int ldb,
    int *info) {
448 449 450 451 452 453 454 455
  dynload::zpotrs_(&uplo,
                   &n,
                   &nrhs,
                   reinterpret_cast<std::complex<double> *>(a),
                   &lda,
                   reinterpret_cast<std::complex<double> *>(b),
                   &ldb,
                   info);
456 457 458
}

template <>
459
void lapackCholeskySolve<phi::dtype::complex<float>>(
460 461 462
    char uplo,
    int n,
    int nrhs,
463
    phi::dtype::complex<float> *a,
464
    int lda,
465
    phi::dtype::complex<float> *b,
466 467
    int ldb,
    int *info) {
468 469 470 471 472 473 474 475
  dynload::cpotrs_(&uplo,
                   &n,
                   &nrhs,
                   reinterpret_cast<std::complex<float> *>(a),
                   &lda,
                   reinterpret_cast<std::complex<float> *>(b),
                   &ldb,
                   info);
476 477 478 479 480 481 482 483 484 485 486
}

// Specialization of lapackCholeskySolve for double: forwards to dpotrs_
// through the dynload layer. a, b and info are handed through unchanged.
template <>
void lapackCholeskySolve<double>(char uplo,
                                 int n,
                                 int nrhs,
                                 double *a,
                                 int lda,
                                 double *b,
                                 int ldb,
                                 int *info) {
  // The routine receives its scalar arguments by pointer, so the by-value
  // parameters are passed by address.
  dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}

// Specialization of lapackCholeskySolve for float: forwards to spotrs_
// through the dynload layer. a, b and info are handed through unchanged.
template <>
void lapackCholeskySolve<float>(char uplo,
                                int n,
                                int nrhs,
                                float *a,
                                int lda,
                                float *b,
                                int ldb,
                                int *info) {
  // The routine receives its scalar arguments by pointer, so the by-value
  // parameters are passed by address.
  dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}

}  // namespace funcs
503
}  // namespace phi