//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"

#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif

#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif

namespace paddle {
namespace operators {
namespace math {

/**
 * Matrix Descriptor of a memory buffer.
 *
 * It is used for Blas::MatMul. The MatMul operator can be batched: if Mat A is
 * [BatchSize, H, W] and Mat B is [BatchSize, H, W], it will be `BatchSize`
 * GEMMs. The batched GEMM could be faster depending on the implementation of
 * the blas library. The batch size could be zero. If any matrix of `matmul`
 * has a batch size, there will be a batched GEMM, too. E.g., if Mat A is
 * [BatchSize, H1, W1] and Mat B is [H2, W2], the result matrix will be
 * [BatchSize, H1, W2].
 *
 * The boolean flag, `trans`, describes whether the memory holds the transpose
 * of the matrix. If `trans` is true, the last two dims of the matrix are
 * transposed, i.e. the memory layout of the matrix is [Width, Height] or
 * [BatchSize, Width, Height].
 *
 * The MatDescriptor is not only the dimension or shape of a matrix; it also
 * contains the layout and stride of the matrix. It is clearer to have a
 * dedicated structure than to reuse `DDim`.
 *
 * All fields are value-initialized, so a default-constructed descriptor is a
 * well-defined (empty, non-transposed) descriptor rather than one holding
 * indeterminate values.
 */
struct MatDescriptor {
  // Matrix height (number of rows).
  int64_t height_{0};
  // Matrix width (number of columns).
  int64_t width_{0};
  // NOTE(review): presumably the element offset between consecutive matrices
  // of a batch — confirm against blas_impl.h.
  int64_t stride_{0};
  // Number of matrices in the batch. The comment above allows zero
  // (presumably meaning a single, non-batched matrix).
  int64_t batch_size_{0};
  // True if the memory holds the transpose of the matrix (see `trans` above).
  bool trans_{false};
};

/**
 * Create a Matrix Descriptor from a tensor dim, num_flatten_cols, and a
 * transpose flag.
 *
 * @param tensor_dim: The dimension of the tensor. The rank of this dimension
 * must be larger than 1.
 *
 * @param num_flatten_cols: Used to reshape the tensor into a matrix. The
 * matrix's first dimension (column length) will be the product of the
 * tensor's first `num_flatten_cols` dimensions. If num_flatten_cols is zero,
 * the first N-2 dimensions will be the batch_size of the descriptor.
 *
 * @param trans: True if the matrix is transposed.
 *
 * @return A MatDescriptor describing `tensor_dim` viewed as a (possibly
 * batched) matrix.
 */
extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
                                            int num_flatten_cols, bool trans);

template <typename DeviceContext>
class Blas {
 public:
  explicit Blas(const DeviceContext& context) : context_(context) {}

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, const T* B, T beta, T* C) const;

  template <typename T>
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
            int lda, const T* B, int ldb, T beta, T* C, int ldc) const;

T
tensor-tang 已提交
93 94 95 96 97
  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
            int ldc) const;

T
tensor-tang 已提交
98
#ifdef PADDLE_WITH_MKLML
T
tensor-tang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  template <typename T>
  T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
                const int K) const;

  template <typename T>
  void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
                 int N, int K, const T alpha, const T* src, const int ld,
                 T* dst) const;

  template <typename T>
  void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
                    const int lda, const T* B, const int ldb, T beta, T* C,
                    const int ldc) const;

  template <typename T>
  void GEMM_FREE(T* data) const;
115

116 117 118 119 120 121
  template <typename T>
  void CSRMM(const char* transa, const int* m, const int* n, const int* k,
             const T* alpha, const char* matdescra, const T* val,
             const int* indx, const int* pntrb, const int* pntre, const T* b,
             const int* ldb, const T* beta, T* c, const int* ldc) const;

122 123 124 125 126 127
#if !defined(PADDLE_WITH_CUDA)
  template <typename T>
  void MatMulWithHead(const framework::Tensor& mat_a,
                      const MatDescriptor& dim_a,
                      const framework::Tensor& mat_b,
                      const MatDescriptor& dim_b, T alpha, int head_number,
128 129
                      framework::Tensor* mat_out, T beta,
                      bool mat_y_split_vertical) const;
