未验证 提交 8dab2690 编写于 作者: N Netpunk 提交者: GitHub

【Hackathon No.17】为 Paddle 新增 paddle.nn.CosineEmbeddingLoss 和...

【Hackathon No.17】为 Paddle 新增 paddle.nn.CosineEmbeddingLoss 和 paddle.nn.functional.cosine_embedding_loss API (#41680)

* add cosine embedding loss API

* new version

* new version

* new version

* set label to int32

* new version

* new version-test

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* new version

* aligning to Chinese document

* add name parameter

* activate CI

* fix format error

* unit test code format

* format code
上级 5413fd79
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from __future__ import print_function
import paddle
import paddle.static as static
import numpy as np
import unittest
def cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='mean'):
z = (input1 * input2).sum(axis=-1)
mag_square1 = np.square(input1).sum(axis=-1) + 10e-12
mag_square2 = np.square(input2).sum(axis=-1) + 10e-12
denom = np.sqrt(mag_square1 * mag_square2)
cos = z / denom
zeros = np.zeros_like(cos)
pos = 1 - cos
neg = np.clip(cos - margin, a_min=0, a_max=np.inf)
out_pos = np.where(label == 1, pos, zeros)
out_neg = np.where(label == -1, neg, zeros)
out = out_pos + out_neg
if reduction == 'none':
return out
if reduction == 'mean':
return np.mean(out)
elif reduction == 'sum':
return np.sum(out)
class TestFunctionCosineEmbeddingLoss(unittest.TestCase):
def setUp(self):
self.input1_np = np.random.random(size=(5, 3)).astype(np.float64)
self.input2_np = np.random.random(size=(5, 3)).astype(np.float64)
a = np.array([-1, -1, -1]).astype(np.int32)
b = np.array([1, 1]).astype(np.int32)
self.label_np = np.concatenate((a, b), axis=0)
np.random.shuffle(self.label_np)
def run_dynamic(self):
input1 = paddle.to_tensor(self.input1_np)
input2 = paddle.to_tensor(self.input2_np)
label = paddle.to_tensor(self.label_np)
dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
expected1 = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='mean')
self.assertTrue(np.allclose(dy_result.numpy(), expected1))
self.assertTrue(dy_result.shape, [1])
dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='sum')
expected2 = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='sum')
self.assertTrue(np.allclose(dy_result.numpy(), expected2))
self.assertTrue(dy_result.shape, [1])
dy_result = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='none')
expected3 = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='none')
self.assertTrue(np.allclose(dy_result.numpy(), expected3))
self.assertTrue(dy_result.shape, [5])
def run_static(self, use_gpu=False):
input1 = static.data(name='input1', shape=[5, 3], dtype='float64')
input2 = static.data(name='input2', shape=[5, 3], dtype='float64')
label = static.data(name='label', shape=[5], dtype='int32')
result0 = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='none')
result1 = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='sum')
result2 = paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
static_result = exe.run(feed={
"input1": self.input1_np,
"input2": self.input2_np,
"label": self.label_np
},
fetch_list=[result0, result1, result2])
expected = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='none')
self.assertTrue(np.allclose(static_result[0], expected))
expected = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='sum')
self.assertTrue(np.allclose(static_result[1], expected))
expected = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='mean')
self.assertTrue(np.allclose(static_result[2], expected))
def test_cpu(self):
paddle.disable_static(place=paddle.CPUPlace())
self.run_dynamic()
paddle.enable_static()
with static.program_guard(static.Program()):
self.run_static()
def test_gpu(self):
if not paddle.is_compiled_with_cuda():
return
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_dynamic()
paddle.enable_static()
with static.program_guard(static.Program()):
self.run_static(use_gpu=True)
def test_errors(self):
paddle.disable_static()
input1 = paddle.to_tensor(self.input1_np)
input2 = paddle.to_tensor(self.input2_np)
label = paddle.to_tensor(self.label_np)
def test_label_shape_error():
label = paddle.to_tensor(
np.random.randint(low=0, high=2, size=(2, 3)))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_label_shape_error)
def test_input_different_shape_error():
input1 = paddle.to_tensor(self.input1_np[0])
label = paddle.to_tensor(np.ndarray([1]))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_input_different_shape_error)
def test_input_shape2D_error():
input1 = paddle.to_tensor(
np.random.random(size=(2, 3, 4)).astype(np.float64))
input2 = paddle.to_tensor(
np.random.random(size=(2, 3, 4)).astype(np.float64))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_input_shape2D_error)
def test_label_value_error():
label = paddle.to_tensor(np.ndarray([-1, -2]))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_label_value_error)
def test_input_type_error():
input1 = paddle.to_tensor(self.input1_np.astype(np.int64))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_input_type_error)
def test_label_type_error():
label = paddle.to_tensor(self.label_np.astype(np.int16))
paddle.nn.functional.cosine_embedding_loss(input1,
input2,
label,
margin=0.5,
reduction='mean')
self.assertRaises(ValueError, test_label_type_error)
class TestClassCosineEmbeddingLoss(unittest.TestCase):
def setUp(self):
self.input1_np = np.random.random(size=(10, 3)).astype(np.float32)
self.input2_np = np.random.random(size=(10, 3)).astype(np.float32)
a = np.array([-1, -1, -1, -1, -1]).astype(np.int64)
b = np.array([1, 1, 1, 1, 1]).astype(np.int64)
self.label_np = np.concatenate((a, b), axis=0)
np.random.shuffle(self.label_np)
self.input1_np_1D = np.random.random(size=10).astype(np.float32)
self.input2_np_1D = np.random.random(size=10).astype(np.float32)
self.label_np_1D = np.array([1]).astype(np.int64)
def run_dynamic(self):
input1 = paddle.to_tensor(self.input1_np)
input2 = paddle.to_tensor(self.input2_np)
label = paddle.to_tensor(self.label_np)
CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(margin=0.5,
reduction='mean')
dy_result = CosineEmbeddingLoss(input1, input2, label)
expected1 = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='mean')
self.assertTrue(np.allclose(dy_result.numpy(), expected1))
self.assertTrue(dy_result.shape, [1])
input1_1D = paddle.to_tensor(self.input1_np_1D)
input2_1D = paddle.to_tensor(self.input2_np_1D)
label_1D = paddle.to_tensor(self.label_np_1D)
dy_result = CosineEmbeddingLoss(input1_1D, input2_1D, label_1D)
expected2 = cosine_embedding_loss(self.input1_np_1D,
self.input2_np_1D,
self.label_np_1D,
margin=0.5,
reduction='mean')
self.assertTrue(np.allclose(dy_result.numpy(), expected2))
def run_static(self):
input1 = static.data(name='input1', shape=[10, 3], dtype='float32')
input2 = static.data(name='input2', shape=[10, 3], dtype='float32')
label = static.data(name='label', shape=[10], dtype='int64')
CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(margin=0.5,
reduction='mean')
result = CosineEmbeddingLoss(input1, input2, label)
place = paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
static_result = exe.run(feed={
"input1": self.input1_np,
"input2": self.input2_np,
"label": self.label_np
},
fetch_list=[result])
expected = cosine_embedding_loss(self.input1_np,
self.input2_np,
self.label_np,
margin=0.5,
reduction='mean')
self.assertTrue(np.allclose(static_result[0], expected))
def test_cpu(self):
paddle.disable_static(place=paddle.CPUPlace())
self.run_dynamic()
paddle.enable_static()
with static.program_guard(static.Program()):
self.run_static()
def test_errors(self):
def test_margin_error():
CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(
margin=2, reduction='mean')
self.assertRaises(ValueError, test_margin_error)
def test_reduction_error():
CosineEmbeddingLoss = paddle.nn.CosineEmbeddingLoss(
margin=2, reduction='reduce_mean')
self.assertRaises(ValueError, test_reduction_error)
if __name__ == "__main__":
unittest.main()
......@@ -107,6 +107,7 @@ from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
......@@ -311,5 +312,6 @@ __all__ = [ #noqa
'MaxUnPool3D',
'HingeEmbeddingLoss',
'Identity',
'CosineEmbeddingLoss',
'RReLU',
]
......@@ -90,6 +90,7 @@ from .loss import margin_cross_entropy # noqa: F401
from .loss import square_error_cost # noqa: F401
from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
......@@ -229,5 +230,6 @@ __all__ = [ #noqa
'class_center_sample',
'sparse_attention',
'fold',
'cosine_embedding_loss',
'rrelu',
]
......@@ -2763,3 +2763,112 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
def cosine_embedding_loss(input1,
input2,
label,
margin=0,
reduction='mean',
name=None):
r"""
This operator computes the cosine embedding loss of Tensor ``input1``, ``input2`` and ``label`` as follows.
If label = 1, then the loss value can be calculated as follow:
.. math::
Out = 1 - cos(input1, input2)
If label = -1, then the loss value can be calculated as follow:
.. math::
Out = max(0, cos(input1, input2)) - margin
The operator cos can be described as follow:
.. math::
cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}
Parameters:
input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
Available dtypes are float32, float64.
input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
Available dtypes are float32, float64.
label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
Available dtypes are int32, int64, float32, float64.
margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
:math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
default value is :math:`0`.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of elements in the output
``'sum'``: the output will be summed.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``input`` .
If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
Examples:
.. code-block:: python
:name: code-example1
import paddle
input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
label = paddle.to_tensor([1, -1], 'int64')
output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='mean')
print(output) # [0.21155193]
output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='sum')
print(output) # [0.42310387]
output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='none')
print(output) # [0.42310387, 0. ]
"""
if len(label.shape) != 1:
raise ValueError(
"1D target tensor expected, multi-target not supported")
if input1.shape != input2.shape:
raise ValueError(
"the shape of input tensor 1 should be equal to input tensor 2, but found inputs with "
"different sizes")
if len(input1.shape) > 2:
raise ValueError(
"1D target tensor expects 1D or 2D input tensors, but found inputs with different sizes"
)
if input1.dtype not in [paddle.float32, paddle.float64]:
raise ValueError(
"The data type of input Variable must be 'float32' or 'float64'")
if label.dtype not in [
paddle.int32, paddle.int64, paddle.float32, paddle.float64
]:
raise ValueError(
"The data type of label Variable must be 'int32', 'int64', 'float32', 'float64'"
)
prod_sum = (input1 * input2).sum(axis=-1)
mag_square1 = paddle.square(input1).sum(axis=-1) + 10e-12
mag_square2 = paddle.square(input2).sum(axis=-1) + 10e-12
denom = paddle.sqrt(mag_square1 * mag_square2)
cos = prod_sum / denom
zeros = paddle.zeros_like(cos)
pos = 1 - cos
neg = paddle.clip(cos - margin, min=0)
out_pos = paddle.where(label == 1, pos, zeros)
out_neg = paddle.where(label == -1, neg, zeros)
out = out_pos + out_neg
if reduction == 'none':
return out
if reduction == 'mean':
return paddle.mean(out, name=name)
elif reduction == 'sum':
return paddle.sum(out, name=name)
......@@ -1309,3 +1309,94 @@ class HingeEmbeddingLoss(Layer):
reduction=self.reduction,
margin=self.margin,
name=self.name)
class CosineEmbeddingLoss(Layer):
r"""
This interface is used to construct a callable object of the ``CosineEmbeddingLoss`` class.
The CosineEmbeddingLoss layer measures the cosine_embedding loss between input predictions ``input1``, ``input2``
and target labels ``label`` with values 1 or 0. This is used for measuring whether two inputs are similar or
dissimilar and is typically used for learning nonlinear embeddings or semi-supervised learning.
The cosine embedding loss can be described as:
If label = 1, then the loss value can be calculated as follow:
.. math::
Out = 1 - cos(input1, input2)
If label = -1, then the loss value can be calculated as follow:
.. math::
Out = max(0, cos(input1, input2)) - margin
The operator cos can be described as follow:
.. math::
cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}
Parameters:
margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
:math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
default value is :math:`0`.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
Available dtypes are float32, float64.
input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
Available dtypes are float32, float64.
label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
Available dtypes are int32, int64, float32, float64.
output (Tensor): Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``input`` .
If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
Examples:
.. code-block:: python
:name: code-example1
import paddle
input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
label = paddle.to_tensor([1, -1], 'int64')
cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
output = cosine_embedding_loss(input1, input2, label)
print(output) # [0.21155193]
cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')
output = cosine_embedding_loss(input1, input2, label)
print(output) # [0.42310387]
cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='none')
output = cosine_embedding_loss(input1, input2, label)
print(output) # [0.42310387, 0. ]
"""
def __init__(self, margin=0, reduction='mean', name=None):
if margin > 1 or margin < -1:
raise ValueError(
"The value of 'margin' should be in the interval of [-1, 1], but received %f, which is not allowed."
% margin)
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' should be 'sum', 'mean' or "
"'none', but received %s, which is not allowed." % reduction)
super(CosineEmbeddingLoss, self).__init__()
self.margin = margin
self.reduction = reduction
self.name = name
def forward(self, input1, input2, label):
return F.cosine_embedding_loss(input1,
input2,
label,
margin=self.margin,
reduction=self.reduction,
name=self.name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册