From 42708ded549cf4c731abd75df8e7b3ef797a4052 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 1 Dec 2017 13:04:08 +0800 Subject: [PATCH] Enable the case N != ldc in EigenBlasGemm. (#5976) * Enable the case N != ldc in EigenBlasGemm. * Use MemoryHandle instead of direct calling of posix_memalign to alloc temporary memory. * Use Eigen's slice() instead of a temporary memory. * Add if-else for different cases in EigenBlasGemm (for N ?= ldc). --- paddle/function/EigenGemm.cpp | 36 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp index b3e666e860..644098a9e7 100644 --- a/paddle/function/EigenGemm.cpp +++ b/paddle/function/EigenGemm.cpp @@ -21,7 +21,7 @@ template struct EigenBlasGemm { typedef Eigen::TensorMap, Eigen::Aligned> - Matrix; + EigenMatrix; static void compute(const bool transA, const bool transB, @@ -56,14 +56,13 @@ struct EigenBlasGemm { sizeB[1] = N; CHECK_EQ(N, ldb); } - Eigen::array sizeC; - sizeC[0] = M; - sizeC[1] = N; - CHECK_EQ(N, ldc); + Eigen::array sizeC = {{M, ldc}}; + Eigen::array offsetC = {{0, 0}}; + Eigen::array extentC = {{M, N}}; - const Matrix a(const_cast(A), sizeA); - const Matrix b(const_cast(B), sizeB); - Matrix c(C, sizeC); + const EigenMatrix a(const_cast(A), sizeA); + const EigenMatrix b(const_cast(B), sizeB); + EigenMatrix c(C, sizeC); typedef typename Eigen::Tensor::DimensionPair DimPair; Eigen::array dims; @@ -72,12 +71,23 @@ struct EigenBlasGemm { dims[0].second = transB ? 1 : 0; Eigen::DefaultDevice device; - if (alpha == T(1) && beta == T(0)) { - c.device(device) = a.contract(b, dims); - } else if (alpha == T(1) && beta == T(1)) { - c.device(device) += a.contract(b, dims); + if (N == ldc) { + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = alpha * a.contract(b, dims) + beta * c; + } } else { - c.device(device) = alpha * a.contract(b, dims) + beta * c; + if (alpha == T(1) && beta == T(0)) { + c.slice(offsetC, extentC).device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.slice(offsetC, extentC).device(device) += a.contract(b, dims); + } else { + c.slice(offsetC, extentC).device(device) = + alpha * a.contract(b, dims) + beta * c.slice(offsetC, extentC); + } } } }; -- GitLab