130
#endif
T
tensor-tang 已提交
131
#endif
T
tensor-tang 已提交
132

T
tensor-tang 已提交
133 134 135 136
  template <typename T>
  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
              T* C) const;

Y
Yu Yang 已提交
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b, T alpha,
              framework::Tensor* mat_out, T beta) const;

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b,
              framework::Tensor* mat_out) const {
    MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
           static_cast<T>(0.0));
  }

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
              framework::Tensor* mat_out) const {
    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
  }

  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;

159 160 161
  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const;

162 163 164
  template <typename T>
  void VSUB(int n, const T* x, const T* y, T* z) const;

T
tensor-tang 已提交
165 166 167
  template <typename T>
  void VMUL(int n, const T* x, const T* y, T* z) const;

168 169 170
  template <typename T>
  void VDIV(int n, const T* x, const T* y, T* z) const;

171 172 173
  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;

Y
Yu Yang 已提交
174
  template <typename T>
T
tensor-tang 已提交
175 176
  void VEXP(int n, const T* x, T* y) const;

T
tensor-tang 已提交
177
  template <typename T>
T
tensor-tang 已提交
178
  void VSQUARE(int n, const T* x, T* y) const;
T
tensor-tang 已提交
179 180 181 182

  template <typename T>
  void VPOW(int n, const T* x, T alpha, T* y) const;

T
tensor-tang 已提交
183
  template <typename T>
Y
Yu Yang 已提交
184 185 186
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;

T
tensor-tang 已提交
187 188 189
  template <typename T>
  T DOT(int n, const T* x, const T* y) const;

T
tensor-tang 已提交
190
  template <typename T>
T
tensor-tang 已提交
191
  void SCAL(int n, const T a, T* x) const;
T
tensor-tang 已提交
192

J
Jacek Czaja 已提交
193 194 195
  template <typename T>
  T ASUM(int n, T* x, int inc) const;

Y
Yu Yang 已提交
196 197 198 199 200
  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;

201 202 203
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
  template <typename T>
  void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
204 205 206 207
                           int W1, int H1, int W2, int H2, T alpha, const T* A,
                           const T* B, T beta, T* C, int batchCount,
                           int64_t strideA, int64_t strideB,
                           int64_t head_number, bool split_b_vertical) const;
208 209
#endif

Y
Yu Yang 已提交
210
  template <typename T>
Y
Yu Yang 已提交
211 212 213
  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
              T alpha, framework::Tensor* mat_out, T beta) const;
Y
Yu Yang 已提交
214

Y
Use mkl  
Yu Yang 已提交
215 216 217
  template <typename T>
  void VINV(int n, const T* a, T* y) const;

Y
Yihua Xu 已提交
218 219 220
  template <typename T>
  void VMERF(int n, const T* a, T* y, int64_t mode) const;

G
Guo Sheng 已提交
221 222 223 224 225
  template <typename T>
  void TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA,
            CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
            int ldb) const;

226 227 228 229 230 231 232 233 234 235 236 237 238
#ifdef PADDLE_WITH_CUDA
  template <typename T>
  void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;

  template <typename T>
  void BatchedGETRI(int n, const T** a, const int* ipiv, T** a_inv, int* info,
                    int batch_size) const;

  template <typename T>
  void BatchedMatInv(int n, const T** a, T** a_inv, int* info,
                     int batch_size) const;
#endif

Y
Yu Yang 已提交
239 240 241 242 243 244 245 246 247 248 249 250 251 252
 private:
  const DeviceContext& context_;
};

/**
 * Element-type-bound view of Blas<DeviceContext>.
 *
 * BlasT fixes the element type `T` once, so callers write `blas.GEMM(...)`
 * instead of `blas.template GEMM<T>(...)`. Each method forwards its arguments
 * unchanged to the same-named Blas member template instantiated with `T`.
 * Blas is inherited privately, so only these forwarding methods are public.
 * The Blas constructor is inherited as-is.
 *
 * NOTE: arguments are taken by value (`ARGS... args`). Anything passed by
 * reference, e.g. a tensor, is copied into the wrapper before being forwarded.
 */
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
 public:
  using Blas<DeviceContext>::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    Base()->template GEMM<T>(args...);
  }

