提交 3b3d210c 编写于 作者: Q qiaolongfei

lookuptable support SelectedRows as table parameter

上级 997e9a1f
......@@ -10,6 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -52,7 +55,7 @@ class SelectedRows {
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simplely concated when adding together. Until a
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr};
......
......@@ -84,7 +84,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
.SetDefault(kNoPadding);
AddComment(R"DOC(
Lookup Table Operator.
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -25,16 +28,37 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static const int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return std::distance(rows.begin(), it);
}
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
auto *ids_var = context.InputVar("Ids");
Tensor *output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t* ids;
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
}
int64_t *ids;
int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
......@@ -42,39 +66,50 @@ class LookupTableKernel : public framework::OpKernel<T> {
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
auto *ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t *>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().data());
auto *ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t *>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
output_t->Resize({ids_numel, table_dim[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
int N = table_t->dims()[0];
int D = table_t->dims()[1];
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} else {
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
auto id_index = getIndex(table_t.rows(), ids[i]);
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
}
}
}
......@@ -84,17 +119,17 @@ class LookupTableKernel : public framework::OpKernel<T> {
template <typename T>
class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext &context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>();
auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
framework::Vector<int64_t> new_rows;
......@@ -104,31 +139,31 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
}
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
d_table->set_height(table->dims()[0]);
auto* d_output_data = d_output->data<T>();
auto* d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto* ids = context.Input<LoDTensor>("Ids");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto* table = context.Input<LoDTensor>("W");
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto *table = context.Input<LoDTensor>("W");
auto* ids_data = ids->data<int64_t>();
auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
int N = table->dims()[0];
int D = d_output->dims()[1];
auto* d_output_data = d_output->data<T>();
auto* d_table_data = d_table->mutable_data<T>(context.GetPlace());
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册