Unverified Commit 8474392d, authored by Y yangguohao, committed by GitHub

【Paddle Hackathon No.11】 (#45595)

* 2022-08-30_update nn.layer.loss nn.functional.loss, test_file

* 2022-08-30_update nn.layer.loss nn.functional.loss, test_file

* fix: test_file

* fix: test_file, docs, multi_margin_loss

* fix: doc weight function

* fix: test_multi_margin_loss

* fix: weight np.testing.assert_allclose

* fix: test_file

* fix: en_doc

* 2022-10-10
Parent 450af30c
# -*- coding: utf-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
def call_MultiMarginLoss_layer(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
    multi_margin_loss = paddle.nn.MultiMarginLoss(p=p,
                                                  margin=margin,
                                                  weight=weight,
                                                  reduction=reduction)
    res = multi_margin_loss(
        input=input,
        label=label,
    )
return res
def call_MultiMarginLoss_functional(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
res = paddle.nn.functional.multi_margin_loss(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
return res
def test_static(place,
input_np,
label_np,
p=1,
margin=1.0,
weight_np=None,
reduction='mean',
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
label = paddle.static.data(name='label',
shape=label_np.shape,
dtype=label_np.dtype)
feed_dict = {
"input": input_np,
"label": label_np,
}
weight = None
if weight_np is not None:
weight = paddle.static.data(name='weight',
shape=weight_np.shape,
dtype=weight_np.dtype)
feed_dict['weight'] = weight_np
if functional:
res = call_MultiMarginLoss_functional(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
else:
res = call_MultiMarginLoss_layer(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result[0]
def test_static_data_shape(place,
input_np,
label_np,
wrong_label_shape=None,
weight_np=None,
wrong_weight_shape=None,
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
if wrong_label_shape is None:
label_shape = label_np.shape
else:
label_shape = wrong_label_shape
label = paddle.static.data(name='label',
shape=label_shape,
dtype=label_np.dtype)
feed_dict = {
"input": input_np,
"label": label_np,
}
weight = None
if weight_np is not None:
if wrong_weight_shape is None:
weight_shape = weight_np.shape
else:
weight_shape = wrong_weight_shape
weight = paddle.static.data(name='weight',
shape=weight_shape,
dtype=weight_np.dtype)
feed_dict['weight'] = weight_np
if functional:
res = call_MultiMarginLoss_functional(
input=input,
label=label,
weight=weight,
)
else:
res = call_MultiMarginLoss_layer(
input=input,
label=label,
weight=weight,
)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result
def test_dygraph(place,
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
functional=False):
paddle.disable_static()
input = paddle.to_tensor(input)
label = paddle.to_tensor(label)
if weight is not None:
weight = paddle.to_tensor(weight)
if functional:
dy_res = call_MultiMarginLoss_functional(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
else:
dy_res = call_MultiMarginLoss_layer(input=input,
label=label,
p=p,
margin=margin,
weight=weight,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_multi_margin_loss(
input,
label,
p=1,
margin=1.0,
weight=None,
reduction='mean',
):
    # NumPy reference that mirrors the formula implemented in
    # paddle.nn.functional.multi_margin_loss.
    index_sample = np.array([input[i, label[i]]
                             for i in range(label.size)]).reshape(-1, 1)
    if weight is None:
        expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p,
                           axis=1) - margin**p / input.shape[1]
    else:
        weight = np.array([weight[label[i]]
                           for i in range(label.size)]).reshape(-1, 1)
        expected = np.mean(
            np.maximum(weight * (margin + input - index_sample), 0.0)**p,
            axis=1) - weight * (margin**p / input.shape[1])
    if reduction == 'mean':
        expected = np.mean(expected)
    elif reduction == 'sum':
        expected = np.sum(expected)
    return expected
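A quick sanity check of the reference formula above (an illustrative sketch,
separate from the test suite; `inp`, `lbl`, and `true_score` are ad hoc names),
tracing the unweighted case with the defaults p=1 and margin=1.0 on a 2x3 batch:

import numpy as np

inp = np.array([[1.0, 5.0, 3.0],
                [0.0, 3.0, 2.0]])
lbl = np.array([1, 2])

# Score of the true class per sample, shape (2, 1): [[5.], [2.]]
true_score = inp[np.arange(2), lbl].reshape(-1, 1)

# Hinge terms max(0, margin - x_y + x_j); the j == label term contributes
# exactly margin**p, which the margin**p / C correction then removes.
hinge = np.maximum(1.0 + inp - true_score, 0.0)
print(np.mean(hinge, axis=1) - 1.0 / 3)  # [0.         0.66666667]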
class TestMultiMarginLoss(unittest.TestCase):
def test_MultiMarginLoss(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
expected = calc_multi_margin_loss(input=input,
label=label,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
label=label,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
label_np=label,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
input_np=input,
label_np=label,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
label=label,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)
def test_MultiMarginLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.MultiMarginLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([0], dtype='int32')
self.assertRaises(ValueError,
paddle.nn.functional.multi_margin_loss,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
def test_MultiMarginLoss_dimension(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
label = paddle.to_tensor([0, 1, 1], dtype='int32')
self.assertRaises(
ValueError,
paddle.nn.functional.multi_margin_loss,
input=input,
label=label,
)
MMLoss = paddle.nn.MultiMarginLoss()
self.assertRaises(
ValueError,
MMLoss,
input=input,
label=label,
)
paddle.enable_static()
def test_MultiMarginLoss_p(self):
p = 2
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
reduction = 'mean'
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
expected = calc_multi_margin_loss(input=input,
p=p,
label=label,
reduction=reduction)
dy_result = test_dygraph(
place=place,
p=p,
input=input,
label=label,
reduction=reduction,
)
static_result = test_static(
place=place,
p=p,
input_np=input,
label_np=label,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
p=p,
input_np=input,
label_np=label,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
p=p,
input=input,
label=label,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)
def test_MultiMarginLoss_weight(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
reduction = 'mean'
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
weight = np.random.uniform(0, 2,
size=(num_classes, )).astype(np.float64)
expected = calc_multi_margin_loss(input=input,
label=label,
weight=weight,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
label=label,
weight=weight,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
label_np=label,
weight_np=weight,
reduction=reduction,
)
np.testing.assert_allclose(static_result, expected)
np.testing.assert_allclose(static_result, dy_result)
np.testing.assert_allclose(dy_result, expected)
static_functional = test_static(place=place,
input_np=input,
label_np=label,
weight_np=weight,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
label=label,
weight=weight,
reduction=reduction,
functional=True)
np.testing.assert_allclose(static_functional, expected)
np.testing.assert_allclose(static_functional, dy_functional)
np.testing.assert_allclose(dy_functional, expected)
def test_MultiMarginLoss_static_data_shape(self):
batch_size = 5
num_classes = 2
shape = (batch_size, num_classes)
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
label = np.random.uniform(0, input.shape[1],
size=(batch_size, )).astype(np.int64)
weight = np.random.uniform(0, 2,
size=(num_classes, )).astype(np.float64)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
wrong_label_shape=(10, ),
functional=True,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
wrong_label_shape=(10, ),
functional=False,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
weight_np=weight,
wrong_weight_shape=(3, ),
functional=True,
)
self.assertRaises(
ValueError,
test_static_data_shape,
place=place,
input_np=input,
label_np=label,
weight_np=weight,
wrong_weight_shape=(3, ),
functional=False,
)
if __name__ == "__main__":
unittest.main()
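For readers skimming the framework diff below, a minimal usage sketch of the
API this PR adds (input values are arbitrary; both forms share the defaults
p=1, margin=1.0):

import paddle
import paddle.nn.functional as F

input = paddle.to_tensor([[0.1, 0.7, 0.2],
                          [0.5, 0.3, 0.9]], dtype='float32')
label = paddle.to_tensor([1, 2], dtype='int64')

# Functional form: a scalar under the default 'mean' reduction.
loss = F.multi_margin_loss(input, label)

# Layer form with reduction='none': one loss value per sample, shape [2].
layer = paddle.nn.MultiMarginLoss(reduction='none')
per_sample = layer(input, label)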
@@ -109,6 +109,7 @@ from .layer.loss import CTCLoss  # noqa: F401
from .layer.loss import SmoothL1Loss  # noqa: F401
from .layer.loss import HingeEmbeddingLoss  # noqa: F401
from .layer.loss import CosineEmbeddingLoss  # noqa: F401
from .layer.loss import MultiMarginLoss
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.loss import SoftMarginLoss
@@ -319,6 +320,7 @@ __all__ = [  # noqa
    'Identity',
    'CosineEmbeddingLoss',
    'RReLU',
    'MultiMarginLoss',
    'TripletMarginWithDistanceLoss',
    'TripletMarginLoss',
    'SoftMarginLoss',
...
@@ -92,6 +92,7 @@ from .loss import square_error_cost  # noqa: F401
from .loss import ctc_loss  # noqa: F401
from .loss import hinge_embedding_loss  # noqa: F401
from .loss import cosine_embedding_loss  # noqa: F401
from .loss import multi_margin_loss
from .loss import multi_label_soft_margin_loss
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
@@ -241,5 +242,6 @@ __all__ = [  # noqa
    'rrelu',
    'triplet_margin_with_distance_loss',
    'triplet_margin_loss',
    'multi_margin_loss',
    'soft_margin_loss',
]
@@ -3457,6 +3457,117 @@ def triplet_margin_loss(input,
    return loss
def multi_margin_loss(input,
label,
p: int = 1,
margin: float = 1.0,
weight=None,
reduction='mean',
name=None):
r"""
Measures a multi-class classification hinge loss between input :math:`input` and label :math:`label`:
    For the i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    label :math:`label_i` is:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}
where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.
    Optionally, you can give non-equal weighting to the classes by passing
    a 1D :attr:`weight` tensor.
    The loss for the i-th sample then becomes:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}
Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is the number of classes.
        label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,).
        p (int, optional): The exponent applied to each hinge term. Default: :math:`1`.
        margin (float, optional): The margin enforced between the true-class score and the other class scores. Default: :math:`1.0`.
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, it has to be a Tensor of shape (C,) with data type float32 or float64.
            Default is ``None`` .
        reduction (str, optional): Indicates how to reduce the loss over the batch;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.
Returns:
        Output: Tensor. The tensor storing the multi_margin_loss of input and label.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
label = paddle.to_tensor([1, 2, 1], dtype=paddle.int32)
loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none')
print(loss)
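            # prints a Tensor whose values are approximately [0., 0.66666667, 0.]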
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'multi_margin_loss')
check_variable_and_dtype(label, 'label', ['int32', 'int64'],
'multi_margin_loss')
if not (input.shape[0] == label.shape[0]):
raise ValueError(
"The label's shape[0] should be equal to input's shape[0], "
"but received input's shape[0] {} and label's shape[0]:{}. ".format(
input.shape[0], label.shape[0]))
label = label.reshape((-1, 1))
index_sample = paddle.index_sample(input, label)
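    # index_sample gathers input_i[label_i] for each sample, shape (N, 1).
    # The class-wise means below also include the j == label_i hinge term,
    # whose argument reduces to margin (times weight, if given); the trailing
    # `- margin**p / C` correction (weight-scaled in the weighted branch)
    # compensates for that extra term.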
if weight is not None:
if not _non_static_mode():
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'multi_margin_loss')
if not (input.shape[1] == weight.shape[0]):
raise ValueError(
"The weight's shape[0] should be equal to input's shape[1]"
"but received weight's shape[0]: {} and input's shape[1]: {}".
format(weight.shape[0], input.shape[1]))
weight = paddle.gather(weight, label, axis=0).reshape((-1, 1))
loss = paddle.mean(
paddle.pow(
paddle.clip(weight *
(margin - index_sample + input), min=0.0), p),
axis=1) - weight * (margin**p / paddle.shape(input)[1])
else:
loss = paddle.mean(paddle.pow(
paddle.clip(margin - index_sample + input, min=0.0), p),
axis=1) - margin**p / paddle.shape(input)[1]
if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
def soft_margin_loss(input, label, reduction='mean', name=None):
    """
    The API measures the soft margin loss between input predictions ``input``
...
@@ -83,6 +83,7 @@ from .loss import HingeEmbeddingLoss  # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .loss import TripletMarginLoss
from .loss import SoftMarginLoss
from .loss import MultiMarginLoss
from .norm import BatchNorm1D  # noqa: F401
from .norm import BatchNorm2D  # noqa: F401
from .norm import BatchNorm3D  # noqa: F401
...
@@ -1674,6 +1674,103 @@ class TripletMarginLoss(Layer):
                                        name=self.name)
class MultiMarginLoss(Layer):
r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between
input :math:`input` and label :math:`label`:
    For the i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    label :math:`label_i` is:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}
where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.
    Optionally, you can give non-equal weighting to the classes by passing
    a 1D :attr:`weight` tensor into the constructor.
    The loss for the i-th sample then becomes:
.. math::
\text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}
Parameters:
        p (int, optional): The exponent applied to each hinge term. Default: :math:`1`.
        margin (float, optional): The margin enforced between the true-class score and the other class scores. Default: :math:`1.0`.
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, it has to be a Tensor of shape (C,) with data type float32 or float64.
            Default is ``None`` .
        reduction (str, optional): Indicates how to reduce the loss over the batch;
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Call parameters:
input (Tensor): Input tensor, the data type is float32 or float64.
        label (Tensor): Label tensor, where 0 <= label < input.shape[1]; the data type is int32 or int64.
Shape:
        input: 2-D Tensor, the shape is [N, C], where N is the batch size and C is the number of classes.
        label: 1-D Tensor, the shape is [N,].
        output: scalar by default; if :attr:`reduction` is ``'none'``, the shape is [N,], the same as the label.
Returns:
A callable object of MultiMarginLoss.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32)
multi_margin_loss = nn.MultiMarginLoss(reduction='mean')
loss = multi_margin_loss(input, label)
print(loss)
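            # prints a Tensor whose value is approximately 1.11111111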
"""
def __init__(self,
p: int = 1,
margin: float = 1.0,
weight=None,
reduction="mean",
name=None):
super(MultiMarginLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
self.p = p
self.margin = margin
self.weight = weight
self.reduction = reduction
self.name = name
def forward(self, input, label):
return F.multi_margin_loss(input,
label,
p=self.p,
margin=self.margin,
weight=self.weight,
reduction=self.reduction,
name=self.name)
class SoftMarginLoss(Layer):
    r"""
    Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
...