//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace framework {
class ExecutionContext;
class Tensor;
}  // namespace framework
}  // namespace paddle

#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif

#if defined(PADDLE_USE_OPENBLAS) || defined(PADDLE_USE_REFERENCE_CBLAS)
#include <cblas.h>
#endif

namespace paddle {
namespace operators {
namespace math {

/**
 * Matrix Descriptor of a memory buffer.
 *
 * It is used for Blas::MatMul. The MatMul operator can be batched.
 * If Mat A is [BatchSize, M, K] and Mat B is [BatchSize, K, N], the
 * multiplication is performed as `BatchSize` GEMMs. The batched GEMM
 * may be faster, depending on the implementation of the BLAS library.
 * The batch size may be zero. If either operand of `matmul` has a batch
 * size, the result is a batched GEMM as well, e.g., if Mat A is
 * [BatchSize, H1, W1] and Mat B is [H2, W2] (with W1 == H2), the result
 * matrix will be [BatchSize, H1, W2].
 *
 * The boolean flag, `trans`, describes whether the memory holds the
 * transpose of the matrix. If trans is true, the last two dims of the
 * matrix are transposed, i.e. the memory layout is [Width, Height] or
 * [BatchSize, Width, Height].
 *
 * The MatDescriptor is not only the dimension or shape of a matrix; it
 * also contains the layout and stride of the matrix. It is clearer to
 * have a dedicated structure than to reuse `DDim`.
 */
struct MatDescriptor {
  int64_t height_;
  int64_t width_;
  int64_t stride_{0};
  int64_t batch_size_{0};
  bool trans_;
};
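
// Illustrative sketch (an assumption about typical values, not normative):
// for a row-major tensor of shape [BatchSize, H, W] with trans == false,
// the descriptor would hold roughly
//   height_     = H
//   width_      = W
//   stride_     = H * W   // elements between consecutive matrices in the batch
//   batch_size_ = BatchSize
//   trans_      = false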

/**
 * Create a Matrix Descriptor from a tensor dim, num_flatten_cols, and a
 * transpose flag.
 *
 * @param tensor_dim: The dimension of the tensor. The rank of this dimension
 * must be larger than 1.
 *
 * @param num_flatten_cols: Reshape the tensor into a matrix. The matrix's
 * first dimension (column length) will be the product of the tensor's first
 * `num_flatten_cols` dimensions. If num_flatten_cols is zero, the first N-2
 * dimensions form the batch_size of the descriptor.
 *
 * @param trans: True if the matrix is transposed.
 */
extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
                                            int num_flatten_cols, bool trans);
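
// Example sketch (hypothetical shapes): for a tensor with dims [8, 16, 32],
//   CreateMatrixDescriptor(dims, /*num_flatten_cols=*/0, /*trans=*/false)
// would describe a batch of 8 matrices of height 16 and width 32, while
//   CreateMatrixDescriptor(dims, /*num_flatten_cols=*/2, /*trans=*/false)
// would flatten the first two dimensions into a single [128, 32] matrix.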

template <typename DeviceContext>
class Blas {
 public:
  explicit Blas(const DeviceContext& context) : context_(context) {}

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, const T* B, T beta, T* C) const;

  template <typename T>
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
            int lda, const T* B, int ldb, T beta, T* C, int ldc) const;

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
            int ldc) const;

#ifdef PADDLE_WITH_MKLML  // @{ Group MKLML: class Blas
  template <typename T>
  T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
                const int K) const;

  template <typename T>
  void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
                 int N, int K, const T alpha, const T* src, const int ld,
                 T* dst) const;

  template <typename T>
  void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
                    const int lda, const T* B, const int ldb, T beta, T* C,
                    const int ldc) const;

  template <typename T>
  void GEMM_FREE(T* data) const;

  template <typename T>
  void CSRMM(const char* transa, const int* m, const int* n, const int* k,
             const T* alpha, const char* matdescra, const T* val,
             const int* indx, const int* pntrb, const int* pntre, const T* b,
             const int* ldb, const T* beta, T* c, const int* ldc) const;

#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  template <typename T>
  void MatMulWithHead(const framework::Tensor& mat_a,
                      const MatDescriptor& dim_a,
                      const framework::Tensor& mat_b,
                      const MatDescriptor& dim_b, T alpha, int head_number,
                      framework::Tensor* mat_out, T beta,
                      bool mat_y_split_vertical) const;
#endif
#endif  // @} End Group MKLML: class Blas

  template <typename T>
  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
              T* C) const;

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b, T alpha,
              framework::Tensor* mat_out, T beta) const;

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b,
              framework::Tensor* mat_out) const {
    MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
           static_cast<T>(0.0));
  }

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
              framework::Tensor* mat_out) const {
    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
  }

  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;

  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VSUB(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VMUL(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VDIV(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;

  template <typename T>
  void VEXP(int n, const T* x, T* y) const;

  template <typename T>
  void VSQUARE(int n, const T* x, T* y) const;

  template <typename T>
  void VPOW(int n, const T* x, T alpha, T* y) const;

  template <typename T>
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;

  template <typename T>
  T DOT(int n, const T* x, const T* y) const;

  template <typename T>
  void SCAL(int n, const T a, T* x) const;

  template <typename T>
  T ASUM(int n, T* x, int inc) const;

  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;

  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, const T** A, const T** B, T beta, T** C,
                   int batchCount) const;

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
    !defined(PADDLE_WITH_HIP)
  template <typename T>
  void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                           int W1, int H1, int W2, int H2, T alpha, const T* A,
                           const T* B, T beta, T* C, int batchCount,
                           int64_t strideA, int64_t strideB,
                           int64_t head_number, bool split_b_vertical) const;
#endif

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
              T alpha, framework::Tensor* mat_out, T beta) const;

  template <typename T>
  void VINV(int n, const T* a, T* y) const;

  template <typename T>
  void VMERF(int n, const T* a, T* y, int64_t mode) const;

  template <typename T>
  void TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA,
            CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
            int ldb) const;

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  template <typename T>
  void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;

  template <typename T>
  void BatchedGETRI(int n, const T** a, const int* ipiv, T** a_inv, int* info,
                    int batch_size) const;

  template <typename T>
  void BatchedMatInv(int n, const T** a, T** a_inv, int* info,
                     int batch_size) const;

  // cuBLAS solve
  template <typename T>
  void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
                    int lda, int* ipiv, T** b, int ldb, int* info,
                    int batch_size) const;
#endif

 private:
  const DeviceContext& context_;
};

template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
 public:
  using Blas<DeviceContext>::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    Base()->template GEMM<T>(args...);
  }

#ifdef PADDLE_WITH_MKLML  // @{ Group MKLML: class BlasT
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_PACK(ARGS... args) const {
    Base()->template GEMM_PACK<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_COMPUTE(ARGS... args) const {
    Base()->template GEMM_COMPUTE<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_FREE(ARGS... args) const {
    Base()->template GEMM_FREE<T>(args...);
  }

  template <typename... ARGS>
  void CSRMM(ARGS... args) const {
    Base()->template CSRMM<T>(args...);
  }

#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  template <typename... ARGS>
  void MatMulWithHead(ARGS... args) const {
    Base()->template MatMulWithHead<T>(args...);
  }
#endif
#endif  // @} End Group MKLML: class BlasT

  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
  }

  template <typename... ARGS>
  void AXPY(ARGS... args) const {
    Base()->template AXPY<T>(args...);
  }

  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);
  }

  template <typename... ARGS>
  void VSUB(ARGS... args) const {
    Base()->template VSUB<T>(args...);
  }

  template <typename... ARGS>
  void VMUL(ARGS... args) const {
    Base()->template VMUL<T>(args...);
  }

  template <typename... ARGS>
  void VDIV(ARGS... args) const {
    Base()->template VDIV<T>(args...);
  }

  template <typename... ARGS>
  void VCOPY(ARGS... args) const {
    Base()->template VCOPY<T>(args...);
  }

  template <typename... ARGS>
  void VEXP(ARGS... args) const {
    Base()->template VEXP<T>(args...);
  }

  template <typename... ARGS>
  void VSQUARE(ARGS... args) const {
    Base()->template VSQUARE<T>(args...);
  }

  template <typename... ARGS>
  void VPOW(ARGS... args) const {
    Base()->template VPOW<T>(args...);
  }

  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }

  template <typename... ARGS>
  T DOT(ARGS... args) const {
    return Base()->template DOT<T>(args...);
  }

  template <typename... ARGS>
  void SCAL(ARGS... args) const {
    Base()->template SCAL<T>(args...);
  }

  template <typename... ARGS>
  T ASUM(ARGS... args) const {
    return Base()->template ASUM<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
  }

  template <typename... ARGS>
  void VINV(ARGS... args) const {
    Base()->template VINV<T>(args...);
  }

  template <typename... ARGS>
  void VMERF(ARGS... args) const {
    Base()->template VMERF<T>(args...);
  }

  template <typename... ARGS>
  void TRSM(ARGS... args) const {
    Base()->template TRSM<T>(args...);
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  template <typename... ARGS>
  void BatchedGETRF(ARGS... args) const {
    Base()->template BatchedGETRF<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGETRI(ARGS... args) const {
    Base()->template BatchedGETRI<T>(args...);
  }

  template <typename... ARGS>
  void BatchedMatInv(ARGS... args) const {
    Base()->template BatchedMatInv<T>(args...);
  }

  // solve
  template <typename... ARGS>
  void BatchedGETRS(ARGS... args) const {
    Base()->template BatchedGETRS<T>(args...);
  }
#endif

 private:
  const Blas<DeviceContext>* Base() const {
    return static_cast<const Blas<DeviceContext>*>(this);
  }
};

template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
    const framework::ExecutionContext& exe_ctx) {
  return BlasT<DeviceContext, T>(
      exe_ctx.template device_context<DeviceContext>());
}

template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
  return BlasT<DeviceContext, T>(dev_ctx);
}
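
// Usage sketch (a non-normative example; the device context type and the
// tensor/pointer names are placeholders, not part of this header):
//
//   auto blas = math::GetBlas<platform::CPUDeviceContext, float>(dev_ctx);
//   blas.MatMul(mat_a, /*trans_a=*/false, mat_b, /*trans_b=*/false, &mat_out);
//   blas.AXPY(n, 1.0f, x_data, y_data);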

}  // namespace math
}  // namespace operators
}  // namespace paddle

#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/math/blas_impl.hip.h"
#endif