diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index 80787e7fd3f389fd36fde3beb9b55855ac961558..8a4c555ad572080ade0013185e347f1195a0743d 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -330,6 +330,13 @@ class TestNCE_OpError(unittest.TestCase): TypeError, paddle.static.nn.nce, input4, label4, 5 ) + input5 = paddle.static.data(name='x', shape=[1], dtype='float32') + label5 = paddle.static.data(name='label', shape=[1], dtype='int64') + + self.assertRaises( + ValueError, paddle.static.nn.nce, input5, label5, 1 + ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/static/nn/loss.py b/python/paddle/static/nn/loss.py index 3f464928c289d65368002c4d037a232118813bb2..41f32e4a63fab32d51b618efc050a06f669e62bf 100644 --- a/python/paddle/static/nn/loss.py +++ b/python/paddle/static/nn/loss.py @@ -129,6 +129,11 @@ def nce( check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce') check_variable_and_dtype(label, 'label', ['int64'], 'nce') + if input.ndim != 2: + raise ValueError( + f'The rank of `input` must be 2, but received {input.ndim}.' + ) + dim = input.shape[1] num_true_class = label.shape[1] w = helper.create_parameter(