未验证 提交 475de6da 编写于 作者: S songyouwei 提交者: GitHub

API(Embedding) error message enhancement (#23533)

* err msg enhance for Embedding

* add ut
test=develop
上级 dc1901f4
......@@ -1576,6 +1576,7 @@ class Embedding(layers.Layer):
'is_distributed', self._is_distributed, 'remote_prefetch',
self._remote_prefetch, 'padding_idx', self._padding_idx)
check_variable_and_dtype(input, 'input', ['int64'], 'Embedding')
attrs = {
'is_sparse': self._is_sparse,
'is_distributed': self._is_distributed,
......
......@@ -25,6 +25,21 @@ import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestDygraphEmbeddingAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
dict_size = 20
layer = fluid.dygraph.nn.Embedding(
size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, layer, x0)
# the input dtype must be int64
data_t = fluid.data(name='word', shape=[1], dtype='int32')
self.assertRaises(TypeError, layer, data_t)
class TestLookupTableOp(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册