blas.h 8.9 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"

#ifdef PADDLE_WITH_MKLML
21
#include "paddle/fluid/platform/dynload/mklml.h"
Y
Yu Yang 已提交
22 23
#endif

T
tensor-tang 已提交
24 25 26 27
#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif

Y
Yu Yang 已提交
28 29 30 31 32 33 34 35
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif

namespace paddle {
namespace operators {
namespace math {

/**
 * Matrix descriptor of a memory buffer.
 *
 * It is used by Blas::MatMul. The MatMul operation can be batched: if Mat A is
 * [BatchSize, H, K] and Mat B is [BatchSize, K, W], it performs `BatchSize`
 * GEMMs. A batched GEMM can be faster, depending on the implementation of the
 * BLAS library. The batch size may be zero (not batched). If any matrix of
 * `matmul` has a batch size, the result is a batched GEMM too, e.g. if Mat A
 * is [BatchSize, H1, W1] and Mat B is [H2, W2] (with W1 == H2), the result
 * matrix will be [BatchSize, H1, W2].
 *
 * The boolean flag `trans` describes whether the memory holds the transpose of
 * the matrix. If `trans` is true, the last two dims of the matrix are
 * transposed, i.e. the memory layout is [Width, Height] or
 * [BatchSize, Width, Height].
 *
 * A MatDescriptor is not only the dimension or shape of a matrix; it also
 * records the layout and stride of the matrix. Having a dedicated structure is
 * clearer than reusing `DDim`.
 *
 * Every field has a default initializer, so a default-constructed descriptor
 * never contains indeterminate values.
 */
struct MatDescriptor {
  int64_t height_{0};      // Matrix height (number of rows).
  int64_t width_{0};       // Matrix width (number of columns).
  int64_t stride_{0};      // Stride of the matrix buffer.
  int64_t batch_size_{0};  // Number of matrices; 0 means not batched.
  bool trans_{false};      // True if the buffer stores the transposed matrix.
};

Y
Yu Yang 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
/**
 * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose
 * flag
 *
 * @param tensor_dim: The dimension of the tensor. The rank of this dimension
 * must larger than 1.
 *
 * @param num_flatten_cols:  Reshape a tensor to a matrix. The matrix's first
 * dimension(column length) will be the product of tensor's first `num_col_dims`
 * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the
 * batch_size of descriptor.
 *
 * @param trans: True if the matrix is transposed.
 */
extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
                                            int num_flatten_cols, bool trans);
Y
Yu Yang 已提交
79

Y
Yu Yang 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92
template <typename DeviceContext>
class Blas {
 public:
  explicit Blas(const DeviceContext& context) : context_(context) {}

  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, const T* B, T beta, T* C) const;

  template <typename T>
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
            int lda, const T* B, int ldb, T beta, T* C, int ldc) const;

T
tensor-tang 已提交
93 94 95 96 97
  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
            T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
            int ldc) const;

T
tensor-tang 已提交
98
#ifdef PADDLE_WITH_MKLML
T
tensor-tang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  template <typename T>
  T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
                const int K) const;

  template <typename T>
  void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
                 int N, int K, const T alpha, const T* src, const int ld,
                 T* dst) const;

  template <typename T>
  void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
                    const int lda, const T* B, const int ldb, T beta, T* C,
                    const int ldc) const;

  template <typename T>
  void GEMM_FREE(T* data) const;
T
tensor-tang 已提交
115
#endif
T
tensor-tang 已提交
116

T
tensor-tang 已提交
117 118 119 120
  template <typename T>
  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
              T* C) const;

Y
Yu Yang 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b, T alpha,
              framework::Tensor* mat_out, T beta) const;

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, bool trans_a,
              const framework::Tensor& mat_b, bool trans_b,
              framework::Tensor* mat_out) const {
    MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
           static_cast<T>(0.0));
  }

  template <typename T>
  void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
              framework::Tensor* mat_out) const {
    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
  }

  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;

143 144 145
  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const;

T
tensor-tang 已提交
146 147 148
  template <typename T>
  void VMUL(int n, const T* x, const T* y, T* z) const;

149 150 151
  template <typename T>
  void VCOPY(int n, const T* x, T* y) const;

Y
Yu Yang 已提交
152
  template <typename T>
T
tensor-tang 已提交
153 154
  void VEXP(int n, const T* x, T* y) const;

T
tensor-tang 已提交
155
  template <typename T>
