Commit ca327508 authored by Yancey1989

update: SelectedRows::Get now copies a table row into a caller-provided tensor via TensorCopyVisitor (renamed from TensorSlicedCopyVisitor) instead of returning a slice

Parent ed6241cd
@@ -38,8 +38,8 @@ struct ReAllocateVisitor {
   framework::DDim dims_;
 };
 
-struct TensorSlicedCopyVisitor {
-  TensorSlicedCopyVisitor(const platform::Place& place, framework::Tensor* dst,
+struct TensorCopyVisitor {
+  TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst,
                     int64_t dst_offset, const framework::Tensor src,
                     int64_t src_offset, int64_t size)
       : place_(place),
@@ -121,10 +121,27 @@ bool SelectedRows::HasKey(int64_t key) const {
              : true;
 }
 
-Tensor SelectedRows::Get(int64_t key) const {
+bool SelectedRows::Get(int64_t key, framework::Tensor* value,
+                       int64_t row) const {
   int64_t index = Index(key);
   PADDLE_ENFORCE_GE(index, 0, "The key should exist in the Table.");
-  return value_->Slice(index, index + 1);
+  PADDLE_ENFORCE(value->IsInitialized(),
+                 "The value tensor should be initialized.");
+
+  int64_t value_width = value->numel() / value->dims()[0];
+  PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0],
+                    "output tensor should have the same shape with table "
+                    "except the dims[0].");
+
+  // TODO(Yancey1989): support other places
+  platform::CPUPlace cpu;
+  framework::VisitDataType(
+      framework::ToDataType(value_->type()),
+      TensorCopyVisitor(cpu, value, row * value_width, *value_.get(),
+                        index * value_width, value_width));
+  return true;
 }
 
 bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
@@ -143,7 +160,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
     rows_.push_back(key);
     index = rows_.size() - 1;
     is_new_key = true;
-    // whether need to resize the value
+    // whether need to resize the table
     if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
       auto dims = value_->dims();
       dims[0] = (dims[0] + 1) << 1;
@@ -154,9 +171,9 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
   framework::VisitDataType(
       framework::ToDataType(value.type()),
-      TensorSlicedCopyVisitor(cpu, value_.get(),
-                              index * value_->numel() / value_->dims()[0],
-                              value, static_cast<int64_t>(0), value.numel()));
+      TensorCopyVisitor(cpu, value_.get(),
+                        index * value_->numel() / value_->dims()[0], value,
+                        static_cast<int64_t>(0), value.numel()));
   return is_new_key;
 }
......
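For reference, the new Get copies exactly one table row: the table's value tensor and the output tensor are both treated as row-major buffers of width value_width = numel / dims[0], the source offset is index * value_width, and the destination offset is row * value_width. Below is a minimal standalone sketch of that offset arithmetic, using plain std::vector and a hypothetical CopyTableRow helper rather than the Paddle types.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Copy row `index` of a row-major [rows x width] table into row `row`
// of a row-major destination buffer with the same row width.
void CopyTableRow(const std::vector<float>& table, int64_t width,
                  int64_t index, std::vector<float>* dst, int64_t row) {
  assert(static_cast<int64_t>(dst->size()) >= (row + 1) * width);
  std::copy(table.begin() + index * width,
            table.begin() + (index + 1) * width,
            dst->begin() + row * width);
}

int main() {
  const int64_t width = 4;
  std::vector<float> table = {0, 0, 0, 0,   // table row 0
                              1, 2, 3, 4};  // table row 1
  std::vector<float> out(3 * width, 0.f);   // caller-owned output, 3 rows
  CopyTableRow(table, width, /*index=*/1, &out, /*row=*/2);
  // out row 2 now holds {1, 2, 3, 4}, mirroring TensorCopyVisitor's
  // dst_offset = row * value_width and src_offset = index * value_width.
  return 0;
}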
@@ -62,9 +62,10 @@ class SelectedRows {
    * @brief Get a value by the specified key, if the
    * key does not exist, this function will throw an exception.
    *
-   * @return a sliced tensor
+   * @return true if the Get operation succeeded.
    */
-  Tensor Get(int64_t key) const;
+  bool Get(int64_t key, framework::Tensor* tensor, int64_t row = 0) const;
 
   /*
    * @brief Set a key-value pair into the table.
......
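The header change turns Get from a slice accessor into a copy into a caller-provided tensor, with an optional destination row. A usage sketch following the pattern of the updated unit test is shown below; the include path and namespace layout are assumptions, not taken from this diff.

// Sketch only: the include path and namespaces are assumed and may not
// match the real tree exactly.
#include "paddle/framework/selected_rows.h"

namespace paddle {
namespace framework {

void SparseTableExample() {
  platform::CPUPlace cpu;
  SelectedRows table;

  // The table owns a pre-allocated value tensor: one row of width 100.
  table.mutable_value()->Resize(make_ddim({1, 100}));
  table.mutable_value()->mutable_data<float>(cpu);
  table.mutable_rows()->push_back(1);

  // Set() stores `value` under key 10000, growing the table if needed.
  Tensor value;
  value.mutable_data<float>(make_ddim({1, 100}), cpu);
  table.Set(10000, value);

  // Get() copies the stored row into row 5 of a caller-owned tensor
  // (already allocated with the same row width) and returns true,
  // instead of returning a slice of the table.
  Tensor out;
  out.mutable_data<float>(make_ddim({20, 100}), cpu);
  table.Get(10000, &out, /*row=*/5);
}

}  // namespace framework
}  // namespace paddle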
@@ -62,6 +62,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
 TEST_F(SelectedRowsTester, Table) {
   platform::CPUPlace cpu;
   SelectedRows table;
+  // initialize a sparse table
+  table.mutable_value()->Resize(framework::make_ddim({1, 100}));
+  table.mutable_value()->mutable_data<float>(cpu);
+  table.mutable_rows()->push_back(1);
 
   int64_t key = 10000;
   framework::Tensor value;
@@ -69,15 +73,21 @@ TEST_F(SelectedRowsTester, Table) {
   auto ptr = value.mutable_data<float>(cpu);
   ptr[0] = static_cast<float>(10);
 
-  ASSERT_EQ(table.rows().size(), static_cast<size_t>(0));
+  ASSERT_EQ(table.rows().size(), static_cast<size_t>(1));
   ASSERT_EQ(table.HasKey(key), false);
 
   table.Set(key, value);
-  ASSERT_EQ(table.rows().size(), static_cast<size_t>(1));
+  ASSERT_EQ(table.rows().size(), static_cast<size_t>(2));
   ASSERT_EQ(table.HasKey(key), true);
-  ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(2));
-  ASSERT_EQ(table.Get(key).data<float>()[0], static_cast<float>(10));
+  // check re-allocate
+  ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(4));
+
+  framework::Tensor get_value;
+  get_value.mutable_data<float>(framework::make_ddim({20, 100}), cpu);
+  table.Get(key, &get_value, 10);
+  ASSERT_EQ(get_value.data<float>()[10 * 100], static_cast<float>(10));
 }
 
 }  // namespace framework
......
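A note on the re-allocation assertion above: the table starts with one pre-allocated row (dims[0] = 1), and Set() of the unseen key pushes rows_.size() to 2, which exceeds dims[0] and triggers the growth rule dims[0] = (dims[0] + 1) << 1 = (1 + 1) * 2 = 4. A tiny standalone sketch of that growth policy follows; the Grow helper name is hypothetical.

#include <cstdint>
#include <cstdio>

// Growth rule applied when the number of stored rows exceeds capacity:
// new_capacity = (old_capacity + 1) << 1, i.e. roughly doubling.
int64_t Grow(int64_t capacity) { return (capacity + 1) << 1; }

int main() {
  int64_t capacity = 1;        // the test pre-allocates a single row
  capacity = Grow(capacity);   // a second key arrives: capacity becomes 4
  std::printf("%lld\n", static_cast<long long>(capacity));  // prints 4
  return 0;
}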
@@ -107,7 +107,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
       for (size_t i = 0; i < grad.rows().size(); i++) {
         PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
                        "Input rows index should be less than height");
-        int64_t id_index = param.index(grad.rows()[i]);
+        int64_t id_index = param.Index(grad.rows()[i]);
+        PADDLE_ENFORCE_GE(id_index, static_cast<int64_t>(0),
+                          "id should be in the table");
         for (size_t j = 0; j < grad_row_width; j++) {
           out_data[id_index * grad_row_width + j] -=
               lr[0] * grad_data[i * grad_row_width + j];
......
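The SGD kernel now resolves each sparse gradient row to a parameter-table row via param.Index() and enforces that the key exists before applying the update. A standalone illustration of the same sparse row update using plain STL containers in place of the Paddle types; all names here are illustrative.

#include <cstdint>
#include <unordered_map>
#include <vector>

// param:   flat [num_rows x width] parameter buffer
// key2idx: maps a sparse row key to its row index inside `param`
// rows:    keys of the gradient rows; grad is [rows.size() x width]
void SparseSGDUpdate(std::vector<float>* param,
                     const std::unordered_map<int64_t, int64_t>& key2idx,
                     const std::vector<int64_t>& rows,
                     const std::vector<float>& grad,
                     int64_t width, float lr) {
  for (size_t i = 0; i < rows.size(); ++i) {
    auto it = key2idx.find(rows[i]);
    if (it == key2idx.end()) continue;  // the kernel enforces presence instead
    const int64_t id_index = it->second;
    for (int64_t j = 0; j < width; ++j) {
      (*param)[id_index * width + j] -= lr * grad[i * width + j];
    }
  }
}

int main() {
  std::vector<float> param(2 * 3, 1.f);                  // 2 rows, width 3
  std::unordered_map<int64_t, int64_t> key2idx = {{42, 1}};
  std::vector<int64_t> rows = {42};                      // one sparse grad row
  std::vector<float> grad = {3.f, 3.f, 3.f};
  SparseSGDUpdate(&param, key2idx, rows, grad, /*width=*/3, /*lr=*/0.1f);
  // param row 1 is now {0.7, 0.7, 0.7}
  return 0;
}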