未验证 提交 ecfb89e1 编写于 作者: W wawltor 提交者: GitHub

Update the error message for the margin_ranking_loss

Update the error message for the margin_ranking_loss
上级 94b05850
...@@ -177,6 +177,16 @@ class MarginRakingLossError(unittest.TestCase): ...@@ -177,6 +177,16 @@ class MarginRakingLossError(unittest.TestCase):
self.assertRaises(ValueError, test_margin_value_error) self.assertRaises(ValueError, test_margin_value_error)
def test_functional_margin_value_error():
x = paddle.static.data(name="x", shape=[10, 10], dtype="float64")
y = paddle.static.data(name="y", shape=[10, 10], dtype="float64")
label = paddle.static.data(
name="label", shape=[10, 10], dtype="float64")
result = paddle.nn.functional.margin_ranking_loss(
x, y, label, margin=0.1, reduction="reduction_mean")
self.assertRaises(ValueError, test_functional_margin_value_error)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -338,6 +338,10 @@ def margin_ranking_loss(input, ...@@ -338,6 +338,10 @@ def margin_ranking_loss(input,
loss = paddle.nn.functional.margin_ranking_loss(input, other, label) loss = paddle.nn.functional.margin_ranking_loss(input, other, label)
print(loss.numpy()) # [0.75] print(loss.numpy()) # [0.75]
""" """
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
if fluid.framework.in_dygraph_mode(): if fluid.framework.in_dygraph_mode():
out = core.ops.elementwise_sub(other, input) out = core.ops.elementwise_sub(other, input)
out = core.ops.elementwise_mul(out, label) out = core.ops.elementwise_mul(out, label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册