// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { namespace math { template struct CBlas; #ifdef PADDLE_WITH_MKLML template <> struct CBlas { template static void GEMM(ARGS... args) { platform::dynload::cblas_sgemm(args...); } template static float *GEMM_ALLOC(ARGS... args) { return platform::dynload::cblas_sgemm_alloc(args...); } template static void GEMM_PACK(ARGS... args) { platform::dynload::cblas_sgemm_pack(args...); } template static void GEMM_COMPUTE(ARGS... args) { platform::dynload::cblas_sgemm_compute(args...); } template static void GEMM_FREE(ARGS... args) { platform::dynload::cblas_sgemm_free(args...); } #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { libxsmm_sgemm(args...); } #endif template static void AXPY(ARGS... args) { platform::dynload::cblas_saxpy(args...); } template static void VCOPY(ARGS... args) { platform::dynload::cblas_scopy(args...); } template static void GEMV(ARGS... args) { platform::dynload::cblas_sgemv(args...); } template static float DOT(ARGS... args) { return platform::dynload::cblas_sdot(args...); } template static void SCAL(ARGS... args) { platform::dynload::cblas_sscal(args...); } template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_sgemm_batch(args...); } template static void VADD(ARGS... args) { platform::dynload::vsAdd(args...); } template static void VMUL(ARGS... args) { platform::dynload::vsMul(args...); } template static void VEXP(ARGS... args) { platform::dynload::vsExp(args...); } template static void VSQR(ARGS... args) { platform::dynload::vsSqr(args...); } template static void VPOW(ARGS... args) { platform::dynload::vsPowx(args...); } }; template <> struct CBlas { template static void GEMM(ARGS... args) { platform::dynload::cblas_dgemm(args...); } template static double *GEMM_ALLOC(ARGS... args) { return platform::dynload::cblas_dgemm_alloc(args...); } template static void GEMM_PACK(ARGS... args) { platform::dynload::cblas_dgemm_pack(args...); } template static void GEMM_COMPUTE(ARGS... args) { platform::dynload::cblas_dgemm_compute(args...); } template static void GEMM_FREE(ARGS... args) { platform::dynload::cblas_dgemm_free(args...); } #ifdef PADDLE_WITH_LIBXSMM template static void SMM_GEMM(ARGS... args) { libxsmm_dgemm(args...); } #endif template static void AXPY(ARGS... args) { platform::dynload::cblas_daxpy(args...); } template static void VCOPY(ARGS... args) { platform::dynload::cblas_dcopy(args...); } template static void GEMV(ARGS... args) { platform::dynload::cblas_dgemv(args...); } template static double DOT(ARGS... args) { return platform::dynload::cblas_ddot(args...); } template static void SCAL(ARGS... args) { platform::dynload::cblas_dscal(args...); } template static void GEMM_BATCH(ARGS... args) { platform::dynload::cblas_dgemm_batch(args...); } template static void VADD(ARGS... args) { platform::dynload::vdAdd(args...); } template static void VMUL(ARGS... args) { platform::dynload::vdMul(args...); } template static void VEXP(ARGS... args) { platform::dynload::vdExp(args...); } template static void VSQR(ARGS... args) { platform::dynload::vdSqr(args...); } template static void VPOW(ARGS... args) { platform::dynload::vdPowx(args...); } }; #else template <> struct CBlas { template static void GEMM(ARGS... args) { cblas_sgemm(args...); } template static void AXPY(ARGS... args) { cblas_saxpy(args...); } template static void VCOPY(ARGS... args) { cblas_scopy(args...); } template static void GEMV(ARGS... args) { cblas_sgemv(args...); } }; template <> struct CBlas { template static void GEMM(ARGS... args) { cblas_dgemm(args...); } template static void AXPY(ARGS... args) { cblas_daxpy(args...); } template static void VCOPY(ARGS... args) { cblas_dcopy(args...); } template static void GEMV(ARGS... args) { cblas_dgemv(args...); } }; #endif template <> struct CBlas { static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } static void SMM_GEMM(...) { PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); } static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); } #endif }; #ifdef PADDLE_WITH_MKLML template <> template T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, const int K) const { return CBlas::GEMM_ALLOC(id, M, N, K); } template <> template void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, int N, int K, const T alpha, const T *src, const int ld, T *dst) const { CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); } template <> template void Blas::GEMM_COMPUTE( int transA, int transB, int M, int N, int K, const T *A, const int lda, const T *B, const int ldb, T beta, T *C, const int ldc) const { CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); } template <> template void Blas::GEMM_FREE(T *data) const { CBlas::GEMM_FREE(data); } #endif template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T *A, const T *B, T beta, T *C) const { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } template <> template void Blas::GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T *A, int lda, const T *B, int ldb, T beta, T *C, int ldc) const { CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } template template void Blas::MatMul(const framework::Tensor &mat_a, bool trans_a, const framework::Tensor &mat_b, bool trans_b, T alpha, framework::Tensor *mat_out, T beta) const { auto dim_a = mat_a.dims(); auto dim_b = mat_b.dims(); auto dim_out = mat_out->dims(); PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, "The input and output of matmul be matrix"); PADDLE_ENFORCE( mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), "The places of matrices must be same"); int M = dim_out[0]; int N = dim_out[1]; int K = !trans_a ? dim_a[1] : dim_a[0]; CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; this->GEMM(transA, transB, M, N, K, alpha, mat_a.data(), mat_b.data(), beta, mat_out->data()); } template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { CBlas::AXPY(n, alpha, x, 1, y, 1); } template <> template void Blas::VCOPY(int n, const T *x, T *y) const { CBlas::VCOPY(n, x, 1, y, 1); } template <> template void Blas::VADD(int n, const T *x, const T *y, T *z) const { #ifdef PADDLE_WITH_MKLML CBlas::VADD(n, x, y, z); #else this->template VCOPY(n, y, z); this->template AXPY(n, 1., x, z); #endif } template <> template void Blas::VMUL(int n, const T *x, const T *y, T *z) const { #ifdef PADDLE_WITH_MKLML CBlas::VMUL(n, x, y, z); #else // try to find if openblas support vmul for (int i = 0; i < n; ++i) { z[i] = x[i] * y[i]; } #endif } template <> template void Blas::VEXP(int n, const T *x, T *y) const { #ifdef PADDLE_WITH_MKLML CBlas::VEXP(n, x, y); #else // try to find if openblas support vexp for (int i = 0; i < n; ++i) { y[i] = std::exp(x[i]); } #endif } template <> template void Blas::VSQR(int n, const T *x, T *y) const { #ifdef PADDLE_WITH_MKLML CBlas::VSQR(n, x, y); #else for (int i = 0; i < n; ++i) { y[i] = std::sqrt(x[i]); } #endif } template <> template void Blas::VPOW(int n, const T *x, T a, T *y) const { #ifdef PADDLE_WITH_MKLML CBlas::VPOW(n, x, a, y); #else for (int i = 0; i < n; ++i) { y[i] = std::pow(x[i], a); } #endif } template <> template T Blas::DOT(int n, const T *x, const T *y) const { #ifdef PADDLE_WITH_MKLML return CBlas::DOT(n, x, 1, y, 1); #else // try to find if openblas support cblas_dot T sum = 0; for (int i = 0; i < n; ++i) { sum += x[i] * y[i]; } return sum; #endif } template <> template void Blas::SCAL(int n, const T a, T *x) const { #ifdef PADDLE_WITH_MKLML CBlas::SCAL(n, a, x, 1); #else // try to find if openblas support cblas_scal for (int i = 0; i < n; ++i) { x[i] = a * x[i]; } #endif } template <> template void Blas::GEMV(bool trans_a, int M, int N, T alpha, const T *A, const T *B, T beta, T *C) const { CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); } template <> template void Blas::BatchedGEMM( CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, int64_t strideA, int64_t strideB) const { #ifdef PADDLE_WITH_MKLML int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; auto a_array = std::vector(batchCount); auto b_array = std::vector(batchCount); auto c_array = std::vector(batchCount); for (int k = 0; k < batchCount; ++k) { a_array[k] = &A[k * strideA]; b_array[k] = &B[k * strideB]; c_array[k] = &C[k * M * N]; } CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, a_array.data(), &lda, b_array.data(), &ldb, &beta, c_array.data(), &ldc, 1 /* group_count */, &batchCount); #else for (int k = 0; k < batchCount; ++k) { auto *Ak = &A[k * strideA]; auto *Bk = &B[k * strideB]; auto *Ck = &C[k * M * N]; this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); } #endif } template template void Blas::MatMul(const int M, const int N, const int K, const T *A, const T *B, T *C) const { this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), A, K, B, N, static_cast(0), C, N); } template <> template void Blas::MatMul(const int M, const int N, const int K, const T *A, const T *B, T *C) const { #ifdef PADDLE_WITH_LIBXSMM // Refer to https://github.com/hfp/libxsmm/blob/master/README.md // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; // Since the matrix is very small, // so the unit of calculation is already very fast, // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, // use xsmm directly. // Note: SMM use ColMajor const char transa = 'N'; const char transb = 'N'; const T alpha = static_cast(1); const T beta = static_cast(0); CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); return; #endif CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), A, K, B, N, static_cast(0), C, N); } template template void Blas::MatMul(const framework::Tensor &mat_a, const MatDescriptor &dim_a, const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, framework::Tensor *mat_out, T beta) const { PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_); CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a.data(), mat_b.data(), beta, mat_out->data()); } else { PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); this->template BatchedGEMM( transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a.data(), mat_b.data(), beta, mat_out->data(), dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); } } } // namespace math } // namespace operators } // namespace paddle