提交 a3043989 编写于 作者: Y Yu Yang

Extract RowBuffer class for SparseRowMatrix.

* The original SparseRowMatrix used two fields to store its rows,
  which made the code very confusing. Extract a RowBuffer class
  to hold the SparseRowMatrix data storage and manage the auto-growth logic.
上级 f8e8d1ad
......@@ -293,7 +293,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
while (true) {
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (realBatchSize) {
CHECK_LE(realBatchSize, batchSize);
CHECK_LE((size_t)realBatchSize, batchSize);
} else {
break;
}
......
......@@ -24,6 +24,73 @@ P_DECLARE_bool(allow_inefficient_sparse_update);
namespace paddle {
/**
 * @brief Row-major storage for the rows of a sparse-row matrix.
 *
 * Every row holds `width` elements of type `real`. The buffer works in one
 * of two modes, selected at construction time:
 *  - preallocated: rows live in a caller-supplied memory handle whose size
 *    is fixed; accesses are bounds-checked against that size;
 *  - auto-growth: no (or a null) memory handle was supplied, rows live in an
 *    internal aligned vector that can be enlarged on demand.
 */
class RowBuffer {
public:
  /**
   * @brief Create an auto-growth buffer.
   * @param width number of elements per row.
   */
  explicit RowBuffer(size_t width) : width_(width) {}

  /**
   * @brief Create a buffer on top of caller-supplied memory.
   * @param mem   memory handle holding the rows; when it is null the buffer
   *              behaves like an auto-growth buffer.
   * @param width number of elements per row.
   */
  RowBuffer(const CpuMemHandlePtr& mem, size_t width)
      : preallocatedBuf_(mem), width_(width) {}

  /**
   * @brief Make sure storage for at least rowCnt rows is available.
   *
   * Preallocated mode: the handle must already be large enough, otherwise
   * the CHECK fails. (The check used to be inverted and failed exactly when
   * the buffer was big enough.)
   *
   * Auto-growth mode: the store is resized, not merely capacity-reserved, so
   * the rows are counted by getRowCount(), pass the bounds check in get() and
   * may be written through data(). The store is never shrunk here.
   *
   * @param rowCnt number of rows that must be addressable.
   */
  inline void reserve(int rowCnt) {
    const size_t elemCnt = static_cast<size_t>(rowCnt) * width_;
    if (preallocatedBuf_) {
      CHECK(preallocatedBuf_->getSize() >= elemCnt * sizeof(real));
    } else if (rowStore_.size() < elemCnt) {
      rowStore_.resize(elemCnt);
    }
  }

  /**
   * @brief Bounds-checked read access to a row.
   * @param row local row index.
   * @return pointer to the first element of the row; a CHECK fails when the
   *         row does not fit entirely inside the current storage.
   */
  inline const real* get(int row) const {
    const size_t begin = static_cast<size_t>(row) * width_;
    if (preallocatedBuf_) {
      CHECK_LE((begin + width_) * sizeof(real), preallocatedBuf_->getSize());
      return reinterpret_cast<real*>(preallocatedBuf_->getBuf()) + begin;
    }
    CHECK_LE(begin + width_, rowStore_.size());
    return rowStore_.data() + begin;
  }

  /**
   * @brief Access a row, enlarging the internal store if necessary.
   *
   * In preallocated mode this is the same as get(): the buffer cannot grow.
   * In auto-growth mode the store is resized whenever the whole row does not
   * fit yet; resizing may reallocate, so pointers returned by earlier calls
   * must not be kept across a call that grows the store.
   *
   * @param row local row index.
   * @return pointer to the first element of the row.
   */
  inline const real* getWithAutoGrowth(int row) {
    if (preallocatedBuf_) {
      return get(row);
    }
    const size_t begin = static_cast<size_t>(row) * width_;
    const size_t end = begin + width_;
    if (rowStore_.size() < end) {
      rowStore_.resize(end);
    }
    return rowStore_.data() + begin;
  }

  /**
   * @brief Writable pointer to the first element of the storage.
   */
  inline real* data() {
    if (preallocatedBuf_) {
      return reinterpret_cast<real*>(preallocatedBuf_->getBuf());
    }
    return rowStore_.data();
  }

  /**
   * @brief Drop all rows held by the internal (auto-growth) store.
   *
   * The preallocated memory handle, if any, is left untouched.
   */
  inline void clear() { rowStore_.clear(); }

  /**
   * @brief Number of complete rows the current storage can hold.
   *
   * Uses sizeof(real), not sizeof(float), so that the count agrees with the
   * byte-size checks in reserve() and get().
   */
  inline size_t getRowCount() const {
    if (preallocatedBuf_) {
      return preallocatedBuf_->getSize() / sizeof(real) / width_;
    }
    return rowStore_.size() / width_;
  }

  /**
   * @brief Whether the buffer owns its storage and may therefore grow.
   */
  inline bool canAutoGrowth() const { return preallocatedBuf_ == nullptr; }

private:
  //! Caller-supplied storage; null in auto-growth mode.
  CpuMemHandlePtr preallocatedBuf_;
  //! Internal storage used in auto-growth mode.
  std::vector<real, AlignedAllocator<real, 32>> rowStore_;
  //! Number of elements per row.
  size_t width_;
};
/**
* Sparse Row
*/
......@@ -45,12 +112,9 @@ public:
IndexDictPtr indexDictHandle = nullptr,
bool trans = false)
: CpuMatrix(nullptr, height, width, trans),
storeMat_(dataHandle,
dataHandle ? dataHandle->getSize() / sizeof(real) / width : 0,
width,
trans),
indexDictHandle_(indexDictHandle) {
init(height, width);
buf_.reset(new RowBuffer(dataHandle, width));
}
virtual ~SparseRowCpuMatrix() {}
......@@ -72,24 +136,17 @@ public:
* @param row row id in local storage
*/
// Returns a writable pointer to the first element of local row `row`.
// NOTE(review): this span looks like a rendered diff in which the removed
// and the added lines are shown together without +/- markers — confirm
// against the real file before relying on the statement order below.
real* getLocalRow(size_t row) {
// Storage path based on storeMat_/rowStore_: use storeMat_ when it has data,
// otherwise enlarge rowStore_ so that row `row` exists and return it.
if (storeMat_.getData()) return storeMat_.rowBuf(row);
if (rowStore_.size() <= row * width_) {
rowStore_.resize((row + 1) * width_);
}
return rowStore_.data() + row * width_;
// RowBuffer path: delegate to buf_ and strip the const from the returned
// pointer. As written here it follows an unconditional return.
return const_cast<real*>(buf_->getWithAutoGrowth(row));
}
/**
* reserve the storage for rows according to current size of indexDictHandle.
* reserve the storage for rows according to current size of
* indexDictHandle.
*
* This is only used when SparseRowCpuMatrix is constructed with
* indexDictHandle.
*/
// NOTE(review): two definitions of reserveStore() appear back to back; this
// looks like the removed and the added side of a diff shown together —
// confirm which one the real file contains.
// Variant 1: when storeMat_ has no data and localIndices_ is not empty,
// resize rowStore_ to localIndices_->size() rows of width_ elements.
void reserveStore() {
if (!storeMat_.getData() && !localIndices_->empty()) {
rowStore_.resize(localIndices_->size() * width_);
}
}
// Variant 2: forward the current number of local indices to buf_->reserve().
void reserveStore() { buf_->reserve(localIndices_->size()); }
// row is the row id in the original matrix
virtual real* getRowBuf(size_t row) { return getRow(row); }
......@@ -117,7 +174,8 @@ public:
*
* If L1 decay is set, use L1; else if L2 is set, use L2; otherwise no decay at all.
*
* t0 is a int vector used by L1/L2 decay, size = height of parameter matrix,
* t0 is a int vector used by L1/L2 decay, size = height of parameter
* matrix,
* store the time that each weight row last updated.
*
* Time is batchId, currentTime is current batchId.
......@@ -176,8 +234,7 @@ public:
protected:
// Invokes `f(pointer, count)` on the row storage, where count is
// localIndices_->size() * width_ elements.
// NOTE(review): both a storeMat_/rowStore_-based call and a buf_-based call
// are present; this looks like the removed and the added side of a diff
// shown together — confirm which one the real file contains.
template <typename Func>
void apply(Func f) {
// Pick storeMat_'s data when it is non-null, otherwise rowStore_'s data.
real* data = storeMat_.getData() ? storeMat_.getData() : rowStore_.data();
f(data, localIndices_->size() * width_);
// Same call, with the pointer taken from buf_->data().
f(buf_->data(), localIndices_->size() * width_);
}
void init(size_t height, size_t width);
......@@ -188,25 +245,25 @@ protected:
globalIndices_[id] = kUnusedId_;
}
localIndices_->clear();
rowStore_.clear();
buf_->clear();
}
// Sanity-checks the number of used local rows against the row storage and
// logs a warning when more than half of the matrix height is in use.
// NOTE(review): the conditions below mix storeMat_-based and buf_-based
// checks and the braces do not balance as rendered; this looks like the
// removed and the added side of a diff shown together — confirm the actual
// control flow against the real file.
inline void checkStoreSize() {
// storeMat_-based check: used rows must not exceed storeMat_'s height.
if (storeMat_.getData()) {
CHECK_LE(localIndices_->size(), storeMat_.getHeight());
} else if (!FLAGS_allow_inefficient_sparse_update) {
if (localIndices_->size() > 0.5 * height_) {
// buf_-based check: a growable buffer only warns, based on its row count.
if (buf_->canAutoGrowth()) {
if (buf_->getRowCount() > 0.5 * height_) {
LOG(WARNING)
<< "There are more than 0.5*height (" << localIndices_->size()
<< ") rows are used for sparse "
<< "update, which is not efficient. Considering not use "
<< "sparse_update or set --allow_inefficient_sparse_update=true";
} else {
// Used rows must not exceed the number of rows buf_ can hold.
CHECK_LE(localIndices_->size(), buf_->getRowCount());
}
}
}
CpuMatrix storeMat_;
std::vector<real, AlignedAllocator<real, 32>> rowStore_;
std::unique_ptr<RowBuffer> buf_;
IndexDictPtr indexDictHandle_;
std::vector<unsigned int>* localIndices_; // =&indexDictHandle_->localIndices
unsigned int* globalIndices_; // =indexDictHandle_->globalIndices.data();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册