diff --git a/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py index 2e74ffa88f5c5dd6d31001aaa21c1105d6f8ab27..0ebe769fb9bce1aee8412ccebc216c2c85e97775 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py +++ b/python/paddle/fluid/tests/unittests/test_nn_margin_rank_loss.py @@ -177,6 +177,16 @@ class MarginRakingLossError(unittest.TestCase): self.assertRaises(ValueError, test_margin_value_error) + def test_functional_margin_value_error(): + x = paddle.static.data(name="x", shape=[10, 10], dtype="float64") + y = paddle.static.data(name="y", shape=[10, 10], dtype="float64") + label = paddle.static.data( + name="label", shape=[10, 10], dtype="float64") + result = paddle.nn.functional.margin_ranking_loss( + x, y, label, margin=0.1, reduction="reduction_mean") + + self.assertRaises(ValueError, test_functional_margin_value_error) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9a214d3982a8837ea810e3560fc0b603965d5810..357a8f8e84fb8af3985487957089ccd35a65abf9 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -338,6 +338,10 @@ def margin_ranking_loss(input, loss = paddle.nn.functional.margin_ranking_loss(input, other, label) print(loss.numpy()) # [0.75] """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) if fluid.framework.in_dygraph_mode(): out = core.ops.elementwise_sub(other, input) out = core.ops.elementwise_mul(out, label)