Unverified commit 8474392d authored by yangguohao, committed by GitHub

【Paddle Hackathon No.11】 (#45595)

* 2022-08-30_update nn.layer.loss nn.functional.loss, test_file

* 2022-08-30_update nn.layer.loss nn.functional.loss, test_file

* fix: test_file

* fix: test_file, docs, multi_margin_loss

* fix: doc weight function

* fix: test_multi_margin_loss

* fix: weight np.testing.assert_allclose

* fix: test_file

* fix: en_doc

* 2022-10-10
Parent 450af30c
# -*- coding: utf-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
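

# The helpers below exercise both the Layer API (paddle.nn.MultiMarginLoss) and
# the functional API (paddle.nn.functional.multi_margin_loss) in dygraph and
# static-graph modes, and compare each against a NumPy reference.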
def call_MultiMarginLoss_layer(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
    multi_margin_loss = paddle.nn.MultiMarginLoss(p=p,
                                                  margin=margin,
                                                  weight=weight,
                                                  reduction=reduction)
    res = multi_margin_loss(
        input=input,
        label=label,
    )
    return res


def call_MultiMarginLoss_functional(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
res = paddle.nn.functional.multi_margin_loss(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
return res


def test_static(place,
input_np,
label_np,
p=1,
margin=1.0,
weight_np=None,
reduction='mean',
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
label = paddle.static.data(name='label',
shape=label_np.shape,
dtype=label_np.dtype)
feed_dict = {
"input": input_np,
"label": label_np,
}
weight = None
if weight_np is not None:
weight = paddle.static.data(name='weight',
shape=weight_np.shape,
dtype=weight_np.dtype)
feed_dict['weight'] = weight_np
if functional:
res = call_MultiMarginLoss_functional(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
else:
res = call_MultiMarginLoss_layer(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result[0]


def test_static_data_shape(place,
input_np,
label_np,
wrong_label_shape=None,
weight_np=None,
wrong_weight_shape=None,
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
if wrong_label_shape is None:
label_shape = label_np.shape
else:
label_shape = wrong_label_shape
label = paddle.static.data(name='label',
shape=label_shape,
dtype=label_np.dtype)
feed_dict = {
"input": input_np,
"label": label_np,
}
weight = None
if weight_np is not None:
if wrong_weight_shape is None:
weight_shape = weight_np.shape
else:
weight_shape = wrong_weight_shape
weight = paddle.static.data(name='weight',
shape=weight_shape,
dtype=weight_np.dtype)
feed_dict['weight'] = weight_np
if functional:
res = call_MultiMarginLoss_functional(
input=input,
label=label,
weight=weight,
)
else:
res = call_MultiMarginLoss_layer(
input=input,
label=label,
weight=weight,
)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result


def test_dygraph(place,
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
functional=False):
paddle.disable_static()
input = paddle.to_tensor(input)
label = paddle.to_tensor(label)
if weight is not None:
weight = paddle.to_tensor(weight)
if functional:
dy_res = call_MultiMarginLoss_functional(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
else:
dy_res = call_MultiMarginLoss_layer(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result


def calc_multi_margin_loss(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
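    # NumPy reference: np.mean over all C columns also averages in the
    # j == label column, whose hinge term is exactly margin**p, so that
    # contribution is subtracted back out instead of being masked.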
index_sample = np.array([input[i, label[i]]
for i in range(label.size)]).reshape(-1, 1)
if weight is None:
expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p,
axis=1) - margin**p / input.shape[1]
else:
weight = np.array([weight[label[i]]
for i in range(label.size)]).reshape(-1, 1)
        expected = np.mean(
            np.maximum(weight * (margin + input - index_sample), 0.0)**p,
            axis=1) - weight * (margin**p / input.shape[1])
    if reduction == 'mean':
        expected = np.mean(expected)
    elif reduction == 'sum':
        expected = np.sum(expected)
    return expected


class TestMultiMarginLoss(unittest.TestCase):
def test_MultiMarginLoss(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
expected = calc_multi_margin_loss(input=input,
label=label,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
label=label,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
label_np=label,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
input_np=input,
label_np=label,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
label=label,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)

    def test_MultiMarginLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.MultiMarginLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([0], dtype='int32')
self.assertRaises(ValueError,
paddle.nn.functional.multi_margin_loss,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()

    def test_MultiMarginLoss_dimension(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
label = paddle.to_tensor([0, 1, 1], dtype='int32')
self.assertRaises(
ValueError,
paddle.nn.functional.multi_margin_loss,
input=input,
label=label,
)
MMLoss = paddle.nn.MultiMarginLoss()
self.assertRaises(
ValueError,
MMLoss,
input=input,
label=label,
)
paddle.enable_static()

    def test_MultiMarginLoss_p(self):
p = 2
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
reduction = 'mean'
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
expected = calc_multi_margin_loss(input=input,
p=p,
label=label,
reduction=reduction)
dy_result = test_dygraph(
place=place,
p=p,
input=input,
label=label,
reduction=reduction,
)
static_result = test_static(
place=place,
p=p,
input_np=input,
label_np=label,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
p=p,
input_np=input,
label_np=label,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
p=p,
input=input,
label=label,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)

    def test_MultiMarginLoss_weight(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
reduction = 'mean'
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
weight = np.random.uniform(0, 2,
size=(num_classes, )).astype(np.float64)
expected = calc_multi_margin_loss(input=input,
label=label,
weight=weight,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
label=label,
weight=weight,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
label_np=label,
weight_np=weight,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
input_np=input,
label_np=label,
weight_np=weight,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
label=label,
weight=weight,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)

    def test_MultiMarginLoss_static_data_shape(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
weight = np.random.uniform(0, 2,
size=(num_classes, )).astype(np.float64)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
wrong_label_shape=(10, ),
functional=True,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
wrong_label_shape=(10, ),
functional=False,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
weight_np=weight,
wrong_weight_shape=(3, ),
functional=True,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
weight_np=weight,
wrong_weight_shape=(3, ),
functional=False,
)
if __name__ == "__main__":
unittest.main()
@@ -109,6 +109,7 @@ from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.loss import MultiMarginLoss
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.loss import SoftMarginLoss
@@ -319,6 +320,7 @@ __all__ = [ # noqa
'Identity',
'CosineEmbeddingLoss',
'RReLU',
'MultiMarginLoss',
'TripletMarginWithDistanceLoss',
'TripletMarginLoss',
'SoftMarginLoss',
...
@@ -92,6 +92,7 @@ from .loss import square_error_cost # noqa: F401
from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .loss import multi_margin_loss
from .loss import multi_label_soft_margin_loss
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
@@ -241,5 +242,6 @@ __all__ = [ # noqa
'rrelu',
'triplet_margin_with_distance_loss',
'triplet_margin_loss',
'multi_margin_loss',
'soft_margin_loss',
]
@@ -3457,6 +3457,117 @@ def triplet_margin_loss(input,
return loss


def multi_margin_loss(input,
label,
p: int = 1,
margin: float = 1.0,
weight=None,
reduction='mean',
name=None):
r"""
Measures a multi-class classification hinge loss between input :math:`input` and label :math:`label`:
    For the i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    label :math:`label_i` is:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}
where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.
    Optionally, you can give unequal weight to the classes by passing
    a 1D :attr:`weight` tensor.
    The loss for the i-th sample then becomes:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}
Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is the number of classes.
        label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,).
        p (int, optional): The power of each element in the loss. Default: :math:`1`.
        margin (float, optional): The margin value. Default: :math:`1.0`.
        weight (Tensor, optional): A manual rescaling weight given to each class.
            If given, it has to be a Tensor of shape (C,) with data type float32 or float64.
            Default: ``None``.
        reduction (str, optional): Indicates how to reduce the loss over the batch;
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, Optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
        Output: Tensor. The tensor storing the multi_margin_loss of the input and label.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
label = paddle.to_tensor([1, 2, 1], dtype=paddle.int32)
loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none')
print(loss)
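            # hand-computed expectation for this input:
            # loss is approximately [0., 0.66666667, 0.]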
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'multi_margin_loss')
check_variable_and_dtype(label, 'label', ['int32', 'int64'],
'multi_margin_loss')
if not (input.shape[0] == label.shape[0]):
raise ValueError(
"The label's shape[0] should be equal to input's shape[0], "
"but received input's shape[0] {} and label's shape[0]:{}. ".format(
input.shape[0], label.shape[0]))
label = label.reshape((-1, 1))
index_sample = paddle.index_sample(input, label)
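    # index_sample[i] == input[i, label[i]], shape (N, 1): the score of the
    # ground-truth class for each sample.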
if weight is not None:
if not _non_static_mode():
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'multi_margin_loss')
if not (input.shape[1] == weight.shape[0]):
raise ValueError(
"The weight's shape[0] should be equal to input's shape[1]"
"but received weight's shape[0]: {} and input's shape[1]: {}".
format(weight.shape[0], input.shape[1]))
weight = paddle.gather(weight, label, axis=0).reshape((-1, 1))
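        # weight now holds w[label_i] for each sample, shape (N, 1), and it
        # broadcasts against the (N, C) hinge terms below.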
loss = paddle.mean(
paddle.pow(
paddle.clip(weight *
(margin - index_sample + input), min=0.0), p),
axis=1) - weight * (margin**p / paddle.shape(input)[1])
else:
loss = paddle.mean(paddle.pow(
paddle.clip(margin - index_sample + input, min=0.0), p),
axis=1) - margin**p / paddle.shape(input)[1]
if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss


def soft_margin_loss(input, label, reduction='mean', name=None):
"""
The API measures the soft margin loss between input predictions ``input``
...
@@ -83,6 +83,7 @@ from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .loss import TripletMarginLoss
from .loss import SoftMarginLoss
from .loss import MultiMarginLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
...
@@ -1674,6 +1674,103 @@ class TripletMarginLoss(Layer):
name=self.name)


class MultiMarginLoss(Layer):
r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between
input :math:`input` and label :math:`label`:
    For the i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    label :math:`label_i` is:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}
where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.
    Optionally, you can give unequal weight to the classes by passing
    a 1D :attr:`weight` tensor into the constructor.
    The loss for the i-th sample then becomes:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}
Parameters:
        p (int, optional): The power of each element in the loss. Default: :math:`1`.
        margin (float, optional): The margin value. Default: :math:`1.0`.
        weight (Tensor, optional): A manual rescaling weight given to each class.
            If given, it has to be a Tensor of shape (C,) with data type float32 or float64.
            Default: ``None``.
        reduction (str, optional): Indicates how to reduce the loss over the batch;
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Call parameters:
input (Tensor): Input tensor, the data type is float32 or float64.
        label (Tensor): Label tensor, where each value satisfies 0 <= label < input.shape[1]; the data type is int32 or int64.
Shape:
        input: 2-D Tensor, the shape is [N, C], where N is the batch size and C is the number of classes.
label: 1-D Tensor, the shape is [N,].
        output: scalar Tensor. If :attr:`reduction` is ``'none'``, the shape is the same as the label's, i.e. [N,].
Returns:
A callable object of MultiMarginLoss.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32)
multi_margin_loss = nn.MultiMarginLoss(reduction='mean')
loss = multi_margin_loss(input, label)
print(loss)
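            # hand-computed expectation for this input: loss is approximately 1.11111111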
"""

    def __init__(self,
p: int = 1,
margin: float = 1.0,
weight=None,
reduction="mean",
name=None):
super(MultiMarginLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
self.p = p
self.margin = margin
self.weight = weight
self.reduction = reduction
self.name = name

    def forward(self, input, label):
return F.multi_margin_loss(input,
label,
p=self.p,
margin=self.margin,
weight=self.weight,
reduction=self.reduction,
name=self.name)


class SoftMarginLoss(Layer):
r"""
Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
...