提交 9247aee7 编写于 作者: G guosheng

Enhance lookup_table_op to support padding_idx

上级 f086ebb8
......@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update")
.SetDefault(false);
AddAttr<int64_t>(
"padding_idx",
"(int64_t, default -1) "
" If given, pads the output with zeros whenever it encounters "
"the index.")
.SetDefault(-1);
AddComment(R"DOC(
Lookup Table Operator.
......
......@@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0];
int D = table_t->dims()[1];
......@@ -39,11 +40,15 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
}
}
};
template <typename T>
......@@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
......@@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework::Vector<int64_t> new_rows;
new_rows.reserve(ids_dim[0]);
for (int64_t i = 0; i < ids_dim[0]; i++) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows);
......@@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset(d_table_data, 0, d_table->numel() * sizeof(T));
for (int64_t i = 0; i < ids->numel(); ++i) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0);
for (int j = 0; j < D; ++j) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册