未验证 提交 f57f9b6b 编写于 作者: S supplyout 提交者: GitHub

Fix npairloss bug (#51092)

上级 731b407e
......@@ -196,5 +196,42 @@ class TestNpairLossOpError(unittest.TestCase):
self.assertRaises(TypeError, test_labels_type)
class TestNpairLossZeroError(unittest.TestCase):
def test_errors(self):
with paddle.fluid.dygraph.guard():
def test_anchor_0_size():
array = np.array([], dtype=np.float32)
anchor = paddle.to_tensor(
np.reshape(array, [0, 0, 0]), dtype='float32'
)
positive = paddle.to_tensor(
np.reshape(array, [0]), dtype='float32'
)
array = np.array([1, 2, 3, 4], dtype=np.float32)
labels = paddle.to_tensor(
np.reshape(array, [4]), dtype='float32'
)
paddle.nn.functional.npair_loss(anchor, positive, labels)
def test_positive_0_size():
array = np.array([1], dtype=np.float32)
array1 = np.array([], dtype=np.float32)
anchor = paddle.to_tensor(
np.reshape(array, [1, 1, 1]), dtype='float32'
)
positive = paddle.to_tensor(
np.reshape(array1, [0]), dtype='float32'
)
array = np.array([1, 2, 3, 4], dtype=np.float32)
labels = paddle.to_tensor(
np.reshape(array, [4]), dtype='float32'
)
paddle.nn.functional.npair_loss(anchor, positive, labels)
self.assertRaises(ValueError, test_anchor_0_size)
self.assertRaises(ValueError, test_positive_0_size)
if __name__ == '__main__':
unittest.main()
......@@ -373,6 +373,10 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
print(npair_loss)
"""
if anchor.size == 0:
raise ValueError("The dims of anchor should be greater than 0.")
if positive.size == 0:
raise ValueError("The dims of positive should be greater than 0.")
check_variable_and_dtype(
anchor, 'anchor', ['float32', 'float64'], 'npair_loss'
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册