From 53b0e427092219b402f0ed6fab4235c3b70fdc7c Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 17 Aug 2017 16:19:59 +0800 Subject: [PATCH] Add EigenGemm. --- paddle/function/EigenGemm.cpp | 92 ++++++++++++++++++++++++++++++ paddle/function/GemmFunctor.cpp | 85 ++++++++++++++++++++++++++++ paddle/function/GemmFunctor.h | 99 +++++++++++---------------------- 3 files changed, 211 insertions(+), 65 deletions(-) create mode 100644 paddle/function/EigenGemm.cpp create mode 100644 paddle/function/GemmFunctor.cpp diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp new file mode 100644 index 00000000000..0b4220fcbef --- /dev/null +++ b/paddle/function/EigenGemm.cpp @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { + +template +struct EigenBlasGemm { + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + Eigen::array sizeA; + if (transA) { + sizeA[0] = K; + sizeA[1] = M; + CHECK_EQ(M, lda); + } else { + sizeA[0] = M; + sizeA[1] = K; + CHECK_EQ(K, lda); + } + Eigen::array sizeB; + if (transB) { + sizeB[0] = N; + sizeB[1] = K; + CHECK_EQ(K, ldb); + } else { + sizeB[0] = K; + sizeB[1] = N; + CHECK_EQ(N, ldb); + } + Eigen::array sizeC; + sizeC[0] = M; + sizeC[1] = N; + CHECK_EQ(N, ldc); + + const Matrix a(const_cast(A), sizeA); + const Matrix b(const_cast(B), sizeB); + Matrix c(C, sizeC); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims; + dims[0] = DimPair(1, 0); + dims[0].first = transA ? 0 : 1; + dims[0].second = transB ? 1 : 0; + + Eigen::DefaultDevice device; + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = + c.constant(alpha) * a.contract(b, dims) + c.constant(beta) * c; + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class EigenBlasGemm; +#else +template class EigenBlasGemm; +#endif + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp new file mode 100644 index 00000000000..8df9b884fed --- /dev/null +++ b/paddle/function/GemmFunctor.cpp @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "GemmFunctor.h" +#include "paddle/math/MathFunctions.h" + +namespace paddle { + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + gemm(transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); + } +}; + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + hl_matrix_mul((T*)A, + transA == false ? HPPL_OP_N : HPPL_OP_T, + (T*)B, + transB == false ? HPPL_OP_N : HPPL_OP_T, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +template class BlasGemm; +template class BlasGemm; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h index d5db5cf5e7a..0809953b4eb 100644 --- a/paddle/function/GemmFunctor.h +++ b/paddle/function/GemmFunctor.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/math/MathFunctions.h" +#include "TensorType.h" namespace paddle { @@ -24,73 +24,42 @@ namespace paddle { // of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul // interface. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc); +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; +// TODO(hedaoyuan): Since the definition of the real type in the Paddle +// conflicts with the Eigen library, so compile the Eigen code can not +// include the Paddle header file. And need an EigenBlasGemm template class +// that does not contain the DeviceType parameter. +// I will fix this problem and merge BlasGemm and EigenBlasGemm into one. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - gemm(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); - } -}; - -template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - hl_matrix_mul((T*)A, - transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - (T*)B, - TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - C, - M, - N, - K, - alpha, - beta, - lda, - ldb, - ldc); - } +struct EigenBlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; } // namespace paddle -- GitLab