From 8f0adcb57507d3b7c812946c0adbb11573a77b66 Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Wed, 11 Jan 2023 10:44:05 +0800 Subject: [PATCH] fix hsigmoid_loss (#49549) --- python/paddle/fluid/tests/unittests/test_hsigmoid_op.py | 5 +++++ python/paddle/nn/functional/loss.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 6881430979..6c5bc338a5 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -735,6 +735,11 @@ class TestHSigmoidLossAPI(unittest.TestCase): x = paddle.to_tensor(np.reshape(x_arr, (10, 0)), dtype='float32') label = paddle.to_tensor([], dtype='int64') weight = paddle.to_tensor([], dtype='float32') + self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 2, weight) + + x = paddle.to_tensor(np.reshape(x_arr, [1, 0, 0, 1]), dtype='float32') + label = paddle.to_tensor(np.reshape(x_arr, [1, 1, 0]), dtype='int64') + weight = paddle.to_tensor([], dtype='float32') self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 0, weight) paddle.enable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e43fb54614..d783251624 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -953,6 +953,11 @@ def hsigmoid_loss( # [2.11009121] # [1.92374969]] """ + if num_classes < 2: + raise ValueError( + 'Expected num_classes >= 2 (got {})'.format(num_classes) + ) + if in_dygraph_mode(): out, _, _ = _C_ops.hsigmoid_loss( input, -- GitLab