diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 3acdca17afc2fea05fb81871e6e03d72691fe91e..461e5bd2d325d38a8bdb9dc29a577b8e58bd877a 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -33,8 +33,13 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + auto ids_var_type = ctx->GetInputsVarType("Ids").front(); + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); + } ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 923340f46102d888f549c79684ae0ae2f78ed038..125e0f94415827127b0264241f2b4870ff409e06 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -74,14 +74,34 @@ class LookupTableCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* table_t = context.Input("W"); - auto* ids_t = context.Input("Ids"); - auto* output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); + auto* ids_var = context.InputVar("Ids"); // int tensor + + int64_t* ids; + int64_t K; + framework::Tensor* output_t; + + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = context.Output("Out"); // float tensor + ids = const_cast(ids_t->data()); + K = ids_t->numel(); + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = const_cast( + &(context.Output("Out") + ->value())); // float tensor + ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); + K = ids_t->rows().size(); + output_t->Resize({K, table_t->dims()[1]}); + } else { + PADDLE_THROW("Unsupported Variable Type of Ids"); + } size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - size_t K = ids_t->numel(); - auto* ids = ids_t->data(); auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index d88b034e919f1127ac3c424e87e4a5f81a598dc8..b2439c68371a8949922855604fe92c11a74c6ec9 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -22,6 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; @@ -29,25 +30,46 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); // float tensor - auto* ids_t = context.Input("Ids"); // int tensor - auto* output_t = context.Output("Out"); // float tensor + auto* table_t = context.Input("W"); // float tensor + auto* ids_var = context.InputVar("Ids"); // int tensor + + int64_t* ids; + int64_t ids_numel; + Tensor* output_t; + + // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows. + // Maybe near future we will add concat_rows op. + if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = context.Output("Out"); + ids = const_cast(ids_t->data()); + ids_numel = ids_t->numel(); + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + output_t = + const_cast(&(context.Output("Out")->value())); + ids = const_cast(ids_t->rows().data()); + ids_numel = ids_t->rows().size(); + output_t->Resize({ids_numel, table_t->dims()[1]}); + } else { + PADDLE_THROW("Unsupported Variable Type of Ids"); + } + int64_t padding_idx = context.Attr("padding_idx"); int N = table_t->dims()[0]; int D = table_t->dims()[1]; - auto* ids = ids_t->data(); auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); if (padding_idx == -1) { - for (int64_t i = 0; i < ids_t->numel(); ++i) { + for (int64_t i = 0; i < ids_numel; ++i) { PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_GE(ids[i], 0); memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); } } else { - for (int64_t i = 0; i < ids_t->numel(); ++i) { + for (int64_t i = 0; i < ids_numel; ++i) { if (ids[i] == padding_idx) { memset(output + i * D, 0, D * sizeof(T)); } else {