Unverified commit 68203566, authored by: W wawltor, committed by: GitHub

Add the loss of MarginRankingLoss for the paddle api2.0 (#26078)

add the api and doc for the margin_ranking_loss and MarginRankingLoss
Parent 50a5bcfc
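In short, the new API can be called as below, a minimal sketch assembled from the docstring examples added in this commit (paddle.disable_static() and paddle.to_variable follow the 2.0-beta dygraph idiom used throughout the change):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype("float32"))
y = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype("float32"))
label = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32"))

# functional form (mean reduction by default)
loss = paddle.nn.functional.margin_ranking_loss(x, y, label)  # [0.75]

# layer form
margin_rank_loss = paddle.nn.loss.MarginRankingLoss()
loss = margin_rank_loss(x, y, label)  # [0.75]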
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.static import Program, program_guard
def calc_margin_rank_loss(x, y, label, margin=0.0, reduction='none'):
    # NumPy reference implementation of the margin ranking loss:
    # max(0, -label * (x - y) + margin), optionally reduced.
    result = (-1 * label) * (x - y) + margin
    result = np.maximum(result, 0)
    if reduction == 'none':
        return result
    elif reduction == 'sum':
        return np.sum(result)
    elif reduction == 'mean':
        return np.mean(result)
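# Worked example (the same data as the docstring example of the new API):
#   x = [[1, 2], [3, 4]], y = [[2, 1], [2, 4]], label = [[1, -1], [-1, -1]]
#   -label * (x - y) = [[1, 1], [1, 0]]; relu leaves it unchanged, so
#   'none' should give [[1, 1], [1, 0]], 'sum' 3.0 and 'mean' 0.75.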
def create_test_case(margin, reduction):
    # Build one TestCase subclass per (margin, reduction) combination.
    class MarginRankingLossCls(unittest.TestCase):
        def setUp(self):
            self.x_data = np.random.rand(10, 10).astype("float64")
            self.y_data = np.random.rand(10, 10).astype("float64")
            self.label_data = np.random.choice(
                [-1, 1], size=[10, 10]).astype("float64")
            self.places = []
            self.places.append(fluid.CPUPlace())
            if core.is_compiled_with_cuda():
                self.places.append(paddle.CUDAPlace(0))

        def run_static_functional_api(self, place):
            paddle.enable_static()
            expected = calc_margin_rank_loss(
                self.x_data,
                self.y_data,
                self.label_data,
                margin=margin,
                reduction=reduction)
            with program_guard(Program(), Program()):
                x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
                y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
                label = paddle.nn.data(
                    name="label", shape=[10, 10], dtype="float64")
                result = paddle.nn.functional.margin_ranking_loss(
                    x, y, label, margin, reduction)
                exe = paddle.static.Executor(place)
                result_numpy, = exe.run(
                    feed={
                        "x": self.x_data,
                        "y": self.y_data,
                        "label": self.label_data
                    },
                    fetch_list=[result])
                self.assertTrue(np.allclose(result_numpy, expected))

        def run_static_api(self, place):
            paddle.enable_static()
            expected = calc_margin_rank_loss(
                self.x_data,
                self.y_data,
                self.label_data,
                margin=margin,
                reduction=reduction)
            with program_guard(Program(), Program()):
                x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
                y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
                label = paddle.nn.data(
                    name="label", shape=[10, 10], dtype="float64")
                margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
                    margin=margin, reduction=reduction)
                result = margin_rank_loss(x, y, label)
                exe = paddle.static.Executor(place)
                result_numpy, = exe.run(
                    feed={
                        "x": self.x_data,
                        "y": self.y_data,
                        "label": self.label_data
                    },
                    fetch_list=[result])
                self.assertTrue(np.allclose(result_numpy, expected))
                self.assertTrue('loss' in result.name)

        def run_dynamic_functional_api(self, place):
            paddle.disable_static(place)
            x = paddle.to_variable(self.x_data)
            y = paddle.to_variable(self.y_data)
            label = paddle.to_variable(self.label_data)
            result = paddle.nn.functional.margin_ranking_loss(
                x, y, label, margin, reduction)
            expected = calc_margin_rank_loss(
                self.x_data,
                self.y_data,
                self.label_data,
                margin=margin,
                reduction=reduction)
            self.assertTrue(np.allclose(result.numpy(), expected))

        def run_dynamic_api(self, place):
            paddle.disable_static(place)
            x = paddle.to_variable(self.x_data)
            y = paddle.to_variable(self.y_data)
            label = paddle.to_variable(self.label_data)
            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
                margin=margin, reduction=reduction)
            result = margin_rank_loss(x, y, label)
            expected = calc_margin_rank_loss(
                self.x_data,
                self.y_data,
                self.label_data,
                margin=margin,
                reduction=reduction)
            self.assertTrue(np.allclose(result.numpy(), expected))

        def run_dynamic_broadcast_api(self, place):
            # a label of shape [10] broadcasts against the [10, 10] inputs
            paddle.disable_static(place)
            label_data = np.random.choice([-1, 1], size=[10]).astype("float64")
            x = paddle.to_variable(self.x_data)
            y = paddle.to_variable(self.y_data)
            label = paddle.to_variable(label_data)
            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
                margin=margin, reduction=reduction)
            result = margin_rank_loss(x, y, label)
            expected = calc_margin_rank_loss(
                self.x_data,
                self.y_data,
                label_data,
                margin=margin,
                reduction=reduction)
            self.assertTrue(np.allclose(result.numpy(), expected))

        def test_case(self):
            for place in self.places:
                self.run_static_api(place)
                self.run_static_functional_api(place)
                self.run_dynamic_api(place)
                self.run_dynamic_functional_api(place)
                self.run_dynamic_broadcast_api(place)

    cls_name = "TestMarginRankLossCase_{}_{}".format(margin, reduction)
    MarginRankingLossCls.__name__ = cls_name
    globals()[cls_name] = MarginRankingLossCls


for margin in [0.0, 0.2]:
    for reduction in ['none', 'mean', 'sum']:
        create_test_case(margin, reduction)
# test case for the raised error message
class MarginRankingLossError(unittest.TestCase):
    # runs when the class body is evaluated, putting Paddle in static mode
    paddle.enable_static()

    def test_errors(self):
        def test_margin_value_error():
            margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
                margin=0.1, reduction="reduce_mean")

        self.assertRaises(ValueError, test_margin_value_error)


if __name__ == "__main__":
    unittest.main()
......
@@ -85,6 +85,7 @@ from .layer.loss import MSELoss #DEFINE_ALIAS
from .layer.loss import L1Loss #DEFINE_ALIAS
from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.loss import MarginRankingLoss #DEFINE_ALIAS
from .layer.norm import BatchNorm #DEFINE_ALIAS
from .layer.norm import GroupNorm #DEFINE_ALIAS
from .layer.norm import LayerNorm #DEFINE_ALIAS
......
@@ -129,7 +129,7 @@ from .loss import iou_similarity #DEFINE_ALIAS
from .loss import kldiv_loss #DEFINE_ALIAS
from .loss import l1_loss #DEFINE_ALIAS
from .loss import log_loss #DEFINE_ALIAS
from .loss import margin_rank_loss #DEFINE_ALIAS
from .loss import margin_ranking_loss #DEFINE_ALIAS
from .loss import mse_loss #DEFINE_ALIAS
from .loss import nll_loss #DEFINE_ALIAS
# from .loss import nce #DEFINE_ALIAS
......
@@ -13,6 +13,7 @@
# limitations under the License.
# TODO: define loss functions of neural network
import numpy as np
import paddle
import paddle.fluid as fluid
from ...fluid.framework import core, in_dygraph_mode
......
@@ -38,10 +39,8 @@ from ...fluid.layers import teacher_student_sigmoid_loss #DEFINE_ALIAS
from ...fluid.layers import edit_distance #DEFINE_ALIAS
from ...fluid.layers import huber_loss #DEFINE_ALIAS
from ...fluid.layers import margin_rank_loss #DEFINE_ALIAS
from ...fluid.layers import sampled_softmax_with_cross_entropy #DEFINE_ALIAS
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from ...fluid.framework import Variable
__all__ = [
......
@@ -55,8 +54,8 @@ __all__ = [
    'kldiv_loss',
    'l1_loss',
    'log_loss',
    'margin_rank_loss',
    'mse_loss',
    'margin_ranking_loss',
    # 'nce',
    'nll_loss',
    'npair_loss',
......
@@ -72,6 +71,110 @@ __all__ = [
]
def margin_ranking_loss(input,
                        other,
                        target,
                        margin=0.0,
                        reduction='mean',
                        name=None):
    """
    This op computes the margin ranking loss between ``input``, ``other`` and ``target``, using the formula below.

    .. math::

        margin\_rank\_loss = max(0, -target * (input - other) + margin)

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::

        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::

        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'none'``, the unreduced ``margin_rank_loss`` is returned.

    Parameters:
        input(Tensor): the first input tensor; its data type should be float32 or float64.
        other(Tensor): the second input tensor; its data type should be float32 or float64.
        target(Tensor): the target value corresponding to input; its data type should be float32 or float64.
        margin (float, optional): The margin value to add. Default is 0.0.
        reduction (str, optional): The reduction to apply to the loss; one of ``'none'``, ``'mean'`` or ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; if ``'mean'``, the mean of the loss is returned; if ``'sum'``, the sum of the loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor. If :attr:`reduction` is ``'mean'`` or ``'sum'``, the output shape is :math:`[1]`; otherwise the shape is the same as ``input``. The dtype is the same as the input tensors.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            paddle.disable_static()

            x = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype('float32'))
            y = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype('float32'))
            target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype('float32'))
            loss = paddle.nn.functional.margin_ranking_loss(x, y, target)
            print(loss.numpy()) # [0.75]
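            # the other reductions follow the formulas above; for this data
            # 'sum' should give [3.] and 'none' the element-wise losses
            # [[1., 1.], [1., 0.]]
            loss_sum = paddle.nn.functional.margin_ranking_loss(
                x, y, target, reduction='sum')
            print(loss_sum.numpy()) # [3.]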
"""
    # dygraph fast path: call the C++ ops directly
    if fluid.framework.in_dygraph_mode():
        out = core.ops.elementwise_sub(other, input)
        out = core.ops.elementwise_mul(out, target)
        if margin != 0.0:
            margin = fluid.dygraph.base.to_variable([margin], dtype=out.dtype)
            out = core.ops.elementwise_add(out, margin)
        out = core.ops.relu(out)
        if reduction == 'sum':
            return core.ops.reduce_sum(out, 'reduce_all', True)
        elif reduction == 'mean':
            return core.ops.mean(out)
        return out

    # static graph path: validate dtypes, then build the ops via LayerHelper
    helper = LayerHelper("margin_ranking_loss", **locals())
    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'margin_rank_loss')
    fluid.data_feeder.check_variable_and_dtype(
        other, 'other', ['float32', 'float64'], 'margin_rank_loss')
    fluid.data_feeder.check_variable_and_dtype(
        target, 'target', ['float32', 'float64'], 'margin_rank_loss')

    out = paddle.elementwise_sub(other, input)
    out = paddle.multiply(out, target)
    if margin != 0.0:
        margin_var = out.block.create_var(dtype=out.dtype)
        paddle.fill_constant([1], out.dtype, margin, out=margin_var)
        out = paddle.add(out, margin_var)

    result_out = helper.create_variable_for_type_inference(input.dtype)
    if reduction == 'none':
        helper.append_op(
            type="relu", inputs={"X": out}, outputs={"Out": result_out})
        return result_out
    elif reduction == 'sum':
        out = paddle.nn.functional.relu(out)
        attrs = {"dim": [0], "keep_dim": False, "reduce_all": True}
        helper.append_op(
            type="reduce_sum",
            inputs={"X": out},
            outputs={"Out": result_out},
            attrs=attrs)
        return result_out
    elif reduction == 'mean':
        out = paddle.nn.functional.relu(out)
        helper.append_op(
            type="mean",
            inputs={"X": out},
            outputs={"Out": result_out},
            attrs={})
        return result_out
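# Note on broadcasting: as the run_dynamic_broadcast_api unit test in this
# change exercises, `target` may broadcast against `input`/`other` (for
# example a [10] target against [10, 10] inputs); the elementwise subtract
# and multiply ops above perform the broadcasting, so callers do not need
# to reshape or tile the target themselves.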
def l1_loss(x, label, reduction='mean', name=None):
    """
    This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows.
......
@@ -62,6 +62,7 @@ from .loss import MSELoss #DEFINE_ALIAS
from .loss import L1Loss #DEFINE_ALIAS
from .loss import NLLLoss #DEFINE_ALIAS
from .loss import BCELoss #DEFINE_ALIAS
from .loss import MarginRankingLoss #DEFINE_ALIAS
from .norm import BatchNorm #DEFINE_ALIAS
from .norm import GroupNorm #DEFINE_ALIAS
from .norm import LayerNorm #DEFINE_ALIAS
......
@@ -13,7 +13,9 @@
# limitations under the License.
# TODO: define loss functions of neural network
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
from .. import functional as F
......
@@ -23,7 +25,8 @@ __all__ = [
    'MSELoss',
    'L1Loss',
    'NLLLoss',
    'BCELoss'
    'BCELoss',
    'MarginRankingLoss'
]
......
@@ -569,3 +572,72 @@ class NLLLoss(fluid.dygraph.Layer):
            ignore_index=self._ignore_index,
            reduction=self._reduction,
            name=self._name)
class MarginRankingLoss(fluid.dygraph.Layer):
    """
    This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
    The MarginRankingLoss layer calculates the margin ranking loss between ``input``, ``other`` and ``target``, using the formula below.

    .. math::

        margin\_rank\_loss = max(0, -target * (input - other) + margin)

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::

        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::

        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'none'``, the unreduced ``margin_rank_loss`` is returned.

    Parameters:
        margin (float, optional): The margin value to add. Default is 0.0.
        reduction (str, optional): The reduction to apply to the loss; one of ``'none'``, ``'mean'`` or ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; if ``'mean'``, the mean of the loss is returned; if ``'sum'``, the sum of the loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input: N-D Tensor with shape [N, *], where N is the batch size and `*` means any number of additional dimensions; available dtypes are float32 and float64.
        other: N-D Tensor with the same shape and dtype as ``input``.
        target: N-D Tensor with the same shape and dtype as ``input``.
        out: If :attr:`reduction` is ``'mean'`` or ``'sum'``, the output shape is :math:`[1]`; otherwise the shape is the same as ``input``. The dtype is the same as the input tensors.

    Returns:
        A callable object of MarginRankingLoss.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            paddle.disable_static()

            input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype("float32"))
            other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype("float32"))
            target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32"))
            margin_rank_loss = paddle.nn.MarginRankingLoss()
            loss = margin_rank_loss(input, other, target)
            print(loss.numpy()) # [0.75]
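            # a non-default configuration: margin=0.5 with 'sum' reduction
            # (for this data each element gains 0.5 before the relu, so the
            # summed loss should be [5.])
            margin_rank_loss_sum = paddle.nn.MarginRankingLoss(
                margin=0.5, reduction='sum')
            loss = margin_rank_loss_sum(input, other, target)
            print(loss.numpy()) # [5.]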
"""
    def __init__(self, margin=0.0, reduction='mean', name=None):
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(MarginRankingLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, other, target):
        # delegate to the functional API
        out = paddle.nn.functional.margin_ranking_loss(
            input, other, target, self.margin, self.reduction, self.name)
        return out
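# An unsupported `reduction` fails fast at construction time, as the
# MarginRankingLossError unit test in this change exercises:
#
#     paddle.nn.loss.MarginRankingLoss(margin=0.1, reduction="reduce_mean")
#     # ValueError: The value of 'reduction' in MarginRankingLoss should be
#     # 'sum', 'mean' or 'none', but received reduce_mean, which is not allowed.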