diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp index b3e666e860d29d89650d48a23cf44917035a02d7..644098a9e7873fb59b6343e805163e4892f060a8 100644 --- a/paddle/function/EigenGemm.cpp +++ b/paddle/function/EigenGemm.cpp @@ -21,7 +21,7 @@ template struct EigenBlasGemm { typedef Eigen::TensorMap, Eigen::Aligned> - Matrix; + EigenMatrix; static void compute(const bool transA, const bool transB, @@ -56,14 +56,13 @@ struct EigenBlasGemm { sizeB[1] = N; CHECK_EQ(N, ldb); } - Eigen::array sizeC; - sizeC[0] = M; - sizeC[1] = N; - CHECK_EQ(N, ldc); + Eigen::array sizeC = {{M, ldc}}; + Eigen::array offsetC = {{0, 0}}; + Eigen::array extentC = {{M, N}}; - const Matrix a(const_cast(A), sizeA); - const Matrix b(const_cast(B), sizeB); - Matrix c(C, sizeC); + const EigenMatrix a(const_cast(A), sizeA); + const EigenMatrix b(const_cast(B), sizeB); + EigenMatrix c(C, sizeC); typedef typename Eigen::Tensor::DimensionPair DimPair; Eigen::array dims; @@ -72,12 +71,23 @@ struct EigenBlasGemm { dims[0].second = transB ? 1 : 0; Eigen::DefaultDevice device; - if (alpha == T(1) && beta == T(0)) { - c.device(device) = a.contract(b, dims); - } else if (alpha == T(1) && beta == T(1)) { - c.device(device) += a.contract(b, dims); + if (N == ldc) { + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = alpha * a.contract(b, dims) + beta * c; + } } else { - c.device(device) = alpha * a.contract(b, dims) + beta * c; + if (alpha == T(1) && beta == T(0)) { + c.slice(offsetC, extentC).device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.slice(offsetC, extentC).device(device) += a.contract(b, dims); + } else { + c.slice(offsetC, extentC).device(device) = + alpha * a.contract(b, dims) + beta * c.slice(offsetC, extentC); + } } } };