Unverified · Commit e4033a06 authored by LielinJiang, committed by GitHub

add Class KLDivLoss and function kl_div (#25977)

* add Class KLDivLoss and function kl_div
Parent 57e83ad7
@@ -16,6 +16,7 @@ from __future__ import print_function
import numpy as np
from functools import partial, reduce
from paddle.utils import deprecated
from . import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
@@ -1619,6 +1620,7 @@ def huber_loss(input, label, delta):
    return out


@deprecated(since="2.0.0", update_to="paddle.nn.functional.kl_div")
@templatedoc()
def kldiv_loss(x, target, reduction='mean', name=None):
    """
......
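
The decorator above marks `fluid.layers.kldiv_loss` as deprecated in favor of the new `paddle.nn.functional.kl_div` added by this commit. A minimal migration sketch, assuming the dygraph helpers `paddle.enable_imperative` and `paddle.to_variable` used in the docstring examples below (illustrative only, not part of the diff):

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.enable_imperative()

    x = paddle.to_variable(
        np.random.uniform(-10, 10, (5, 20)).astype('float32'))
    target = paddle.to_variable(
        np.random.uniform(-10, 10, (5, 20)).astype('float32'))

    # Old call, deprecated since 2.0.0 (still works, with a warning):
    #   loss = paddle.fluid.layers.kldiv_loss(x, target, reduction='mean')

    # New API introduced by this commit:
    loss = F.kl_div(x, target, reduction='mean')
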
@@ -13,6 +13,7 @@
from __future__ import division
import paddle
import unittest
import numpy as np
from op_test import OpTest
@@ -77,5 +78,36 @@ class TestKLDivLossOp4(TestKLDivLossOp):
        self.reduction = 'sum'


class TestKLDivLossDygraph(unittest.TestCase):
    def run_kl_loss(self, reduction, shape=(5, 20)):
        x = np.random.uniform(-10, 10, shape).astype('float64')
        target = np.random.uniform(-10, 10, shape).astype('float64')
        gt_loss = kldiv_loss(x, target, reduction)

        with paddle.fluid.dygraph.guard():
            kldiv_criterion = paddle.nn.KLDivLoss(reduction)
            pred_loss = kldiv_criterion(
                paddle.to_variable(x), paddle.to_variable(target))
            self.assertTrue(np.allclose(pred_loss.numpy(), gt_loss))

    def test_kl_loss_batchmean(self):
        self.run_kl_loss('batchmean')

    def test_kl_loss_mean(self):
        self.run_kl_loss('mean')

    def test_kl_loss_sum(self):
        self.run_kl_loss('sum')

    def test_kl_loss_none(self):
        self.run_kl_loss('none')

    def test_kl_loss_static_api(self):
        input = paddle.fluid.data(name='input', shape=[5, 20])
        label = paddle.fluid.data(name='label', shape=[5, 20])

        pred_loss = paddle.nn.functional.kl_div(input, label)


if __name__ == "__main__":
    unittest.main()
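
The `kldiv_loss` ground-truth helper called in `run_kl_loss` above is defined earlier in this test file and does not appear in the hunk. A plausible NumPy sketch of such a reference is shown below; the name `kldiv_loss_ref` and its body are assumptions, mirroring the element-wise formula y * (log(y) - x) and the reduction semantics documented elsewhere in this commit:

.. code-block:: python

    import numpy as np

    def kldiv_loss_ref(x, target, reduction='mean'):
        # element-wise KL term y * (log(y) - x); points where the
        # target is negative are taken to contribute zero loss
        output = target * (np.log(target) - x)
        loss = np.where(target >= 0, output, np.zeros_like(x))

        if reduction == 'batchmean':
            return loss.sum() / x.shape[0]   # sum divided by batch size
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss  # 'none': no reduction
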
@@ -86,6 +86,7 @@ from .layer.loss import MSELoss #DEFINE_ALIAS
from .layer.loss import L1Loss #DEFINE_ALIAS
from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.loss import KLDivLoss #DEFINE_ALIAS
from .layer.loss import MarginRankingLoss #DEFINE_ALIAS
from .layer.norm import BatchNorm #DEFINE_ALIAS
from .layer.norm import GroupNorm #DEFINE_ALIAS
......
@@ -126,7 +126,7 @@ from .loss import dice_loss #DEFINE_ALIAS
from .loss import edit_distance #DEFINE_ALIAS
from .loss import huber_loss #DEFINE_ALIAS
from .loss import iou_similarity #DEFINE_ALIAS
from .loss import kldiv_loss #DEFINE_ALIAS
from .loss import kl_div #DEFINE_ALIAS
from .loss import l1_loss #DEFINE_ALIAS
from .loss import log_loss #DEFINE_ALIAS
from .loss import margin_ranking_loss #DEFINE_ALIAS
......
@@ -25,7 +25,6 @@ from ...fluid.layers import center_loss #DEFINE_ALIAS
from ...fluid.layers import cross_entropy #DEFINE_ALIAS
from ...fluid.layers import dice_loss #DEFINE_ALIAS
from ...fluid.layers import iou_similarity #DEFINE_ALIAS
from ...fluid.layers import kldiv_loss #DEFINE_ALIAS
from ...fluid.layers import log_loss #DEFINE_ALIAS
from ...fluid.layers import npair_loss #DEFINE_ALIAS
from ...fluid.layers import rank_loss #DEFINE_ALIAS
@@ -52,7 +51,7 @@ __all__ = [
    'edit_distance',
    'huber_loss',
    'iou_similarity',
    'kldiv_loss',
    'kl_div',
    'l1_loss',
    'log_loss',
    'mse_loss',
@@ -374,6 +373,105 @@ def nll_loss(input,
    return out


def kl_div(input, label, reduction='mean', name=None):
    """
    This operator calculates the Kullback-Leibler divergence loss
    between Input(X) and Input(Target). Note that Input(X) is the
    log-probability and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    $$l(x, y) = y * (\log(y) - x)$$

    Here :math:`x` is the input and :math:`y` is the label.

    When :attr:`reduction` is ``'none'``, the output loss has the
    same shape as the input; the loss at each point is calculated
    separately and no reduction is applied.

    When :attr:`reduction` is ``'mean'``, the output loss has
    shape [1] and the loss value is the mean of all losses.

    When :attr:`reduction` is ``'sum'``, the output loss has
    shape [1] and the loss value is the sum of all losses.

    When :attr:`reduction` is ``'batchmean'``, the output loss has
    shape [1] and the loss value is the sum of all losses divided
    by the batch size.

    Args:
        input (Tensor): The input tensor. The shape is [N, *], where N is the batch size and `*` means
            any number of additional dimensions. Its data type should be float32 or float64.
        label (Tensor): The label tensor. The shape is [N, *], the same as ``input``. Its data type should be float32 or float64.
        reduction (str, optional): Indicates how to reduce the loss;
            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'batchmean'``, the sum of the loss divided by the batch size is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction is applied.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information,
            please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The KL divergence loss. The data type is the same as the input tensor.
    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn.functional as F

            paddle.enable_imperative()

            shape = (5, 20)
            input = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_variable(input),
                                 paddle.to_variable(target), reduction='batchmean')
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_variable(input),
                                 paddle.to_variable(target), reduction='mean')
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_variable(input),
                                 paddle.to_variable(target), reduction='sum')
            # shape=[1]

            # 'none' reduction, loss shape is the same as the input shape
            pred_loss = F.kl_div(paddle.to_variable(input),
                                 paddle.to_variable(target), reduction='none')
            # shape=[5, 20]
    """
    if paddle.in_dynamic_mode():
        out = core.ops.kldiv_loss(input, label, 'reduction', reduction)
        return out

    helper = LayerHelper('kl_div', **locals())

    fluid.data_feeder.check_variable_and_dtype(input, 'input',
                                               ['float32', 'float64'], 'kl_div')
    fluid.data_feeder.check_variable_and_dtype(label, 'label',
                                               ['float32', 'float64'], 'kl_div')
    fluid.data_feeder.check_type(reduction, 'reduction', str, 'kl_div')

    loss = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='kldiv_loss',
        inputs={'X': input,
                'Target': label},
        outputs={'Loss': loss},
        attrs={'reduction': reduction})
    return loss
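
Alongside the dygraph examples in the docstring above, a static-graph usage sketch for `paddle.nn.functional.kl_div` may help. It assumes the `paddle.fluid` Program/Executor APIs of this Paddle version and that static mode is still the default, as in the unit test in this commit (illustrative, not part of the diff):

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid
    import paddle.nn.functional as F

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        input = fluid.data(name='input', shape=[5, 20], dtype='float32')
        label = fluid.data(name='label', shape=[5, 20], dtype='float32')
        loss = F.kl_div(input, label, reduction='mean')

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)

    x_np = np.random.uniform(-10, 10, (5, 20)).astype('float32')
    y_np = np.random.uniform(-10, 10, (5, 20)).astype('float32')
    loss_np, = exe.run(main_prog,
                       feed={'input': x_np, 'label': y_np},
                       fetch_list=[loss])
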


def mse_loss(input, label, reduction='mean', name=None):
    """
    This op accepts input predictions and labels and returns the mean square error.
......
@@ -62,6 +62,7 @@ from .loss import MSELoss #DEFINE_ALIAS
from .loss import L1Loss #DEFINE_ALIAS
from .loss import NLLLoss #DEFINE_ALIAS
from .loss import BCELoss #DEFINE_ALIAS
from .loss import KLDivLoss #DEFINE_ALIAS
from .loss import MarginRankingLoss #DEFINE_ALIAS
from .norm import BatchNorm #DEFINE_ALIAS
from .norm import GroupNorm #DEFINE_ALIAS
......
@@ -26,6 +26,7 @@ __all__ = [
    'L1Loss',
    'NLLLoss',
    'BCELoss',
    'KLDivLoss',
    'MarginRankingLoss'
]
@@ -574,6 +575,75 @@ class NLLLoss(fluid.dygraph.Layer):
            name=self._name)


class KLDivLoss(fluid.dygraph.Layer):
    """
    This interface calculates the Kullback-Leibler divergence loss
    between Input(X) and Input(Target). Note that Input(X) is the
    log-probability and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    $$l(x, y) = y * (\log(y) - x)$$

    Parameters:
        reduction (str, optional): Indicates how to reduce the loss;
            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if it is ``'batchmean'``, the sum of the loss divided by the batch size is returned;
            if it is ``'sum'``, the summed loss is returned;
            if it is ``'none'``, no reduction is applied.
            Default is ``'mean'``.

    Shape:
        - input: (N, *), where * means any number of additional dimensions.
        - label: (N, *), same shape as input.
        - output: a tensor with shape (1) by default.
    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn as nn

            paddle.enable_imperative()

            shape = (5, 20)
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_variable(x),
                                        paddle.to_variable(target))
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
            pred_loss = kldiv_criterion(paddle.to_variable(x),
                                        paddle.to_variable(target))
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='sum')
            pred_loss = kldiv_criterion(paddle.to_variable(x),
                                        paddle.to_variable(target))
            # shape=[1]

            # 'none' reduction, loss shape is the same as the input shape
            kldiv_criterion = nn.KLDivLoss(reduction='none')
            pred_loss = kldiv_criterion(paddle.to_variable(x),
                                        paddle.to_variable(target))
            # shape=[5, 20]
    """

    def __init__(self, reduction='mean'):
        super(KLDivLoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label):
        out = paddle.nn.functional.kl_div(input, label, self.reduction)
        return out
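
Since the docstring notes that the input to `KLDivLoss` must already be a log-probability and the label a probability, a short dygraph sketch of preparing both from raw scores may be useful. It assumes `paddle.nn.functional.log_softmax` and `paddle.nn.functional.softmax` are available in this version and accept an `axis` keyword (illustrative, not part of the diff):

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    paddle.enable_imperative()

    # raw, unnormalized scores for prediction and target
    logits = paddle.to_variable(
        np.random.uniform(-1, 1, (5, 20)).astype('float32'))
    target_logits = paddle.to_variable(
        np.random.uniform(-1, 1, (5, 20)).astype('float32'))

    log_prob = F.log_softmax(logits, axis=-1)   # input: log-probabilities
    prob = F.softmax(target_logits, axis=-1)    # label: probabilities

    kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
    loss = kldiv_criterion(log_prob, prob)
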
class MarginRankingLoss(fluid.dygraph.Layer):
"""
......