未验证 提交 8f0adcb5 编写于 作者: L Linjie Chen 提交者: GitHub

fix hsigmoid_loss (#49549)

上级 08bf1b49
...@@ -735,6 +735,11 @@ class TestHSigmoidLossAPI(unittest.TestCase): ...@@ -735,6 +735,11 @@ class TestHSigmoidLossAPI(unittest.TestCase):
x = paddle.to_tensor(np.reshape(x_arr, (10, 0)), dtype='float32') x = paddle.to_tensor(np.reshape(x_arr, (10, 0)), dtype='float32')
label = paddle.to_tensor([], dtype='int64') label = paddle.to_tensor([], dtype='int64')
weight = paddle.to_tensor([], dtype='float32') weight = paddle.to_tensor([], dtype='float32')
self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 2, weight)
x = paddle.to_tensor(np.reshape(x_arr, [1, 0, 0, 1]), dtype='float32')
label = paddle.to_tensor(np.reshape(x_arr, [1, 1, 0]), dtype='int64')
weight = paddle.to_tensor([], dtype='float32')
self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 0, weight) self.assertRaises(ValueError, F.hsigmoid_loss, x, label, 0, weight)
paddle.enable_static() paddle.enable_static()
......
...@@ -953,6 +953,11 @@ def hsigmoid_loss( ...@@ -953,6 +953,11 @@ def hsigmoid_loss(
# [2.11009121] # [2.11009121]
# [1.92374969]] # [1.92374969]]
""" """
if num_classes < 2:
raise ValueError(
'Expected num_classes >= 2 (got {})'.format(num_classes)
)
if in_dygraph_mode(): if in_dygraph_mode():
out, _, _ = _C_ops.hsigmoid_loss( out, _, _ = _C_ops.hsigmoid_loss(
input, input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册