提交 9ade63e6 作者: xutianbing

Clean up the code a little bit.

上级 171eaff2
......@@ -56,7 +56,16 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
/// todo(tianbing), clean the code
CHECK(!out.isTransposed()) << "Not supported";
CHECK_EQ(out.getValueType(), FLOAT_VALUE);
CHECK(!a.isTransposed() || !b.isTransposed())
<< "Not support both a and b are transpose matrices";
if (!a.isTransposed() && b.isTransposed()) {
CHECK(out.getFormat() != SPARSE_CSC)
<< "Not supported CSC format when a is not trans and b is trans";
}
if (scaleT == 0) {
out.zeroMem();
}
const real* A = a.getData();
const real* B = b.getData();
real* C = out.getValue();
......@@ -64,15 +73,11 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
int* cols = out.getCols();
size_t height = out.getHeight();
size_t width = out.getWidth();
if (scaleT == 0) {
out.zeroMem();
}
if (!a.isTransposed() && !b.isTransposed()) {
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height &&
b.getWidth() == width);
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), height);
CHECK_EQ(b.getWidth(), width);
if (out.getFormat() == SPARSE_CSC) {
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
......@@ -86,26 +91,27 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else {
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) {
for (size_t k = 0; k < a.getWidth(); k++) {
sum += A[i * m + k] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
}
} else if (a.isTransposed() && !b.isTransposed()) {
size_t m = a.getHeight();
CHECK_EQ(m, b.getHeight());
CHECK_EQ(b.getWidth(), width);
CHECK_EQ(a.getWidth(), height);
return;
}
if (a.isTransposed() && !b.isTransposed()) {
CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width &&
a.getWidth() == height);
size_t m = a.getHeight();
if (out.getFormat() == SPARSE_CSC) {
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
......@@ -119,25 +125,27 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else {
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
int start = out.getRowStartIdx(i);
int end = out.getRowStartIdx(i + 1);
for (int j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) {
for (size_t k = 0; k < a.getHeight(); k++) {
sum += A[k * height + i] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
}
} else if (!a.isTransposed() && b.isTransposed()) {
return;
}
if (!a.isTransposed() && b.isTransposed()) {
CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height &&
b.getHeight() == width);
size_t m = a.getWidth();
CHECK_EQ(b.getWidth(), m);
CHECK_EQ(a.getHeight(), height);
CHECK_EQ(b.getHeight(), width);
if (out.getFormat() == SPARSE_CSR) {
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
......@@ -151,12 +159,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else {
LOG(FATAL) << "Not supported csc format "
"when a is not trans and b is trans";
}
} else {
LOG(FATAL) << "Not supported";
return;
}
}
......@@ -166,159 +170,75 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& b,
real scaleAB,
real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.isTransposed()) << "Not supported";
CBLAS_TRANSPOSE aTrans = CblasNoTrans;
size_t aRow = a.getHeight();
size_t aCol = a.getWidth();
CBLAS_TRANSPOSE bTrans = CblasNoTrans;
size_t bRow = b.getHeight();
size_t bCol = b.getWidth();
if (a.isTransposed()) {
aTrans = CblasTrans;
aRow = a.getWidth();
aCol = a.getHeight();
}
if (b.isTransposed()) {
bTrans = CblasTrans;
bRow = b.getWidth();
bCol = b.getHeight();
}
CHECK(!out.isTransposed()) << "out matrix transpose not supported";
CBLAS_TRANSPOSE aTrans = a.isTransposed() ? CblasTrans : CblasNoTrans;
size_t aRow = a.isTransposed() ? a.getWidth() : a.getHeight();
size_t aCol = a.isTransposed() ? a.getHeight() : a.getWidth();
CBLAS_TRANSPOSE bTrans = b.isTransposed() ? CblasTrans : CblasNoTrans;
size_t bRow = b.isTransposed() ? b.getWidth() : b.getHeight();
size_t bCol = b.isTransposed() ? b.getHeight() : b.getWidth();
/// C = A * B, for matrix format
CHECK_EQ(aCol, bRow);
CHECK_EQ(aRow, out.getHeight());
CHECK_EQ(bCol, out.getWidth());
const real* A = a.getData();
const real* B = b.getData();
real* C = out.getData();
int M = out.getHeight();
int N = out.getWidth();
int K = aCol;
int lda = a.getStride();
int ldb = b.getStride();
int ldc = out.getStride();
GEMM(aTrans, bTrans, M, N, K, scaleAB, A, lda, B, ldb, scaleT, C, ldc);
VLOG(2) << " A[0]=" << A[0] << " A[1]=" << A[1] << " B[0]=" << B[0]
<< " B[1]=" << B[1] << " C[0]=" << C[0] << " C[1]=" << C[1];
GEMM(aTrans,
bTrans,
out.getHeight(),
out.getWidth(),
aCol,
scaleAB,
a.getData(),
a.getStride(),
b.getData(),
b.getStride(),
scaleT,
out.getData(),
out.getStride());
}
static ThreadLocal<std::vector<const real*>> threadLocalColArray;
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.isTransposed()) << "Not supported";
CHECK(!b.isTransposed()) << "Not supported";
CHECK(scaleT == 0 || scaleT == 1) << "Not support";
CHECK_EQ(scaleAB, static_cast<real>(1.0)) << "Not supported";
CHECK_EQ(a.getFormat(), SPARSE_CSR) << "Not supported";
const real* B = b.getData();
real* C = out.getData();
size_t height = out.getHeight();
size_t width = out.getWidth();
int* cols = a.getCols();
real* values = a.getValue();
if (scaleT == 0) {
out.zeroMem();
}
if (!a.isTransposed()) {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), height);
CHECK_EQ(b.getWidth(), width);
if (a.getValueType() == NO_VALUE) {
if (width % 32 == 0) { // use libaddto
CHECK_EQ((size_t)B % 32, 0UL);
CHECK_EQ((size_t)C % 32, 0UL);
auto& colArray = *threadLocalColArray;
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
size_t colNum = end - start;
colArray.resize(colNum);
for (int j = 0; j < end - start; ++j) {
colArray[j] = const_cast<CpuMatrix&>(b).getRow(cols[j + start]);
}
simd::batchAddTo(out.getRow(i), &colArray[0], colNum, width);
}
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == out.getWidth());
} else {
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
vecAddTo(out.getRow(i),
const_cast<CpuMatrix&>(b).getRow(cols[j]),
width);
}
}
}
} else if (a.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
vecAddTo(out.getRow(i),
const_cast<CpuMatrix&>(b).getRow(cols[j]),
values[j],
width);
}
CHECK(b.getHeight() == a.getHeight() && a.getWidth() == out.getHeight() &&
b.getWidth() == out.getWidth());
}
if (scaleT == 0) {
out.zeroMem();
}
} else /*if (a->isTransposed())*/ {
size_t m = a.getHeight();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getWidth(), height);
CHECK_EQ(b.getWidth(), width);
if (a.getValueType() == NO_VALUE) {
if (width % 32 == 0) { // use libaddto
const real* B = b.getData();
real* C = out.getData();
if (out.getWidth() % 32 == 0) {
CHECK_EQ((size_t)B % 32, 0UL);
CHECK_EQ((size_t)C % 32, 0UL);
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
simd::addTo(out.getRow(cols[j]),
const_cast<CpuMatrix&>(b).getRow(i),
width);
}
}
} else {
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
vecAddTo(out.getRow(cols[j]),
const_cast<CpuMatrix&>(b).getRow(i),
width);
}
}
}
} else if (a.getValueType() == FLOAT_VALUE) {
int* cols = a.getCols();
real* values = a.getValue();
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
vecAddTo(out.getRow(cols[j]),
const_cast<CpuMatrix&>(b).getRow(i),
values[j],
width);
}
}
vecAddTo(!a.isTransposed() ? out.getRow(i) : out.getRow(cols[j]),
!a.isTransposed() ? const_cast<CpuMatrix&>(b).getRow(cols[j])
: const_cast<CpuMatrix&>(b).getRow(i),
(a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
out.getWidth());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册