diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 02ffbd136113b9d25bc429a886f56e20ded0a81f..8760cc2ee9643bf7a9569c926d8a85588c77832e 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -120,12 +120,22 @@ template class LookupTableGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { + auto *table_var = context.InputVar("W"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("W")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("W"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW("table only support LoDTensor and SelectedRows"); + } + bool is_sparse = context.Attr("is_sparse"); // Since paddings are not trainable and fixed in forward, the gradient of // paddings makes no sense and we don't deal with it in backward. if (is_sparse) { auto *ids = context.Input("Ids"); - auto *table = context.Input("W"); auto *d_output = context.Input(framework::GradVarName("Out")); auto *d_table = context.Output(framework::GradVarName("W")); @@ -140,10 +150,10 @@ class LookupTableGradKernel : public framework::OpKernel { d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_dim[0], table->dims()[1]}); + d_table_value->Resize({ids_dim[0], table_dim[1]}); d_table_value->mutable_data(context.GetPlace()); - d_table->set_height(table->dims()[0]); + d_table->set_height(table_dim[0]); auto *d_output_data = d_output->data(); auto *d_table_data = d_table_value->data(); @@ -154,12 +164,11 @@ class LookupTableGradKernel : public framework::OpKernel { auto *ids = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); auto *d_table = context.Output(framework::GradVarName("W")); - auto *table = context.Input("W"); auto *ids_data = ids->data(); auto ids_dim = ids->dims(); - int N = table->dims()[0]; + int N = table_dim[0]; int D = d_output->dims()[1]; auto *d_output_data = d_output->data();