blas.h 9.2 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"

#ifdef PADDLE_WITH_MKLML
21
#include "paddle/fluid/platform/dynload/mklml.h"
Y
Yu Yang 已提交
22 23
#endif

T
tensor-tang 已提交
24 25 26 27
#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif

Y
Yu Yang 已提交
28 29 30 31 32 33 34 35
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif

namespace paddle {
namespace operators {
namespace math {

Y
Yu Yang 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
/**
 * Matrix Descriptor of a memory buffer.
 *
 * It is used for Blas::MatMul. MatMul operator can be batched.
 * If Mat A is [BatchSize, H, W] and Mat B is [BatchSize, H, W], the product is
 * computed as `batch_size` GEMMs. A batched GEMM may be faster, depending on
 * the implementation of the BLAS library. The batch size can be zero. If
 * either matrix of `matmul` has a batch size, the product is still a batched
 * GEMM, e.g., if Mat A is [BatchSize, H1, W1] and Mat B is [H2, W2], the
 * result matrix will be [BatchSize, H1, W2].
 *
 * The boolean flag `trans` describes whether the memory holds the transpose of
 * the matrix. If `trans` is true, the last two dims of the matrix are
 * transposed, i.e. the memory layout of the matrix is [Width, Height] or
 * [BatchSize, Width, Height].
 *
 * The MatDescriptor is not only the dimension or shape of a matrix, it also
 * contains the layout, stride of matrix. It is clearer to have a structure than
 * reuse `DDim`.
 */
Y
Yu Yang 已提交
55
// Plain description of how a memory buffer is to be interpreted as a
// (possibly batched, possibly transposed) matrix.
//
// Every member has a default initializer so that a default-constructed
// descriptor holds well-defined values instead of indeterminate ones.
struct MatDescriptor {
  // Height of the matrix.
  int64_t height_{0};
  // Width of the matrix.
  int64_t width_{0};
  // NOTE(review): presumably the distance between consecutive matrices of a
  // batch -- confirm against CreateMatrixDescriptor's implementation.
  int64_t stride_{0};
  // Number of matrices in the batch; zero is allowed.
  int64_t batch_size_{0};
  // True when the memory holds the transpose of the matrix, i.e. the last two
  // dims are laid out as [Width, Height].
  bool trans_{false};
};

Y
Yu Yang 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
/**
 * Build a MatDescriptor from a tensor dimension, a flatten position and a
 * transpose flag.
 *
 * @param tensor_dim: The dimension of the tensor. Its rank must be larger
 * than 1.
 *
 * @param num_flatten_cols: Controls how the tensor is reshaped to a matrix.
 * The matrix's first dimension (column length) is the product of the tensor's
 * first `num_flatten_cols` dimensions. If `num_flatten_cols` is zero, the
 * first N-2 dimensions become the batch_size of the descriptor instead.
 *
 * @param trans: True if the matrix is transposed.
 *
 * @return The descriptor for the given dimension.
 */
MatDescriptor CreateMatrixDescriptor(
    const framework::DDim& tensor_dim,
    int num_flatten_cols,
    bool trans);
Y
Yu Yang 已提交
79

Y
Yu Yang 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92
/**
 * Device-bound entry point for BLAS-style routines.
 *
 * A Blas object keeps only a reference to the DeviceContext it was built
 * from, so that context must outlive the Blas object. Except for the two
 * inline MatMul convenience overloads, the member templates are only declared
 * in this class.
 * NOTE(review): the definitions presumably live in blas_impl.h and
 * blas_impl.cu.h, which this header includes at its end -- confirm.
 * NOTE(review): routine names follow the BLAS / vector-math naming convention
 * (GEMM, GEMV, AXPY, DOT, ...); the exact numerical contract of each routine
 * is determined by its definition, not by this declaration.
 */
template <typename DeviceContext>
class Blas {
 public:
  // Binds the object to `context`. Only the reference is stored.
  explicit Blas(const DeviceContext& context) : context_(context) {}

  // GEMM overloads. They differ in how transposition is requested
  // (CBLAS_TRANSPOSE vs. bool) and in whether lda/ldb/ldc are passed
  // explicitly.
  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, const T* B, T beta, T* C) const;

  template <typename T>
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
            int lda, const T* B, int ldb, T beta, T* C, int ldc) const;

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
            int ldc) const;

#ifdef PADDLE_WITH_MKLML
  // The following four routines are only declared when PADDLE_WITH_MKLML is
  // defined.
  // NOTE(review): presumably thin wrappers over the MKL packed-GEMM API;
  // a buffer obtained from GEMM_ALLOC is presumably released with GEMM_FREE
  // -- confirm in the implementation.
  template <typename T>
  T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
                const int K) const;

  template <typename T>
  void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
                 int N, int K, const T alpha, const T* src, const int ld,
                 T* dst) const;

  template <typename T>
  void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
                    const int lda, const T* B, const int ldb, T beta, T* C,
                    const int ldc) const;

  template <typename T>
  void GEMM_FREE(T* data) const;
