未验证 提交 0b8630b9 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9897 from Yancey1989/auto_grwon_sparse_table

Auto-grown sparse table
...@@ -17,6 +17,52 @@ limitations under the License. */ ...@@ -17,6 +17,52 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct ReAllocateVisitor {
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
: tensor_(tensor), dims_(dims) {}
template <typename T>
void operator()() const {
framework::Tensor cpu_tensor;
platform::CPUPlace cpu;
T* ptr = cpu_tensor.mutable_data<T>(dims_, cpu);
const T* old_ptr =
tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
if (old_ptr != nullptr) {
std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
}
tensor_->ShareDataWith(cpu_tensor);
}
framework::Tensor* tensor_;
framework::DDim dims_;
};
struct TensorCopyVisitor {
TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
const framework::Tensor src, int64_t src_offset,
int64_t size)
: dst_(dst),
dst_offset_(dst_offset),
src_(src),
src_offset_(src_offset),
size_(size) {}
template <typename T>
void operator()() const {
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
src_.data<T>() + src_offset_, size_ * sizeof(T));
}
framework::Tensor* dst_;
int64_t dst_offset_;
framework::Tensor src_;
int64_t src_offset_;
int64_t size_;
};
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version { // the 1st field, uint32_t version
...@@ -69,5 +115,66 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, ...@@ -69,5 +115,66 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
TensorFromStream(is, selected_rows->mutable_value(), dev_ctx); TensorFromStream(is, selected_rows->mutable_value(), dev_ctx);
} }
bool SelectedRows::HasKey(int64_t key) const {
return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false
: true;
}
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
std::vector<int64_t> non_keys;
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"execpt the dims[0].");
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys.push_back(keys[i]);
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
return non_keys;
}
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized.");
if (value_->IsInitialized()) {
PADDLE_ENFORCE_EQ(
value.type(), value_->type(),
"The type of the value should be same with the original value");
}
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
"The first dim of value should be 1.");
auto index = Index(key);
bool is_new_key = false;
if (index == -1) {
rows_.push_back(key);
index = rows_.size() - 1;
is_new_key = true;
// whether need to resize the table
if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
auto dims = value_->dims();
dims[0] = (dims[0] + 1) << 1;
framework::VisitDataType(framework::ToDataType(value.type()),
ReAllocateVisitor(value_.get(), dims));
}
}
framework::VisitDataType(
framework::ToDataType(value.type()),
TensorCopyVisitor(value_.get(),
index * value_->numel() / value_->dims()[0], value,
static_cast<int64_t>(0), value.numel()));
return is_new_key;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,15 +14,33 @@ limitations under the License. */ ...@@ -14,15 +14,33 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class SelectedRows { class SelectedRows {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`
* number,
* and the value is a Tensor which the first dimension is 0.
* You can use the following interface to operate the sparse table, and you
* can find
* some detail information from the comments of each interface:
*
* HasKey(key), whether the sparse table has the specified key.
* Set(key, value), set a key-value pair into the sparse table.
* Get(keys, value*), get value by given key list and apply it to the given
* value pointer
* with the specified offset.
*
*/
public: public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height) SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) { : rows_(rows), height_(height) {
...@@ -50,12 +68,45 @@ class SelectedRows { ...@@ -50,12 +68,45 @@ class SelectedRows {
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; } void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/** /*
* get the index of id in rows * @brief wheter has the specified key in the table.
*
* @return true if the key is exists.
*/
bool HasKey(int64_t key) const;
/*
* @brief Get value by the key list, if the
*
* @return a list of keys which does not exists in table
*/
std::vector<int64_t> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const;
/*
* @brief Set a key-value pair into the table.
* This function will double the value memory if it's not engouth.
*
* @note:
* 1. The first dim of the value should be 1
* 2. The value should be initialized and the data type
* should be the same with the table.
*
* @return true if the key is a new one, otherwise false
*
*/ */
int64_t index(int64_t id) const { bool Set(int64_t key, const Tensor& value);
auto it = std::find(rows_.begin(), rows_.end(), id);
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows"); /*
* @brief Get the index of key in rows
*
* @return -1 if the key does not exists.
*/
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
return static_cast<int64_t>(-1);
}
return static_cast<int64_t>(std::distance(rows_.begin(), it)); return static_cast<int64_t>(std::distance(rows_.begin(), it));
} }
......
...@@ -17,7 +17,7 @@ namespace framework { ...@@ -17,7 +17,7 @@ namespace framework {
class SelectedRowsTester : public ::testing::Test { class SelectedRowsTester : public ::testing::Test {
public: public:
virtual void SetUp() override { void SetUp() override {
std::vector<int64_t> rows{0, 4, 7}; std::vector<int64_t> rows{0, 4, 7};
int64_t height = 10; int64_t height = 10;
int64_t row_numel = 100; int64_t row_numel = 100;
...@@ -59,5 +59,40 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ...@@ -59,5 +59,40 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
} }
TEST_F(SelectedRowsTester, Table) {
platform::CPUPlace cpu;
SelectedRows table;
// initialize a sparse table
table.mutable_value()->Resize(framework::make_ddim({1, 100}));
table.mutable_value()->mutable_data<float>(cpu);
table.mutable_rows()->push_back(1);
int64_t key = 10000;
int64_t non_key = 999;
framework::Tensor value;
value.Resize(framework::make_ddim({1, 100}));
auto ptr = value.mutable_data<float>(cpu);
ptr[0] = static_cast<float>(10);
ASSERT_EQ(table.rows().size(), static_cast<size_t>(1));
ASSERT_EQ(table.HasKey(key), false);
table.Set(key, value);
ASSERT_EQ(table.rows().size(), static_cast<size_t>(2));
ASSERT_EQ(table.HasKey(key), true);
// check re-allocate
ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(4));
framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key});
auto non_keys = table.Get(keys, &get_value);
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
ASSERT_EQ(non_keys[0], non_key);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -103,7 +103,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -103,7 +103,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset(output + i * row_width, 0, row_width * sizeof(T)); memset(output + i * row_width, 0, row_width * sizeof(T));
} else { } else {
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.index(ids[i]); auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
memcpy(output + i * row_width, table + id_index * row_width, memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
......
...@@ -107,7 +107,9 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -107,7 +107,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < grad.rows().size(); i++) { for (size_t i = 0; i < grad.rows().size(); i++) {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(), PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height"); "Input rows index should less than height");
int64_t id_index = param.index(grad.rows()[i]); int64_t id_index = param.Index(grad.rows()[i]);
PADDLE_ENFORCE_GE(id_index, static_cast<int64_t>(0),
"id should be in the table");
for (size_t j = 0; j < grad_row_width; j++) { for (size_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -= out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j]; lr[0] * grad_data[i * grad_row_width + j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册