From ecfb89e1338fb0203af79b806575fbf1fedc3bd9 Mon Sep 17 00:00:00 2001 From: wawltor Date: Fri, 21 Aug 2020 14:28:04 +0800 Subject: [PATCH] Update the error message for the margin_ranking_loss Update the error message for the margin_ranking_loss --- .../fluid/tests/unittests/test_nn_margin_rank_loss.py | 10 ++++++++++ python/paddle/nn/functional/loss.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py index 2e74ffa88f..0ebe769fb9 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py +++ b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py @@ -177,6 +177,16 @@ class MarginRakingLossError(unittest.TestCase): self.assertRaises(ValueError, test_margin_value_error) + def test_functional_margin_value_error(): + x = paddle.static.data(name="x", shape=[10, 10], dtype="float64") + y = paddle.static.data(name="y", shape=[10, 10], dtype="float64") + label = paddle.static.data( + name="label", shape=[10, 10], dtype="float64") + result = paddle.nn.functional.margin_ranking_loss( + x, y, label, margin=0.1, reduction="reduction_mean") + + self.assertRaises(ValueError, test_functional_margin_value_error) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9a214d3982..357a8f8e84 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -338,6 +338,10 @@ def margin_ranking_loss(input, loss = paddle.nn.functional.margin_ranking_loss(input, other, label) print(loss.numpy()) # [0.75] """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) if fluid.framework.in_dygraph_mode(): out = core.ops.elementwise_sub(other, input) out = core.ops.elementwise_mul(out, label) -- GitLab