diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h index 3489510b25f0989fb25a9687e8e34ab34479a20b..1f86f49911c484455931810dc1f3264e7d8a9b55 100644 --- a/paddle/function/BufferArg.h +++ b/paddle/function/BufferArg.h @@ -167,7 +167,7 @@ public: ValueType valueType() const { return valueType_; } BufferType bufferType() const { return bufferType_; } const TensorShape& shape() const { return shape_; } - bool isSparse() const { return TENSOR_SPARSE == bufferType_; } + bool isSparseArg() const { return TENSOR_SPARSE == bufferType_; } bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; } const SequenceArg& sequence() const; diff --git a/paddle/function/MulOp.cpp b/paddle/function/MulOp.cpp index 1fa29fae8d4e1ece2fa8a735683fc3f0921209a5..7d341182523cbb4508bf13ddc0f9bbbf46752151 100644 --- a/paddle/function/MulOp.cpp +++ b/paddle/function/MulOp.cpp @@ -13,16 +13,471 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "MulOp.h" +#include "paddle/math/MathFunctions.h" +#include "paddle/math/SIMDFunctions.h" +#include "paddle/utils/ThreadLocal.h" + +#ifndef PADDLE_TYPE_DOUBLE +#define GEMM paddle::gemm +#else +#define GEMM paddle::gemm +#endif + +namespace { +inline void vecAddTo(real* a, const real* b, size_t len) { + for (unsigned int i = 0; i < len; ++i) { + a[i] += b[i]; + } +} + +inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) { + for (unsigned int i = 0; i < len; ++i) { + a[i] += scaleB * b[i]; + } +} + +inline void colVecAddTo( + real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) { + for (unsigned int i = 0; i < len; ++i) { + a[i * aWidth] += b[i * bWidth]; + } +} + +inline void colVecAddTo( + real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) { + for (unsigned int i = 0; i < len; ++i) { + a[i * aWidth] += b[i * bWidth] * c; + } +} +} // namespace namespace paddle { +template <> +void MulOp(CpuSparseMatrix& out, + const CpuMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT) { + /// todo(tianbing), clean the code + CHECK(!out.isTransposed()) << "Not supported"; + CHECK_EQ(out.getValueType(), FLOAT_VALUE); + + const real* A = a.getData(); + const real* B = b.getData(); + real* C = out.getValue(); + int* rows = out.getRows(); + int* cols = out.getCols(); + size_t height = out.getHeight(); + size_t width = out.getWidth(); + if (scaleT == 0) { + out.zeroMem(); + } + + if (!a.isTransposed() && !b.isTransposed()) { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), m); + CHECK_EQ(a.getHeight(), height); + CHECK_EQ(b.getWidth(), width); + if (out.getFormat() == SPARSE_CSC) { + for (size_t i = 0; i < width; i++) { + size_t start = out.getColStartIdx(i); + size_t end = out.getColStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t rowIdx = rows[j]; + for (size_t k = 0; k < m; k++) { + sum += A[rowIdx * m + k] * B[k * width + i]; + } + C[j] = scaleAB * sum + scaleT * C[j]; + } + } + } else { + for (size_t i = 0; i < height; i++) { + size_t start = out.getRowStartIdx(i); + size_t end = out.getRowStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t colIdx = cols[j]; + for (size_t k = 0; k < m; k++) { + sum += A[i * m + k] * B[k * width + colIdx]; + } + C[j] = scaleAB * sum + scaleT * C[j]; + } + } + } + } else if (a.isTransposed() && !b.isTransposed()) { + size_t m = a.getHeight(); + CHECK_EQ(m, b.getHeight()); + CHECK_EQ(b.getWidth(), width); + CHECK_EQ(a.getWidth(), height); + + if (out.getFormat() == SPARSE_CSC) { + for (size_t i = 0; i < width; i++) { + size_t start = out.getColStartIdx(i); + size_t end = out.getColStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t rowIdx = rows[j]; + for (size_t k = 0; k < m; k++) { + sum += A[k * height + rowIdx] * B[k * width + i]; + } + C[j] = scaleAB * sum + scaleT * C[j]; + } + } + } else { + for (size_t i = 0; i < height; i++) { + int start = out.getRowStartIdx(i); + int end = out.getRowStartIdx(i + 1); + for (int j = start; j < end; j++) { + real sum = 0; + size_t colIdx = cols[j]; + for (size_t k = 0; k < m; k++) { + sum += A[k * height + i] * B[k * width + colIdx]; + } + C[j] = scaleAB * sum + scaleT * C[j]; + } + } + } + } else if (!a.isTransposed() && b.isTransposed()) { + size_t m = a.getWidth(); + CHECK_EQ(b.getWidth(), m); + CHECK_EQ(a.getHeight(), height); + CHECK_EQ(b.getHeight(), width); + if (out.getFormat() == SPARSE_CSR) { + for (size_t i = 0; i < height; i++) { + size_t start = out.getRowStartIdx(i); + size_t end = out.getRowStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t colIdx = cols[j]; + for (size_t k = 0; k < m; k++) { + sum += A[i * m + k] * B[colIdx * m + k]; + } + C[j] = scaleAB * sum + scaleT * C[j]; + } + } + } else { + LOG(FATAL) << "Not supported csc format " + "when a is not trans and b is trans"; + } + } else { + LOG(FATAL) << "Not supported"; + } +} + +template <> +void MulOp(CpuMatrix& out, + const CpuMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT) { + /// todo(tianbing), clean the code + CHECK(!out.isTransposed()) << "Not supported"; + CBLAS_TRANSPOSE aTrans = CblasNoTrans; + size_t aRow = a.getHeight(); + size_t aCol = a.getWidth(); + CBLAS_TRANSPOSE bTrans = CblasNoTrans; + size_t bRow = b.getHeight(); + size_t bCol = b.getWidth(); + if (a.isTransposed()) { + aTrans = CblasTrans; + aRow = a.getWidth(); + aCol = a.getHeight(); + } + if (b.isTransposed()) { + bTrans = CblasTrans; + bRow = b.getWidth(); + bCol = b.getHeight(); + } + + /// C = A * B, for matrix format + CHECK_EQ(aCol, bRow); + CHECK_EQ(aRow, out.getHeight()); + CHECK_EQ(bCol, out.getWidth()); + + const real* A = a.getData(); + const real* B = b.getData(); + real* C = out.getData(); + + int M = out.getHeight(); + int N = out.getWidth(); + int K = aCol; + int lda = a.getStride(); + int ldb = b.getStride(); + int ldc = out.getStride(); + + GEMM(aTrans, bTrans, M, N, K, scaleAB, A, lda, B, ldb, scaleT, C, ldc); + + VLOG(2) << " A[0]=" << A[0] << " A[1]=" << A[1] << " B[0]=" << B[0] + << " B[1]=" << B[1] << " C[0]=" << C[0] << " C[1]=" << C[1]; +} + +static ThreadLocal> threadLocalColArray; + +template <> +void MulOp(CpuMatrix& out, + const CpuSparseMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT) { + /// todo(tianbing), clean the code + CHECK(!out.isTransposed()) << "Not supported"; + CHECK(!b.isTransposed()) << "Not supported"; + CHECK(scaleT == 0 || scaleT == 1) << "Not support"; + CHECK_EQ(scaleAB, static_cast(1.0)) << "Not supported"; + CHECK_EQ(a.getFormat(), SPARSE_CSR) << "Not supported"; + + const real* B = b.getData(); + real* C = out.getData(); + size_t height = out.getHeight(); + size_t width = out.getWidth(); + int* cols = a.getCols(); + real* values = a.getValue(); + + if (scaleT == 0) { + out.zeroMem(); + } + + if (!a.isTransposed()) { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), m); + CHECK_EQ(a.getHeight(), height); + CHECK_EQ(b.getWidth(), width); + + if (a.getValueType() == NO_VALUE) { + if (width % 32 == 0) { // use libaddto + CHECK_EQ((size_t)B % 32, 0UL); + CHECK_EQ((size_t)C % 32, 0UL); + auto& colArray = *threadLocalColArray; + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + size_t colNum = end - start; + colArray.resize(colNum); + for (int j = 0; j < end - start; ++j) { + colArray[j] = const_cast(b).getRow(cols[j + start]); + } + simd::batchAddTo(out.getRow(i), &colArray[0], colNum, width); + } + + } else { + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + vecAddTo(out.getRow(i), + const_cast(b).getRow(cols[j]), + width); + } + } + } + } else if (a.getValueType() == FLOAT_VALUE) { + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + vecAddTo(out.getRow(i), + const_cast(b).getRow(cols[j]), + values[j], + width); + } + } + } + } else /*if (a->isTransposed())*/ { + size_t m = a.getHeight(); + CHECK_EQ(b.getHeight(), m); + CHECK_EQ(a.getWidth(), height); + CHECK_EQ(b.getWidth(), width); + if (a.getValueType() == NO_VALUE) { + if (width % 32 == 0) { // use libaddto + CHECK_EQ((size_t)B % 32, 0UL); + CHECK_EQ((size_t)C % 32, 0UL); + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + simd::addTo(out.getRow(cols[j]), + const_cast(b).getRow(i), + width); + } + } + + } else { + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + vecAddTo(out.getRow(cols[j]), + const_cast(b).getRow(i), + width); + } + } + } + } else if (a.getValueType() == FLOAT_VALUE) { + for (size_t i = 0; i < a.getHeight(); ++i) { + const int start = a.getRowStartIdx(i); + const int end = a.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + vecAddTo(out.getRow(cols[j]), + const_cast(b).getRow(i), + values[j], + width); + } + } + } + } +} + +template <> +void MulOp(CpuMatrix& out, + const CpuMatrix& a, + const CpuSparseMatrix& b, + real scaleAB, + real scaleT) { + /// todo(tianbing), clean the code + CHECK(!out.trans_) << "Not supported"; + CHECK(!a.isTransposed()) << "Not supported"; + CHECK(scaleT == 0 || scaleT == 1); + CHECK_EQ(scaleAB, static_cast(1.0)); + + real* A = const_cast(a.getData()); + real* B = const_cast(b.getValue()); + real* C = out.getData(); + int* rows = b.getRows(); + int* cols = b.getCols(); + + if (scaleT == 0) { + out.zeroMem(); + } + /// todo(tianbing), clean the code + if (b.getFormat() == SPARSE_CSC) { + if (!b.isTransposed()) { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), m); + CHECK_EQ(a.getHeight(), out.height_); + CHECK_EQ(b.getWidth(), out.width_); + + if (b.getValueType() == NO_VALUE) { + for (size_t j = 0; j < b.getWidth(); ++j) { + int start = b.getColStartIdx(j); + int end = b.getColStartIdx(j + 1); + for (int i = start; i < end; ++i) { + colVecAddTo( + C + j, A + rows[i], out.height_, out.width_, a.getWidth()); + } + } + } else if (b.getValueType() == FLOAT_VALUE) { + for (size_t j = 0; j < b.getWidth(); ++j) { + int start = b.getColStartIdx(j); + int end = b.getColStartIdx(j + 1); + for (int i = start; i < end; ++i) { + colVecAddTo(C + j, + A + rows[i], + B[i], + out.height_, + out.width_, + a.getWidth()); + } + } + } + } else /*if (b.isTransposed())*/ { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), out.width_); + CHECK_EQ(a.getHeight(), out.height_); + CHECK_EQ(b.getWidth(), m); + if (b.getValueType() == NO_VALUE) { + for (size_t i = 0; i < b.getWidth(); ++i) { + int start = b.getColStartIdx(i); + int end = b.getColStartIdx(i + 1); + for (int j = start; j < end; ++j) { + colVecAddTo( + C + rows[j], A + i, out.height_, out.width_, a.getWidth()); + } + } + } else if (b.getValueType() == FLOAT_VALUE) { + for (size_t i = 0; i < b.getWidth(); ++i) { + int start = b.getColStartIdx(i); + int end = b.getColStartIdx(i + 1); + for (int j = start; j < end; ++j) { + colVecAddTo(C + rows[j], + A + i, + B[j], + out.height_, + out.width_, + a.getWidth()); + } + } + } + } + } else { + if (!b.isTransposed()) { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), m); + CHECK_EQ(a.getHeight(), out.height_); + CHECK_EQ(b.getWidth(), out.width_); + + if (b.getValueType() == NO_VALUE) { + for (size_t j = 0; j < b.getHeight(); ++j) { + int start = b.getRowStartIdx(j); + int end = b.getRowStartIdx(j + 1); + for (int i = start; i < end; ++i) { + colVecAddTo( + C + cols[i], A + j, out.height_, out.width_, a.getWidth()); + } + } + } else if (b.getValueType() == FLOAT_VALUE) { + for (size_t j = 0; j < b.getHeight(); ++j) { + int start = b.getRowStartIdx(j); + int end = b.getRowStartIdx(j + 1); + for (int i = start; i < end; ++i) { + colVecAddTo(C + cols[i], + A + j, + B[i], + out.height_, + out.width_, + a.getWidth()); + } + } + } + } else /*if (b.isTransposed())*/ { + size_t m = a.getWidth(); + CHECK_EQ(b.getHeight(), out.width_); + CHECK_EQ(a.getHeight(), out.height_); + CHECK_EQ(b.getWidth(), m); + if (b.getValueType() == NO_VALUE) { + for (size_t i = 0; i < b.getHeight(); ++i) { + int start = b.getRowStartIdx(i); + int end = b.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + colVecAddTo( + C + i, A + cols[j], out.height_, out.width_, a.getWidth()); + } + } + } else if (b.getValueType() == FLOAT_VALUE) { + for (size_t i = 0; i < b.getHeight(); ++i) { + int start = b.getRowStartIdx(i); + int end = b.getRowStartIdx(i + 1); + for (int j = start; j < end; ++j) { + colVecAddTo(C + i, + A + cols[j], + B[j], + out.height_, + out.width_, + a.getWidth()); + } + } + } + } + } +} /** * mul operator * out = scaleT * out + scaleAB*(in1 * in2) * - * \param outputs[0] output matrix, N * M - * \param inputs[0] first input (sparse) matrix, N * K - * \param inputs[1] second input matrix, K * M (non-transpose) + * \param outputs[0] output matrix, M * N + * \param inputs[0] first input (sparse) matrix, M * K (if non-trans) + * \param inputs[1] second input matrix, K * N (if non-trans) */ template class MulFunc : public FunctionBase { @@ -33,19 +488,23 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - /// todo(tianbing), add more checks - CHECK_EQ((size_t)1, inputs.size()); - CHECK_EQ((size_t)2, outputs.size()); + CHECK_EQ((size_t)2, inputs.size()); + CHECK_EQ((size_t)1, outputs.size()); CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)2); CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); - CHECK(inputs[0].isSparse()) << "SparseMatrix requried here"; - const auto in1_mat = inputs[0].sparse().SparseMatrix(); + auto in1_mat = inputs[0].matrix(); + if (inputs[0].isSparseArg()) { + in1_mat = inputs[0].sparse().SparseMatrix(); + } + auto in2_mat = inputs[1].matrix(); + if (inputs[1].isSparseArg()) { + in2_mat = inputs[1].sparse().SparseMatrix(); + } auto out_mat = outputs[0].matrix(); - const auto in2_mat = inputs[1].matrix(); MulOp(out_mat, in1_mat, in2_mat, scaleAB_, scaleT_); } @@ -54,6 +513,7 @@ private: real scaleT_; }; +REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc); #ifndef PADDLE_ONLY_CPU REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc); #endif diff --git a/paddle/function/MulOp.h b/paddle/function/MulOp.h index bcea1864026b2768dbcc4b5b822ce121e97796c1..f3699f8c78cda71154a65b34a00dc1986bd4c221 100644 --- a/paddle/function/MulOp.h +++ b/paddle/function/MulOp.h @@ -19,6 +19,40 @@ limitations under the License. */ #include "paddle/math/SparseMatrix.h" namespace paddle { +template +void MulOp(CpuMatrix& out, + const CpuMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT); + +template +void MulOp(CpuMatrix& out, + const CpuSparseMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT); + +template +void MulOp(CpuMatrix& out, + const CpuMatrix& a, + const CpuSparseMatrix& b, + real scaleAB, + real scaleT); + +template +void MulOp(CpuSparseMatrix& out, + const CpuMatrix& a, + const CpuMatrix& b, + real scaleAB, + real scaleT); + +template +void MulOp(GpuMatrix& out, + const GpuMatrix& a, + const GpuMatrix& b, + real scaleAB, + real scaleT); template void MulOp(GpuMatrix& out, @@ -27,4 +61,11 @@ void MulOp(GpuMatrix& out, real scaleAB, real scaleT); +template +void MulOp(GpuMatrix& out, + const GpuMatrix& a, + const GpuSparseMatrix& b, + real scaleAB, + real scaleT); + } // namespace paddle diff --git a/paddle/function/MulOpGpu.cu b/paddle/function/MulOpGpu.cu index db716c1e46b4e1473b99a5d4c645a76e640adc98..73d788a4743326ab06392aa582b85c0c0ce75b2b 100644 --- a/paddle/function/MulOpGpu.cu +++ b/paddle/function/MulOpGpu.cu @@ -20,6 +20,65 @@ limitations under the License. */ namespace paddle { /** * out = scale_t * out + scale_ab * (a * b) + * out : output matrix, M * N + */ +template <> +void MulOp(GpuMatrix& out, + const GpuMatrix& a, + const GpuMatrix& b, + real scale_ab, + real scale_t) { + CHECK(!out.isTransposed()) << "Not supported"; + + if (!a.isTransposed() && !b.isTransposed()) { + /// a : M * K, b: K * N + CHECK_EQ(out.width_, b.width_); + CHECK_EQ(out.height_, a.height_); + CHECK_EQ(a.width_, b.height_); + } else if (a.isTransposed() && !b.isTransposed()) { + /// a : K * M, b : K * N + CHECK_EQ(out.width_, b.width_); + CHECK_EQ(out.height_, a.width_); + CHECK_EQ(a.height_, b.height_); + } else if (!a.isTransposed() && b.isTransposed()) { + /// a: M * K, b : N * K + CHECK_EQ(out.width_, b.height_); + CHECK_EQ(out.height_, a.height_); + CHECK_EQ(a.width_, b.width_); + } else { + LOG(FATAL) << "Is not supported"; + } + + real* a_data = a.data_; + real* b_data = b.data_; + real* out_data = out.data_; + int dim_m = out.getHeight(); + int dim_n = out.getWidth(); + int dim_k = !a.isTransposed() ? a.width_ : a.height_; + int lda = a.getStride(); + int ldb = b.getStride(); + int ldc = out.getStride(); + hl_trans_op_t trans_a = !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T; + hl_trans_op_t trans_b = !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T; + + hl_matrix_mul(a_data, + trans_a, + b_data, + trans_b, + out_data, + dim_m, + dim_n, + dim_k, + scale_ab, + scale_t, + lda, + ldb, + ldc); +} + +/** + * out = scale_t * out + scale_ab * (a * b) + * out : M * N */ template <> void MulOp(GpuMatrix& out, @@ -32,12 +91,15 @@ void MulOp(GpuMatrix& out, CHECK(b.useGpu_ == true) << "Matrix type are not equal"; CHECK(!out.trans_ && !b.trans_) << "not supported"; if (!a.trans_) { + /// a: M * K, b: K * N CHECK(out.width_ == b.width_ && out.height_ == a.height_ - && a.width_ == b.height_) << "Matrix dimensions are not equal"; + && a.width_ == b.height_) << "Matrix dimensions are not equal"; } else { + /// a: K * M, transpose, b: K * N CHECK(out.width_ == b.width_ && out.height_ == a.width_ - && a.height_ == b.height_) << "Matrix dimensions are not equal"; + && a.height_ == b.height_) << "Matrix dimensions are not equal"; } + hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N; hl_sparse_matrix_s a_data = a.sMatrix_.get(); real* b_data = b.data_; @@ -54,4 +116,58 @@ void MulOp(GpuMatrix& out, scale_t); } +/** + * out = scale_t * out + scale_ab * (a * b) + * out : M * N + */ +template <> +void MulOp(GpuMatrix& out, + const GpuMatrix& a, + const GpuSparseMatrix& b, + real scale_ab, + real scale_t) { + CHECK(out.isContiguous()); + CHECK(a.isContiguous()); + CHECK(a.useGpu_ == true) << "Matrix type are not equal"; + + hl_sparse_matrix_s b_data = b.sMatrix_.get(); + real* a_data = a.data_; + real* out_data = out.data_; + hl_trans_op_t trans_b = b.trans_ ? HPPL_OP_T : HPPL_OP_N; + if (!b.trans_) { + /// a : M * K, b : K * N + CHECK(out.width_ == b.width_ && + out.height_ == a.height_ && a.width_ == b.height_) + << "Matrix dimensions are not equal"; + } else { + /// a : M * K, b : N * K, transpose + CHECK(out.width_ == b.height_ && + out.height_ == a.height_ && a.width_ == b.width_) + << "Matrix dimensions are not equal"; + } + if (b.format_ == SPARSE_CSC) { + hl_matrix_dense_mul_csc(a_data, + HPPL_OP_N, + b_data, + trans_b, + out_data, + out.height_, + out.width_, + a.width_, + scale_ab, + scale_t); + } else { + hl_matrix_dense_mul_csr(a_data, + HPPL_OP_N, + b_data, + trans_b, + out_data, + out.height_, + out.width_, + a.width_, + scale_ab, + scale_t); + } +} + } // namespace paddle diff --git a/paddle/function/MulOpTest.cpp b/paddle/function/MulOpTest.cpp index bc1fa9f607a575c8a55aa019a0a2340c2d6ced04..ce9d37d664710c8f434bc7e9268198644f52aa5b 100644 --- a/paddle/function/MulOpTest.cpp +++ b/paddle/function/MulOpTest.cpp @@ -22,31 +22,41 @@ using namespace paddle; // NOLINT void testSpMatrixMul(int M, int N, int K, real rate, real scale1, real scale2) { /// todo(tianbing) check CPU/GPU - const auto gpuFunc = FunctionBase::funcRegistrar_.createByType("MulOP-GPU"); + const auto gpuFunc = FunctionBase::funcRegistrar_.createByType("MulOp-GPU"); gpuFunc->init(FuncConfig().set("scaleAB", scale1).set("scaleT", scale2)); - int nnz = M * K * rate; - auto gpuA = std::make_shared(M, K, nnz); - const auto gpuB = std::make_shared(K, N); - const auto gpuOut = std::make_shared(M, N); + int nnz = M * N * rate; + MatrixPtr cpuA = std::make_shared(M, K); + MatrixPtr cpuB = std::make_shared(N, K); + MatrixPtr cpuC(new CpuSparseMatrix(M, N, nnz)); - gpuA->randomizeUniform(); - gpuB->randomizeUniform(); - gpuOut->randomizeUniform(); + MatrixPtr gpuA = std::make_shared(M, K); + MatrixPtr gpuB = std::make_shared(N, K); + MatrixPtr gpuC(new GpuSparseMatrix(M, N, nnz)); + + cpuA->randomizeUniform(); + cpuB->randomizeUniform(); + cpuC->randomizeUniform(); + + hl_stream_t stream(HPPL_STREAM_3); + gpuA->copyFrom(*cpuA, stream); + gpuB->copyFrom(*cpuB, stream); + gpuC->copyFrom(*cpuC, stream); + hl_stream_synchronize(stream); BufferArgs inputs; BufferArgs outputs; - inputs.addArg(*gpuA); - inputs.addArg(*gpuB); - outputs.addArg(*gpuOut); + inputs.addArg(*gpuA->getTranspose()); + inputs.addArg(*gpuB->getTranspose()); + outputs.addArg(*gpuC, ASSIGN_TO); gpuFunc->calc(inputs, outputs); } TEST(SMatrix, sMatrixMul) { for (auto M : {1, 40, 128, 200}) { - for (auto N : {100, 2000, 20480}) { - for (auto K : {100, 512, 1024}) { + for (auto N : {100}) { + for (auto K : {100}) { /// todo(tianbing), add scaleAB and scaleT VLOG(3) << " M=" << M << " N=" << N << " K=" << K; testSpMatrixMul(M, N, K, 0.05, 1, 1);