未验证 提交 c8548af3 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case8: paddle.static.nn.nce (#49989)

* add dimension check for nce

* add unittest

* fix incorrect type in test_nce
上级 24e395f6
......@@ -330,6 +330,13 @@ class TestNCE_OpError(unittest.TestCase):
TypeError, paddle.static.nn.nce, input4, label4, 5
)
input5 = paddle.static.data(name='x', shape=[1], dtype='float32')
label5 = paddle.static.data(name='label', shape=[1], dtype='int64')
self.assertRaises(
ValueError, paddle.static.nn.nce, input5, label5, 1
)
if __name__ == '__main__':
unittest.main()
......@@ -129,6 +129,11 @@ def nce(
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce')
check_variable_and_dtype(label, 'label', ['int64'], 'nce')
if input.ndim != 2:
raise ValueError(
f'The rank of `input` must be 2, but received {input.ndim}.'
)
dim = input.shape[1]
num_true_class = label.shape[1]
w = helper.create_parameter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册