diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 3acdca17afc2fea05fb81871e6e03d72691fe91e..50eeadab72e71f39325c5eda69e9a3c3e6517d7d 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
     auto table_dims = ctx->GetInputDim("W");
     auto ids_dims = ctx->GetInputDim("Ids");
 
-    PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
-    PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    auto ids_var_type = ctx->GetInputsVarType("Ids").front();
+    // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W
+    // and it must be a column vector with rank = 2 while the 2nd dimension
+    // size must be 1, when Ids's type is SelectedRows, the rows of Ids
+    // contains the ids to be looked up in W;
+    if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
+      PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    }
 
     ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
     ctx->ShareLoD("Ids", /*->*/ "Out");
@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
   LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("W",
-             "An input represents embedding tensors, "
+             "(Tensor) The input represents embedding tensors, "
              "which is a learnable parameter.");
-    AddInput("Ids",
-             "An input with type int32 or int64 "
-             "contains the ids to be looked up in W. "
-             "Ids must be a column vector with rank = 2. "
-             "The 2nd dimension size must be 1.");
-    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddInput(
+        "Ids",
+        "(Tensor or SelectedRows) Ids's type can be Tensor or "
+        "SelectedRows, when Ids's type is Tensor, this tensor contains "
+        "the ids to be looked up in W and it must be a column vector with "
+        "rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
+        "SelectedRows, the rows of Ids contains the ids to be looked up "
+        "in W.");
+    AddOutput("Out",
+              "(Tensor or SelectedRows) The lookup results, which have the "
+              "same type as W.");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
-                  "Sparse update")
+                  "Sparse update.")
         .SetDefault(false);
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
 Lookup Table Operator.
 
 This operator is used to perform lookups on the parameter W,
-then concatenated into a dense tensor.
+then concatenated into a dense or sparse tensor.
+
+The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
+type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
+when Ids's type is Tensor, this tensor contains the ids to be looked up in W
+and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
+at this time, Ids can carry the LoD (Level of Details) information, or not, and
+the output only shares the LoD information with input Ids.
 
-The input Ids can carry the LoD (Level of Details) information,
-or not. And the output only shares the LoD information with input Ids.
 
)DOC"); } diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 923340f46102d888f549c79684ae0ae2f78ed038..6d81fccd2059c511f71d403229e04587e553e93d 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* table_t = context.Input("W"); - auto* ids_t = context.Input("Ids"); - auto* output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); + auto* ids_var = context.InputVar("Ids"); + Tensor* output_t = context.Output("Out"); + + int64_t* ids; + int64_t K; + + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. + if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + ids = const_cast(ids_t->data()); + K = ids_t->numel(); + } else if (ids_var->IsType()) { + auto* ids_t = context.Input("Ids"); + ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); + K = ids_t->rows().size(); + output_t->Resize({K, table_t->dims()[1]}); + } else { + PADDLE_THROW("Unsupported Variable Type of Ids"); + } size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - size_t K = ids_t->numel(); - auto* ids = ids_t->data(); auto* table = table_t->data(); auto* output = output_t->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index d88b034e919f1127ac3c424e87e4a5f81a598dc8..c92ce78eeffb8f1517e61c6d6624d406e04d974d 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -22,6 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; @@ -29,25 +30,45 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); // float tensor - auto* ids_t = context.Input("Ids"); // int tensor - auto* output_t = context.Output("Out"); // float tensor + auto* table_t = context.Input("W"); + auto* ids_var = context.InputVar("Ids"); + Tensor* output_t = context.Output("Out"); + + int64_t* ids; + int64_t ids_numel; + + // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type + // is LoDTensor, this tensor contains the ids to be looked up in W; + // when Ids's type is SelectedRows, the rows of Ids contains the + // ids to be looked up in W. 
+    if (ids_var->IsType<framework::LoDTensor>()) {
+      auto* ids_t = context.Input<LoDTensor>("Ids");
+      ids = const_cast<int64_t*>(ids_t->data<int64_t>());
+      ids_numel = ids_t->numel();
+    } else if (ids_var->IsType<framework::SelectedRows>()) {
+      auto* ids_t = context.Input<framework::SelectedRows>("Ids");
+      ids = const_cast<int64_t*>(ids_t->rows().data());
+      ids_numel = ids_t->rows().size();
+      output_t->Resize({ids_numel, table_t->dims()[1]});
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Ids");
+    }
+
     int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
     int N = table_t->dims()[0];
     int D = table_t->dims()[1];
-    auto* ids = ids_t->data<int64_t>();
     auto* table = table_t->data<T>();
     auto* output = output_t->mutable_data<T>(context.GetPlace());
 
     if (padding_idx == -1) {
-      for (int64_t i = 0; i < ids_t->numel(); ++i) {
+      for (int64_t i = 0; i < ids_numel; ++i) {
         PADDLE_ENFORCE_LT(ids[i], N);
         PADDLE_ENFORCE_GE(ids[i], 0);
         memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
       }
     } else {
-      for (int64_t i = 0; i < ids_t->numel(); ++i) {
+      for (int64_t i = 0; i < ids_numel; ++i) {
         if (ids[i] == padding_idx) {
           memset(output + i * D, 0, D * sizeof(T));
         } else {
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
index 03a5bd24a133e703855400532517c293196b64f0..ed920ad388ff0e01887404e70fe82565b4cd28fa 100644
--- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py
@@ -15,6 +15,8 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
 
 
 class TestLookupTableOp(OpTest):
@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
         pass
 
 
+class TestLookupTableIdsIsSelectedRows(OpTest):
+    def check_with_place(self, place):
+        scope = core.Scope()
+
+        # create and initialize Variable
+        height = 10
+        rows = [0, 4, 4, 7]
+        row_numel = 12
+
+        # create and initialize W Variable
+        W = scope.var('W').get_tensor()
+        W_array = np.full((height, row_numel), 1.0).astype("float32")
+        for i in range(height):
+            W_array[i] *= i
+        W.set(W_array, place)
+
+        # create and initialize Ids Variable
+        ids_selected_rows = scope.var('Ids').get_selected_rows()
+        ids_selected_rows.set_height(len(rows))
+        ids_selected_rows.set_rows(rows)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
+        ids_tensor = ids_selected_rows.get_tensor()
+        ids_tensor.set(np_array, place)
+
+        # create Out Variable
+        Out = scope.var('Out').get_selected_rows()
+
+        # create and run lookup_table operator
+        concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
+        concat_rows_op.run(scope, place)
+
+        # get result from Out
+        Out_tensor = Out.get_tensor()
+        result_array = np.array(Out_tensor)
+
+        # all(): return True if all elements of the iterable are true (or if the iterable is empty)
+        for idx, row in enumerate(rows):
+            assert (row == result_array[idx]).all()
+
+    def test_concat_rows(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        for place in places:
+            self.check_with_place(place)
+
+
 if __name__ == "__main__":
     unittest.main()
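
Note: the semantics of the new SelectedRows-Ids path amount to a row gather over W, with the gathered rows concatenated in the order of Ids.rows(), which is exactly what TestLookupTableIdsIsSelectedRows asserts. Below is a minimal NumPy sketch of the expected result; the names mirror the unit test fixture and are purely illustrative, not part of any Paddle API.

    import numpy as np

    # Mirror the test fixture: W has `height` rows of width `row_numel`,
    # and row i is filled with the value i.
    height, row_numel = 10, 12
    rows = [0, 4, 4, 7]

    W_array = np.full((height, row_numel), 1.0).astype("float32")
    for i in range(height):
        W_array[i] *= i

    # lookup_table with SelectedRows Ids gathers W's rows in the order given
    # by Ids.rows(); duplicate ids (row 4 here) are simply repeated.
    expected = W_array[rows, :]

    for idx, row in enumerate(rows):
        assert (row == expected[idx]).all()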