From c8548af34f665d886efdd62f431f3a252b62ad40 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Thu, 2 Feb 2023 16:40:42 +0800 Subject: [PATCH] Fix Python IndexError of case8: paddle.static.nn.nce (#49989) * add dimension check for nce * add unittest * fix incorrect type in test_nce --- python/paddle/fluid/tests/unittests/test_nce.py | 7 +++++++ python/paddle/static/nn/loss.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index 80787e7fd3..8a4c555ad5 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -330,6 +330,13 @@ class TestNCE_OpError(unittest.TestCase): TypeError, paddle.static.nn.nce, input4, label4, 5 ) + input5 = paddle.static.data(name='x', shape=[1], dtype='float32') + label5 = paddle.static.data(name='label', shape=[1], dtype='int64') + + self.assertRaises( + ValueError, paddle.static.nn.nce, input5, label5, 1 + ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/static/nn/loss.py b/python/paddle/static/nn/loss.py index 3f464928c2..41f32e4a63 100644 --- a/python/paddle/static/nn/loss.py +++ b/python/paddle/static/nn/loss.py @@ -129,6 +129,11 @@ def nce( check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce') check_variable_and_dtype(label, 'label', ['int64'], 'nce') + if input.ndim != 2: + raise ValueError( + f'The rank of `input` must be 2, but received {input.ndim}.' + ) + dim = input.shape[1] num_true_class = label.shape[1] w = helper.create_parameter( -- GitLab