diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index d9d6b7dd67f1c6e4bbd6a4e1a8f0843d4cb93c05..f1dbd75e403c48a1d23bb35f507b0f1f7095c6ae 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -17,6 +17,53 @@ limitations under the License. */ namespace paddle { namespace framework { +struct ReAllocateVisitor { + ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims) + : tensor_(tensor), dims_(dims) {} + + template + void operator()() const { + framework::Tensor cpu_tensor; + platform::CPUPlace cpu; + T* ptr = cpu_tensor.mutable_data(dims_, cpu); + const T* old_ptr = + tensor_->memory_size() == 0 ? nullptr : tensor_->data(); + if (old_ptr != nullptr) { + std::copy(old_ptr, old_ptr + tensor_->numel(), ptr); + } + tensor_->ShareDataWith(cpu_tensor); + } + + framework::Tensor* tensor_; + framework::DDim dims_; +}; + +struct TensorSlicedCopyVisitor { + TensorSlicedCopyVisitor(const platform::Place& place, framework::Tensor* dst, + int64_t dst_offset, const framework::Tensor src, + int64_t src_offset, int64_t size) + : place_(place), + dst_(dst), + dst_offset_(dst_offset), + src_(src), + src_offset_(src_offset), + size_(size) {} + + template + void operator()() const { + std::copy(src_.data() + src_offset_, + src_.data() + src_offset_ + size_, + dst_->mutable_data(place_) + dst_offset_); + } + + platform::Place place_; + framework::Tensor* dst_; + int64_t dst_offset_; + framework::Tensor src_; + int64_t src_offset_; + int64_t size_; +}; + void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, const platform::DeviceContext& dev_ctx) { { // the 1st field, uint32_t version @@ -69,5 +116,49 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, TensorFromStream(is, selected_rows->mutable_value(), dev_ctx); } +bool SelectedRows::HasKey(int64_t key) const { + return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false + : true; +} + +Tensor SelectedRows::Get(int64_t key) const { + int64_t index = Index(key); + PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table."); + return value_->Slice(index, index + 1); +} + +bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { + PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized."); + if (value_->IsInitialized()) { + PADDLE_ENFORCE_EQ( + value.type(), value_->type(), + "The type of the value should be same with the original value"); + } + PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), + "The first dim of value should be 1."); + auto index = Index(key); + platform::Place cpu = platform::CPUPlace(); + bool is_new_key = false; + if (index == -1) { + rows_.push_back(key); + index = rows_.size() - 1; + is_new_key = true; + // whether need to resize the value + if (static_cast(rows_.size()) > value_->dims()[0]) { + auto dims = value_->dims(); + dims[0] = (dims[0] + 1) << 1; + framework::VisitDataType(framework::ToDataType(value.type()), + ReAllocateVisitor(value_.get(), dims)); + } + } + + framework::VisitDataType( + framework::ToDataType(value.type()), + TensorSlicedCopyVisitor(cpu, value_.get(), + index * value_->numel() / value_->dims()[0], + value, static_cast(0), value.numel())); + return is_new_key; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 8e2d9470d3954e0f66c74828a8d8292c2875a8f4..6a125d59ec77a871a99ab539b3ccf9ea509d2cdd 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/lod_tensor.h" @@ -50,12 +51,45 @@ class SelectedRows { void set_rows(const Vector& rows) { rows_ = rows; } - /** - * get the index of id in rows + /* + * @brief wheter has the specified key in the table. + * + * @return true if the key is exists. */ - int64_t index(int64_t id) const { - auto it = std::find(rows_.begin(), rows_.end(), id); - PADDLE_ENFORCE(it != rows_.end(), "id should be in rows"); + bool HasKey(int64_t key) const; + + /* + * @brief Get a value by the specified key, if the + * key does not exists, this function would throw an exception. + * + * @return a sliced tensor + */ + Tensor Get(int64_t key) const; + + /* + * @brief Set a key-value pair into the table. + * This function will double the value memory if it's not engouth. + * + * @note: + * 1. The first dim of the value should be 1 + * 2. The value should be initialized and the data type + * should be the same with the table. + * + * @return true if the key is a new one, otherwise false + * + */ + bool Set(int64_t key, const Tensor& value); + + /* + * @brief Get the index of key in rows + * + * @return -1 if the key does not exists. + */ + int64_t Index(int64_t key) const { + auto it = std::find(rows_.begin(), rows_.end(), key); + if (it == rows_.end()) { + return static_cast(-1); + } return static_cast(std::distance(rows_.begin(), it)); } diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index 960d8d64f04a819217413ff881977ce5fb5a30f2..2cbf2bfea2d5d00e54e637bab1a42eb68264f88c 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -17,7 +17,7 @@ namespace framework { class SelectedRowsTester : public ::testing::Test { public: - virtual void SetUp() override { + void SetUp() override { std::vector rows{0, 4, 7}; int64_t height = 10; int64_t row_numel = 100; @@ -59,5 +59,26 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); } +TEST_F(SelectedRowsTester, Table) { + platform::CPUPlace cpu; + SelectedRows table; + + int64_t key = 10000; + framework::Tensor value; + value.Resize(framework::make_ddim({1, 100})); + auto ptr = value.mutable_data(cpu); + ptr[0] = static_cast(10); + + ASSERT_EQ(table.rows().size(), static_cast(0)); + ASSERT_EQ(table.HasKey(key), false); + + table.Set(key, value); + + ASSERT_EQ(table.rows().size(), static_cast(1)); + ASSERT_EQ(table.HasKey(key), true); + ASSERT_EQ(table.value().dims()[0], static_cast(2)); + ASSERT_EQ(table.Get(key).data()[0], static_cast(10)); +} + } // namespace framework } // namespace paddle