diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 92c7d7f9cad6ac5bcf6c020e0afb4a22c715085b..deabcdc99f819851b2df9bb0c7b05a5b339568f3 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -18,6 +18,22 @@ limitations under the License. */ namespace paddle { namespace operators { +static inline framework::OpKernelType ExpectedKernelType( + const framework::ExecutionContext& ctx) { + auto* table_var = ctx.InputVar("W"); + if (table_var->IsType()) { + return framework::OpKernelType( + framework::ToDataType(table_var->Get().type()), + ctx.device_context()); + } else if (table_var->IsType()) { + return framework::OpKernelType( + framework::ToDataType(table_var->Get().value().type()), + ctx.device_context()); + } else { + PADDLE_THROW("W should be LoDTensor or SelectedRows"); + } +} + class LookupTableOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("W")->type()), - ctx.device_context()); + return ExpectedKernelType(ctx); } }; @@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("W")->type()), - ctx.device_context()); + return ExpectedKernelType(ctx); } }; diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index ed920ad388ff0e01887404e70fe82565b4cd28fa..3f739afd2516fdc2bdf3711d4780a1196c6f3f13 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -96,5 +96,47 @@ class TestLookupTableIdsIsSelectedRows(OpTest): self.check_with_place(place) +class TestLookupTableWIsSelectedRows(OpTest): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Id Variable + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.array([[0], [4], [3], [5]]).astype("int64") + ids_tensor.set(ids_array, place) + + # create and initialize W Variable + rows = [0, 1, 2, 3, 4, 5, 6] + row_numel = 12 + + w_selected_rows = scope.var('W').get_selected_rows() + w_selected_rows.set_height(len(rows)) + w_selected_rows.set_rows(rows) + w_array = np.ones((len(rows), row_numel)).astype("float32") + for i in range(len(rows)): + w_array[i] *= i + ids_tensor = w_selected_rows.get_tensor() + ids_tensor.set(w_array, place) + + # create Out Variable + Out_tensor = scope.var('Out').get_tensor() + + # create and run lookup_table operator + lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') + lookup_table.run(scope, place) + + # get result from Out + result_array = np.array(Out_tensor) + # all(): return True if all elements of the iterable are true (or if the iterable is empty) + for idx, row in enumerate(ids_array): + assert (row[0] == result_array[idx]).all() + + def test_w_is_selected_rows(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main()