diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index b1837ca3c425a5bd580f4403c24fc3f2b0b91972..2eefe79588f32658b1f0eab317528a31142d7ce5 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -121,27 +121,32 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -bool SelectedRows::Get(int64_t key, framework::Tensor* value, - int64_t offset) const { - int64_t index = Index(key); - PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table."); +std::vector SelectedRows::Get(std::vector keys, + framework::Tensor* value) const { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); - int64_t value_width = value->numel() / value->dims()[0]; - PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0], + std::vector non_keys; + int64_t value_width = value_->numel() / value_->dims()[0]; + PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], "output tensor should have the same shape with table " "execpt the dims[0]."); // TODO(Yancey1989): support other place platform::CPUPlace cpu; - framework::VisitDataType( - framework::ToDataType(value_->type()), - TensorCopyVisitor(cpu, value, offset * value_width, *value_.get(), - index * value_width, value_width)); - - return true; + for (size_t i = 0; i < keys.size(); ++i) { + int64_t index = Index(keys[i]); + if (index == -1) { + non_keys.push_back(keys[i]); + } else { + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(cpu, value, i * value_width, *value_.get(), + index * value_width, value_width)); + } + } + return non_keys; } bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index f329ae8939d996363d469c593903f8bfb0b2abe3..cef3ddab475eee45cb1eec9064ae75cf9ba4ab53 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -35,7 +35,7 @@ class SelectedRows { * * HasKey(key), whether the sparse table has the specified key. * Set(key, value), set a key-value pair into the sparse table. - * Get(key, value*, offset), get a value by key and apply it to the given + * Get(keys, value*), get value by given key list and apply it to the given * value pointer * with the specified offset. * @@ -75,13 +75,12 @@ class SelectedRows { bool HasKey(int64_t key) const; /* - * @brief Get a value by the specified key, if the - * key does not exists, this function would throw an exception. + * @brief Get value by the key list, if the * - * @return true if the Get operation successed. + * @return a list of keys which does not exists in table */ - - bool Get(int64_t key, framework::Tensor* tensor, int64_t offset = 0) const; + std::vector Get(std::vector keys, + framework::Tensor* tensor) const; /* * @brief Set a key-value pair into the table. diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index 21dade1ab94fd94d1471fc34e4d8e43de4007122..39fe6d92940606084c28eec1a4d6486cb58844ce 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -68,6 +68,7 @@ TEST_F(SelectedRowsTester, Table) { table.mutable_rows()->push_back(1); int64_t key = 10000; + int64_t non_key = 999; framework::Tensor value; value.Resize(framework::make_ddim({1, 100})); auto ptr = value.mutable_data(cpu); @@ -84,10 +85,13 @@ TEST_F(SelectedRowsTester, Table) { ASSERT_EQ(table.value().dims()[0], static_cast(4)); framework::Tensor get_value; - get_value.mutable_data(framework::make_ddim({20, 100}), cpu); - table.Get(key, &get_value, 10); + get_value.mutable_data(framework::make_ddim({2, 100}), cpu); + std::vector keys({non_key, key}); + auto non_keys = table.Get(keys, &get_value); - ASSERT_EQ(get_value.data()[10 * 100], static_cast(10)); + ASSERT_EQ(get_value.data()[100], static_cast(10)); + ASSERT_EQ(non_keys.size(), static_cast(1)); + ASSERT_EQ(non_keys[0], non_key); } } // namespace framework