未验证 提交 5d48528f 编写于 作者: Y yangguohao 提交者: GitHub

【Hachathon No.30】 (#40545)

* 'TripletMarginDistanceLoss'

* 'test_file'

* '2022_03_27'

* 2022-03-31

* 2022-04-05

* 2

* 2022-04-17

* 2022-04-17_2

* 2022-04-17_3

* 2022-04-17_4

* 2022-04-25

* 2022-05-02_V1

* 2022-05-06_V1

* 2022-05-07_V1

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* Update loss.py

* 2022-06-01_pre-commit

* 2022-06-05

* 2022-06-06

* 2022-06-07

* 2022-06-07_V2
上级 0d0258f8
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
def call_TripletMarginDistanceLoss_layer(
input,
positive,
negative,
distance_function=None,
margin=0.3,
swap=False,
reduction='mean',
):
triplet_margin_with_distance_loss = paddle.nn.TripletMarginWithDistanceLoss(
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
res = triplet_margin_with_distance_loss(
input=input,
positive=positive,
negative=negative,
)
return res
def call_TripletMaginDistanceLoss_functional(
input,
positive,
negative,
distance_function=None,
margin=0.3,
swap=False,
reduction='mean',
):
res = paddle.nn.functional.triplet_margin_with_distance_loss(
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
return res
def test_static(place,
input_np,
positive_np,
negative_np,
distance_function=None,
margin=0.3,
swap=False,
reduction='mean',
functional=False):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype='float64')
positive = paddle.static.data(name='positive',
shape=positive_np.shape,
dtype='float64')
negative = paddle.static.data(name='negative',
shape=negative_np.shape,
dtype='float64')
feed_dict = {
"input": input_np,
"positive": positive_np,
"negative": negative_np
}
if functional:
res = call_TripletMaginDistanceLoss_functional(
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
else:
res = call_TripletMarginDistanceLoss_layer(
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result
def test_dygraph(place,
input,
positive,
negative,
distance_function=None,
margin=0.3,
swap=False,
reduction='mean',
functional=False):
paddle.disable_static()
input = paddle.to_tensor(input)
positive = paddle.to_tensor(positive)
negative = paddle.to_tensor(negative)
if functional:
dy_res = call_TripletMaginDistanceLoss_functional(
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
else:
dy_res = call_TripletMarginDistanceLoss_layer(
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
margin=margin,
swap=swap,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_triplet_margin_distance_loss(
input,
positive,
negative,
distance_function=None,
margin=0.3,
swap=False,
reduction='mean',
):
distance_function = np.linalg.norm
positive_dist = distance_function((input - positive), 2, axis=1)
negative_dist = distance_function((input - negative), 2, axis=1)
if swap:
swap_dist = np.linalg.norm((positive - negative), 2, axis=1)
negative_dist = np.minimum(negative_dist, swap_dist)
expected = np.maximum(positive_dist - negative_dist + margin, 0)
if reduction == 'mean':
expected = np.mean(expected)
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected
return expected
class TestTripletMarginWithDistanceLoss(unittest.TestCase):
def test_TripletMarginDistanceLoss(self):
shape = (5, 5)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
expected = calc_triplet_margin_distance_loss(
input=input,
positive=positive,
negative=negative,
reduction=reduction)
dy_result = test_dygraph(
place=place,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_TripletMarginDistanceLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.TripletMarginWithDistanceLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_with_distance_loss,
input=input,
positive=positive,
negative=negative,
reduction="unsupport reduction")
paddle.enable_static()
def test_TripletMarginDistanceLoss_distance_function(self):
def distance_function_1(x1, x2):
return 1.0 - paddle.nn.functional.cosine_similarity(x1, x2)
def distance_function_2(x1, x2):
return paddle.max(paddle.abs(x1 - x2), axis=1)
distance_function_list = [distance_function_1, distance_function_2]
shape = (5, 5)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
place = paddle.CPUPlace()
reduction = 'mean'
for distance_function in distance_function_list:
dy_result = test_dygraph(
place=place,
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
reduction=reduction,
)
static_result = test_static(
place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
distance_function=distance_function,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, dy_result))
static_functional = test_static(place=place,
input_np=input,
positive_np=positive,
negative_np=negative,
distance_function=distance_function,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
input=input,
positive=positive,
negative=negative,
distance_function=distance_function,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, dy_functional))
def test_TripletMarginWithDistanceLoss_distance_funtion_error(self):
paddle.disable_static()
def distance_function(x1, x2):
return -1.0 - paddle.nn.functional.cosine_similarity(x1, x2)
func = distance_function
shape = (5, 5)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_with_distance_loss,
input=input,
positive=positive,
negative=negative,
distance_function=func,
)
paddle.enable_static()
def test_TripletMarginDistanceLoss_dimension(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_with_distance_loss,
input=input,
positive=positive,
negative=negative,
)
triplet_margin_with_distance_loss = paddle.nn.loss.TripletMarginWithDistanceLoss(
)
self.assertRaises(
ValueError,
triplet_margin_with_distance_loss,
input=input,
positive=positive,
negative=negative,
)
paddle.enable_static()
def test_TripletMarginWithDistanceLoss_swap(self):
reduction = 'mean'
place = paddle.CPUPlace()
shape = (5, 5)
input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64)
positive = np.random.uniform(0, 2, size=shape).astype(np.float64)
negative = np.random.uniform(0, 2, size=shape).astype(np.float64)
expected = calc_triplet_margin_distance_loss(input=input,
swap=True,
positive=positive,
negative=negative,
reduction=reduction)
dy_result = test_dygraph(
place=place,
swap=True,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
)
static_result = test_static(
place=place,
swap=True,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place=place,
swap=True,
input_np=input,
positive_np=positive,
negative_np=negative,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place=place,
swap=True,
input=input,
positive=positive,
negative=negative,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_TripletMarginWithDistanceLoss_margin(self):
paddle.disable_static()
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
margin = -0.5
self.assertRaises(
ValueError,
paddle.nn.functional.triplet_margin_with_distance_loss,
margin=margin,
input=input,
positive=positive,
negative=negative,
)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -108,6 +108,7 @@ from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
......@@ -154,7 +155,7 @@ from . import functional # noqa: F401
from . import initializer # noqa: F401
from . import quant # noqa: F401
#TODO: remove 'diag_embed', 'remove_weight_norm', 'weight_norm' months later.
# TODO: remove 'diag_embed', 'remove_weight_norm', 'weight_norm' months later.
import paddle.utils.deprecated as deprecated
......@@ -191,7 +192,7 @@ def weight_norm(*args):
return utils.weight_norm(*args)
__all__ = [ #noqa
__all__ = [ # noqa
'BatchNorm',
'CELU',
'GroupNorm',
......@@ -314,4 +315,5 @@ __all__ = [ #noqa
'Identity',
'CosineEmbeddingLoss',
'RReLU',
'TripletMarginWithDistanceLoss',
]
......@@ -91,6 +91,7 @@ from .loss import square_error_cost # noqa: F401
from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .loss import triplet_margin_with_distance_loss
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
......@@ -125,7 +126,7 @@ from .extension import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
__all__ = [ #noqa
__all__ = [ # noqa
'celu',
'conv1d',
'conv1d_transpose',
......@@ -232,4 +233,5 @@ __all__ = [ #noqa
'fold',
'cosine_embedding_loss',
'rrelu',
'triplet_margin_with_distance_loss',
]
......@@ -2872,3 +2872,130 @@ def cosine_embedding_loss(input1,
return paddle.mean(out, name=name)
elif reduction == 'sum':
return paddle.sum(out, name=name)
def triplet_margin_with_distance_loss(input,
positive,
negative,
distance_function=None,
margin=1.0,
swap=False,
reduction='mean',
name=None):
r"""
Measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, D)`.
The loss function for each sample in the mini-batch is:
.. math::
L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}
where the default distance function
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
or user can defined their own distance functions. `margin` is a nonnegative margin representing the minimum difference
between the positive and negative distances that is required for the loss to be 0. If `swap` is true, it will compare distance of (input, negative) with
distance of (negative, positive) and change it to the smaller one. For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.
Parameters:
input (Tensor):Input tensor, the data type is float32 or float64.
the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.
positive (Tensor):Positive tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
negative (Tensor):Negative tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
distance_function (callable, optional): Quantifies the distance between two tensors. if not specified, 2 norm functions will be used.
margin (float, optional):Default: :math:`1`.A nonnegative margin representing the minimum difference
between the positive and negative distances required for the loss to be 0.
swap (bool, optional):The distance swap changes the negative distance to the swap distance (distance between positive samples
and negative samples) if swap distance smaller than negative distance. Default: ``False``.
reduction (str, optional):Indicate how to average the loss by batch_size.
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Output: Tensor. The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
loss = F.triplet_margin_with_distance_loss(input, positive, negative, margin=1.0, reduction='none')
print(loss)
# Tensor([0. , 0.57496738, 0. ])
loss = F.triplet_margin_with_distance_loss(input, positive, negative, margin=1.0, reduction='mean')
print(loss)
# Tensor([0.19165580])
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError("'reduction' in 'triplet_margin_with_distance_loss' "
"should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
if margin < 0:
raise ValueError(
"The margin between positive samples and negative samples should be greater than 0."
)
if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'triplet_margin_with_distance_loss')
check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
'triplet_margin_with_distance_loss')
check_variable_and_dtype(negative, 'negative', ['float32', 'float64'],
'triplet_margin_with_distance_loss')
if not (input.shape == positive.shape == negative.shape):
raise ValueError("input's shape must equal to "
"positive's shape and "
"negative's shape")
distance_function = distance_function if distance_function is not None \
else paddle.nn.PairwiseDistance(2)
positive_dist = distance_function(input, positive)
negative_dist = distance_function(input, negative)
if swap:
swap_dist = distance_function(positive, negative)
negative_dist = paddle.minimum(negative_dist, swap_dist)
if not paddle.all(positive_dist > 0) or not paddle.all(negative_dist > 0):
raise ValueError(
"The positive distance or negative distance should be greater than 0, "
"The distance functions should be checked.")
loss = paddle.clip(positive_dist - negative_dist + margin, min=0.0)
if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
......@@ -79,6 +79,7 @@ from .loss import MarginRankingLoss # noqa: F401
from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
......
......@@ -1400,3 +1400,110 @@ class CosineEmbeddingLoss(Layer):
margin=self.margin,
reduction=self.reduction,
name=self.name)
class TripletMarginWithDistanceLoss(Layer):
r"""
Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, D)`.
The loss function for each sample in the mini-batch is:
.. math::
L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}
where the default `distance_function`
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_2
or user can define their own distance function. `margin` is a nonnegative margin representing the minimum difference
between the positive and negative distances that is required for the loss to be 0. If `swap` is true, it will compare distance of (input, negative) with
distance of (negative, positive) and change it to the smaller one. For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.
Parameters:
distance_function (Callable, Optional): Quantifies the distance between two tensors. if not specified, 2 norm functions will be used.
margin (float, Optional):Default: :math:`1`.A nonnegative margin representing the minimum difference
between the positive and negative distances required for the loss to be 0. Larger
margins penalize cases where the negative examples are not distant enough from the
anchors, relative to the positives.
swap (bool, Optional):The distance swap changes the negative distance to the swap distance (distance between positive samples
and negative samples) if swap distance smaller than negative distance. Default: ``False``.
reduction (str, Optional):Indicate how to average the loss by batch_size.
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shapes:
input (Tensor):Input tensor, the data type is float32 or float64.
the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.
positive (Tensor):Positive tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
negative (Tensor):Negative tensor, the data type is float32 or float64.
The shape of label is the same as the shape of input.
output(Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.
Return:
A callable object of TripletMarginWithDistanceLoss
Examples:
.. code-block:: python
import paddle
from paddle.nn import TripletMarginWithDistanceLoss
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='none')
loss = triplet_margin_with_distance_loss(input, positive, negative,)
print(loss)
# Tensor([0. , 0.57496738, 0. ])
triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='mean')
loss = triplet_margin_with_distance_loss(input, positive, negative,)
print(loss)
# Tensor([0.19165580])
"""
def __init__(self,
distance_function=None,
margin=1.0,
swap=False,
reduction: str = 'mean',
name=None):
super(TripletMarginWithDistanceLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in TripletMarginWithDistanceLoss "
"should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
self.margin = margin
self.swap = swap
self.reduction = reduction
self.distance_function = distance_function
self.name = name
def forward(self, input, positive, negative):
return F.triplet_margin_with_distance_loss(input,
positive,
negative,
margin=self.margin,
swap=self.swap,
reduction=self.reduction,
name=self.name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册