T
tensor-tang 已提交
156
  void VSQUARE(int n, const T* x, T* y) const;
T
tensor-tang 已提交
157 158 159 160

  template <typename T>
  void VPOW(int n, const T* x, T alpha, T* y) const;

T
tensor-tang 已提交
161
  template <typename T>
Y
Yu Yang 已提交
162 163 164
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;

T
tensor-tang 已提交
165 166 167
  template <typename T>
  T DOT(int n, const T* x, const T* y) const;

T
tensor-tang 已提交
168
  template <typename T>
T
tensor-tang 已提交
169
  void SCAL(int n, const T a, T* x) const;
T
tensor-tang 已提交
170

Y
Yu Yang 已提交
171 172 173 174 175
  template <typename T>
  void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;

Y
Yu Yang 已提交
176
  template <typename T>
Y
Yu Yang 已提交
177 178 179
  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
              T alpha, framework::Tensor* mat_out, T beta) const;
Y
Yu Yang 已提交
180

Y
Yu Yang 已提交
181 182 183 184 185 186 187 188 189 190 191 192 193 194
 private:
  const DeviceContext& context_;
};

/**
 * Typed facade over Blas<DeviceContext>.
 *
 * BlasT fixes the element type `T`, so callers write `blas.GEMM(...)` instead
 * of `blas.template GEMM<T>(...)`. Each method takes its arguments by value
 * and forwards them to the Blas member template of the same name instantiated
 * with `T`. The inheritance is private, so the untyped Blas members are not
 * reachable through a BlasT.
 */
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
 public:
  using Blas<DeviceContext>::Blas;

  template <typename... ARGS>
  void GEMM(ARGS... args) const {
    Base()->template GEMM<T>(args...);
  }

#ifdef PADDLE_WITH_MKLML
  template <typename... ARGS>
  T* GEMM_ALLOC(ARGS... args) const {
    return Base()->template GEMM_ALLOC<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_PACK(ARGS... args) const {
    Base()->template GEMM_PACK<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_COMPUTE(ARGS... args) const {
    Base()->template GEMM_COMPUTE<T>(args...);
  }

  template <typename... ARGS>
  void GEMM_FREE(ARGS... args) const {
    Base()->template GEMM_FREE<T>(args...);
  }
#endif

  template <typename... ARGS>
  void MatMul(ARGS... args) const {
    Base()->template MatMul<T>(args...);
  }

  template <typename... ARGS>
  void AXPY(ARGS... args) const {
    Base()->template AXPY<T>(args...);
  }

  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);
  }

  template <typename... ARGS>
  void VMUL(ARGS... args) const {
    Base()->template VMUL<T>(args...);
  }

  template <typename... ARGS>
  void VCOPY(ARGS... args) const {
    Base()->template VCOPY<T>(args...);
  }

  template <typename... ARGS>
  void VEXP(ARGS... args) const {
    Base()->template VEXP<T>(args...);
  }

  template <typename... ARGS>
  void VSQUARE(ARGS... args) const {
    Base()->template VSQUARE<T>(args...);
  }

  template <typename... ARGS>
  void VPOW(ARGS... args) const {
    Base()->template VPOW<T>(args...);
  }

  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
  }

  template <typename... ARGS>
  T DOT(ARGS... args) const {
    return Base()->template DOT<T>(args...);
  }

  template <typename... ARGS>
  void SCAL(ARGS... args) const {
    Base()->template SCAL<T>(args...);
  }

  template <typename... ARGS>
  void BatchedGEMM(ARGS... args) const {
    Base()->template BatchedGEMM<T>(args...);
  }

 private:
  // Upcast to the private base that implements the untyped routines.
  const Blas<DeviceContext>* Base() const {
    return static_cast<const Blas<DeviceContext>*>(this);
  }
};

/**
 * Build a typed BLAS helper bound to `dev_ctx`.
 *
 * The returned BlasT keeps a reference to `dev_ctx`, so the context must
 * outlive the helper.
 */
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
  return BlasT<DeviceContext, T>(dev_ctx);
}

/**
 * Build a typed BLAS helper bound to the `DeviceContext` obtained from
 * `exe_ctx`.
 */
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
    const framework::ExecutionContext& exe_ctx) {
  const DeviceContext& device =
      exe_ctx.template device_context<DeviceContext>();
  return GetBlas<DeviceContext, T>(device);
}

}  // namespace math
}  // namespace operators
}  // namespace paddle

#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif