未验证 提交 42708ded 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enable the case N != ldc in EigenBlasGemm. (#5976)

* Enable the case N != ldc in EigenBlasGemm.

* Use MemoryHandle instead of direct calling of posix_memalign to alloc temporary memory.

* Use Eigen's slice() instead of a temporary memory.

* Add if-else for different cases in EigenBlasGemm (for N ?= ldc).
上级 5f0d0818
......@@ -21,7 +21,7 @@ template <class T>
struct EigenBlasGemm {
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, int>,
Eigen::Aligned>
Matrix;
EigenMatrix;
static void compute(const bool transA,
const bool transB,
......@@ -56,14 +56,13 @@ struct EigenBlasGemm {
sizeB[1] = N;
CHECK_EQ(N, ldb);
}
Eigen::array<int, 2> sizeC;
sizeC[0] = M;
sizeC[1] = N;
CHECK_EQ(N, ldc);
Eigen::array<int, 2> sizeC = {{M, ldc}};
Eigen::array<int, 2> offsetC = {{0, 0}};
Eigen::array<int, 2> extentC = {{M, N}};
const Matrix a(const_cast<T*>(A), sizeA);
const Matrix b(const_cast<T*>(B), sizeB);
Matrix c(C, sizeC);
const EigenMatrix a(const_cast<T*>(A), sizeA);
const EigenMatrix b(const_cast<T*>(B), sizeB);
EigenMatrix c(C, sizeC);
typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims;
......@@ -72,12 +71,23 @@ struct EigenBlasGemm {
dims[0].second = transB ? 1 : 0;
Eigen::DefaultDevice device;
if (alpha == T(1) && beta == T(0)) {
c.device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) {
c.device(device) += a.contract(b, dims);
if (N == ldc) {
if (alpha == T(1) && beta == T(0)) {
c.device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) {
c.device(device) += a.contract(b, dims);
} else {
c.device(device) = alpha * a.contract(b, dims) + beta * c;
}
} else {
c.device(device) = alpha * a.contract(b, dims) + beta * c;
if (alpha == T(1) && beta == T(0)) {
c.slice(offsetC, extentC).device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) {
c.slice(offsetC, extentC).device(device) += a.contract(b, dims);
} else {
c.slice(offsetC, extentC).device(device) =
alpha * a.contract(b, dims) + beta * c.slice(offsetC, extentC);
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册