From 475de6da22d962949a97ce42806c9d5a149ae4ca Mon Sep 17 00:00:00 2001 From: songyouwei Date: Fri, 10 Apr 2020 21:25:32 +0800 Subject: [PATCH] API(Embedding) error message enhancement (#23533) * err msg enhance for Embedding * add ut test=develop --- python/paddle/fluid/dygraph/nn.py | 1 + .../tests/unittests/test_lookup_table_v2_op.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index baa49e5f3e4..d4277daebdc 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1576,6 +1576,7 @@ class Embedding(layers.Layer): 'is_distributed', self._is_distributed, 'remote_prefetch', self._remote_prefetch, 'padding_idx', self._padding_idx) + check_variable_and_dtype(input, 'input', ['int64'], 'Embedding') attrs = { 'is_sparse': self._is_sparse, 'is_distributed': self._is_distributed, diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 9c026f0482b..98d8b7f9f88 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -25,6 +25,21 @@ import paddle.fluid as fluid from paddle.fluid import Program, program_guard +class TestDygraphEmbeddingAPIError(unittest.TestCase): + def test_errors(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + dict_size = 20 + layer = fluid.dygraph.nn.Embedding( + size=[dict_size, 32], param_attr='emb.w', is_sparse=False) + # the input must be Variable. + x0 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, layer, x0) + # the input dtype must be int64 + data_t = fluid.data(name='word', shape=[1], dtype='int32') + self.assertRaises(TypeError, layer, data_t) + + class TestLookupTableOp(OpTest): def setUp(self): self.op_type = "lookup_table_v2" -- GitLab