Commit 9ade63e6 authored by xutianbing

clean code a little bit.

Parent 171eaff2
@@ -56,7 +56,16 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
   /// todo(tianbing), clean the code
   CHECK(!out.isTransposed()) << "Not supported";
   CHECK_EQ(out.getValueType(), FLOAT_VALUE);
+  CHECK(!a.isTransposed() || !b.isTransposed())
+      << "Not support both a and b are transpose matrices";
+  if (!a.isTransposed() && b.isTransposed()) {
+    CHECK(out.getFormat() != SPARSE_CSC)
+        << "Not supported CSC format when a is not trans and b is trans";
+  }
+  if (scaleT == 0) {
+    out.zeroMem();
+  }
   const real* A = a.getData();
   const real* B = b.getData();
   real* C = out.getValue();
@@ -64,15 +73,11 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
   int* cols = out.getCols();
   size_t height = out.getHeight();
   size_t width = out.getWidth();
-  if (scaleT == 0) {
-    out.zeroMem();
-  }
   if (!a.isTransposed() && !b.isTransposed()) {
+    CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height &&
+          b.getWidth() == width);
     size_t m = a.getWidth();
-    CHECK_EQ(b.getHeight(), m);
-    CHECK_EQ(a.getHeight(), height);
-    CHECK_EQ(b.getWidth(), width);
     if (out.getFormat() == SPARSE_CSC) {
       for (size_t i = 0; i < width; i++) {
         size_t start = out.getColStartIdx(i);
@@ -86,26 +91,27 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
           C[j] = scaleAB * sum + scaleT * C[j];
         }
       }
-    } else {
+    } else {  /// out.getFormat() == SPARSE_CSR
      for (size_t i = 0; i < height; i++) {
        size_t start = out.getRowStartIdx(i);
        size_t end = out.getRowStartIdx(i + 1);
        for (size_t j = start; j < end; j++) {
          real sum = 0;
          size_t colIdx = cols[j];
-          for (size_t k = 0; k < m; k++) {
+          for (size_t k = 0; k < a.getWidth(); k++) {
            sum += A[i * m + k] * B[k * width + colIdx];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
     }
-  } else if (a.isTransposed() && !b.isTransposed()) {
-    size_t m = a.getHeight();
-    CHECK_EQ(m, b.getHeight());
-    CHECK_EQ(b.getWidth(), width);
-    CHECK_EQ(a.getWidth(), height);
+    return;
+  }
+  if (a.isTransposed() && !b.isTransposed()) {
+    CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width &&
+          a.getWidth() == height);
+    size_t m = a.getHeight();
     if (out.getFormat() == SPARSE_CSC) {
       for (size_t i = 0; i < width; i++) {
         size_t start = out.getColStartIdx(i);
@@ -119,25 +125,27 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
           C[j] = scaleAB * sum + scaleT * C[j];
         }
       }
-    } else {
+    } else {  /// out.getFormat() == SPARSE_CSR
      for (size_t i = 0; i < height; i++) {
        int start = out.getRowStartIdx(i);
        int end = out.getRowStartIdx(i + 1);
        for (int j = start; j < end; j++) {
          real sum = 0;
          size_t colIdx = cols[j];
-          for (size_t k = 0; k < m; k++) {
+          for (size_t k = 0; k < a.getHeight(); k++) {
            sum += A[k * height + i] * B[k * width + colIdx];
          }
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
     }
-  } else if (!a.isTransposed() && b.isTransposed()) {
+    return;
+  }
+  if (!a.isTransposed() && b.isTransposed()) {
+    CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height &&
+          b.getHeight() == width);
     size_t m = a.getWidth();
-    CHECK_EQ(b.getWidth(), m);
-    CHECK_EQ(a.getHeight(), height);
-    CHECK_EQ(b.getHeight(), width);
     if (out.getFormat() == SPARSE_CSR) {
       for (size_t i = 0; i < height; i++) {
         size_t start = out.getRowStartIdx(i);
@@ -151,12 +159,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
          C[j] = scaleAB * sum + scaleT * C[j];
        }
      }
-    } else {
-      LOG(FATAL) << "Not supported csc format "
-                    "when a is not trans and b is trans";
     }
-  } else {
-    LOG(FATAL) << "Not supported";
+    return;
   }
 }
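Note: after this change, every supported branch of the sparse-output multiply uses the same kernel: walk only the stored positions of `out` and accumulate a dense dot product there, scaled by `scaleAB` and blended with `scaleT`. A minimal self-contained sketch of that kernel for a CSR output (plain C++; `CsrPattern` and `sparseOutMul` are made-up names standing in for Paddle's `CpuSparseMatrix` API, so this is an illustration, not the project's code):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical CSR sparsity pattern of the output (not Paddle's CpuSparseMatrix).
struct CsrPattern {
  std::vector<int> rowStart;  // size height + 1; row i occupies [rowStart[i], rowStart[i+1])
  std::vector<int> cols;      // column index of each stored element
};

// values = scaleAB * (A * B) + scaleT * values, evaluated only at the stored
// positions described by pat. A is height x m, B is m x width, both dense row-major.
void sparseOutMul(const CsrPattern& pat, std::vector<float>& values,
                  const float* A, const float* B,
                  size_t height, size_t width, size_t m,
                  float scaleAB, float scaleT) {
  for (size_t i = 0; i < height; ++i) {
    for (int j = pat.rowStart[i]; j < pat.rowStart[i + 1]; ++j) {
      const size_t col = pat.cols[j];
      float sum = 0;
      for (size_t k = 0; k < m; ++k) {
        sum += A[i * m + k] * B[k * width + col];  // dense dot product for one nonzero
      }
      values[j] = scaleAB * sum + scaleT * values[j];
    }
  }
}
```

Only the stored entries are touched, so the cost is O(nnz * m) rather than O(height * width * m).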
@@ -166,159 +170,75 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                             const CpuMatrix& b,
                             real scaleAB,
                             real scaleT) {
-  /// todo(tianbing), clean the code
-  CHECK(!out.isTransposed()) << "Not supported";
-  CBLAS_TRANSPOSE aTrans = CblasNoTrans;
-  size_t aRow = a.getHeight();
-  size_t aCol = a.getWidth();
-  CBLAS_TRANSPOSE bTrans = CblasNoTrans;
-  size_t bRow = b.getHeight();
-  size_t bCol = b.getWidth();
-  if (a.isTransposed()) {
-    aTrans = CblasTrans;
-    aRow = a.getWidth();
-    aCol = a.getHeight();
-  }
-  if (b.isTransposed()) {
-    bTrans = CblasTrans;
-    bRow = b.getWidth();
-    bCol = b.getHeight();
-  }
+  CHECK(!out.isTransposed()) << "out matrix transpose not supported";
+  CBLAS_TRANSPOSE aTrans = a.isTransposed() ? CblasTrans : CblasNoTrans;
+  size_t aRow = a.isTransposed() ? a.getWidth() : a.getHeight();
+  size_t aCol = a.isTransposed() ? a.getHeight() : a.getWidth();
+  CBLAS_TRANSPOSE bTrans = b.isTransposed() ? CblasTrans : CblasNoTrans;
+  size_t bRow = b.isTransposed() ? b.getWidth() : b.getHeight();
+  size_t bCol = b.isTransposed() ? b.getHeight() : b.getWidth();
   /// C = A * B, for matrix format
   CHECK_EQ(aCol, bRow);
   CHECK_EQ(aRow, out.getHeight());
   CHECK_EQ(bCol, out.getWidth());
-  const real* A = a.getData();
-  const real* B = b.getData();
-  real* C = out.getData();
-  int M = out.getHeight();
-  int N = out.getWidth();
-  int K = aCol;
-  int lda = a.getStride();
-  int ldb = b.getStride();
-  int ldc = out.getStride();
-  GEMM(aTrans, bTrans, M, N, K, scaleAB, A, lda, B, ldb, scaleT, C, ldc);
-  VLOG(2) << " A[0]=" << A[0] << " A[1]=" << A[1] << " B[0]=" << B[0]
-          << " B[1]=" << B[1] << " C[0]=" << C[0] << " C[1]=" << C[1];
+  GEMM(aTrans,
+       bTrans,
+       out.getHeight(),
+       out.getWidth(),
+       aCol,
+       scaleAB,
+       a.getData(),
+       a.getStride(),
+       b.getData(),
+       b.getStride(),
+       scaleT,
+       out.getData(),
+       out.getStride());
 }
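The cleaned-up dense path above just maps the transpose flags, shapes, and strides onto one GEMM call. For reference, a rough standalone analogue written directly against CBLAS (assuming row-major, single-precision storage; `GEMM` in the diff is Paddle's own wrapper and `denseMul` below is a made-up name, so this is an illustration rather than the project's implementation):

```cpp
#include <cblas.h>

// C = scaleAB * op(A) * op(B) + scaleT * C for row-major float matrices.
// M x N is the shape of C, K the shared inner dimension after applying the
// transpose flags (aCol in the diff); lda/ldb/ldc are the row strides.
void denseMul(bool aTransposed, bool bTransposed,
              int M, int N, int K,
              float scaleAB, const float* A, int lda,
              const float* B, int ldb,
              float scaleT, float* C, int ldc) {
  cblas_sgemm(CblasRowMajor,
              aTransposed ? CblasTrans : CblasNoTrans,
              bTransposed ? CblasTrans : CblasNoTrans,
              M, N, K,
              scaleAB, A, lda,
              B, ldb,
              scaleT, C, ldc);
}
```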
-static ThreadLocal<std::vector<const real*>> threadLocalColArray;
 template <>
 void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                             const CpuSparseMatrix& a,
                             const CpuMatrix& b,
                             real scaleAB,
                             real scaleT) {
-  /// todo(tianbing), clean the code
   CHECK(!out.isTransposed()) << "Not supported";
   CHECK(!b.isTransposed()) << "Not supported";
   CHECK(scaleT == 0 || scaleT == 1) << "Not support";
   CHECK_EQ(scaleAB, static_cast<real>(1.0)) << "Not supported";
   CHECK_EQ(a.getFormat(), SPARSE_CSR) << "Not supported";
-  const real* B = b.getData();
-  real* C = out.getData();
-  size_t height = out.getHeight();
-  size_t width = out.getWidth();
-  int* cols = a.getCols();
-  real* values = a.getValue();
+  if (!a.isTransposed()) {
+    CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
+          b.getWidth() == out.getWidth());
+  } else {
+    CHECK(b.getHeight() == a.getHeight() && a.getWidth() == out.getHeight() &&
+          b.getWidth() == out.getWidth());
+  }
   if (scaleT == 0) {
     out.zeroMem();
   }
-  if (!a.isTransposed()) {
-    size_t m = a.getWidth();
-    CHECK_EQ(b.getHeight(), m);
-    CHECK_EQ(a.getHeight(), height);
-    CHECK_EQ(b.getWidth(), width);
-    if (a.getValueType() == NO_VALUE) {
-      if (width % 32 == 0) {  // use libaddto
-        CHECK_EQ((size_t)B % 32, 0UL);
-        CHECK_EQ((size_t)C % 32, 0UL);
-        auto& colArray = *threadLocalColArray;
-        for (size_t i = 0; i < a.getHeight(); ++i) {
-          const int start = a.getRowStartIdx(i);
-          const int end = a.getRowStartIdx(i + 1);
-          size_t colNum = end - start;
-          colArray.resize(colNum);
-          for (int j = 0; j < end - start; ++j) {
-            colArray[j] = const_cast<CpuMatrix&>(b).getRow(cols[j + start]);
-          }
-          simd::batchAddTo(out.getRow(i), &colArray[0], colNum, width);
-        }
-      } else {
-        for (size_t i = 0; i < a.getHeight(); ++i) {
-          const int start = a.getRowStartIdx(i);
-          const int end = a.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            vecAddTo(out.getRow(i),
-                     const_cast<CpuMatrix&>(b).getRow(cols[j]),
-                     width);
-          }
-        }
-      }
-    } else if (a.getValueType() == FLOAT_VALUE) {
-      for (size_t i = 0; i < a.getHeight(); ++i) {
-        const int start = a.getRowStartIdx(i);
-        const int end = a.getRowStartIdx(i + 1);
-        for (int j = start; j < end; ++j) {
-          vecAddTo(out.getRow(i),
-                   const_cast<CpuMatrix&>(b).getRow(cols[j]),
-                   values[j],
-                   width);
-        }
-      }
-    }
-  } else /*if (a->isTransposed())*/ {
-    size_t m = a.getHeight();
-    CHECK_EQ(b.getHeight(), m);
-    CHECK_EQ(a.getWidth(), height);
-    CHECK_EQ(b.getWidth(), width);
-    if (a.getValueType() == NO_VALUE) {
-      if (width % 32 == 0) {  // use libaddto
-        CHECK_EQ((size_t)B % 32, 0UL);
-        CHECK_EQ((size_t)C % 32, 0UL);
-        for (size_t i = 0; i < a.getHeight(); ++i) {
-          const int start = a.getRowStartIdx(i);
-          const int end = a.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            simd::addTo(out.getRow(cols[j]),
-                        const_cast<CpuMatrix&>(b).getRow(i),
-                        width);
-          }
-        }
-      } else {
-        for (size_t i = 0; i < a.getHeight(); ++i) {
-          const int start = a.getRowStartIdx(i);
-          const int end = a.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            vecAddTo(out.getRow(cols[j]),
-                     const_cast<CpuMatrix&>(b).getRow(i),
-                     width);
-          }
-        }
-      }
-    } else if (a.getValueType() == FLOAT_VALUE) {
-      for (size_t i = 0; i < a.getHeight(); ++i) {
-        const int start = a.getRowStartIdx(i);
-        const int end = a.getRowStartIdx(i + 1);
-        for (int j = start; j < end; ++j) {
-          vecAddTo(out.getRow(cols[j]),
-                   const_cast<CpuMatrix&>(b).getRow(i),
-                   values[j],
-                   width);
-        }
-      }
+  const real* B = b.getData();
+  real* C = out.getData();
+  if (out.getWidth() % 32 == 0) {
+    CHECK_EQ((size_t)B % 32, 0UL);
+    CHECK_EQ((size_t)C % 32, 0UL);
+  }
+  int* cols = a.getCols();
+  real* values = a.getValue();
+  for (size_t i = 0; i < a.getHeight(); ++i) {
+    const int start = a.getRowStartIdx(i);
+    const int end = a.getRowStartIdx(i + 1);
+    for (int j = start; j < end; ++j) {
+      vecAddTo(!a.isTransposed() ? out.getRow(i) : out.getRow(cols[j]),
+               !a.isTransposed() ? const_cast<CpuMatrix&>(b).getRow(cols[j])
                                  : const_cast<CpuMatrix&>(b).getRow(i),
+               (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
+               out.getWidth());
     }
   }
 }
...
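In the last hunk, the four NO_VALUE/FLOAT_VALUE and transposed/non-transposed branches collapse into a single loop: each stored element of the CSR matrix `a` adds a scaled row of `b` into a row of `out`, with the row roles swapped when `a` is transposed. A simplified sketch of that accumulation (plain C++; a basic `axpy` stands in for Paddle's `vecAddTo`/`simd::batchAddTo`, and the 32-byte-alignment fast path is ignored):

```cpp
#include <cstddef>
#include <vector>

// y += scale * x over len elements; simple stand-in for Paddle's vecAddTo.
static void axpy(float* y, const float* x, float scale, size_t len) {
  for (size_t k = 0; k < len; ++k) y[k] += scale * x[k];
}

// out += op(a) * b, where a is CSR (rowStart/cols/values; an empty values
// vector means every stored entry is 1, i.e. NO_VALUE), and b, out are dense
// row-major with `width` columns.
void sparseDenseMul(const std::vector<int>& rowStart,
                    const std::vector<int>& cols,
                    const std::vector<float>& values,
                    bool aTransposed,
                    const float* b, float* out, size_t width) {
  for (size_t i = 0; i + 1 < rowStart.size(); ++i) {
    for (int j = rowStart[i]; j < rowStart[i + 1]; ++j) {
      // a(i, cols[j]) scales row cols[j] of b into row i of out, or row i of b
      // into row cols[j] of out when a is transposed.
      const size_t outRow = aTransposed ? static_cast<size_t>(cols[j]) : i;
      const size_t bRow = aTransposed ? i : static_cast<size_t>(cols[j]);
      const float scale = values.empty() ? 1.0f : values[j];
      axpy(out + outRow * width, b + bRow * width, scale, width);
    }
  }
}
```

The FLOAT_VALUE/NO_VALUE distinction reduces to whether `values` supplies a per-entry scale or the scale defaults to 1, and the `scaleT == 0` case corresponds to zeroing `out` before the loop, as the function in the diff does.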