Unverified · Commit 349a059d authored by LyndonKong, committed by GitHub

【Hackathon No.16】add PoissonNLLLoss API (#51117)

* add PoissonNLLLoss API

* update unittests

* Fix poisson_nll_loss init and update data type support

* remove type comment

* Update doc string

* Fix doc string error

* Fix doc string math equation format

* Add float16 and bfloat16 support
Parent 4970dd65
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
np.random.seed(100)
def ref_poisson_nll_loss(
input,
label,
log_input=True,
full=False,
epsilon=1e-8,
reduction="mean",
):
if epsilon <= 0:
raise ValueError(
"The value of `epsilon` in PoissonNLLLoss should be positve, but received %f, which is not allowed"
% epsilon
)
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction
)
loss_out = 0
if log_input:
loss_out = np.exp(input) - label * input
else:
loss_out = input - label * np.log(input + epsilon)
if full:
stirling_approx = (
label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label)
)
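# keep the Stirling term only where it exceeds 1; smaller values contribute 0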
loss_out += np.where(stirling_approx <= 1, 0, stirling_approx)
if reduction == 'none':
return loss_out
elif reduction == 'sum':
return [np.sum(loss_out)]
elif reduction == 'mean':
return [np.mean(loss_out)]
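# Illustrative sanity check of the reference implementation above (not one of
# the unit tests below): with log_input=True, a single element with
# input = 0.0 and label = 1.0 gives exp(0) - 1 * 0 = 1.0.
assert np.isclose(
ref_poisson_nll_loss(np.array([0.0]), np.array([1.0]), reduction='none'),
1.0,
).all()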
class TestPoissonNLLLossBasicCase(unittest.TestCase):
def setUp(self, dtype="float32"):
self.shape = [10, 2]
self.dtype = dtype
self.input_np = np.random.random(self.shape).astype(self.dtype)
self.label_np = np.random.random(self.shape).astype(self.dtype)
self.place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
def test_static_case(
self,
dtype="float32",
log_input=True,
full=False,
epsilon=1e-8,
reduction="mean",
):
self.setUp(dtype)
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data('input', self.shape, dtype)
label = paddle.static.data('label', self.shape, dtype)
input.desc.set_need_check_feed(False)
label.desc.set_need_check_feed(False)
out1 = F.poisson_nll_loss(
input,
label,
log_input=log_input,
full=full,
epsilon=epsilon,
reduction=reduction,
)
poisson_nll_loss = paddle.nn.PoissonNLLLoss(
log_input=log_input,
full=full,
epsilon=epsilon,
reduction=reduction,
)
out2 = poisson_nll_loss(input, label)
exe = paddle.static.Executor(self.place)
exe.run(startup_prog)
res = exe.run(
prog,
feed={'input': self.input_np, 'label': self.label_np},
fetch_list=[out1, out2],
)
out_ref = ref_poisson_nll_loss(
self.input_np,
self.label_np,
log_input=log_input,
full=full,
epsilon=epsilon,
reduction=reduction,
)
for r in res:
np.testing.assert_allclose(out_ref, r, rtol=1e-5)
def test_dynamic_case(
self,
dtype="float32",
log_input=True,
full=False,
epsilon=1e-8,
reduction="mean",
type=None,
):
self.setUp(dtype)
paddle.disable_static(self.place)
input_x = paddle.to_tensor(self.input_np)
label = paddle.to_tensor(self.label_np)
out_ref = ref_poisson_nll_loss(
self.input_np,
self.label_np,
log_input=log_input,
full=full,
epsilon=epsilon,
reduction=reduction,
)
out1 = F.poisson_nll_loss(
input_x,
label,
log_input=log_input,
full=full,
epsilon=epsilon,
reduction=reduction,
)
if type == 'test_err_reduction':
self.assertRaises(
ValueError,
paddle.nn.functional.poisson_nll_loss,
input=input_x,
label=label,
log_input=log_input,
full=full,
epsilon=epsilon,
reduction="unsupport reduction",
)
elif type == 'test_err_epsilon':
self.assertRaises(
ValueError,
paddle.nn.functional.poisson_nll_loss,
input=input_x,
label=label,
log_input=log_input,
full=full,
epsilon=-1,
reduction="mean",
)
poisson_nll_loss = paddle.nn.PoissonNLLLoss(
log_input=log_input, full=full, epsilon=epsilon, reduction=reduction
)
out2 = poisson_nll_loss(input_x, label)
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-5)
paddle.enable_static()
def test_api(self):
pass
class TestPoissonNLLLossErrCase(TestPoissonNLLLossBasicCase):
def test_err_reduction(self):
self.test_dynamic_case(type="test_err_reduction")
def test_err_epsilon(self):
self.test_dynamic_case(type="test_err_epsilon")
def test_api(self):
self.test_err_reduction()
self.test_err_epsilon()
class TestPoissonNLLLossFloat16Case(TestPoissonNLLLossBasicCase):
def test_api(self):
if core.is_compiled_with_cuda():
self.test_static_case(dtype="float16")
self.test_dynamic_case(dtype="float16")
class TestPoissonNLLLossBfloat16Case(TestPoissonNLLLossBasicCase):
def test_api(self):
if core.is_compiled_with_cuda():
self.test_static_case(dtype="uint16")
self.test_dynamic_case(dtype="uint16")
class TestPoissonNLLLossFloat32Case(TestPoissonNLLLossBasicCase):
def test_api(self):
self.test_static_case(dtype="float32")
self.test_dynamic_case(dtype="float32")
class TestPoissonNLLLossFloat64Case(TestPoissonNLLLossBasicCase):
def test_api(self):
self.test_static_case(dtype="float64")
self.test_dynamic_case(dtype="float64")
class TestPoissonNLLLossNoLoginputCase(TestPoissonNLLLossBasicCase):
def test_api(self):
self.test_static_case(log_input=False)
self.test_dynamic_case(log_input=False)
class TestPoissonNLLLossFulllossCase(TestPoissonNLLLossBasicCase):
def test_api(self):
self.test_static_case(full=True)
self.test_dynamic_case(full=True)
class TestPoissonNLLLossSumReductionCase(TestPoissonNLLLossBasicCase):
def test_api(self):
self.test_static_case(reduction="sum")
self.test_dynamic_case(reduction="sum")
if __name__ == "__main__":
unittest.main()
@@ -100,6 +100,7 @@ from .layer.loss import HSigmoidLoss # noqa: F401
from .layer.loss import MSELoss # noqa: F401
from .layer.loss import L1Loss # noqa: F401
from .layer.loss import NLLLoss # noqa: F401
from .layer.loss import PoissonNLLLoss # noqa: F401
from .layer.loss import BCELoss # noqa: F401
from .layer.loss import KLDivLoss # noqa: F401
from .layer.loss import MarginRankingLoss # noqa: F401
@@ -268,6 +269,7 @@ __all__ = [ # noqa
'AdaptiveAvgPool3D',
'AdaptiveMaxPool3D',
'NLLLoss',
'PoissonNLLLoss',
'Conv1D',
'Sequential',
'Hardswish',
......
@@ -83,6 +83,7 @@ from .loss import log_loss # noqa: F401
from .loss import margin_ranking_loss # noqa: F401
from .loss import mse_loss # noqa: F401
from .loss import nll_loss # noqa: F401
from .loss import poisson_nll_loss # noqa: F401
from .loss import npair_loss # noqa: F401
from .loss import sigmoid_focal_loss # noqa: F401
from .loss import smooth_l1_loss # noqa: F401
@@ -214,6 +215,7 @@ __all__ = [ # noqa
'margin_ranking_loss',
'multi_label_soft_margin_loss',
'nll_loss',
'poisson_nll_loss',
'npair_loss',
'sigmoid_focal_loss',
'smooth_l1_loss',
......
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
# TODO: define loss functions of neural network
import paddle
from paddle import _C_ops, _legacy_C_ops, fluid, in_dynamic_mode
@@ -1322,10 +1324,16 @@ def l1_loss(input, label, reduction='mean', name=None):
return unreduced
else:
check_variable_and_dtype(
input,
'input',
['float32', 'float64', 'int32', 'int64'],
'l1_loss',
)
check_variable_and_dtype(
label,
'label',
['float32', 'float64', 'int32', 'int64'],
'l1_loss',
)
if reduction == 'sum':
@@ -1462,6 +1470,116 @@ def nll_loss(
return out
def poisson_nll_loss(
input,
label,
log_input=True,
full=False,
epsilon=1e-8,
reduction="mean",
name=None,
):
r"""Poisson negative log likelihood loss.
See more details in :ref:`PoissonNLLLoss <api_paddle_nn_PoissonNLLLoss>` .
Parameters:
input (Tensor):
Input tensor, the expectation of the underlying Poisson distribution.
The shape of the input tensor should be `(N, *)` or `(*)` where `(*)` denotes any number of extra dimensions.
Its data type should be float16, bfloat16, float32 or float64.
label (Tensor):
Label tensor, randomly sampled from the Poisson distribution :math:`label \sim \text{Poisson}(input)`.
The shape of the label tensor should be `(N, *)` or `(*)`, the same shape as the input tensor.
Its data type should be float16, bfloat16, float32 or float64.
log_input (bool, optional):
Whether to treat the input tensor as a log input.
If ``True``, the loss is computed as :math:`\exp(\text{input}) - \text{label} * \text{input}` .
If ``False``, the loss is :math:`\text{input} - \text{label} * \log(\text{input}+\text{epsilon})` .
Default: ``True``.
full (bool, optional):
Whether to compute full loss.
If ``True``, the Stirling approximation term is added.
If ``False``, the Stirling approximation is dropped.
Default: ``False``.
epsilon (float, optional):
A small value to avoid evaluation of :math:`\log(0)` when `log_input`\ =\ ``False``. ``epsilon > 0``.
Default: 1e-8.
reduction (str, optional):
Indicate how to reduce the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'mean'``, the reduced mean loss is returned;
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
name (str, optional):
Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.randn([5, 2], dtype=paddle.float32)
label = paddle.randn([5, 2], dtype=paddle.float32)
loss = F.poisson_nll_loss(input, label, log_input=True, reduction='none')
print(loss)
loss = F.poisson_nll_loss(input, label, reduction='mean')
print(loss)
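# With log_input=False the input is treated as the rate itself and must be
# positive; `epsilon` guards the log against zero rates (an illustrative
# variant, reusing `label` from above).
rate = paddle.rand([5, 2], dtype=paddle.float32) + 0.1
loss = F.poisson_nll_loss(rate, label, log_input=False, epsilon=1e-8)
print(loss)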
"""
# check parameter values
if epsilon <= 0:
raise ValueError(
"The value of `epsilon` in poisson_nll_loss should be positve, but received %f, which is not allowed"
% epsilon
)
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in poisson_nll_loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction
)
# check input dtype and dimension
check_variable_and_dtype(
input,
'input',
['float16', 'uint16', 'float32', 'float64'],
'poisson_nll_loss',
)
check_variable_and_dtype(
label,
'label',
['float16', 'uint16', 'float32', 'float64'],
'poisson_nll_loss',
)
if input.shape != label.shape:
raise ValueError("input's shape must equal label's shape")
label = paddle.cast(label, input.dtype)
loss_out = 0
if log_input:
loss_out = paddle.exp(input) - label * input
else:
loss_out = input - label * paddle.log(input + epsilon)
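# Stirling approximation of log(label!); values of the term that are <= 1
# are replaced by zero below, so it only contributes where it exceeds 1.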
if full:
stirling_approx = (
label * paddle.log(label)
- label
+ 0.5 * paddle.log(2 * math.pi * label)
)
loss_out += paddle.where(
stirling_approx <= 1,
paddle.zeros_like(stirling_approx),
stirling_approx,
)
if reduction == 'mean':
loss_out = paddle.mean(loss_out)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
return loss_out
def kl_div(input, label, reduction='mean', name=None):
r"""
Calculate the Kullback-Leibler divergence loss
......
@@ -72,6 +72,7 @@ from .loss import CrossEntropyLoss # noqa: F401
from .loss import MSELoss # noqa: F401
from .loss import L1Loss # noqa: F401
from .loss import NLLLoss # noqa: F401
from .loss import PoissonNLLLoss # noqa: F401
from .loss import BCELoss # noqa: F401
from .loss import KLDivLoss # noqa: F401
from .loss import MarginRankingLoss # noqa: F401
......
@@ -882,6 +882,99 @@ class NLLLoss(Layer):
)
class PoissonNLLLoss(Layer):
r"""Generate a callable object of 'PoissonNLLLoss' to calculate the
Poisson negative log likelihood loss between Input(input) and
Input(label). Notes that Input(input) is the expectation of underlying
Poisson distribution and Input(label) is the random samples from the
Poisson distribution
Poisson negative log likelihood loss is calculated as follows:
.. math::
\text{loss}(\text{input}, \text{label}) = \text{input} - \text{label} * \log(\text{label}) + \log(\text{label!})
The last term can be approximated with Stirling formula. This approximation term is used when :attr:`full` is ``True``.
The approximation is added when label values are more than 1 and omitted when the labels are less than or equal to 1.
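Concretely, the Stirling approximation of the last term (a standard form, matching the implementation) is:
.. math::
\log(\text{label}!) \approx \text{label} * \log(\text{label}) - \text{label} + 0.5 * \log(2\pi\text{label})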
Parameters:
log_input (bool, optional):
Whether to treat the input tensor as a log input.
If ``True``, the loss is computed as :math:`\exp(\text{input}) - \text{label} * \text{input}` .
If ``False``, the loss is :math:`\text{input} - \text{label} * \log(\text{input}+\text{epsilon})` .
Default: ``True``.
full (bool, optional):
Whether to compute full loss.
If ``True``, the Stirling approximation term is added.
If ``False``, the Stirling approximation is dropped.
Default: ``False``.
epsilon (float, optional):
A small value to avoid evaluation of :math:`\log(0)` when ``log_input`` = ``False``. ``epsilon > 0``.
Default: 1e-8.
reduction (str, optional):
Indicate how to reduce the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'mean'``, the reduced mean loss is returned;
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
name (str, optional):
Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input (Tensor): The shape of the input tensor should be `(N, *)` or `(*)` where `(*)` denotes any number of extra dimensions.
- label (Tensor): The shape of the label tensor should be `(N, *)` or `(*)`, the same shape as the input tensor.
- output (Tensor): scalar if :attr:`reduction` is ``'mean'`` (default) or ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, the same shape as the input.
Examples:
.. code-block:: python
import paddle
poisson_nll_loss = paddle.nn.PoissonNLLLoss()
input = paddle.randn([5, 2], dtype=paddle.float32)
label = paddle.randn([5, 2], dtype=paddle.float32)
loss = poisson_nll_loss(input, label)
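print(loss)

# reduction='none' keeps the elementwise loss with the input's shape
# (illustrative variant):
poisson_nll_loss_none = paddle.nn.PoissonNLLLoss(reduction='none')
loss_none = poisson_nll_loss_none(input, label)
print(loss_none)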
"""
def __init__(
self,
log_input=True,
full=False,
epsilon=1e-8,
reduction="mean",
name=None,
):
if epsilon <= 0:
raise ValueError(
"The value of `epsilon` in PoissonNLLLoss should be positve, but received %f, which is not allowed"
% epsilon
)
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in PoissonNLLLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction
)
super().__init__()
self._log_input = log_input
self._full = full
self._epsilon = epsilon
self._reduction = reduction
self._name = name
def forward(self, input, label):
return F.poisson_nll_loss(
input,
label,
log_input=self._log_input,
full=self._full,
epsilon=self._epsilon,
reduction=self._reduction,
name=self._name,
)
class KLDivLoss(Layer):
r"""
......
@@ -159,7 +159,7 @@ def log(x, name=None):
return _C_ops.log(x)
else:
check_variable_and_dtype(
x, 'x', ['uint16', 'float16', 'float32', 'float64'], "log"
)
inputs = {'X': [x]}
helper = LayerHelper('log', **locals())
......
@@ -565,6 +565,7 @@ def exp(x, name=None):
[
'int32',
'int64',
'uint16',
'float16',
'float32',
'float64',
......