// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif #ifdef PADDLE_WITH_LIBXSMM #include #endif #ifdef PADDLE_USE_OPENBLAS #include #endif namespace paddle { namespace operators { namespace math { /** * Matrix Descriptor of a memory buffer. * * It is used for Blas::MatMul. MatMul operator can be batched. * if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a * `batch_size` times of GEMM. The batched GEMM could be faster base on the * implementation of the blas library. The batch size could be zero. If any * matrix of `matmul` has a batch size, the will be a batched GEMM, too. e.g., * Mat A is [BatchSize, H1, W2], and Mat B [H2, W2], The result matrix wil be * [BatchSize, H1, W2] * * The boolean flag, `trans`, describe the memory is the transpose of matrix or * not. If the trans is true, the last two dims of matrix are transposed. The * memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height]. * * The MatDescriptor is not only the dimension or shape of a matrix, it also * contains the layout, stride of matrix. It is clearer to have a structure than * reuse `DDim`. */ struct MatDescriptor { int64_t height_; int64_t width_; int64_t stride_{0}; int64_t batch_size_{0}; bool trans_; }; /** * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose * flag * * @param tensor_dim: The dimension of the tensor. The rank of this dimension * must larger than 1. * * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first * dimension(column length) will be the product of tensor's first `num_col_dims` * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the * batch_size of descriptor. * * @param trans: True if the matrix is transposed. */ extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim, int num_flatten_cols, bool trans); template class Blas { public: explicit Blas(const DeviceContext& context) : context_(context) {} template void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C) const; template void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; template void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc) const; #ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, const int K) const; template void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, int N, int K, const T alpha, const T* src, const int ld, T* dst) const; template void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A, const int lda, const T* B, const int ldb, T beta, T* C, const int ldc) const; template void GEMM_FREE(T* data) const; #endif template void MatMul(const int M, const int N, const int K, const T* A, const T* B, T* C) const; template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, T alpha, framework::Tensor* mat_out, T beta) const; template void MatMul(const framework::Tensor& mat_a, bool trans_a, const framework::Tensor& mat_b, bool trans_b, framework::Tensor* mat_out) const { MatMul(mat_a, trans_a, mat_b, trans_b, static_cast(1.0), mat_out, static_cast(0.0)); } template void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b, framework::Tensor* mat_out) const { this->template MatMul(mat_a, false, mat_b, false, mat_out); } template void AXPY(int n, T alpha, const T* x, T* y) const; template void VADD(int n, const T* x, const T* y, T* z) const; template void VMUL(int n, const T* x, const T* y, T* z) const; template void VCOPY(int n, const T* x, T* y) const; template void VEXP(int n, const T* x, T* y) const; template void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, T* C) const; template T DOT(int n, const T* x, const T* y) const; template void SCAL(int n, const T a, T* x) const; template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T* A, const T* B, T beta, T* C, int batchCount, int64_t strideA, int64_t strideB) const; template void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a, const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, framework::Tensor* mat_out, T beta) const; private: const DeviceContext& context_; }; template class BlasT : private Blas { public: using Blas::Blas; template void GEMM(ARGS... args) const { Base()->template GEMM(args...); } #ifdef PADDLE_WITH_MKLML template T* GEMM_ALLOC(ARGS... args) const { return Base()->template GEMM_ALLOC(args...); } template void GEMM_PACK(ARGS... args) const { Base()->template GEMM_PACK(args...); } template void GEMM_COMPUTE(ARGS... args) const { Base()->template GEMM_COMPUTE(args...); } template void GEMM_FREE(ARGS... args) const { Base()->template GEMM_FREE(args...); } #endif template void MatMul(ARGS... args) const { Base()->template MatMul(args...); } template void AXPY(ARGS... args) const { Base()->template AXPY(args...); } template void VADD(ARGS... args) const { Base()->template VADD(args...); } template void VMUL(ARGS... args) const { Base()->template VMUL(args...); } template void VCOPY(ARGS... args) const { Base()->template VCOPY(args...); } template void VEXP(ARGS... args) const { Base()->template VEXP(args...); } template void GEMV(ARGS... args) const { Base()->template GEMV(args...); } template T DOT(ARGS... args) const { return Base()->template DOT(args...); } template void SCAL(ARGS... args) const { Base()->template SCAL(args...); } template void BatchedGEMM(ARGS... args) const { Base()->template BatchedGEMM(args...); } private: const Blas* Base() const { return static_cast*>(this); } }; template inline BlasT GetBlas( const framework::ExecutionContext& exe_ctx) { return BlasT( exe_ctx.template device_context()); } template inline BlasT GetBlas(const DeviceContext& dev_ctx) { return BlasT(dev_ctx); } } // namespace math } // namespace operators } // namespace paddle #include "paddle/fluid/operators/math/blas_impl.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/operators/math/blas_impl.cu.h" #endif