From 509ae79a5de846dfd38bd85618b2467066413a97 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Thu, 26 Oct 2017 00:47:06 +0800 Subject: [PATCH] Add rowScale for CpuSparseMatrix --- paddle/math/CpuSparseMatrix.cpp | 17 +++++++++++++++++ paddle/math/CpuSparseMatrix.h | 9 +++++++++ 2 files changed, 26 insertions(+) diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp index bf62229c03b..e211c23a7e6 100644 --- a/paddle/math/CpuSparseMatrix.cpp +++ b/paddle/math/CpuSparseMatrix.cpp @@ -260,6 +260,23 @@ void CpuSparseMatrix::printOneRow(std::ostream& os, size_t idx) const { os << ";"; } +void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) { + CHECK(getFormat() != SPARSE_CSC) << "Not supported"; + CHECK(height_ == b.getHeight()); + CHECK(width_ == b.getWidth()); + real* A = getValue(); + real* B = b.getValue(); + for (size_t i = 0; i < height_; i++) { + size_t start = getRowStartIdx(i); + size_t end = getRowStartIdx(i + 1); + CHECK(start == b.getRowStartIdx(i)); + CHECK(end == b.getRowStartIdx(i + 1)); + for (size_t j = start; j < end; j++) { + A[j] = B[j] * c.getElement(i, cCol); + } + } +} + void CpuSparseMatrix::randomizeUniform() { CHECK_LE(elementCnt_, height_ * width_); if (valueType_ == FLOAT_VALUE) { diff --git a/paddle/math/CpuSparseMatrix.h b/paddle/math/CpuSparseMatrix.h index 36d57bbb652..8f9ad67215f 100644 --- a/paddle/math/CpuSparseMatrix.h +++ b/paddle/math/CpuSparseMatrix.h @@ -236,6 +236,15 @@ public: const unsigned int* cols, const real* values); + /** + * @brief this_row = b_row * c_row[cCol] + * + * @param[in] cCol the column of matrix c used to scale each row of b + * @param[in] b CpuSparseMatrix + * @param[in] c Matrix + */ + void rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c); + void randomizeUniform(); void copyFrom(const GpuSparseMatrix& src, hl_stream_t stream); -- GitLab