#endif

  // MatMul on raw buffers described by the sizes M, N and K.
  template <typename T>
  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
              T* C) const;

  // MatMul on tensors with per-operand transpose flags and explicit
  // alpha / beta scalars.
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b, T alpha,
              framework::Tensor* mat_out, T beta) const;

  // Convenience overload: same as the tensor MatMul above with alpha fixed to
  // 1 and beta fixed to 0.
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b,
              framework::Tensor* mat_out) const {
    MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
           static_cast<T>(0.0));
  }

  // Convenience overload: neither operand is transposed; alpha and beta take
  // the defaults of the overload it delegates to (1 and 0).
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
              framework::Tensor* mat_out) const {
    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
  }

  // Vector routines. Each takes an element count `n` and raw pointers; input
  // buffers are const-qualified, output buffers are not.
  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;

  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VMUL(int n, const T* x, const T* y, T* z) const;

  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;

  template <typename T>
  void VEXP(int n, const T* x, T* y) const;

  template <typename T>
  void VSQUARE(int n, const T* x, T* y) const;

  template <typename T>
  void VPOW(int n, const T* x, T alpha, T* y) const;

  // Matrix-vector routine; `trans_a` selects whether A is used transposed.
  template <typename T>
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;

  // Returns a scalar computed from the two input vectors x and y.
  template <typename T>
  T DOT(int n, const T* x, const T* y) const;

  // Modifies x in place using the scalar `a`.
  template <typename T>
  void SCAL(int n, const T a, T* x) const;

  // Returns a scalar computed from x, visited with increment `inc`.
  template <typename T>
  T ASUM(int n, T* x, int inc) const;

  // Batched GEMM over `batchCount` problems; `strideA` / `strideB` are passed
  // alongside A and B.
  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;

  // MatMul on tensors whose shape, batch and transpose information is given
  // by MatDescriptor objects rather than by the tensors' own dimensions.
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
              T alpha, framework::Tensor* mat_out, T beta) const;

  template <typename T>
  void VINV(int n, const T* a, T* y) const;

 private:
  // Non-owning reference to the device context supplied at construction.
  const DeviceContext& context_;
};

/**
 * Blas<DeviceContext> with the element type fixed to T.
 *
 * Every public member is a variadic shim that hands its arguments to the
 * same-named Blas<DeviceContext> member template instantiated with T, so that
 * callers do not have to repeat the element type on each call. Overload
 * selection among the underlying Blas members therefore happens on the
 * forwarded arguments.
 *
 * Arguments are received by value (ARGS...), so class-type arguments are
 * copied once on their way through the shim.
 *
 * Constructors are inherited from Blas<DeviceContext>; the base is private, so
 * only the shims below are reachable from outside.
 */
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
 public:
  using Blas<DeviceContext>::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    Base()->template GEMM<T>(args...);
  }

#ifdef PADDLE_WITH_MKLML
  // These shims exist only when PADDLE_WITH_MKLML is defined, matching the
  // conditional declarations in Blas.
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_PACK(ARGS... args) const {
    Base()->template GEMM_PACK<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_COMPUTE(ARGS... args) const {
    Base()->template GEMM_COMPUTE<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_FREE(ARGS... args) const {
    Base()->template GEMM_FREE<T>(args...);
  }
#endif

  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
  }

  template <typename... ARGS>
  void AXPY(ARGS... args) const {
    Base()->template AXPY<T>(args...);
  }

  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);
  }

  template <typename... ARGS>
  void VMUL(ARGS... args) const {
    Base()->template VMUL<T>(args...);
  }

  template <typename... ARGS>
  void VCOPY(ARGS... args) const {
    Base()->template VCOPY<T>(args...);
  }

  template <typename... ARGS>
  void VEXP(ARGS... args) const {
    Base()->template VEXP<T>(args...);
  }

  template <typename... ARGS>
  void VSQUARE(ARGS... args) const {
    Base()->template VSQUARE<T>(args...);
  }

  template <typename... ARGS>
  void VPOW(ARGS... args) const {
    Base()->template VPOW<T>(args...);
  }

  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }

  // Returns whatever Blas::DOT<T> returns for the forwarded arguments.
  template <typename... ARGS>
  T DOT(ARGS... args) const {
    return Base()->template DOT<T>(args...);
  }

  template <typename... ARGS>
  void SCAL(ARGS... args) const {
    Base()->template SCAL<T>(args...);
  }

  // Returns whatever Blas::ASUM<T> returns for the forwarded arguments.
  template <typename... ARGS>
  T ASUM(ARGS... args) const {
    return Base()->template ASUM<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
  }

  template <typename... ARGS>
  void VINV(ARGS... args) const {
    Base()->template VINV<T>(args...);
  }

 private:
  // Views this object as its private Blas<DeviceContext> base so the shims can
  // name the base-class member templates explicitly.
  const Blas<DeviceContext>* Base() const {
    return static_cast<const Blas<DeviceContext>*>(this);
  }
};

// Creates a BlasT<DeviceContext, T> bound to the DeviceContext that `exe_ctx`
// reports through device_context<DeviceContext>().
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
    const framework::ExecutionContext& exe_ctx) {
  using TypedBlas = BlasT<DeviceContext, T>;
  const auto& dev_ctx = exe_ctx.template device_context<DeviceContext>();
  return TypedBlas(dev_ctx);
}

// Creates a BlasT<DeviceContext, T> bound directly to `dev_ctx`. The returned
// object refers to `dev_ctx`, it does not copy it.
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
  using TypedBlas = BlasT<DeviceContext, T>;
  return TypedBlas(dev_ctx);
}

}  // namespace math
}  // namespace operators
}  // namespace paddle

#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif