diff --git a/paddle/math/RowBuffer.h b/paddle/math/RowBuffer.h index e35820461214aab0dcba7cb51d9558231fd1ad9f..bb55ca5f9f0d241fb00e124dedaf880c124aeee8 100644 --- a/paddle/math/RowBuffer.h +++ b/paddle/math/RowBuffer.h @@ -58,13 +58,13 @@ public: * @param row the index of row. * @return row buffer. */ - inline const real* get(int row) const { + inline real* get(int row) const { if (preallocatedBuf_) { CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_; } else { CHECK_LE((row + 1) * width_, rowStore_.size()); - return rowStore_.data() + row * width_; + return const_cast(rowStore_.data() + row * width_); } } @@ -74,7 +74,7 @@ public: * @param row the index of row. * @return row buffer. */ - inline const real* getWithAutoGrowth(int row) { + inline real* getWithAutoGrowth(int row) { if (preallocatedBuf_) { return get(row); } else { @@ -119,6 +119,12 @@ public: */ inline bool isAutoGrowth() const { return preallocatedBuf_ == nullptr; } + /** + * @brief return the width of matrix. a.k.a length of row. + * @return width of matrix + */ + inline size_t getWidth() const { return width_; } + private: CpuMemHandlePtr preallocatedBuf_; std::vector> rowStore_; diff --git a/paddle/math/SparseRowMatrix.h b/paddle/math/SparseRowMatrix.h index d77d8c3ed13b14270b16280acb3f62db44e4ff49..8532bca879dfda4f6a02bce11727a5cf1e22d955 100644 --- a/paddle/math/SparseRowMatrix.h +++ b/paddle/math/SparseRowMatrix.h @@ -69,9 +69,7 @@ public: * * @param row row id in local storage */ - real* getLocalRow(size_t row) { - return const_cast(buf_->getWithAutoGrowth(row)); - } + real* getLocalRow(size_t row) { return buf_->getWithAutoGrowth(row); } /** * reserve the storage for rows according to current size of diff --git a/paddle/math/tests/CMakeLists.txt b/paddle/math/tests/CMakeLists.txt index fe5177291c21c3505c3694201b36b54397150ccf..9403bb073a273a369f8b7520eb93ae7eca32d0e3 100644 --- a/paddle/math/tests/CMakeLists.txt +++ b/paddle/math/tests/CMakeLists.txt @@ -4,6 +4,7 @@ add_simple_unittest(test_ExecViaCpu) add_simple_unittest(test_SIMDFunctions) add_simple_unittest(test_TrainingAlgorithm) add_simple_unittest(test_SparseMatrix) +add_simple_unittest(test_RowBuffer) # TODO(yuyang18): Refactor TestUtil.cpp. Remove this cross module reference. add_unittest(test_matrixCompare diff --git a/paddle/math/tests/test_RowBuffer.cpp b/paddle/math/tests/test_RowBuffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f66f22ef73dcff1868c1a3e03139a680b1ce2b5 --- /dev/null +++ b/paddle/math/tests/test_RowBuffer.cpp @@ -0,0 +1,65 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/math/RowBuffer.h" + +TEST(RowBuffer, testAutoGrow) { + paddle::RowBuffer buf(128); + ASSERT_EQ(128, buf.getWidth()); + ASSERT_TRUE(buf.isAutoGrowth()); + buf.resize(2); + ASSERT_EQ(2, buf.getRowCount()); + for (size_t i = 0; i < buf.getWidth() * 2; ++i) { + buf.data()[i] = i; + } + for (size_t i = 0; i < buf.getRowCount(); ++i) { + for (size_t j = 0; j < buf.getWidth(); ++j) { + ASSERT_NEAR(i * buf.getWidth() + j, buf.get(i)[j], 1e-5); + } + } + + auto data = buf.getWithAutoGrowth(2); + for (size_t i = 0; i < buf.getWidth(); ++i) { + data[i] = i; + } + + ASSERT_EQ(3, buf.getRowCount()); + for (size_t i = 0; i < buf.getRowCount() - 1; ++i) { + for (size_t j = 0; j < buf.getWidth(); ++j) { + ASSERT_NEAR(i * buf.getWidth() + j, buf.get(i)[j], 1e-5); + } + } + for (size_t i = 0; i < buf.getWidth(); ++i) { + ASSERT_NEAR(i, buf.get(2)[i], 1e-5); + } +} + +TEST(RowBuffer, testWithMemBuf) { + paddle::CpuMemHandlePtr mem = + std::make_shared(128 * 2 * sizeof(real)); + paddle::RowBuffer buf(mem, 128); + ASSERT_TRUE(!buf.isAutoGrowth()); + ASSERT_EQ(2, buf.getRowCount()); + for (size_t i = 0; i < buf.getWidth() * 2; ++i) { + buf.data()[i] = i; + } + for (size_t i = 0; i < buf.getRowCount(); ++i) { + for (size_t j = 0; j < buf.getWidth(); ++j) { + ASSERT_NEAR(i * buf.getWidth() + j, buf.getWithAutoGrowth(i)[j], 1e-5); + } + } + + ASSERT_DEATH_IF_SUPPORTED(buf.getWithAutoGrowth(3), ".*"); +}