diff --git a/paddle/function/MulOp.cpp b/paddle/function/MulOp.cpp
index 37f8808605e10c7c0e6f88f6fec7b5f20697fbaf..1117944a4ed99440a8f80d65c33f7a3935d7df0f 100644
--- a/paddle/function/MulOp.cpp
+++ b/paddle/function/MulOp.cpp
@@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
   }
 }
 
-inline void colVecAddTo(
-    real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) {
-  for (unsigned int i = 0; i < len; ++i) {
-    a[i * aWidth] += b[i * bWidth];
-  }
-}
-
 inline void colVecAddTo(
     real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
   for (unsigned int i = 0; i < len; ++i) {
@@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                             const CpuMatrix& a,
                             const CpuSparseMatrix& b,
                             real scaleAB,
                             real scaleT) {
-  /// todo(tianbing), clean the code
   CHECK(!out.trans_) << "Not supported";
   CHECK(!a.isTransposed()) << "Not supported";
   CHECK(scaleT == 0 || scaleT == 1);
   CHECK_EQ(scaleAB, static_cast<real>(1.0));
+  if (!b.isTransposed()) {  /// b is not transposed
+    CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
+          b.getWidth() == out.getWidth());
+  } else {
+    CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
+          b.getWidth() == a.getWidth());
+  }
+  if (scaleT == 0) {
+    out.zeroMem();
+  }
   real* A = const_cast<real*>(a.getData());
   real* B = const_cast<real*>(b.getValue());
   real* C = out.getData();
   int* rows = b.getRows();
   int* cols = b.getCols();
-  if (scaleT == 0) {
-    out.zeroMem();
-  }
-  /// todo(tianbing), clean the code
+
+  /// b.getFormat() == SPARSE_CSC
   if (b.getFormat() == SPARSE_CSC) {
-    if (!b.isTransposed()) {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), m);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), out.width_);
-
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t j = 0; j < b.getWidth(); ++j) {
-          int start = b.getColStartIdx(j);
-          int end = b.getColStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(
-                C + j, A + rows[i], out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t j = 0; j < b.getWidth(); ++j) {
-          int start = b.getColStartIdx(j);
-          int end = b.getColStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(C + j,
-                        A + rows[i],
-                        B[i],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    } else /*if (b.isTransposed())*/ {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), out.width_);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), m);
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t i = 0; i < b.getWidth(); ++i) {
-          int start = b.getColStartIdx(i);
-          int end = b.getColStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(
-                C + rows[j], A + i, out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t i = 0; i < b.getWidth(); ++i) {
-          int start = b.getColStartIdx(i);
-          int end = b.getColStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(C + rows[j],
-                        A + i,
-                        B[j],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
+    for (size_t j = 0; j < b.getWidth(); ++j) {
+      int start = b.getColStartIdx(j);
+      int end = b.getColStartIdx(j + 1);
+      for (int i = start; i < end; ++i) {
+        colVecAddTo(!b.isTransposed() ? C + j : C + rows[i],
+                    !b.isTransposed() ? A + rows[i] : A + j,
+                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
+                    out.getHeight(),
+                    out.getWidth(),
+                    a.getWidth());
       }
     }
-  } else {
-    if (!b.isTransposed()) {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), m);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), out.width_);
-
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t j = 0; j < b.getHeight(); ++j) {
-          int start = b.getRowStartIdx(j);
-          int end = b.getRowStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(
-                C + cols[i], A + j, out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t j = 0; j < b.getHeight(); ++j) {
-          int start = b.getRowStartIdx(j);
-          int end = b.getRowStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(C + cols[i],
-                        A + j,
-                        B[i],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    } else /*if (b.isTransposed())*/ {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), out.width_);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), m);
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t i = 0; i < b.getHeight(); ++i) {
-          int start = b.getRowStartIdx(i);
-          int end = b.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(
-                C + i, A + cols[j], out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t i = 0; i < b.getHeight(); ++i) {
-          int start = b.getRowStartIdx(i);
-          int end = b.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(C + i,
-                        A + cols[j],
-                        B[j],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
+    return;
+  }
+
+  /// b.getFormat() == SPARSE_CSR
+  if (b.getFormat() == SPARSE_CSR) {
+    for (size_t j = 0; j < b.getHeight(); ++j) {
+      int start = b.getRowStartIdx(j);
+      int end = b.getRowStartIdx(j + 1);
+      for (int i = start; i < end; ++i) {
+        colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j,
+                    !b.isTransposed() ? A + j : A + cols[i],
+                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
+                    out.getHeight(),
+                    out.getWidth(),
+                    a.getWidth());
      }
    }
+    return;
  }
}
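The consolidated CPU kernel above folds the four old branches (b transposed or not, NO_VALUE or FLOAT_VALUE) into a single loop per sparse format by moving the distinctions into ternary arguments of colVecAddTo. As a minimal standalone sketch of what the non-transposed FLOAT_VALUE CSC case computes, the following uses plain arrays; the names here (sparseCscMulDense, colStart, rows, values) are illustrative, not Paddle APIs:

// Standalone sketch: C (h x n, row-major) += A (h x k, row-major) * B,
// where B (k x n) is sparse CSC. colStart has n + 1 entries; rows[i] and
// values[i] hold the row index and value of the i-th stored element.
#include <cstddef>

using real = float;

void sparseCscMulDense(real* C, const real* A, const int* colStart,
                       const int* rows, const real* values,
                       size_t h, size_t k, size_t n) {
  for (size_t j = 0; j < n; ++j) {  // each column of B
    for (int i = colStart[j]; i < colStart[j + 1]; ++i) {
      // B(rows[i], j) scales column rows[i] of A into column j of C;
      // this inner loop is what colVecAddTo does via its stride arguments.
      for (size_t r = 0; r < h; ++r) {
        C[r * n + j] += A[r * k + rows[i]] * values[i];
      }
    }
  }
}

The NO_VALUE case is the same loop with values[i] replaced by 1.0, which is exactly what the (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i] argument selects; the transposed case swaps which operand the sparse row index offsets.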
diff --git a/paddle/function/MulOpGpu.cu b/paddle/function/MulOpGpu.cu
index 3c4654b9b27574fd146d27e66164254f5e40da7d..09d2a764911f183cbd156e064ee50fc3ef300a38 100644
--- a/paddle/function/MulOpGpu.cu
+++ b/paddle/function/MulOpGpu.cu
@@ -19,154 +19,147 @@ limitations under the License. */
 
 namespace paddle {
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : output matrix, M * N
  */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
-  CHECK(!out.isTransposed()) << "Not supported";
-
+                            real scaleAB,
+                            real scaleT) {
+  CHECK(!out.isTransposed()) << "Transpose not supported for out matrix";
   if (!a.isTransposed() && !b.isTransposed()) {
-    /// a : M * K, b: K * N
-    CHECK_EQ(out.width_, b.width_);
-    CHECK_EQ(out.height_, a.height_);
-    CHECK_EQ(a.width_, b.height_);
+    /// a : M * K, b: K * N
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getHeight());
   } else if (a.isTransposed() && !b.isTransposed()) {
-    /// a : K * M, b : K * N
-    CHECK_EQ(out.width_, b.width_);
-    CHECK_EQ(out.height_, a.width_);
-    CHECK_EQ(a.height_, b.height_);
+    /// a : K * M, b : K * N
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getWidth() &&
+          a.getHeight() == b.getHeight());
   } else if (!a.isTransposed() && b.isTransposed()) {
-    /// a: M * K, b : N * K
-    CHECK_EQ(out.width_, b.height_);
-    CHECK_EQ(out.height_, a.height_);
-    CHECK_EQ(a.width_, b.width_);
+    /// a: M * K, b : N * K
+    CHECK(out.getWidth() == b.getHeight() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getWidth());
   } else {
-    LOG(FATAL) << "Is not supported";
+    LOG(FATAL) << "Transpose of both a and b is not supported";
   }
-  real* a_data = a.data_;
-  real* b_data = b.data_;
-  real* out_data = out.data_;
-  int dim_m = out.getHeight();
-  int dim_n = out.getWidth();
-  int dim_k = !a.isTransposed() ? a.width_ : a.height_;
-  int lda = a.getStride();
-  int ldb = b.getStride();
-  int ldc = out.getStride();
-  hl_trans_op_t trans_a = !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
-  hl_trans_op_t trans_b = !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
-
-  hl_matrix_mul(a_data,
-                trans_a,
-                b_data,
-                trans_b,
-                out_data,
-                dim_m,
-                dim_n,
-                dim_k,
-                scale_ab,
-                scale_t,
-                lda,
-                ldb,
-                ldc);
+  real* aData = const_cast<real*>(a.getData());
+  real* bData = const_cast<real*>(b.getData());
+  real* outData = const_cast<real*>(out.getData());
+  hl_matrix_mul(aData,
+                !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
+                bData,
+                !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
+                outData,
+                out.getHeight(),
+                out.getWidth(),
+                !a.isTransposed() ? a.getWidth() : a.getHeight(),
+                scaleAB,
+                scaleT,
+                a.getStride(),
+                b.getStride(),
+                out.getStride());
 }
 
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : M * N
  */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuSparseMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
+                            real scaleAB,
+                            real scaleT) {
   CHECK(out.isContiguous());
   CHECK(b.isContiguous());
-  CHECK(b.useGpu_ == true) << "Matrix type are not equal";
-  CHECK(!out.trans_ && !b.trans_) << "not supported";
-  if (!a.trans_) {
+  CHECK(b.useGpu_) << "Matrix types are not equal";
+  CHECK(!out.isTransposed() && !b.isTransposed()) << "not supported";
+  if (!a.isTransposed()) {
     /// a: M * K, b: K * N
-    CHECK(out.width_ == b.width_ && out.height_ == a.height_
-          && a.width_ == b.height_) << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getHeight()
+          && a.getWidth() == b.getHeight()) << "Matrix dimensions are not equal";
   } else {
     /// a: K * M, transpose, b: K * N
-    CHECK(out.width_ == b.width_ && out.height_ == a.width_
-          && a.height_ == b.height_) << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getWidth()
+          && a.getHeight() == b.getHeight()) << "Matrix dimensions are not equal";
   }
-  hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  hl_sparse_matrix_s a_data = a.sMatrix_.get();
-  real* b_data = b.data_;
-  real* out_data = out.data_;
-  hl_matrix_csr_mul_dense(a_data,
-                          a_trans,
-                          b_data,
+  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_sparse_matrix_s aData = a.sMatrix_.get();
+  real* bData = const_cast<real*>(b.getData());
+  real* outData = const_cast<real*>(out.getData());
+  hl_matrix_csr_mul_dense(aData,
+                          aTrans,
+                          bData,
                           HPPL_OP_N,
-                          out_data,
-                          out.height_,
-                          out.width_,
-                          b.height_,
-                          scale_ab,
-                          scale_t);
+                          outData,
+                          out.getHeight(),
+                          out.getWidth(),
+                          b.getHeight(),
+                          scaleAB,
+                          scaleT);
 }
 
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : M * N
 */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuMatrix& a,
                             const GpuSparseMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
+                            real scaleAB,
+                            real scaleT) {
   CHECK(out.isContiguous());
   CHECK(a.isContiguous());
-  CHECK(a.useGpu_ == true) << "Matrix type are not equal";
-
-  hl_sparse_matrix_s b_data = b.sMatrix_.get();
-  real* a_data = a.data_;
-  real* out_data = out.data_;
-  hl_trans_op_t trans_b = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  if (!b.trans_) {
-    /// a : M * K, b : K * N
-    CHECK(out.width_ == b.width_ &&
-          out.height_ == a.height_ && a.width_ == b.height_)
-        << "Matrix dimensions are not equal";
+  CHECK(a.useGpu_) << "Matrix types are not equal";
+  if (!b.isTransposed()) {
+    /// a : M * K, b : K * N
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getHeight())
+        << "Matrix dimensions are not equal";
   } else {
-    /// a : M * K, b : N * K, transpose
-    CHECK(out.width_ == b.height_ &&
-          out.height_ == a.height_ && a.width_ == b.width_)
-        << "Matrix dimensions are not equal";
+    /// a : M * K, b : N * K, transpose
+    CHECK(out.getWidth() == b.getHeight() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getWidth())
+        << "Matrix dimensions are not equal";
   }
+
+  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_sparse_matrix_s bData = b.sMatrix_.get();
+  real* aData = const_cast<real*>(a.getData());
+  real* outData = const_cast<real*>(out.getData());
+
   if (b.format_ == SPARSE_CSC) {
-    hl_matrix_dense_mul_csc(a_data,
+    hl_matrix_dense_mul_csc(aData,
                             HPPL_OP_N,
-                            b_data,
-                            trans_b,
-                            out_data,
-                            out.height_,
-                            out.width_,
-                            a.width_,
-                            scale_ab,
-                            scale_t);
+                            bData,
+                            bTrans,
+                            outData,
+                            out.getHeight(),
+                            out.getWidth(),
+                            a.getWidth(),
+                            scaleAB,
+                            scaleT);
   } else {
-    hl_matrix_dense_mul_csr(a_data,
+    hl_matrix_dense_mul_csr(aData,
                             HPPL_OP_N,
-                            b_data,
-                            trans_b,
-                            out_data,
-                            out.height_,
-                            out.width_,
-                            a.width_,
-                            scale_ab,
-                            scale_t);
+                            bData,
+                            bTrans,
+                            outData,
+                            out.getHeight(),
+                            out.getWidth(),
+                            a.getWidth(),
+                            scaleAB,
+                            scaleT);
   }
 }
 
@@ -174,38 +167,36 @@ template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
                             const GpuMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
-  /// todo(tianbing), clean the code
-  CHECK(a.useGpu_ && b.useGpu_) << "type not match";
-  CHECK(!out.trans_) << "trans not supported";
-  real* a_data = const_cast<real*>(a.getData());
-  real* b_data = const_cast<real*>(b.getData());
-  hl_sparse_matrix_s out_data = out.sMatrix_.get();
-  hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
-
-  if (!a.trans_ && !b.trans_) {
-    CHECK(out.height_ == a.getHeight());
-    CHECK(out.width_ == b.getWidth());
-    CHECK(a.getWidth() == b.getHeight());
-  } else if (a.trans_ && !b.trans_) {
-    CHECK(out.height_ == a.getWidth());
-    CHECK(out.width_ == b.getWidth());
-    CHECK(a.getHeight() == b.getHeight());
-  } else if (!a.trans_ && b.trans_) {
-    CHECK(out.height_ == a.getHeight());
-    CHECK(out.width_ == b.getHeight());
-    CHECK(a.getWidth() == b.getWidth());
+                            real scaleAB,
+                            real scaleT) {
+  CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
+  CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";
+
+  if (!a.isTransposed() && !b.isTransposed()) {
+    CHECK(out.getHeight() == a.getHeight() &&
+          out.getWidth() == b.getWidth() &&
+          a.getWidth() == b.getHeight());
+  } else if (a.isTransposed() && !b.isTransposed()) {
+    CHECK(out.getHeight() == a.getWidth() &&
+          out.getWidth() == b.getWidth() &&
+          a.getHeight() == b.getHeight());
+  } else if (!a.isTransposed() && b.isTransposed()) {
+    CHECK(out.getHeight() == a.getHeight() &&
+          out.getWidth() == b.getHeight() &&
+          a.getWidth() == b.getWidth());
   } else {
-    LOG(INFO) << "Not support";
+    LOG(FATAL) << "Transpose of both a and b is not supported";
   }
-  int dim_m = out.height_;
-  int dim_n = out.width_;
-  int dim_k = !b.trans_ ? b.getHeight() : b.getWidth();
-  hl_sparse_matrix_mul(
-      a_data, a_trans, b_data, b_trans, out_data,
-      dim_m, dim_n, dim_k, scale_ab, scale_t);
+
+  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
+  real* aData = const_cast<real*>(a.getData());
+  real* bData = const_cast<real*>(b.getData());
+  hl_sparse_matrix_s outData = out.sMatrix_.get();
+
+  hl_sparse_matrix_mul(aData, aTrans, bData, bTrans, outData,
+                       out.getHeight(), out.getWidth(), dimK, scaleAB, scaleT);
 }
 
 }  // namespace paddle
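Every GPU path above performs the same shape bookkeeping before handing off to an hl_matrix_* kernel: derive the GEMM dimensions M, N, K from the operand shapes and transpose flags, then check that the shared inner dimension agrees. A small sketch of that derivation, assuming illustrative helper types (Shape, Dims, and gemmDims are not Paddle APIs):

// Sketch of the dimension bookkeeping for out(M x N) = op(a) * op(b).
#include <cassert>
#include <cstddef>

struct Shape { size_t height, width; };
struct Dims { size_t m, n, k; };

Dims gemmDims(Shape out, Shape a, bool aTrans, Shape b, bool bTrans) {
  size_t m = aTrans ? a.width : a.height;  // rows of op(a)
  size_t k = aTrans ? a.height : a.width;  // cols of op(a) == rows of op(b)
  size_t n = bTrans ? b.height : b.width;  // cols of op(b)
  assert(k == (bTrans ? b.width : b.height));  // inner dimensions agree
  assert(out.height == m && out.width == n);   // output shape matches
  return {m, n, k};
}

This is also why the sparse-output path computes dimK as b.getHeight() when b is not transposed and b.getWidth() otherwise: K is the shared inner dimension, which the CHECK blocks verify for each transpose combination.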
diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp
index 630070b845a9af7aff734ea3e8ff9a7cf62fd7d3..965ffea20ce7fb03c743145d84918ff81f4858e5 100644
--- a/paddle/function/MulOpTest.cpp
+++ b/paddle/function/MulOpTest.cpp
@@ -76,12 +76,12 @@ void testDDDMatrix(bool transa, bool transb, int dimM, int dimN, int dimK) {
 
 TEST(Matrix, DDDMul) {
   LOG(INFO) << "test for dense = dense * dense matrix";
-  for (auto transa : {false, true}) {
-    for (auto transb : {false, true}) {
-      for (auto dimM : {1, 10, 100}) {
-        for (auto dimN : {1, 10}) {
-          for (auto dimK : {8}) {
-            if (true == transa && true == transb) {
+  for (const auto transa : {false, true}) {
+    for (const auto transb : {false, true}) {
+      for (const auto dimM : {1, 10, 100}) {
+        for (const auto dimN : {1, 10}) {
+          for (const auto dimK : {8}) {
+            if (transa && transb) {
               continue;
             }
             VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
@@ -89,7 +89,6 @@ TEST(Matrix, DDDMul) {
                     << " dimM=" << std::setw(5) << dimM
                     << " dimN=" << std::setw(5) << dimN
                     << " dimK=" << std::setw(5) << dimK;
-
             testDDDMatrix(transa, transb, dimM, dimN, dimK);
           }
        }
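Tests like testDDDMatrix sweep the whole grid of transa/transb/dimM/dimN/dimK combinations; a common way to validate such kernels is to compare them against a naive triple-loop reference. The sketch below is only an illustration of such an oracle (not the harness MulOpTest actually uses), computing out = scaleT * out + scaleAB * op(a) * op(b) for row-major storage:

// Naive reference GEMM for validating optimized kernels (illustrative).
#include <cstddef>

using real = float;

void referenceMul(real* out, const real* a, const real* b,
                  bool transa, bool transb,
                  size_t m, size_t n, size_t k,
                  real scaleAB, real scaleT) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      real sum = 0;
      for (size_t l = 0; l < k; ++l) {
        // op(a) is m x k and op(b) is k x n; transposition only changes
        // how the underlying row-major storage is indexed.
        real av = transa ? a[l * m + i] : a[i * k + l];
        real bv = transb ? b[j * k + l] : b[l * n + j];
        sum += av * bv;
      }
      out[i * n + j] = scaleT * out[i * n + j] + scaleAB * sum;
    }
  }
}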