//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

#ifdef PADDLE_WITH_MKLML
#include "paddle/phi/backends/dynload/mklml.h"
#endif

#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif

#if defined(PADDLE_USE_OPENBLAS) || defined(PADDLE_USE_REFERENCE_CBLAS)
#include <cblas.h>
#endif

namespace phi {
namespace funcs {

/**
 * Matrix Descriptor of a memory buffer.
 *
 * It is used for Blas::MatMul. The MatMul operator can be batched: if Mat A is
 * [BatchSize, H, W] and Mat B is [BatchSize, H, W], the call performs
 * `batch_size` GEMMs. The batched GEMM could be faster, depending on the
 * implementation of the blas library. The batch size could be zero. If any
 * matrix of `matmul` has a batch size, the result will be a batched GEMM, too.
 * e.g., if Mat A is [BatchSize, H1, W1] and Mat B is [H2, W2], the result
 * matrix will be [BatchSize, H1, W2].
 *
 * The boolean flag, `trans`, describes whether the memory holds the transpose
 * of the matrix. If trans is true, the last two dims of the matrix are
 * transposed, i.e. the memory layout is [Width, Height] or
 * [BatchSize, Width, Height].
 *
 * The MatDescriptor is not only the dimension or shape of a matrix; it also
 * contains the layout and stride of the matrix. It is clearer to have a
 * dedicated structure than to reuse `DDim`.
 */
struct MatDescriptor {
  int64_t height_;
  int64_t width_;
  int64_t stride_{0};
  int64_t batch_size_{0};
  bool trans_;
};
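
// Illustrative sketch (not normative): for a buffer holding Mat A of shape
// [BatchSize, H, W] with trans_ == false, the descriptor is expected to carry
//
//   dim_a.height_ = H;              // rows of each matrix in the batch
//   dim_a.width_ = W;               // columns of each matrix in the batch
//   dim_a.stride_ = H * W;          // elements between consecutive matrices
//   dim_a.batch_size_ = BatchSize;
//   dim_a.trans_ = false;
//
// In practice the descriptor is produced by CreateMatrixDescriptor below
// rather than filled in by hand; the field values here are an assumption for
// illustration only.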

/**
 * Create a Matrix Descriptor from a tensor dim, num_flatten_cols, and a
 * transpose flag.
 *
 * @param tensor_dim: The dimension of the tensor. The rank of this dimension
 * must be larger than 1.
 *
 * @param num_flatten_cols: Reshape the tensor to a matrix. The matrix's first
 * dimension (the column length, i.e. the height) will be the product of the
 * tensor's first `num_flatten_cols` dimensions. If num_flatten_cols is zero,
 * the first N-2 dimensions are folded into the batch_size of the descriptor.
 *
 * @param trans: True if the matrix is transposed.
 */
extern MatDescriptor CreateMatrixDescriptor(const DDim& tensor_dim,
                                            int num_flatten_cols,
                                            bool trans);
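
// Usage sketch for the batched case (illustrative only; `mat_a`, `mat_b`,
// `mat_out`, and `dev_ctx` are assumed to be prepared by the caller):
//
//   // mat_a: [BatchSize, M, K], mat_b: [K, N]
//   auto dim_a = CreateMatrixDescriptor(mat_a.dims(), 0, /*trans=*/false);
//   auto dim_b = CreateMatrixDescriptor(mat_b.dims(), 0, /*trans=*/false);
//   auto blas = GetBlas<phi::CPUContext, float>(dev_ctx);
//   // dim_a carries a batch size while dim_b does not, so this lowers to a
//   // batched GEMM and mat_out has shape [BatchSize, M, N].
//   blas.MatMul(mat_a, dim_a, mat_b, dim_b, 1.0f, mat_out, 0.0f);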

template <typename DeviceContext>
class Blas {
 public:
  explicit Blas(const DeviceContext& context) : context_(context) {}

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA,
            CBLAS_TRANSPOSE transB,
            int M,
            int N,
            int K,
            T alpha,
            const T* A,
            const T* B,
            T beta,
            T* C) const;

  template <typename T>
  void GEMM(bool transA,
            bool transB,
            int M,
            int N,
            int K,
            T alpha,
            const T* A,
            int lda,
            const T* B,
            int ldb,
            T beta,
            T* C,
            int ldc) const;

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA,
            CBLAS_TRANSPOSE transB,
            int M,
            int N,
            int K,
            T alpha,
            const T* A,
            int lda,
            const T* B,
            int ldb,
            T beta,
            T* C,
            int ldc) const;

#ifdef PADDLE_WITH_MKLML  // @{ Group MKLML: class Blas
  template <typename T>
  T* GEMM_ALLOC(const CBLAS_IDENTIFIER id,
                const int M,
                const int N,
                const int K) const;

  template <typename T>
  void GEMM_PACK(const CBLAS_IDENTIFIER id,
                 const CBLAS_TRANSPOSE trans,
                 int M,
                 int N,
                 int K,
                 const T alpha,
                 const T* src,
                 const int ld,
                 T* dst) const;

  template <typename T>
  void GEMM_COMPUTE(int transA,
                    int transB,
                    int M,
                    int N,
                    int K,
                    const T* A,
                    const int lda,
                    const T* B,
                    const int ldb,
                    T beta,
                    T* C,
                    const int ldc) const;

  template <typename T>
  void GEMM_FREE(T* data) const;
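
  // Packed GEMM usage sketch (assumes the MKLML build; buffer names and sizes
  // are illustrative). The flow mirrors the cblas_?gemm_alloc/pack/compute/
  // free family: pack one operand once, reuse it across several GEMM_COMPUTE
  // calls, then release it.
  //
  //   auto blas = GetBlas<phi::CPUContext, float>(dev_ctx);
  //   float* packed_b = blas.GEMM_ALLOC(CblasBMatrix, M, N, K);
  //   blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, M, N, K, 1.0f, b, N, packed_b);
  //   blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, M, N, K, a, K, packed_b, N,
  //                     0.0f, c, N);
  //   blas.GEMM_FREE(packed_b);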

  template <typename T>
  void CSRMM(const char* transa,
             const int* m,
             const int* n,
             const int* k,
             const T* alpha,
             const char* matdescra,
             const T* val,
             const int* indx,
             const int* pntrb,
             const int* pntre,
             const T* b,
             const int* ldb,
             const T* beta,
             T* c,
             const int* ldc) const;

#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  template <typename T>
  void MatMulWithHead(const phi::DenseTensor& mat_a,
                      const MatDescriptor& dim_a,
                      const phi::DenseTensor& mat_b,
                      const MatDescriptor& dim_b,
                      T alpha,
                      int head_number,
                      phi::DenseTensor* mat_out,
                      T beta,
                      bool mat_y_split_vertical) const;
#endif
#endif  // @} End Group MKLML: class Blas

  template <typename T>
  void MatMul(const int M,
              const int N,
              const int K,
              const T* A,
              const T* B,
              T* C) const;

  template <typename T>
  void MatMul(const phi::DenseTensor& mat_a,
              bool trans_a,
              const phi::DenseTensor& mat_b,
              bool trans_b,
              T alpha,
              phi::DenseTensor* mat_out,
              T beta) const;

  template <typename T>
  void MatMul(const phi::DenseTensor& mat_a,
              bool trans_a,
              const phi::DenseTensor& mat_b,
              bool trans_b,
              phi::DenseTensor* mat_out) const {
    MatMul(mat_a,
           trans_a,
           mat_b,
           trans_b,
           static_cast<T>(1.0),
           mat_out,
           static_cast<T>(0.0));
  }

  template <typename T>
  void MatMul(const phi::DenseTensor& mat_a,
              const phi::DenseTensor& mat_b,
              phi::DenseTensor* mat_out) const {
    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
  }

  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;

  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VSUB(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VMUL(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VDIV(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;

  template <typename T>
  void VEXP(int n, const T* x, T* y) const;

  template <typename T>
  void VSQUARE(int n, const T* x, T* y) const;

  template <typename T>
  void VPOW(int n, const T* x, T alpha, T* y) const;

  template <typename T>
  void GEMV(bool trans_a,
            int M,
            int N,
            T alpha,
            const T* A,
            const T* B,
            T beta,
            T* C) const;

  template <typename T>
  T DOT(int n, const T* x, const T* y) const;

  template <typename T>
  void SCAL(int n, const T a, T* x) const;

  template <typename T>
  T ASUM(int n, T* x, int inc) const;

  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA,
                   CBLAS_TRANSPOSE transB,
                   int M,
                   int N,
                   int K,
                   T alpha,
                   const T* A,
                   const T* B,
                   T beta,
                   T* C,
                   int batchCount,
                   int64_t strideA,
                   int64_t strideB) const;

  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA,
                   CBLAS_TRANSPOSE transB,
                   int M,
                   int N,
                   int K,
                   T alpha,
                   const T** A,
                   const T** B,
                   T beta,
                   T** C,
                   int batchCount) const;
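
  // Sketch for the strided overload (the first BatchedGEMM above), with
  // illustrative values: multiply batch_count pairs of row-major [M, K] x
  // [K, N] blocks laid out contiguously in A and B.
  //
  //   blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, A, B, 0.0f,
  //                    C, batch_count, /*strideA=*/int64_t(M) * K,
  //                    /*strideB=*/int64_t(K) * N);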

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
    !defined(PADDLE_WITH_HIP)
  template <typename T>
  void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA,
                           CBLAS_TRANSPOSE transB,
                           int W1,
                           int H1,
                           int W2,
                           int H2,
                           T alpha,
                           const T* A,
                           const T* B,
                           T beta,
                           T* C,
                           int batchCount,
                           int64_t strideA,
                           int64_t strideB,
                           int64_t head_number,
                           bool split_b_vertical) const;
#endif

  template <typename T>
  void MatMul(const phi::DenseTensor& mat_a,
              const MatDescriptor& dim_a,
              const phi::DenseTensor& mat_b,
              const MatDescriptor& dim_b,
              T alpha,
              phi::DenseTensor* mat_out,
              T beta) const;

  template <typename T>
  void MatMul(const T* mat_a,
              const MatDescriptor& dim_a,
              const T* mat_b,
              const MatDescriptor& dim_b,
              T alpha,
              T* mat_out,
              T beta) const;

  template <typename T>
  void VINV(int n, const T* a, T* y) const;

  template <typename T>
  void VMERF(int n, const T* a, T* y, int64_t mode) const;

  template <typename T>
  void TRSM(CBLAS_SIDE side,
            CBLAS_UPLO uplo,
            CBLAS_TRANSPOSE transA,
            CBLAS_DIAG diag,
            int M,
            int N,
            T alpha,
            const T* A,
            int lda,
            T* B,
            int ldb) const;

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  template <typename T>
  void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;

  template <typename T>
  void BatchedGETRI(int n,
                    const T** a,
                    const int* ipiv,
                    T** a_inv,
                    int* info,
                    int batch_size) const;

  template <typename T>
  void BatchedMatInv(
      int n, const T** a, T** a_inv, int* info, int batch_size) const;

  // cuBlas solve
  template <typename T>
  void BatchedGETRS(CBLAS_TRANSPOSE trans,
                    int n,
                    int nrhs,
                    const T** a,
                    int lda,
                    int* ipiv,
                    T** b,
                    int ldb,
                    int* info,
                    int batch_size) const;

  // cuBlas triangular_solve
  template <typename T>
  void BatchedTRSM(CBLAS_SIDE side,
                   CBLAS_UPLO uplo,
                   CBLAS_TRANSPOSE transA,
                   CBLAS_DIAG diag,
                   int M,
                   int N,
                   T alpha,
                   const T** a,
                   int lda,
                   T** b,
                   int ldb,
                   int batch_size) const;
#endif

 private:
  const DeviceContext& context_;
};

template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
 public:
  using Blas<DeviceContext>::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    Base()->template GEMM<T>(args...);
  }

#ifdef PADDLE_WITH_MKLML  // @{ Group MKLML: class BlasT
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_PACK(ARGS... args) const {
    Base()->template GEMM_PACK<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_COMPUTE(ARGS... args) const {
    Base()->template GEMM_COMPUTE<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_FREE(ARGS... args) const {
    Base()->template GEMM_FREE<T>(args...);
  }

  template <typename... ARGS>
  void CSRMM(ARGS... args) const {
    Base()->template CSRMM<T>(args...);
  }

#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  template <typename... ARGS>
  void MatMulWithHead(ARGS... args) const {
    Base()->template MatMulWithHead<T>(args...);
  }
#endif
#endif  // @} End Group MKLML: class BlasT

  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
  }

  template <typename... ARGS>
  void AXPY(ARGS... args) const {
    Base()->template AXPY<T>(args...);
  }

  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);
  }

  template <typename... ARGS>
  void VSUB(ARGS... args) const {
    Base()->template VSUB<T>(args...);
  }

  template <typename... ARGS>
  void VMUL(ARGS... args) const {
    Base()->template VMUL<T>(args...);
  }

  template <typename... ARGS>
  void VDIV(ARGS... args) const {
    Base()->template VDIV<T>(args...);
  }

  template <typename... ARGS>
  void VCOPY(ARGS... args) const {
    Base()->template VCOPY<T>(args...);
  }

  template <typename... ARGS>
  void VEXP(ARGS... args) const {
    Base()->template VEXP<T>(args...);
  }

  template <typename... ARGS>
  void VSQUARE(ARGS... args) const {
    Base()->template VSQUARE<T>(args...);
  }

  template <typename... ARGS>
  void VPOW(ARGS... args) const {
    Base()->template VPOW<T>(args...);
  }

  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }

  template <typename... ARGS>
  T DOT(ARGS... args) const {
    return Base()->template DOT<T>(args...);
  }

  template <typename... ARGS>
  void SCAL(ARGS... args) const {
    Base()->template SCAL<T>(args...);
  }

  template <typename... ARGS>
  T ASUM(ARGS... args) const {
    return Base()->template ASUM<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
  }

  template <typename... ARGS>
  void VINV(ARGS... args) const {
    Base()->template VINV<T>(args...);
  }

  template <typename... ARGS>
  void VMERF(ARGS... args) const {
    Base()->template VMERF<T>(args...);
  }

  template <typename... ARGS>
  void TRSM(ARGS... args) const {
    Base()->template TRSM<T>(args...);
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  template <typename... ARGS>
  void BatchedGETRF(ARGS... args) const {
    Base()->template BatchedGETRF<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGETRI(ARGS... args) const {
    Base()->template BatchedGETRI<T>(args...);
  }

  template <typename... ARGS>
  void BatchedMatInv(ARGS... args) const {
    Base()->template BatchedMatInv<T>(args...);
  }

  // solve
  template <typename... ARGS>
  void BatchedGETRS(ARGS... args) const {
    Base()->template BatchedGETRS<T>(args...);
  }

  // triangular_solve
  template <typename... ARGS>
  void BatchedTRSM(ARGS... args) const {
    Base()->template BatchedTRSM<T>(args...);
  }
#endif

 private:
  const Blas<DeviceContext>* Base() const {
    return static_cast<const Blas<DeviceContext>*>(this);
  }
};

template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
  return BlasT<DeviceContext, T>(dev_ctx);
}
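
// Typical entry point (a sketch, not a normative example): obtain a typed
// BLAS wrapper from a device context and run a single-precision GEMM.
// `dev_ctx` and the raw buffers `a`, `b`, `c` are assumed to be prepared by
// the caller; CPUContext is used only for illustration.
//
//   auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(dev_ctx);
//   blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a, b, 0.0f, c);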

}  // namespace funcs
}  // namespace phi

#include "paddle/phi/kernels/funcs/blas/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/kernels/funcs/blas/blas_impl.cu.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/phi/kernels/funcs/blas/blas_impl.hip.h"
#endif