diff --git a/paddle/function/MulOp.cpp b/paddle/function/MulOp.cpp index 1117944a4ed99440a8f80d65c33f7a3935d7df0f..b911ccd13b607a6ea97d9844bfbc58166adc3fd6 100644 --- a/paddle/function/MulOp.cpp +++ b/paddle/function/MulOp.cpp @@ -56,7 +56,16 @@ void MulOp(CpuSparseMatrix& out, /// todo(tianbing), clean the code CHECK(!out.isTransposed()) << "Not supported"; CHECK_EQ(out.getValueType(), FLOAT_VALUE); + CHECK(!a.isTransposed() || !b.isTransposed()) + << "Not support both a and b are transpose matrices"; + if (!a.isTransposed() && b.isTransposed()) { + CHECK(out.getFormat() != SPARSE_CSC) + << "Not supported CSC format when a is not trans and b is trans"; + } + if (scaleT == 0) { + out.zeroMem(); + } const real* A = a.getData(); const real* B = b.getData(); real* C = out.getValue(); @@ -64,15 +73,11 @@ void MulOp(CpuSparseMatrix& out, int* cols = out.getCols(); size_t height = out.getHeight(); size_t width = out.getWidth(); - if (scaleT == 0) { - out.zeroMem(); - } if (!a.isTransposed() && !b.isTransposed()) { + CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height && + b.getWidth() == width); size_t m = a.getWidth(); - CHECK_EQ(b.getHeight(), m); - CHECK_EQ(a.getHeight(), height); - CHECK_EQ(b.getWidth(), width); if (out.getFormat() == SPARSE_CSC) { for (size_t i = 0; i < width; i++) { size_t start = out.getColStartIdx(i); @@ -86,26 +91,27 @@ void MulOp(CpuSparseMatrix& out, C[j] = scaleAB * sum + scaleT * C[j]; } } - } else { + } else { /// out.getFormat() == SPARSE_CSR for (size_t i = 0; i < height; i++) { size_t start = out.getRowStartIdx(i); size_t end = out.getRowStartIdx(i + 1); for (size_t j = start; j < end; j++) { real sum = 0; size_t colIdx = cols[j]; - for (size_t k = 0; k < m; k++) { + for (size_t k = 0; k < a.getWidth(); k++) { sum += A[i * m + k] * B[k * width + colIdx]; } C[j] = scaleAB * sum + scaleT * C[j]; } } } - } else if (a.isTransposed() && !b.isTransposed()) { - size_t m = a.getHeight(); - CHECK_EQ(m, b.getHeight()); - CHECK_EQ(b.getWidth(), width); - CHECK_EQ(a.getWidth(), height); + return; + } + if (a.isTransposed() && !b.isTransposed()) { + CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width && + a.getWidth() == height); + size_t m = a.getHeight(); if (out.getFormat() == SPARSE_CSC) { for (size_t i = 0; i < width; i++) { size_t start = out.getColStartIdx(i); @@ -119,25 +125,27 @@ void MulOp(CpuSparseMatrix& out, C[j] = scaleAB * sum + scaleT * C[j]; } } - } else { + } else { /// out.getFormat() == SPARSE_CSR for (size_t i = 0; i < height; i++) { int start = out.getRowStartIdx(i); int end = out.getRowStartIdx(i + 1); for (int j = start; j < end; j++) { real sum = 0; size_t colIdx = cols[j]; - for (size_t k = 0; k < m; k++) { + for (size_t k = 0; k < a.getHeight(); k++) { sum += A[k * height + i] * B[k * width + colIdx]; } C[j] = scaleAB * sum + scaleT * C[j]; } } } - } else if (!a.isTransposed() && b.isTransposed()) { + return; + } + + if (!a.isTransposed() && b.isTransposed()) { + CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height && + b.getHeight() == width); size_t m = a.getWidth(); - CHECK_EQ(b.getWidth(), m); - CHECK_EQ(a.getHeight(), height); - CHECK_EQ(b.getHeight(), width); if (out.getFormat() == SPARSE_CSR) { for (size_t i = 0; i < height; i++) { size_t start = out.getRowStartIdx(i); @@ -151,12 +159,8 @@ void MulOp(CpuSparseMatrix& out, C[j] = scaleAB * sum + scaleT * C[j]; } } - } else { - LOG(FATAL) << "Not supported csc format " - "when a is not trans and b is trans"; } - } else { - LOG(FATAL) << "Not supported"; + return; } } @@ -166,159 +170,75 @@ void MulOp(CpuMatrix& out, const CpuMatrix& b, real scaleAB, real scaleT) { - /// todo(tianbing), clean the code - CHECK(!out.isTransposed()) << "Not supported"; - CBLAS_TRANSPOSE aTrans = CblasNoTrans; - size_t aRow = a.getHeight(); - size_t aCol = a.getWidth(); - CBLAS_TRANSPOSE bTrans = CblasNoTrans; - size_t bRow = b.getHeight(); - size_t bCol = b.getWidth(); - if (a.isTransposed()) { - aTrans = CblasTrans; - aRow = a.getWidth(); - aCol = a.getHeight(); - } - if (b.isTransposed()) { - bTrans = CblasTrans; - bRow = b.getWidth(); - bCol = b.getHeight(); - } + CHECK(!out.isTransposed()) << "out matrix transpose not supported"; + CBLAS_TRANSPOSE aTrans = a.isTransposed() ? CblasTrans : CblasNoTrans; + size_t aRow = a.isTransposed() ? a.getWidth() : a.getHeight(); + size_t aCol = a.isTransposed() ? a.getHeight() : a.getWidth(); + CBLAS_TRANSPOSE bTrans = b.isTransposed() ? CblasTrans : CblasNoTrans; + size_t bRow = b.isTransposed() ? b.getWidth() : b.getHeight(); + size_t bCol = b.isTransposed() ? b.getHeight() : b.getWidth(); /// C = A * B, for matrix format CHECK_EQ(aCol, bRow); CHECK_EQ(aRow, out.getHeight()); CHECK_EQ(bCol, out.getWidth()); - const real* A = a.getData(); - const real* B = b.getData(); - real* C = out.getData(); - - int M = out.getHeight(); - int N = out.getWidth(); - int K = aCol; - int lda = a.getStride(); - int ldb = b.getStride(); - int ldc = out.getStride(); - - GEMM(aTrans, bTrans, M, N, K, scaleAB, A, lda, B, ldb, scaleT, C, ldc); - - VLOG(2) << " A[0]=" << A[0] << " A[1]=" << A[1] << " B[0]=" << B[0] - << " B[1]=" << B[1] << " C[0]=" << C[0] << " C[1]=" << C[1]; + GEMM(aTrans, + bTrans, + out.getHeight(), + out.getWidth(), + aCol, + scaleAB, + a.getData(), + a.getStride(), + b.getData(), + b.getStride(), + scaleT, + out.getData(), + out.getStride()); } -static ThreadLocal> threadLocalColArray; - template <> void MulOp(CpuMatrix& out, const CpuSparseMatrix& a, const CpuMatrix& b, real scaleAB, real scaleT) { - /// todo(tianbing), clean the code CHECK(!out.isTransposed()) << "Not supported"; CHECK(!b.isTransposed()) << "Not supported"; CHECK(scaleT == 0 || scaleT == 1) << "Not support"; CHECK_EQ(scaleAB, static_cast(1.0)) << "Not supported"; CHECK_EQ(a.getFormat(), SPARSE_CSR) << "Not supported"; - const real* B = b.getData(); - real* C = out.getData(); - size_t height = out.getHeight(); - size_t width = out.getWidth(); - int* cols = a.getCols(); - real* values = a.getValue(); + if (!a.isTransposed()) { + CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() && + b.getWidth() == out.getWidth()); + } else { + CHECK(b.getHeight() == a.getHeight() && a.getWidth() == out.getHeight() && + b.getWidth() == out.getWidth()); + } if (scaleT == 0) { out.zeroMem(); } + const real* B = b.getData(); + real* C = out.getData(); + if (out.getWidth() % 32 == 0) { + CHECK_EQ((size_t)B % 32, 0UL); + CHECK_EQ((size_t)C % 32, 0UL); + } - if (!a.isTransposed()) { - size_t m = a.getWidth(); - CHECK_EQ(b.getHeight(), m); - CHECK_EQ(a.getHeight(), height); - CHECK_EQ(b.getWidth(), width); - - if (a.getValueType() == NO_VALUE) { - if (width % 32 == 0) { // use libaddto - CHECK_EQ((size_t)B % 32, 0UL); - CHECK_EQ((size_t)C % 32, 0UL); - auto& colArray = *threadLocalColArray; - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - size_t colNum = end - start; - colArray.resize(colNum); - for (int j = 0; j < end - start; ++j) { - colArray[j] = const_cast(b).getRow(cols[j + start]); - } - simd::batchAddTo(out.getRow(i), &colArray[0], colNum, width); - } - - } else { - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - for (int j = start; j < end; ++j) { - vecAddTo(out.getRow(i), - const_cast(b).getRow(cols[j]), - width); - } - } - } - } else if (a.getValueType() == FLOAT_VALUE) { - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - for (int j = start; j < end; ++j) { - vecAddTo(out.getRow(i), - const_cast(b).getRow(cols[j]), - values[j], - width); - } - } - } - } else /*if (a->isTransposed())*/ { - size_t m = a.getHeight(); - CHECK_EQ(b.getHeight(), m); - CHECK_EQ(a.getWidth(), height); - CHECK_EQ(b.getWidth(), width); - if (a.getValueType() == NO_VALUE) { - if (width % 32 == 0) { // use libaddto - CHECK_EQ((size_t)B % 32, 0UL); - CHECK_EQ((size_t)C % 32, 0UL); - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - for (int j = start; j < end; ++j) { - simd::addTo(out.getRow(cols[j]), - const_cast(b).getRow(i), - width); - } - } - - } else { - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - for (int j = start; j < end; ++j) { - vecAddTo(out.getRow(cols[j]), - const_cast(b).getRow(i), - width); - } - } - } - } else if (a.getValueType() == FLOAT_VALUE) { - for (size_t i = 0; i < a.getHeight(); ++i) { - const int start = a.getRowStartIdx(i); - const int end = a.getRowStartIdx(i + 1); - for (int j = start; j < end; ++j) { - vecAddTo(out.getRow(cols[j]), - const_cast(b).getRow(i), - values[j], - width); - } - } + int* cols = a.getCols(); + real* values = a.getValue(); + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + vecAddTo(!a.isTransposed() ? out.getRow(i) : out.getRow(cols[j]), + !a.isTransposed() ? const_cast(b).getRow(cols[j]) + : const_cast(b).getRow(i), + (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0, + out.getWidth()); } } }