未验证 提交 9bbcca2b 编写于 作者: Y yangguohao 提交者: GitHub

[Hackathon No.26] (#40487)

* 'triplet_margin_loss'

* 'test_file_corret'

* '2022_03_27'

* 2022_04_05

* 2022-04-17_1

* 2022-04-17

* 2022-04-17_2

* 2022-04-25

* 2022-05-02_V1

* 2022-05-06_V1

* 2022-05-07_V1

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update test_triplet_margin_loss.py

* Update loss.py

* 2022-06-01_pre-commit

* 2022-06-05

* 2022-06-06

* 2022-06-06

* code_style_check

* code_style_check

* Update loss.py

* 2022-06-07_V2

* Update loss.py

* Update loss.py
上级 6fe10181
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
def call_TripletMarginLoss_layer(
input,
positive,
negative,
p=2,
margin=0.3,
swap=False,
eps=1e-6,
reduction='mean',
):
triplet_margin_loss = paddle.nn.TripletMarginLoss(p=p,
epsilon=eps,
margin=margin,
swap=swap,
reduction=reduction)
res = triplet_margin_loss(
input=input,
positive=positive,
negative=negative,
)
return res
def call_TripletMarginLoss_functional(
input,
positive,
negative,
p=2,
margin=0.3,
swap=False,
eps=1e-6,
reduction='mean',
):
res = paddle.nn.functional.triplet_margin_loss(input=input,
positive=positive,
negative=negative,
p=p,
epsilon=eps,
margin=margin,
swap=swap,
reduction=reduction)
return res
def test_static(place,
input_np,
positive_np,
negative_np,
p=2,
margin=0.3,
swap=False,
eps=1e-6,
reduction='mean',
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype='float64')
positive = paddle.static.data(name='positive',
shape=positive_np.shape,
dtype='float64')
negative = paddle.static.data(name='negative',
shape=negative_np.shape,
dtype='float64')
feed_dict = {
"input": input_np,
"positive": positive_np,
"negative": negative_np
}
if functional:
res = call_TripletMarginLoss_functional(input=input,
positive=positive,
negative=negative,
p=p,
eps=eps,
margin=margin,
swap=swap,
reduction=reduction)
else:
res = call_TripletMarginLoss_layer(input=input,
positive=positive,
negative=negative,
p=p,
eps=eps,
margin=margin,
swap=swap,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result
def test_dygraph(place,
input,
positive,
negative,
p=2,
margin=0.3,
swap=False,
eps=1e-6,
reduction='mean',
functional=False):
paddle.disable_static()
input = paddle.to_tensor(input)
positive = paddle.to_tensor(positive)
negative = paddle.to_tensor(negative)
if functional:
dy_res = call_TripletMarginLoss_functional(input=input,
positive=positive,
negative=negative,
p=p,
eps=eps,
margin=margin,
swap=swap,
reduction=reduction)
else:
dy_res = call_TripletMarginLoss_layer(input=input,
positive=positive,
negative=negative,
p=p,
eps=eps,
margin=margin,
swap=swap,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_triplet_margin_loss(
input,
positive,
negative,
p=2,
margin=0.3,
swap=False,
reduction='mean',
):
positive_dist = np.linalg.norm((input - positive), p, axis=1)
negative_dist = np.linalg.norm((input - negative), p, axis=1)
if swap:
swap_dist = np.linalg.norm((positive - negative), p, axis=1)
negative_dist = np.minimum(negative_dist, swap_dist)
expected = np.maximum(positive_dist - negative_dist + margin, 0)
if reduction == 'mean':
expected = np.mean(expected)
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected
return expected
class TestTripletMarginLoss(unittest.TestCase):
def test_TripletMarginLoss(self):
shape = (2, 2)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
expected = calc_triplet_margin_loss(input=input,
positive=positive,
negative=negative,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_TripletMarginLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.loss.TripletMarginLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
self.assertRaises(ValueError,
paddle.nn.functional.triplet_margin_loss,
input=input,
positive=positive,
negative=negative,
reduction="unsupport reduction")
paddle.enable_static()
def test_TripletMarginLoss_dimension(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_loss,
input=input,
positive=positive,
negative=negative,
)
TMLoss = paddle.nn.loss.TripletMarginLoss()
self.assertRaises(
ValueError,
TMLoss,
input=input,
positive=positive,
negative=negative,
)
paddle.enable_static()
def test_TripletMarginLoss_swap(self):
reduction = 'mean'
place = paddle.CPUPlace()
shape = (2, 2)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
expected = calc_triplet_margin_loss(input=input,
swap=True,
positive=positive,
negative=negative,
reduction=reduction)
dy_result = test_dygraph(
place=place,
swap=True,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
)
static_result = test_static(
place=place,
swap=True,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place=place,
swap=True,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
swap=True,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_TripletMarginLoss_margin(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
margin = -0.5
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_loss,
margin=margin,
input=input,
positive=positive,
negative=negative,
)
paddle.enable_static()
def test_TripletMarginLoss_p(self):
p = 3
shape = (2, 2)
reduction = 'mean'
place = paddle.CPUPlace()
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
expected = calc_triplet_margin_loss(input=input,
p=p,
positive=positive,
negative=negative,
reduction=reduction)
dy_result = test_dygraph(
place=place,
p=p,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
)
static_result = test_static(
place=place,
p=p,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place=place,
p=p,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
p=p,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
if __name__ == "__main__":
unittest.main()
......@@ -109,6 +109,7 @@ from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
......@@ -316,4 +317,5 @@ __all__ = [ # noqa
'CosineEmbeddingLoss',
'RReLU',
'TripletMarginWithDistanceLoss',
'TripletMarginLoss',
]
......@@ -92,6 +92,7 @@ from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
......@@ -234,4 +235,5 @@ __all__ = [ # noqa
'cosine_embedding_loss',
'rrelu',
'triplet_margin_with_distance_loss',
'triplet_margin_loss',
]
......@@ -28,7 +28,7 @@ from ...static import Variable
from paddle.utils import deprecated
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.framework import core
from paddle.framework import core, _non_static_mode
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode, _current_expected_place
__all__ = []
......@@ -2999,3 +2999,124 @@ def triplet_margin_with_distance_loss(input,
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
def triplet_margin_loss(input,
positive,
negative,
margin=1.0,
p=2,
epsilon=1e-6,
swap=False,
reduction='mean',
name=None):
r"""
Measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, *)`.
The loss function for each sample in the mini-batch is:
.. math::
L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}
where
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
Parameters:
input (Tensor): Input tensor, the data type is float32 or float64.
the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.
positive (Tensor): Positive tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
negative (Tensor): Negative tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
margin (float, Optional): Default: :math:`1`.
p (int, Optional): The norm degree for pairwise distance. Default: :math:`2`.
epsilon (float, Optional): Add small value to avoid division by zero,
default value is 1e-6.
swap (bool,Optional): The distance swap change the negative distance to the distance between
positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`.
Default: ``False``.
reduction (str, Optional):Indicate how to average the loss by batch_size.
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, Optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Output: Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='none')
print(loss)
# Tensor([0. , 0.57496738, 0. ])
loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='mean')
print(loss)
# Tensor([0.19165580])
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'triplet_margin_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
if margin < 0:
raise ValueError(
"The margin between positive samples and negative samples should be greater than 0."
)
if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'triplet_margin_loss')
check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
'triplet_margin_loss')
check_variable_and_dtype(negative, 'negative', ['float32', 'float64'],
'triplet_margin_loss')
if not (input.shape == positive.shape == negative.shape):
raise ValueError("input's shape must equal to "
"positive's shape and "
"negative's shape")
distance_function = paddle.nn.PairwiseDistance(p, epsilon=epsilon)
positive_dist = distance_function(input, positive)
negative_dist = distance_function(input, negative)
if swap:
swap_dist = distance_function(positive, negative)
negative_dist = paddle.minimum(negative_dist, swap_dist)
loss = paddle.clip(positive_dist - negative_dist + margin, min=0.0)
if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
......@@ -80,6 +80,7 @@ from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .loss import TripletMarginLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
......
......@@ -1507,3 +1507,109 @@ class TripletMarginWithDistanceLoss(Layer):
swap=self.swap,
reduction=self.reduction,
name=self.name)
class TripletMarginLoss(Layer):
r"""
Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, *)`.
The loss function for each sample in the mini-batch is:
.. math::
L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}
where
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
Parameters:
margin (float, Optional):Default: :math:`1`.
p (int, Optional):The norm degree for pairwise distance. Default: :math:`2`.
epsilon (float, Optional):Add small value to avoid division by zero,
default value is 1e-6.
swap (bool, Optional):The distance swap change the negative distance to the distance between
positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`.
Default: ``False``.
reduction (str, Optional):Indicate how to average the loss by batch_size.
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str,Optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Call Parameters:
input (Tensor):Input tensor, the data type is float32 or float64.
the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.
positive (Tensor):Positive tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
negative (Tensor):Negative tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
Returns:
Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.
Examples:
.. code-block:: python
import paddle
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none')
loss = triplet_margin_loss(input, positive, negative)
print(loss)
# Tensor([0. , 0.57496738, 0. ])
triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean', )
loss = triplet_margin_loss(input, positive, negative,)
print(loss)
# Tensor([0.19165580])
"""
def __init__(self,
margin=1.0,
p=2.,
epsilon=1e-6,
swap=False,
reduction='mean',
name=None):
super(TripletMarginLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in TripletMarginLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
self.margin = margin
self.p = p
self.epsilon = epsilon
self.swap = swap
self.reduction = reduction
self.name = name
def forward(self, input, positive, negative):
return F.triplet_margin_loss(input,
positive,
negative,
margin=self.margin,
p=self.p,
epsilon=self.epsilon,
swap=self.swap,
reduction=self.reduction,
name=self.name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册