未验证 提交 42708ded 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enable the case N != ldc in EigenBlasGemm. (#5976)

* Enable the case N != ldc in EigenBlasGemm.

* Use MemoryHandle instead of direct calling of posix_memalign to alloc temporary memory.

* Use Eigen's slice() instead of a temporary memory.

* Add if-else for different cases in EigenBlasGemm (for N ?= ldc).
上级 5f0d0818
...@@ -21,7 +21,7 @@ template <class T> ...@@ -21,7 +21,7 @@ template <class T>
struct EigenBlasGemm { struct EigenBlasGemm {
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, int>, typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, int>,
Eigen::Aligned> Eigen::Aligned>
Matrix; EigenMatrix;
static void compute(const bool transA, static void compute(const bool transA,
const bool transB, const bool transB,
...@@ -56,14 +56,13 @@ struct EigenBlasGemm { ...@@ -56,14 +56,13 @@ struct EigenBlasGemm {
sizeB[1] = N; sizeB[1] = N;
CHECK_EQ(N, ldb); CHECK_EQ(N, ldb);
} }
Eigen::array<int, 2> sizeC; Eigen::array<int, 2> sizeC = {{M, ldc}};
sizeC[0] = M; Eigen::array<int, 2> offsetC = {{0, 0}};
sizeC[1] = N; Eigen::array<int, 2> extentC = {{M, N}};
CHECK_EQ(N, ldc);
const Matrix a(const_cast<T*>(A), sizeA); const EigenMatrix a(const_cast<T*>(A), sizeA);
const Matrix b(const_cast<T*>(B), sizeB); const EigenMatrix b(const_cast<T*>(B), sizeB);
Matrix c(C, sizeC); EigenMatrix c(C, sizeC);
typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair; typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims; Eigen::array<DimPair, 1> dims;
...@@ -72,6 +71,7 @@ struct EigenBlasGemm { ...@@ -72,6 +71,7 @@ struct EigenBlasGemm {
dims[0].second = transB ? 1 : 0; dims[0].second = transB ? 1 : 0;
Eigen::DefaultDevice device; Eigen::DefaultDevice device;
if (N == ldc) {
if (alpha == T(1) && beta == T(0)) { if (alpha == T(1) && beta == T(0)) {
c.device(device) = a.contract(b, dims); c.device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) { } else if (alpha == T(1) && beta == T(1)) {
...@@ -79,6 +79,16 @@ struct EigenBlasGemm { ...@@ -79,6 +79,16 @@ struct EigenBlasGemm {
} else { } else {
c.device(device) = alpha * a.contract(b, dims) + beta * c; c.device(device) = alpha * a.contract(b, dims) + beta * c;
} }
} else {
if (alpha == T(1) && beta == T(0)) {
c.slice(offsetC, extentC).device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) {
c.slice(offsetC, extentC).device(device) += a.contract(b, dims);
} else {
c.slice(offsetC, extentC).device(device) =
alpha * a.contract(b, dims) + beta * c.slice(offsetC, extentC);
}
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册