diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 68814309791926eca9a60e0f0cfddabfcee1d906..6c5bc338a565b87a47bac3184ca97f4fed4483d6 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -735,6 +735,11 @@ class TestHSigmoidLossAPI(unittest.TestCase): x = paddle.to_tensor(np.reshape(x_arr, (10, 0)), dtype='float32') label = paddle.to_tensor([], dtype='int64') weight = paddle.to_tensor([], dtype='float32') + self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 2, weight) + + x = paddle.to_tensor(np.reshape(x_arr, [1, 0, 0, 1]), dtype='float32') + label = paddle.to_tensor(np.reshape(x_arr, [1, 1, 0]), dtype='int64') + weight = paddle.to_tensor([], dtype='float32') self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 0, weight) paddle.enable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e43fb54614307530f568c42623ff4de102cb3271..d783251624c25b8a3e8a85449caaa54128a29124 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -953,6 +953,11 @@ def hsigmoid_loss( # [2.11009121] # [1.92374969]] """ + if num_classes < 2: + raise ValueError( + 'Expected num_classes >= 2 (got {})'.format(num_classes) + ) + if in_dygraph_mode(): out, _, _ = _C_ops.hsigmoid_loss( input,