From a3043989a4d2fb082e19173e430809c8025afba2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 13 Dec 2016 16:57:09 +0800 Subject: [PATCH] Extract RowBuffer class for SparseRowMatrix. * The original SparseRowMatrix use two fields to store each rows, which let code very confusing. Try to extract a RowBuffer class, for SparseRowMatrix data storage, and manage auto-growth logic. --- paddle/gserver/tests/test_PyDataProvider2.cpp | 2 +- paddle/math/SparseRowMatrix.h | 107 ++++++++++++++---- 2 files changed, 83 insertions(+), 26 deletions(-) diff --git a/paddle/gserver/tests/test_PyDataProvider2.cpp b/paddle/gserver/tests/test_PyDataProvider2.cpp index 436318d3563..7a3b51da8b2 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.cpp +++ b/paddle/gserver/tests/test_PyDataProvider2.cpp @@ -293,7 +293,7 @@ TEST(PyDataProvider2, can_over_batch_size) { while (true) { int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); if (realBatchSize) { - CHECK_LE(realBatchSize, batchSize); + CHECK_LE((size_t)realBatchSize, batchSize); } else { break; } diff --git a/paddle/math/SparseRowMatrix.h b/paddle/math/SparseRowMatrix.h index badb4b9c1cc..db1530f7cfd 100644 --- a/paddle/math/SparseRowMatrix.h +++ b/paddle/math/SparseRowMatrix.h @@ -24,6 +24,73 @@ P_DECLARE_bool(allow_inefficient_sparse_update); namespace paddle { +/** + * @brief The RowBuffer class + * Represent the SparseRow Matrix Data. + * + * If not set memory handler, then the data could be auto growth. + */ +class RowBuffer { +public: + explicit RowBuffer(size_t width) : width_(width) {} + RowBuffer(const CpuMemHandlePtr& mem, size_t width) + : preallocatedBuf_(mem), width_(width) {} + + inline void reserve(int rowCnt) { + if (preallocatedBuf_) { + CHECK(preallocatedBuf_->getSize() < rowCnt * width_ * sizeof(real)); + } else { + rowStore_.reserve(rowCnt * width_); + } + } + + inline const real* get(int row) const { + if (preallocatedBuf_) { + CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); + return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_; + } else { + CHECK_LE((row + 1) * width_, rowStore_.size()); + return rowStore_.data() + row * width_; + } + } + + inline const real* getWithAutoGrowth(int row) { + if (preallocatedBuf_) { + return get(row); + } else { + if ((rowStore_.size() <= row * width_)) { + rowStore_.resize((row + 1) * width_); + } + return rowStore_.data() + row * width_; + } + } + + inline real* data() { + if (preallocatedBuf_) { + return reinterpret_cast(preallocatedBuf_->getBuf()); + } else { + return rowStore_.data(); + } + } + + inline void clear() { rowStore_.clear(); } + + inline size_t getRowCount() const { + if (preallocatedBuf_) { + return preallocatedBuf_->getSize() / sizeof(float) / width_; + } else { + return rowStore_.size() / width_; + } + } + + inline bool canAutoGrowth() const { return preallocatedBuf_ == nullptr; } + +private: + CpuMemHandlePtr preallocatedBuf_; + std::vector> rowStore_; + size_t width_; +}; + /** * Sparse Row */ @@ -45,12 +112,9 @@ public: IndexDictPtr indexDictHandle = nullptr, bool trans = false) : CpuMatrix(nullptr, height, width, trans), - storeMat_(dataHandle, - dataHandle ? dataHandle->getSize() / sizeof(real) / width : 0, - width, - trans), indexDictHandle_(indexDictHandle) { init(height, width); + buf_.reset(new RowBuffer(dataHandle, width)); } virtual ~SparseRowCpuMatrix() {} @@ -72,24 +136,17 @@ public: * @param row row id in local storage */ real* getLocalRow(size_t row) { - if (storeMat_.getData()) return storeMat_.rowBuf(row); - if (rowStore_.size() <= row * width_) { - rowStore_.resize((row + 1) * width_); - } - return rowStore_.data() + row * width_; + return const_cast(buf_->getWithAutoGrowth(row)); } /** - * reserve the storage for rows according to current size of indexDictHandle. + * reserve the storage for rows according to current size of + * indexDictHandle. * * This is only used when SparseRowCpuMatrix is constructed with * indexDictHandle. */ - void reserveStore() { - if (!storeMat_.getData() && !localIndices_->empty()) { - rowStore_.resize(localIndices_->size() * width_); - } - } + void reserveStore() { buf_->reserve(localIndices_->size()); } // row is the row id in the original matrix virtual real* getRowBuf(size_t row) { return getRow(row); } @@ -117,7 +174,8 @@ public: * * If L1 decay set use L1, else if L2 set use L2, otherwise no decay atall. * - * t0 is a int vector used by L1/L2 decay, size = height of parameter matrix, + * t0 is a int vector used by L1/L2 decay, size = height of parameter + * matrix, * store the time that each weight row last updated. * * Time is batchId, currentTime is current batchId. @@ -176,8 +234,7 @@ public: protected: template void apply(Func f) { - real* data = storeMat_.getData() ? storeMat_.getData() : rowStore_.data(); - f(data, localIndices_->size() * width_); + f(buf_->data(), localIndices_->size() * width_); } void init(size_t height, size_t width); @@ -188,25 +245,25 @@ protected: globalIndices_[id] = kUnusedId_; } localIndices_->clear(); - rowStore_.clear(); + buf_->clear(); } inline void checkStoreSize() { - if (storeMat_.getData()) { - CHECK_LE(localIndices_->size(), storeMat_.getHeight()); - } else if (!FLAGS_allow_inefficient_sparse_update) { - if (localIndices_->size() > 0.5 * height_) { + if (buf_->canAutoGrowth()) { + if (buf_->getRowCount() > 0.5 * height_) { LOG(WARNING) << "There are more than 0.5*height (" << localIndices_->size() << ") rows are used for sparse " << "update, which is not efficient. Considering not use " << "sparse_update or set --allow_inefficient_sparse_update=true"; + + } else { + CHECK_LE(localIndices_->size(), buf_->getRowCount()); } } } - CpuMatrix storeMat_; - std::vector> rowStore_; + std::unique_ptr buf_; IndexDictPtr indexDictHandle_; std::vector* localIndices_; // =&indexDictHandle_->localIndices unsigned int* globalIndices_; // =indexDictHandle_->globalIndices.data(); -- GitLab