diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index bb03def4391da80c6219f7863d300fd3c8d8c7ac..54c326c1d9a39938fadbc9746730e071cd5b00de 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { "(boolean, default false) " "Sparse update") .SetDefault(false); + AddAttr( + "padding_idx", + "(int64_t, default -1) " + " If given, pads the output with zeros whenever it encounters " + "the index.") + .SetDefault(-1); AddComment(R"DOC( Lookup Table Operator. diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index 2fd3335868406455ec01f9ded6bacc7bda5e2a67..2fa45e2437dedc49d5564c96a9c18b944e9240b2 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel { auto* table_t = context.Input("W"); // float tensor auto* ids_t = context.Input("Ids"); // int tensor auto* output_t = context.Output("Out"); // float tensor + int64_t padding_idx = context.Attr("padding_idx"); int N = table_t->dims()[0]; int D = table_t->dims()[1]; @@ -39,9 +40,13 @@ class LookupTableKernel : public framework::OpKernel { auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); for (int64_t i = 0; i < ids_t->numel(); ++i) { - PADDLE_ENFORCE_LT(ids[i], N); - PADDLE_ENFORCE_GE(ids[i], 0); - memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); + if (ids[i] == padding_idx) { + memset(output + i * D, 0, D * sizeof(T)); + } else { + PADDLE_ENFORCE_LT(ids[i], N); + PADDLE_ENFORCE_GE(ids[i], 0); + memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); + } } } }; @@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { bool is_sparse = context.Attr("is_sparse"); + int64_t padding_idx = context.Attr("padding_idx"); + if (is_sparse) { auto* ids = context.Input("Ids"); auto* table = context.Input("W"); @@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel { framework::Vector new_rows; new_rows.reserve(ids_dim[0]); for (int64_t i = 0; i < ids_dim[0]; i++) { + if (ids_data[i] == padding_idx) + continue; // Paddings are not trainable and the gradient are not + // necessary. new_rows.push_back(ids_data[i]); } d_table->set_rows(new_rows); @@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel { memset(d_table_data, 0, d_table->numel() * sizeof(T)); for (int64_t i = 0; i < ids->numel(); ++i) { + if (ids_data[i] == padding_idx) + continue; // Paddings are not trainable and the gradient are not + // necessary. PADDLE_ENFORCE_LT(ids_data[i], N); PADDLE_ENFORCE_GE(ids_data[i], 0); for (int j = 0; j < D; ++j) {