Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e4033a06
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e4033a06
编写于
8月 15, 2020
作者:
L
LielinJiang
提交者:
GitHub
8月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add Class KLDivLoss and function kl_div (#25977)
* add Class KLDivLoss and function kl_div
上级
57e83ad7
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
207 addition
and
3 deletion
+207
-3
python/paddle/fluid/layers/loss.py
python/paddle/fluid/layers/loss.py
+2
-0
python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
+32
-0
python/paddle/nn/__init__.py
python/paddle/nn/__init__.py
+1
-0
python/paddle/nn/functional/__init__.py
python/paddle/nn/functional/__init__.py
+1
-1
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+100
-2
python/paddle/nn/layer/__init__.py
python/paddle/nn/layer/__init__.py
+1
-0
python/paddle/nn/layer/loss.py
python/paddle/nn/layer/loss.py
+70
-0
未找到文件。
python/paddle/fluid/layers/loss.py
浏览文件 @
e4033a06
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
numpy
as
np
from
functools
import
partial
,
reduce
from
paddle.utils
import
deprecated
from
.
import
nn
from
.layer_function_generator
import
templatedoc
from
..layer_helper
import
LayerHelper
...
...
@@ -1619,6 +1620,7 @@ def huber_loss(input, label, delta):
return
out
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.nn.functional.kl_div"
)
@
templatedoc
()
def
kldiv_loss
(
x
,
target
,
reduction
=
'mean'
,
name
=
None
):
"""
...
...
python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
浏览文件 @
e4033a06
...
...
@@ -13,6 +13,7 @@
from
__future__
import
division
import
paddle
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
...
...
@@ -77,5 +78,36 @@ class TestKLDivLossOp4(TestKLDivLossOp):
self
.
reduction
=
'sum'
class
TestKLDivLossDygraph
(
unittest
.
TestCase
):
def
run_kl_loss
(
self
,
reduction
,
shape
=
(
5
,
20
)):
x
=
np
.
random
.
uniform
(
-
10
,
10
,
shape
).
astype
(
'float64'
)
target
=
np
.
random
.
uniform
(
-
10
,
10
,
shape
).
astype
(
'float64'
)
gt_loss
=
kldiv_loss
(
x
,
target
,
reduction
)
with
paddle
.
fluid
.
dygraph
.
guard
():
kldiv_criterion
=
paddle
.
nn
.
KLDivLoss
(
reduction
)
pred_loss
=
kldiv_criterion
(
paddle
.
to_variable
(
x
),
paddle
.
to_variable
(
target
))
self
.
assertTrue
(
np
.
allclose
(
pred_loss
.
numpy
(),
gt_loss
))
def
test_kl_loss_batchmean
(
self
):
self
.
run_kl_loss
(
'batchmean'
)
def
test_kl_loss_mean
(
self
):
self
.
run_kl_loss
(
'mean'
)
def
test_kl_loss_sum
(
self
):
self
.
run_kl_loss
(
'sum'
)
def
test_kl_loss_none
(
self
):
self
.
run_kl_loss
(
'none'
)
def
test_kl_loss_static_api
(
self
):
input
=
paddle
.
fluid
.
data
(
name
=
'input'
,
shape
=
[
5
,
20
])
label
=
paddle
.
fluid
.
data
(
name
=
'label'
,
shape
=
[
5
,
20
])
pred_loss
=
paddle
.
nn
.
functional
.
kl_div
(
input
,
label
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/nn/__init__.py
浏览文件 @
e4033a06
...
...
@@ -86,6 +86,7 @@ from .layer.loss import MSELoss #DEFINE_ALIAS
from
.layer.loss
import
L1Loss
#DEFINE_ALIAS
from
.layer.loss
import
NLLLoss
#DEFINE_ALIAS
from
.layer.loss
import
BCELoss
#DEFINE_ALIAS
from
.layer.loss
import
KLDivLoss
#DEFINE_ALIAS
from
.layer.loss
import
MarginRankingLoss
#DEFINE_ALIAS
from
.layer.norm
import
BatchNorm
#DEFINE_ALIAS
from
.layer.norm
import
GroupNorm
#DEFINE_ALIAS
...
...
python/paddle/nn/functional/__init__.py
浏览文件 @
e4033a06
...
...
@@ -126,7 +126,7 @@ from .loss import dice_loss #DEFINE_ALIAS
from
.loss
import
edit_distance
#DEFINE_ALIAS
from
.loss
import
huber_loss
#DEFINE_ALIAS
from
.loss
import
iou_similarity
#DEFINE_ALIAS
from
.loss
import
kl
div_loss
#DEFINE_ALIAS
from
.loss
import
kl
_div
#DEFINE_ALIAS
from
.loss
import
l1_loss
#DEFINE_ALIAS
from
.loss
import
log_loss
#DEFINE_ALIAS
from
.loss
import
margin_ranking_loss
#DEFINE_ALIAS
...
...
python/paddle/nn/functional/loss.py
浏览文件 @
e4033a06
...
...
@@ -25,7 +25,6 @@ from ...fluid.layers import center_loss #DEFINE_ALIAS
from
...fluid.layers
import
cross_entropy
#DEFINE_ALIAS
from
...fluid.layers
import
dice_loss
#DEFINE_ALIAS
from
...fluid.layers
import
iou_similarity
#DEFINE_ALIAS
from
...fluid.layers
import
kldiv_loss
#DEFINE_ALIAS
from
...fluid.layers
import
log_loss
#DEFINE_ALIAS
from
...fluid.layers
import
npair_loss
#DEFINE_ALIAS
from
...fluid.layers
import
rank_loss
#DEFINE_ALIAS
...
...
@@ -52,7 +51,7 @@ __all__ = [
'edit_distance'
,
'huber_loss'
,
'iou_similarity'
,
'kl
div_loss
'
,
'kl
_div
'
,
'l1_loss'
,
'log_loss'
,
'mse_loss'
,
...
...
@@ -374,6 +373,105 @@ def nll_loss(input,
return
out
def
kl_div
(
input
,
label
,
reduction
=
'mean'
,
name
=
None
):
"""
This operator calculates the Kullback-Leibler divergence loss
between Input(X) and Input(Target). Notes that Input(X) is the
log-probability and Input(Target) is the probability.
KL divergence loss is calculated as follows:
$$l(x, y) = y * (\log(y) - x)$$
While :math:`x` is input and :math:`y` is label.
While :attr:`reduction` is :attr:`none`, output loss is in
the same shape as input, loss in each point is calculated
seperately and no reduction is applied.
While :attr:`reduction` is :attr:`mean`, output loss is in
shape of [1] and loss value is the mean value of all losses.
While :attr:`reduction` is :attr:`sum`, output loss is in
shape of [1] and loss value is the sum value of all losses.
While :attr:`reduction` is :attr:`batchmean`, output loss is
in shape of [1] and loss value is the sum value of all losses
divided by batch size.
Args:
input (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means
any number of additional dimensions. It's data type should be float32, float64.
label (Tensor): label. The shapes is [N, *], same shape as ``input`` . It's data type should be float32, float64.
reduction (Tensor): Indicate how to average the loss,
the candicates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'mean'``, the reduced mean loss is returned;
If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be apllied.
Default is ``'mean'``.
name(str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The KL divergence loss. The data type is same as input tensor
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle.nn.functional as F
paddle.enable_imperative()
shape = (5, 20)
input = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32')
# 'batchmean' reduction, loss shape will be [N]
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='batchmean')
# shape=[5]
# 'mean' reduction, loss shape will be [1]
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='mean')
# shape=[1]
# 'sum' reduction, loss shape will be [1]
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='sum')
# shape=[1]
# 'none' reduction, loss shape is same with input shape
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='none')
# shape=[5, 20]
"""
if
paddle
.
in_dynamic_mode
():
out
=
core
.
ops
.
kldiv_loss
(
input
,
label
,
'reduction'
,
reduction
)
return
out
helper
=
LayerHelper
(
'kl_div'
,
**
locals
())
fluid
.
data_feeder
.
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'kl_div'
)
fluid
.
data_feeder
.
check_variable_and_dtype
(
label
,
'label'
,
[
'float32'
,
'float64'
],
'kl_div'
)
fluid
.
data_feeder
.
check_type
(
reduction
,
'reduction'
,
str
,
'kl_div'
)
loss
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
type
=
'kldiv_loss'
,
inputs
=
{
'X'
:
input
,
'Target'
:
label
},
outputs
=
{
'Loss'
:
loss
},
attrs
=
{
'reduction'
:
reduction
})
return
loss
def
mse_loss
(
input
,
label
,
reduction
=
'mean'
,
name
=
None
):
"""
This op accepts input predications and label and returns the mean square error.
...
...
python/paddle/nn/layer/__init__.py
浏览文件 @
e4033a06
...
...
@@ -62,6 +62,7 @@ from .loss import MSELoss #DEFINE_ALIAS
from
.loss
import
L1Loss
#DEFINE_ALIAS
from
.loss
import
NLLLoss
#DEFINE_ALIAS
from
.loss
import
BCELoss
#DEFINE_ALIAS
from
.loss
import
KLDivLoss
#DEFINE_ALIAS
from
.loss
import
MarginRankingLoss
#DEFINE_ALIAS
from
.norm
import
BatchNorm
#DEFINE_ALIAS
from
.norm
import
GroupNorm
#DEFINE_ALIAS
...
...
python/paddle/nn/layer/loss.py
浏览文件 @
e4033a06
...
...
@@ -26,6 +26,7 @@ __all__ = [
'L1Loss'
,
'NLLLoss'
,
'BCELoss'
,
'KLDivLoss'
,
'MarginRankingLoss'
]
...
...
@@ -574,6 +575,75 @@ class NLLLoss(fluid.dygraph.Layer):
name
=
self
.
_name
)
class
KLDivLoss
(
fluid
.
dygraph
.
Layer
):
"""
This interface calculates the Kullback-Leibler divergence loss
between Input(X) and Input(Target). Notes that Input(X) is the
log-probability and Input(Target) is the probability.
KL divergence loss is calculated as follows:
$$l(x, y) = y * (\log(y) - x)$$
Parameters:
reduction (str, optional): Indicate how to average the loss,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
Default is ``'mean'``.
Shape:
- input: (N, *) where * means, any number of additional dimensions.
- label: (N, *), same shape as input
- output: tensor with shape: (1) by default.
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle.nn as nn
paddle.enable_imperative()
shape = (5, 20)
x = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32')
# 'batchmean' reduction, loss shape will be [N]
kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
pred_loss = kldiv_criterion(paddle.to_variable(x),
paddle.to_variable(target))
# shape=[5]
# 'mean' reduction, loss shape will be [1]
kldiv_criterion = nn.KLDivLoss(reduction='mean')
pred_loss = kldiv_criterion(paddle.to_variable(x),
paddle.to_variable(target))
# shape=[1]
# 'sum' reduction, loss shape will be [1]
kldiv_criterion = nn.KLDivLoss(reduction='sum')
pred_loss = kldiv_criterion(paddle.to_variable(x),
paddle.to_variable(target))
# shape=[1]
# 'none' reduction, loss shape is same with X shape
kldiv_criterion = nn.KLDivLoss(reduction='none')
pred_loss = kldiv_criterion(paddle.to_variable(x),
paddle.to_variable(target))
# shape=[5, 20]
"""
def
__init__
(
self
,
reduction
=
'mean'
):
super
(
KLDivLoss
,
self
).
__init__
()
self
.
reduction
=
reduction
def
forward
(
self
,
input
,
label
):
out
=
paddle
.
nn
.
functional
.
kl_div
(
input
,
label
,
self
.
reduction
)
return
out
class
MarginRankingLoss
(
fluid
.
dygraph
.
Layer
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录