提交 70bf732f 编写于 作者: Y Yancey1989

refine get interface

上级 b920b516
......@@ -121,27 +121,32 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}
bool SelectedRows::Get(int64_t key, framework::Tensor* value,
int64_t offset) const {
int64_t index = Index(key);
PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table.");
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
int64_t value_width = value->numel() / value->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0],
std::vector<int64_t> non_keys;
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"execpt the dims[0].");
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(cpu, value, offset * value_width, *value_.get(),
index * value_width, value_width));
return true;
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys.push_back(keys[i]);
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(cpu, value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
return non_keys;
}
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
......
......@@ -35,7 +35,7 @@ class SelectedRows {
*
* HasKey(key), whether the sparse table has the specified key.
* Set(key, value), set a key-value pair into the sparse table.
* Get(key, value*, offset), get a value by key and apply it to the given
* Get(keys, value*), get value by given key list and apply it to the given
* value pointer
* with the specified offset.
*
......@@ -75,13 +75,12 @@ class SelectedRows {
bool HasKey(int64_t key) const;
/*
* @brief Get a value by the specified key, if the
* key does not exists, this function would throw an exception.
* @brief Get value by the key list, if the
*
* @return true if the Get operation successed.
* @return a list of keys which does not exists in table
*/
bool Get(int64_t key, framework::Tensor* tensor, int64_t offset = 0) const;
std::vector<int64_t> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const;
/*
* @brief Set a key-value pair into the table.
......
......@@ -68,6 +68,7 @@ TEST_F(SelectedRowsTester, Table) {
table.mutable_rows()->push_back(1);
int64_t key = 10000;
int64_t non_key = 999;
framework::Tensor value;
value.Resize(framework::make_ddim({1, 100}));
auto ptr = value.mutable_data<float>(cpu);
......@@ -84,10 +85,13 @@ TEST_F(SelectedRowsTester, Table) {
ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(4));
framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({20, 100}), cpu);
table.Get(key, &get_value, 10);
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key});
auto non_keys = table.Get(keys, &get_value);
ASSERT_EQ(get_value.data<float>()[10 * 100], static_cast<float>(10));
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
ASSERT_EQ(non_keys[0], non_key);
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册