From ca327508ccbac09cf92841d2cb026a9478e40faf Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 17 Apr 2018 18:36:18 +0800 Subject: [PATCH] update --- paddle/fluid/framework/selected_rows.cc | 37 ++++++++++++++------ paddle/fluid/framework/selected_rows.h | 5 +-- paddle/fluid/framework/selected_rows_test.cc | 18 +++++++--- paddle/fluid/operators/sgd_op.h | 4 ++- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index f1dbd75e403..ec82611d2e0 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -38,10 +38,10 @@ struct ReAllocateVisitor { framework::DDim dims_; }; -struct TensorSlicedCopyVisitor { - TensorSlicedCopyVisitor(const platform::Place& place, framework::Tensor* dst, - int64_t dst_offset, const framework::Tensor src, - int64_t src_offset, int64_t size) +struct TensorCopyVisitor { + TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst, + int64_t dst_offset, const framework::Tensor src, + int64_t src_offset, int64_t size) : place_(place), dst_(dst), dst_offset_(dst_offset), @@ -121,10 +121,27 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -Tensor SelectedRows::Get(int64_t key) const { +bool SelectedRows::Get(int64_t key, framework::Tensor* value, + int64_t row) const { int64_t index = Index(key); PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table."); - return value_->Slice(index, index + 1); + PADDLE_ENFORCE(value->IsInitialized(), + "The value tensor should be initialized."); + + int64_t value_width = value->numel() / value->dims()[0]; + PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0], + "output tensor should have the same shape with table " + "execpt the dims[0]."); + + // TODO(Yancey1989): support other place + platform::CPUPlace cpu; + + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(cpu, value, row * value_width, *value_.get(), + index * value_width, value_width)); + + return true; } bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { @@ -143,7 +160,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { rows_.push_back(key); index = rows_.size() - 1; is_new_key = true; - // whether need to resize the value + // whether need to resize the table if (static_cast(rows_.size()) > value_->dims()[0]) { auto dims = value_->dims(); dims[0] = (dims[0] + 1) << 1; @@ -154,9 +171,9 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { framework::VisitDataType( framework::ToDataType(value.type()), - TensorSlicedCopyVisitor(cpu, value_.get(), - index * value_->numel() / value_->dims()[0], - value, static_cast(0), value.numel())); + TensorCopyVisitor(cpu, value_.get(), + index * value_->numel() / value_->dims()[0], value, + static_cast(0), value.numel())); return is_new_key; } diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 6a125d59ec7..df580c3a145 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -62,9 +62,10 @@ class SelectedRows { * @brief Get a value by the specified key, if the * key does not exists, this function would throw an exception. * - * @return a sliced tensor + * @return true if the Get operation successed. */ - Tensor Get(int64_t key) const; + + bool Get(int64_t key, framework::Tensor* tensor, int64_t row = 0) const; /* * @brief Set a key-value pair into the table. diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index 2cbf2bfea2d..21dade1ab94 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -62,6 +62,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { TEST_F(SelectedRowsTester, Table) { platform::CPUPlace cpu; SelectedRows table; + // initialize a sparse table + table.mutable_value()->Resize(framework::make_ddim({1, 100})); + table.mutable_value()->mutable_data(cpu); + table.mutable_rows()->push_back(1); int64_t key = 10000; framework::Tensor value; @@ -69,15 +73,21 @@ TEST_F(SelectedRowsTester, Table) { auto ptr = value.mutable_data(cpu); ptr[0] = static_cast(10); - ASSERT_EQ(table.rows().size(), static_cast(0)); + ASSERT_EQ(table.rows().size(), static_cast(1)); ASSERT_EQ(table.HasKey(key), false); table.Set(key, value); - ASSERT_EQ(table.rows().size(), static_cast(1)); + ASSERT_EQ(table.rows().size(), static_cast(2)); ASSERT_EQ(table.HasKey(key), true); - ASSERT_EQ(table.value().dims()[0], static_cast(2)); - ASSERT_EQ(table.Get(key).data()[0], static_cast(10)); + // check re-allocate + ASSERT_EQ(table.value().dims()[0], static_cast(4)); + + framework::Tensor get_value; + get_value.mutable_data(framework::make_ddim({20, 100}), cpu); + table.Get(key, &get_value, 10); + + ASSERT_EQ(get_value.data()[10 * 100], static_cast(10)); } } // namespace framework diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/sgd_op.h index cfc8793e1e0..f3e88b0a0b0 100644 --- a/paddle/fluid/operators/sgd_op.h +++ b/paddle/fluid/operators/sgd_op.h @@ -107,7 +107,9 @@ class SGDOpKernel : public framework::OpKernel { for (size_t i = 0; i < grad.rows().size(); i++) { PADDLE_ENFORCE(grad.rows()[i] < grad.height(), "Input rows index should less than height"); - int64_t id_index = param.index(grad.rows()[i]); + int64_t id_index = param.Index(grad.rows()[i]); + PADDLE_ENFORCE_GE(id_index, static_cast(0), + "id should be in the table"); for (size_t j = 0; j < grad_row_width; j++) { out_data[id_index * grad_row_width + j] -= lr[0] * grad_data[i * grad_row_width + j]; -- GitLab