From fa2c06fb053769f1c82e4d5c98b2bcdee376d6d0 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 13 Dec 2016 17:25:38 +0800 Subject: [PATCH] Add comments --- paddle/math/RowBuffer.h | 127 ++++++++++++++++++++++++++++++++++ paddle/math/SparseRowMatrix.h | 72 +------------------ 2 files changed, 130 insertions(+), 69 deletions(-) create mode 100644 paddle/math/RowBuffer.h diff --git a/paddle/math/RowBuffer.h b/paddle/math/RowBuffer.h new file mode 100644 index 00000000000..e3582046121 --- /dev/null +++ b/paddle/math/RowBuffer.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "MemoryHandle.h" +#include "paddle/utils/Util.h" + +namespace paddle { + +/** + * @brief The RowBuffer class + * Represent the SparseRow Matrix Data. + * + * If not set memory handler, then the data could be auto growth. + */ +class RowBuffer { +public: + /** + * @brief RowBuffer create a auto-growth row buffer. The row length is width. + * @param width the length of each row, a.k.a matrix width. + */ + explicit RowBuffer(size_t width) : width_(width) {} + + /** + * @brief RowBuffer create a row buffer, which cannot be auto-growth. + * @param mem the pre-allocated memory. + * @param width the length of each row, a.k.a matrix width. + */ + RowBuffer(const CpuMemHandlePtr& mem, size_t width) + : preallocatedBuf_(mem), width_(width) {} + + /** + * @brief resize resize the buffer with rowCount + * @param rowCnt number of row. matrix height. + */ + inline void resize(int rowCnt) { + if (preallocatedBuf_) { + CHECK(preallocatedBuf_->getSize() < rowCnt * width_ * sizeof(real)); + } else { + rowStore_.resize(rowCnt * width_); + } + } + + /** + * @brief get a row buffer with row index. + * @param row the index of row. + * @return row buffer. + */ + inline const real* get(int row) const { + if (preallocatedBuf_) { + CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); + return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_; + } else { + CHECK_LE((row + 1) * width_, rowStore_.size()); + return rowStore_.data() + row * width_; + } + } + + /** + * @brief get a row buffer with row index. If row index is larger than local + * buffer, the size of local buffer will grow. + * @param row the index of row. + * @return row buffer. + */ + inline const real* getWithAutoGrowth(int row) { + if (preallocatedBuf_) { + return get(row); + } else { + if ((rowStore_.size() <= row * width_)) { + rowStore_.resize((row + 1) * width_); + } + return rowStore_.data() + row * width_; + } + } + + /** + * @return raw data buffer. + */ + inline real* data() { + if (preallocatedBuf_) { + return reinterpret_cast(preallocatedBuf_->getBuf()); + } else { + return rowStore_.data(); + } + } + + /** + * @brief clear local buffer. It only affect auto-growth buffer. + */ + inline void clear() { rowStore_.clear(); } + + /** + * @brief get current number of rows. + * @return number of rows. + */ + inline size_t getRowCount() const { + if (preallocatedBuf_) { + return preallocatedBuf_->getSize() / sizeof(float) / width_; + } else { + return rowStore_.size() / width_; + } + } + + /** + * @brief get is this buffer can automatically grow or not. + * @return ture if can automacitally grow. + */ + inline bool isAutoGrowth() const { return preallocatedBuf_ == nullptr; } + +private: + CpuMemHandlePtr preallocatedBuf_; + std::vector> rowStore_; + size_t width_; +}; +} // namespace paddle diff --git a/paddle/math/SparseRowMatrix.h b/paddle/math/SparseRowMatrix.h index db1530f7cfd..d77d8c3ed13 100644 --- a/paddle/math/SparseRowMatrix.h +++ b/paddle/math/SparseRowMatrix.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "Matrix.h" +#include "RowBuffer.h" #include "paddle/utils/CommandLineParser.h" #include "paddle/utils/Util.h" @@ -24,73 +25,6 @@ P_DECLARE_bool(allow_inefficient_sparse_update); namespace paddle { -/** - * @brief The RowBuffer class - * Represent the SparseRow Matrix Data. - * - * If not set memory handler, then the data could be auto growth. - */ -class RowBuffer { -public: - explicit RowBuffer(size_t width) : width_(width) {} - RowBuffer(const CpuMemHandlePtr& mem, size_t width) - : preallocatedBuf_(mem), width_(width) {} - - inline void reserve(int rowCnt) { - if (preallocatedBuf_) { - CHECK(preallocatedBuf_->getSize() < rowCnt * width_ * sizeof(real)); - } else { - rowStore_.reserve(rowCnt * width_); - } - } - - inline const real* get(int row) const { - if (preallocatedBuf_) { - CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); - return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_; - } else { - CHECK_LE((row + 1) * width_, rowStore_.size()); - return rowStore_.data() + row * width_; - } - } - - inline const real* getWithAutoGrowth(int row) { - if (preallocatedBuf_) { - return get(row); - } else { - if ((rowStore_.size() <= row * width_)) { - rowStore_.resize((row + 1) * width_); - } - return rowStore_.data() + row * width_; - } - } - - inline real* data() { - if (preallocatedBuf_) { - return reinterpret_cast(preallocatedBuf_->getBuf()); - } else { - return rowStore_.data(); - } - } - - inline void clear() { rowStore_.clear(); } - - inline size_t getRowCount() const { - if (preallocatedBuf_) { - return preallocatedBuf_->getSize() / sizeof(float) / width_; - } else { - return rowStore_.size() / width_; - } - } - - inline bool canAutoGrowth() const { return preallocatedBuf_ == nullptr; } - -private: - CpuMemHandlePtr preallocatedBuf_; - std::vector> rowStore_; - size_t width_; -}; - /** * Sparse Row */ @@ -146,7 +80,7 @@ public: * This is only used when SparseRowCpuMatrix is constructed with * indexDictHandle. */ - void reserveStore() { buf_->reserve(localIndices_->size()); } + void reserveStore() { buf_->resize(localIndices_->size()); } // row is the row id in the original matrix virtual real* getRowBuf(size_t row) { return getRow(row); } @@ -249,7 +183,7 @@ protected: } inline void checkStoreSize() { - if (buf_->canAutoGrowth()) { + if (buf_->isAutoGrowth()) { if (buf_->getRowCount() > 0.5 * height_) { LOG(WARNING) << "There are more than 0.5*height (" << localIndices_->size() -- GitLab