未验证 提交 802129b3 编写于 作者: Z Zman 提交者: GitHub

Add GaussianNLLLoss API. (#50843)

* Add GaussianNLLLoss API.

* Change `rotl` `atol`.Check `var` in dynamic graph

* remove assertTrue

* update unittest

* update unittest for ci-covarage.add broadcast with same dim.

* Supply static err print.

* Repair note and example.

* Split unitest.

* empty commit.

* for standard commit.

* for standard commit.

* Add int dynamic graph test.

* Repair parameters name.

* Repair unitest parameters name.

* Repair unitest parameters name

* Repair unitest parameters name

* Repair unitest parameters name

* add square in code-block

* fit few notes.

* fit few notes.

* fit few notes.

* fit few notes.

* add few interpretations.

* add few interpretations.

* add few interpretations.

* fix import.

* fix space.

* empty commit for ci.
上级 e05df020
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
np.random.seed(10)
def ref_gaussian_nll_loss(
input, label, variance, full=False, eps=1e-6, reduction='none'
):
if variance.shape != input.shape:
if input.shape[:-1] == variance.shape:
variance = np.expand_dims(variance, -1)
elif (
input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1
):
pass
else:
raise ValueError("variance is of incorrect size")
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
raise ValueError(reduction + " is not valid")
if np.any(variance < 0):
raise ValueError("var has negative entry/entries")
variance = variance.copy()
variance = np.clip(variance, a_min=eps, a_max=None)
loss = 0.5 * (np.log(variance) + (input - label) ** 2 / variance)
if full:
loss += 0.5 * np.log(2 * np.pi)
if reduction == 'none':
return loss
elif reduction == 'sum':
return [np.sum(loss)]
elif reduction == 'mean':
return [np.mean(loss)]
class TestGaussianNLLLossAPI(unittest.TestCase):
# test paddle.nn.functional.gaussian_nll_loss, paddle.nn.gaussian_nll_loss
def setUp(self, type=None):
self.shape = [10, 2]
if type in ['float16', 'float64', 'int32', 'int64']:
dtype = np.dtype(type)
self.input_np = np.random.random(self.shape).astype(dtype)
self.label_np = np.random.random(self.shape).astype(dtype)
self.variance_np = np.ones(self.shape).astype(dtype)
elif type == 'broadcast1':
self.shape = [10, 2, 3]
self.broadcast_shape = [10, 2]
self.input_np = np.random.random(self.shape).astype(np.float32)
self.label_np = np.random.random(self.shape).astype(np.float32)
self.variance_np = np.ones(self.broadcast_shape).astype(np.float32)
elif type == 'broadcast2':
self.shape = [10, 2, 3]
self.broadcast_shape = [10, 2, 1]
self.input_np = np.random.random(self.shape).astype(np.float32)
self.label_np = np.random.random(self.shape).astype(np.float32)
self.variance_np = np.ones(self.broadcast_shape).astype(np.float32)
else:
dtype = np.dtype('float32')
self.input_np = np.random.random(self.shape).astype(dtype)
self.label_np = np.random.random(self.shape).astype(dtype)
self.variance_np = np.ones(self.shape).astype(dtype)
if type == 'test_err':
self.variance_np = -np.ones(self.shape).astype(np.float32)
self.place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
def test_dynamic_case(self, type=None, full=False, reduction='none'):
self.setUp(type)
paddle.disable_static(self.place)
input_x = paddle.to_tensor(self.input_np)
label = paddle.to_tensor(self.label_np)
variance = paddle.to_tensor(self.variance_np)
if type in ['test_err', 'int32', 'int64']:
self.assertRaises(
ValueError,
paddle.nn.functional.gaussian_nll_loss,
input=input_x,
label=label,
variance=variance,
)
else:
out_ref = ref_gaussian_nll_loss(
self.input_np,
self.label_np,
self.variance_np,
full=full,
reduction=reduction,
)
out1 = F.gaussian_nll_loss(
input_x, label, variance, full=full, reduction=reduction
)
gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
full, reduction=reduction
)
out2 = gaussian_nll_loss(input_x, label, variance)
for r in [out1, out2]:
np.allclose(out_ref, r.numpy(), rtol=1e-5, atol=1e-5)
paddle.enable_static()
def test_static_case(self, type=None, full=False, reduction='none'):
self.setUp(type)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
if type in ['int32', 'int64', 'float64']:
input_x = paddle.static.data('Input_x', self.shape, type)
label = paddle.static.data('Label', self.shape, type)
variance = paddle.static.data('Variance', self.shape, type)
elif type in ['broadcast1', 'broadcast2']:
input_x = paddle.static.data('Input_x', self.shape)
label = paddle.static.data('Label', self.shape)
variance = paddle.static.data('Variance', self.broadcast_shape)
else:
input_x = paddle.static.data('Input_x', self.shape, 'float32')
label = paddle.static.data('Label', self.shape, 'float32')
variance = paddle.static.data('Variance', self.shape, 'float32')
out1 = F.gaussian_nll_loss(
input_x, label, variance, full=full, reduction=reduction
)
gaussian_nll_loss = paddle.nn.GaussianNLLLoss(
full, reduction=reduction
)
out2 = gaussian_nll_loss(input_x, label, variance)
exe = paddle.static.Executor(self.place)
if type not in ['test_err', 'int32', 'int64']:
out_ref = ref_gaussian_nll_loss(
self.input_np,
self.label_np,
self.variance_np,
full=full,
reduction=reduction,
)
res = exe.run(
feed={
'Input_x': self.input_np,
'Label': self.label_np,
'Variance': self.variance_np,
},
fetch_list=[out1, out2],
)
for r in res:
np.allclose(out_ref, r, rtol=1e-5, atol=1e-5)
else:
try:
res = exe.run(
feed={
'Input_x': self.input_np,
'Label': self.label_np,
'Variance': self.variance_np,
},
fetch_list=[out1, out2],
)
except ValueError:
pass
def test_api(self):
self.test_dynamic_case()
self.test_static_case()
def test_float64(self):
self.test_dynamic_case('float64')
self.test_static_case('float64')
def test_broadcast(self):
self.test_dynamic_case('broadcast1')
self.test_static_case('broadcast1')
def test_broadcast_with_same_dim(self):
self.test_dynamic_case('broadcast2')
self.test_static_case('broadcast2')
def test_reduction(self):
self.test_dynamic_case(full=True, reduction='mean')
self.test_dynamic_case(full=True, reduction='sum')
self.test_static_case(full=True, reduction='mean')
def test_error(self):
self.test_dynamic_case('test_err')
self.test_static_case('test_err')
def test_int(self):
self.test_dynamic_case('int64')
self.test_dynamic_case('int32')
if __name__ == "__main__":
unittest.main()
......@@ -114,6 +114,8 @@ from .layer.loss import MultiMarginLoss
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.loss import SoftMarginLoss
from .layer.loss import GaussianNLLLoss
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
......@@ -335,4 +337,5 @@ __all__ = [ # noqa
'TripletMarginWithDistanceLoss',
'TripletMarginLoss',
'SoftMarginLoss',
'GaussianNLLLoss',
]
......@@ -99,6 +99,8 @@ from .loss import multi_label_soft_margin_loss
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
from .loss import soft_margin_loss
from .loss import gaussian_nll_loss
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
......@@ -248,4 +250,5 @@ __all__ = [ # noqa
'triplet_margin_loss',
'multi_margin_loss',
'soft_margin_loss',
'gaussian_nll_loss',
]
......@@ -18,6 +18,7 @@ import math
import paddle
from paddle import _C_ops, _legacy_C_ops, fluid, in_dynamic_mode
from paddle.framework import core
from paddle.static.nn.control_flow import Assert
from paddle.utils import deprecated
from ...common_ops_import import Variable
......@@ -4007,3 +4008,163 @@ def soft_margin_loss(input, label, reduction='mean', name=None):
return paddle.mean(out, name=name)
else:
return out
def gaussian_nll_loss(
input,
label,
variance,
full=False,
epsilon=1e-6,
reduction='mean',
name=None,
):
r"""Gaussian negative log likelihood loss.
Gaussian negative log likelihood loss among ``input``, ``variance`` and
``label``. Note that the ``label`` is treated as samples from Gaussian distributions.
This function is used to train a neural network predicts
the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to
be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs.
For a ``label`` having Gaussian distribution with ``input`` and ``variance`` predicted by neural network
the loss is calculated as follows:
.. math::
\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
\ \text{epsilon}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
{\text{max}\left(\text{var}, \ \text{epsilon}\right)}\right) + \text{const.}
where :attr:`epsilon` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``variance`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
Args:
input (Tensor): input tensor, :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
dimensions. Expectation of the Gaussian distribution, available dtype is float32, float64.
label (Tensor): target label tensor, :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
but with one dimension equal to 1 (to allow for broadcasting). Sample from the Gaussian distribution, available dtype is float32, float64.
variance (Tensor): tensor of positive variance(s), :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
with one dimension equal to 1, or same shape as the input but with one fewer
dimension (to allow for broadcasting). One for each of the expectations
in the input (heteroscedastic), or a single one (homoscedastic), available dtype is float32, float64.
full (bool, optional): include the constant term in the loss
calculation. Default: ``False``.
epsilon (float, optional): value used to clamp ``variance`` (see note below), for
stability. Default: 1e-6.
reduction (str, optional): specifies the reduction to apply to the
output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
will be applied, ``'mean'``: the output is the average of all batch
member losses, ``'sum'``: the output is the sum of all batch member
losses. Default: ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is [1].
Examples::
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.randn([5, 2], dtype=paddle.float32)
label = paddle.randn([5, 2], dtype=paddle.float32)
variance = paddle.ones([5, 2], dtype=paddle.float32)
loss = F.gaussian_nll_loss(input, label, variance, reduction='none')
print(loss)
loss = F.gaussian_nll_loss(input, label, variance, reduction='mean')
print(loss)
Note:
The clamping of ``variance`` is ignored with respect to autograd, and so the
gradients are unaffected by it.
"""
# Check variance shape
# If variance.shape == input.shape, the case is heteroscedastic and no further checks are needed.
# Otherwise:
if variance.shape != input.shape:
# If variance is one dimension short of input, but the shape match otherwise, then this is a homoscedastic case.
# e.g. input.shape = (10, 2, 3), variance.shape = (10, 2)
# -> unsqueeze variance so that variance.shape = (10, 2, 1)
# this is done so that broadcasting can happen in the loss calculation
if input.shape[:-1] == variance.shape:
variance = paddle.unsqueeze(variance, -1)
# This checks if the shape match up to the final dimension, and the final dimension of variance is of shape 1.
# This is also a homoscedastic case.
# e.g. input.shape = (10, 2, 3), variance.shape = (10, 2, 1)
elif (
input.shape[:-1] == variance.shape[:-1] and variance.shape[-1] == 1
): # Heteroscedastic case
pass
# If none of the above pass, then the shape of variance is incorrect.
else:
raise ValueError("variance is of incorrect shape")
# Check validity of reduction mode
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
raise ValueError(reduction + " is not valid")
check_variable_and_dtype(
input,
'Input',
['float32', 'float64'],
'gaussian_nll_loss',
)
check_variable_and_dtype(
label,
'Label',
['float32', 'float64'],
'gaussian_nll_loss',
)
check_variable_and_dtype(
variance,
'Variance',
['float32', 'float64'],
'gaussian_nll_loss',
)
# Entries of variance must be non-negative
if not in_dygraph_mode():
condition = paddle.all(variance > 0)
Assert(condition, [variance], 6)
else:
if input.dtype not in [paddle.float32, paddle.float64]:
raise ValueError(
"The data type of input Variable must be 'float32' or 'float64'"
)
if label.dtype not in [
paddle.float32,
paddle.float64,
]:
raise ValueError(
"The data type of label Variable must be 'float32', 'float64'"
)
if variance.dtype not in [paddle.float32, paddle.float64]:
raise ValueError(
"The data type of variance Variable must be 'float32', 'float64'"
)
if paddle.any(variance < 0):
raise ValueError("variance has negative entry/entries")
# Clamp for stability
variance = variance.clone()
with paddle.no_grad():
variance = paddle.clip(variance, min=epsilon)
# Calculate the loss
loss = 0.5 * (
paddle.log(variance) + paddle.square(input - label) / variance
)
if full:
loss += 0.5 * math.log(2 * math.pi)
if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
......@@ -85,6 +85,8 @@ from .loss import TripletMarginWithDistanceLoss
from .loss import TripletMarginLoss
from .loss import SoftMarginLoss
from .loss import MultiMarginLoss
from .loss import GaussianNLLLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
......
......@@ -2046,3 +2046,97 @@ class SoftMarginLoss(Layer):
input, label, self.reduction, self.name
)
return out
class GaussianNLLLoss(Layer):
r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss.
This class create a callable object of Gaussian negative log likelihood loss among ``input``, ``variance`` and
``label``. Note that the ``label`` is treated as samples from Gaussian distributions.
This class is used to train a neural network predicts
the ``input`` and ``variance`` of a gaussian distribution that ``label`` are supposed to
be coming from. This means ``input`` and ``variance`` should be functions(the neural network) of some inputs.
For a ``label`` having Gaussian distribution with ``input`` and ``variance`` predicted by neural network
the loss is calculated as follows:
.. math::
\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
\ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
{\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
where :attr:`epsilon` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``variance`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
Args:
full (bool, optional): include the constant term in the loss
calculation. Default: ``False``, means omit the constant term.
epsilon (float, optional): value used to clamp ``variance`` (see note below), for
stability. Default: 1e-6.
reduction (str, optional): specifies the reduction to apply to the
output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
will be applied, ``'mean'``: the output is the average of all batch
member losses, ``'sum'``: the output is the sum of all batch member
losses. Default: ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- Input(Tensor): :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
dimensions. Available dtype is float32, float64.
- Label(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
but with one dimension equal to 1 (to allow for broadcasting). Available dtype is float32, float64.
- Variance(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
with one dimension equal to 1, or same shape as the input but with one fewer
dimension (to allow for broadcasting). Available dtype is float32, float64.
- Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
shape as the input
Returns:
A callable object of GaussianNLLLoss.
Examples::
.. code-block:: python
import paddle
import paddle.nn as nn
input = paddle.randn([5, 2], dtype=paddle.float32)
label = paddle.randn([5, 2], dtype=paddle.float32)
variance = paddle.ones([5, 2], dtype=paddle.float32)
gs_nll_loss = nn.GaussianNLLLoss(full=False, epsilon=1e-6, reduction='none')
loss = gs_nll_loss(input, label, variance)
print(loss)
Note:
The clamping of ``variance`` is ignored with respect to autograd, and so the
gradients are unaffected by it.
"""
def __init__(self, full=False, epsilon=1e-6, reduction='mean', name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in GaussianNLLLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction
)
super().__init__()
self.full = full
self.epsilon = epsilon
self.reduction = reduction
self.name = name
def forward(self, input, label, variance):
out = F.gaussian_nll_loss(
input,
label,
variance,
self.full,
self.epsilon,
self.reduction,
self.name,
)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册