From f57f9b6bc3a35ca6396cb5c0bd4183fd4aa595a9 Mon Sep 17 00:00:00 2001 From: supplyout <32414691+supplyout@users.noreply.github.com> Date: Mon, 6 Mar 2023 16:44:49 +0800 Subject: [PATCH] Fix npairloss bug (#51092) --- .../tests/unittests/test_npair_loss_op.py | 37 +++++++++++++++++++ python/paddle/nn/functional/loss.py | 4 ++ 2 files changed, 41 insertions(+) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_npair_loss_op.py diff --git a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py old mode 100644 new mode 100755 index 841d7acc2c0..ad4aaf1f0e2 --- a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py @@ -196,5 +196,42 @@ class TestNpairLossOpError(unittest.TestCase): self.assertRaises(TypeError, test_labels_type) +class TestNpairLossZeroError(unittest.TestCase): + def test_errors(self): + with paddle.fluid.dygraph.guard(): + + def test_anchor_0_size(): + array = np.array([], dtype=np.float32) + anchor = paddle.to_tensor( + np.reshape(array, [0, 0, 0]), dtype='float32' + ) + positive = paddle.to_tensor( + np.reshape(array, [0]), dtype='float32' + ) + array = np.array([1, 2, 3, 4], dtype=np.float32) + labels = paddle.to_tensor( + np.reshape(array, [4]), dtype='float32' + ) + paddle.nn.functional.npair_loss(anchor, positive, labels) + + def test_positive_0_size(): + array = np.array([1], dtype=np.float32) + array1 = np.array([], dtype=np.float32) + anchor = paddle.to_tensor( + np.reshape(array, [1, 1, 1]), dtype='float32' + ) + positive = paddle.to_tensor( + np.reshape(array1, [0]), dtype='float32' + ) + array = np.array([1, 2, 3, 4], dtype=np.float32) + labels = paddle.to_tensor( + np.reshape(array, [4]), dtype='float32' + ) + paddle.nn.functional.npair_loss(anchor, positive, labels) + + self.assertRaises(ValueError, test_anchor_0_size) + self.assertRaises(ValueError, test_positive_0_size) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 2892d7b667d..328c25ba563 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -373,6 +373,10 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): print(npair_loss) """ + if anchor.size == 0: + raise ValueError("The dims of anchor should be greater than 0.") + if positive.size == 0: + raise ValueError("The dims of positive should be greater than 0.") check_variable_and_dtype( anchor, 'anchor', ['float32', 'float64'], 'npair_loss' ) -- GitLab