diff --git a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py old mode 100644 new mode 100755 index 841d7acc2c0864cdf703cd34b8e43896e168584a..ad4aaf1f0e29c475d4d3e32f69b9b17fc916d5ac --- a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py @@ -196,5 +196,42 @@ class TestNpairLossOpError(unittest.TestCase): self.assertRaises(TypeError, test_labels_type) +class TestNpairLossZeroError(unittest.TestCase): + def test_errors(self): + with paddle.fluid.dygraph.guard(): + + def test_anchor_0_size(): + array = np.array([], dtype=np.float32) + anchor = paddle.to_tensor( + np.reshape(array, [0, 0, 0]), dtype='float32' + ) + positive = paddle.to_tensor( + np.reshape(array, [0]), dtype='float32' + ) + array = np.array([1, 2, 3, 4], dtype=np.float32) + labels = paddle.to_tensor( + np.reshape(array, [4]), dtype='float32' + ) + paddle.nn.functional.npair_loss(anchor, positive, labels) + + def test_positive_0_size(): + array = np.array([1], dtype=np.float32) + array1 = np.array([], dtype=np.float32) + anchor = paddle.to_tensor( + np.reshape(array, [1, 1, 1]), dtype='float32' + ) + positive = paddle.to_tensor( + np.reshape(array1, [0]), dtype='float32' + ) + array = np.array([1, 2, 3, 4], dtype=np.float32) + labels = paddle.to_tensor( + np.reshape(array, [4]), dtype='float32' + ) + paddle.nn.functional.npair_loss(anchor, positive, labels) + + self.assertRaises(ValueError, test_anchor_0_size) + self.assertRaises(ValueError, test_positive_0_size) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 2892d7b667d76e534a7236ae371ce589709029ac..328c25ba56315023fdd232cbd4d54087c32ca58d 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -373,6 +373,10 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): print(npair_loss) """ + if anchor.size == 0: + raise ValueError("The dims of anchor should be greater than 0.") + if positive.size == 0: + raise ValueError("The dims of positive should be greater than 0.") check_variable_and_dtype( anchor, 'anchor', ['float32', 'float64'], 'npair_loss' )