提交 9247aee7 编写于 作者: G guosheng

Enhance lookup_table_op to support padding_idx

上级 f086ebb8
...@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) " "(boolean, default false) "
"Sparse update") "Sparse update")
.SetDefault(false); .SetDefault(false);
AddAttr<int64_t>(
"padding_idx",
"(int64_t, default -1) "
" If given, pads the output with zeros whenever it encounters "
"the index.")
.SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
Lookup Table Operator. Lookup Table Operator.
......
...@@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
...@@ -39,9 +40,13 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -39,9 +40,13 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table = table_t->data<T>(); auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); if (ids[i] == padding_idx) {
PADDLE_ENFORCE_GE(ids[i], 0); memset(output + i * D, 0, D * sizeof(T));
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); } else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
} }
} }
}; };
...@@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W"); auto* table = context.Input<LoDTensor>("W");
...@@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework::Vector<int64_t> new_rows; framework::Vector<int64_t> new_rows;
new_rows.reserve(ids_dim[0]); new_rows.reserve(ids_dim[0]);
for (int64_t i = 0; i < ids_dim[0]; i++) { for (int64_t i = 0; i < ids_dim[0]; i++) {
if (ids_data[i] == padding_idx)
continue; // Padding rows are not trainable, so their gradients are not
// needed.
new_rows.push_back(ids_data[i]); new_rows.push_back(ids_data[i]);
} }
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
...@@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset(d_table_data, 0, d_table->numel() * sizeof(T)); memset(d_table_data, 0, d_table->numel() * sizeof(T));
for (int64_t i = 0; i < ids->numel(); ++i) { for (int64_t i = 0; i < ids->numel(); ++i) {
if (ids_data[i] == padding_idx)
continue; // Padding rows are not trainable, so their gradients are not
// needed.
PADDLE_ENFORCE_LT(ids_data[i], N); PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0); PADDLE_ENFORCE_GE(ids_data[i], 0);
for (int j = 0; j < D; ++j) { for (int j = 0; j < D; ++j) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册