#ifdef PADDLE_WITH_MKLML
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_PACK(ARGS... args) const {
    Base()->template GEMM_PACK<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_COMPUTE(ARGS... args) const {
    Base()->template GEMM_COMPUTE<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_FREE(ARGS... args) const {
    Base()->template GEMM_FREE<T>(args...);
  }

  template <typename... ARGS>
  void CSRMM(ARGS... args) const {
    Base()->template CSRMM<T>(args...);
  }

#if !defined(PADDLE_WITH_CUDA)
  template <typename... ARGS>
  void MatMulWithHead(ARGS... args) const {
    Base()->template MatMulWithHead<T>(args...);
  }
#endif
#endif

  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
  }

  template <typename... ARGS>
  void AXPY(ARGS... args) const {
    Base()->template AXPY<T>(args...);
  }

  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);
  }

  template <typename... ARGS>
  void VSUB(ARGS... args) const {
    Base()->template VSUB<T>(args...);
  }

  template <typename... ARGS>
  void VMUL(ARGS... args) const {
    Base()->template VMUL<T>(args...);
  }

  template <typename... ARGS>
  void VDIV(ARGS... args) const {
    Base()->template VDIV<T>(args...);
  }

  template <typename... ARGS>
  void VCOPY(ARGS... args) const {
    Base()->template VCOPY<T>(args...);
  }

  template <typename... ARGS>
  void VEXP(ARGS... args) const {
    Base()->template VEXP<T>(args...);
  }

  template <typename... ARGS>
  void VSQUARE(ARGS... args) const {
    Base()->template VSQUARE<T>(args...);
  }

  template <typename... ARGS>
  void VPOW(ARGS... args) const {
    Base()->template VPOW<T>(args...);
  }

  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }

  template <typename... ARGS>
  T DOT(ARGS... args) const {
    return Base()->template DOT<T>(args...);
  }

  template <typename... ARGS>
  void SCAL(ARGS... args) const {
    Base()->template SCAL<T>(args...);
  }

  template <typename... ARGS>
  T ASUM(ARGS... args) const {
    return Base()->template ASUM<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
  }

#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
  // Blas declares BatchedGEMMWithHead under this same guard. Without this
  // forwarder it was the only Blas routine unreachable through BlasT.
  template <typename... ARGS>
  void BatchedGEMMWithHead(ARGS... args) const {
    Base()->template BatchedGEMMWithHead<T>(args...);
  }
#endif

  template <typename... ARGS>
  void VINV(ARGS... args) const {
    Base()->template VINV<T>(args...);
  }

  template <typename... ARGS>
  void VMERF(ARGS... args) const {
    Base()->template VMERF<T>(args...);
  }

  template <typename... ARGS>
  void TRSM(ARGS... args) const {
    Base()->template TRSM<T>(args...);
  }

#ifdef PADDLE_WITH_CUDA
  template <typename... ARGS>
  void BatchedGETRF(ARGS... args) const {
    Base()->template BatchedGETRF<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGETRI(ARGS... args) const {
    Base()->template BatchedGETRI<T>(args...);
  }

  template <typename... ARGS>
  void BatchedMatInv(ARGS... args) const {
    Base()->template BatchedMatInv<T>(args...);
  }
#endif

 private:
  // Upcast to the private base. This is legal here because the cast is
  // performed inside the derived class.
  const Blas<DeviceContext>* Base() const {
    return static_cast<const Blas<DeviceContext>*>(this);
  }
};

/**
 * Build a BlasT<DeviceContext, T> bound to `dev_ctx`.
 *
 * The returned object keeps only a reference to `dev_ctx` (Blas stores a
 * `const DeviceContext&`), so the context must outlive it.
 */
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
  return BlasT<DeviceContext, T>(dev_ctx);
}

/**
 * Build a BlasT<DeviceContext, T> bound to the `DeviceContext` held by an
 * operator's execution context. Equivalent to calling the overload above with
 * `exe_ctx.device_context<DeviceContext>()`.
 */
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
    const framework::ExecutionContext& exe_ctx) {
  const auto& device_ctx = exe_ctx.template device_context<DeviceContext>();
  return GetBlas<DeviceContext, T>(device_ctx);
}

}  // namespace math
}  // namespace operators
}  // namespace paddle

#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif