From d03dd9d5f18b033c606906439c236f173b8b49ca Mon Sep 17 00:00:00 2001 From: wawltor Date: Mon, 17 Aug 2020 12:12:54 +0800 Subject: [PATCH] fix the margin ranking loss doc and api, test=develop (#26266) * update the target to label, test=develop * Update the code for the margin_ranking_loss, test=develop --- python/paddle/nn/functional/loss.py | 22 +++++++++++----------- python/paddle/nn/layer/loss.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index dac1a6e2db3..19ffe572a9c 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -73,16 +73,16 @@ __all__ = [ def margin_ranking_loss(input, other, - target, + label, margin=0.0, reduction='mean', name=None): """ - This op the calcluate the the margin rank loss between the input x, y and target, use the math function as follows. + This op calculates the margin rank loss between the input, other and label, using the math function as follows. .. math:: - margin\_rank\_loss = max(0, -target * (input - other) + margin) + margin\_rank\_loss = max(0, -label * (input - other) + margin) If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: @@ -99,7 +99,7 @@ def margin_ranking_loss(input, Parameters: input(Tensor): the first input tensor, it's data type should be float32, float64. other(Tensor): the second input tensor, it's data type should be float32, float64. - target(Tensor): the target value corresponding to input, it's data type should be float32, float64. + label(Tensor): the label value corresponding to input, it's data type should be float32, float64. 
margin (float, optional): The margin value to add, default value is 0; reduction (str, optional): Indicate the reduction to apply to the loss, the candicates are ``'none'``, ``'mean'``, ``'sum'``.If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -115,15 +115,15 @@ def margin_ranking_loss(input, paddle.disable_static() - x = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype('float32')) - y = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype('float32')) - target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype('float32')) - loss = paddle.nn.functional.margin_ranking_loss(x, y, target) + input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype('float32')) + other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype('float32')) + label = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype('float32')) + loss = paddle.nn.functional.margin_ranking_loss(input, other, label) print(loss.numpy()) # [0.75] """ if fluid.framework.in_dygraph_mode(): out = core.ops.elementwise_sub(other, input) - out = core.ops.elementwise_mul(out, target) + out = core.ops.elementwise_mul(out, label) if margin != 0.0: margin = fluid.dygraph.base.to_variable([margin], dtype=out.dtype) out = core.ops.elementwise_add(out, margin) @@ -140,10 +140,10 @@ def margin_ranking_loss(input, fluid.data_feeder.check_variable_and_dtype( other, 'other', ['float32', 'float64'], 'margin_rank_loss') fluid.data_feeder.check_variable_and_dtype( - target, 'target', ['float32', 'float64'], 'margin_rank_loss') + label, 'label', ['float32', 'float64'], 'margin_rank_loss') out = paddle.elementwise_sub(other, input) - out = paddle.multiply(out, target) + out = paddle.multiply(out, 
label) if margin != 0.0: margin_var = out.block.create_var(dtype=out.dtype) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 6a478e44fe8..bc4f32f9c31 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -648,11 +648,11 @@ class MarginRankingLoss(fluid.dygraph.Layer): """ This interface is used to construct a callable object of the ``MarginRankingLoss`` class. - The MarginRankingLoss layer calculates the margin rank loss between the input, other and target + The MarginRankingLoss layer calculates the margin rank loss between the input, other and label , use the math function as follows. .. math:: - margin\_rank\_loss = max(0, -target * (input - other) + margin) + margin\_rank\_loss = max(0, -label * (input - other) + margin) If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: @@ -674,8 +674,8 @@ class MarginRankingLoss(fluid.dygraph.Layer): Shape: input: N-D Tensor, the shape is [N, *], N is batch size and `*` means any number of additional dimensions., available dtype is float32, float64. other: N-D Tensor, `other` have the same shape and dtype as `input`. - target: N-D Tensor, target have the same shape and dtype as `input`. - out: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[1]`, otherwise the shape is the same as `input` .The same dtype as input tensor. + label: N-D Tensor, label have the same shape and dtype as `input`. + output: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[1]`, otherwise the shape is the same as `input` .The same dtype as input tensor. Returns: A callable object of MarginRankingLoss. 
@@ -691,9 +691,9 @@ class MarginRankingLoss(fluid.dygraph.Layer): input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype("float32")) other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype("float32")) - target = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32")) + label = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32")) margin_rank_loss = paddle.nn.MarginRankingLoss() - loss = margin_rank_loss(input, other, target) + loss = margin_rank_loss(input, other, label) print(loss.numpy()) # [0.75] """ @@ -707,7 +707,7 @@ class MarginRankingLoss(fluid.dygraph.Layer): self.reduction = reduction self.name = name - def forward(self, input, other, target): + def forward(self, input, other, label): out = paddle.nn.functional.margin_ranking_loss( - input, other, target, self.margin, self.reduction, self.name) + input, other, label, self.margin, self.reduction, self.name) return out -- GitLab