From 3e962aecc1626976dda7a7de33841a1c404bd786 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 27 Apr 2020 17:43:17 +0800 Subject: [PATCH] fix kldiv_loss sample code diff. test=develop test=document_fix (#23660) --- python/paddle/fluid/layers/loss.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index 0e6dcf39bc..18790317e5 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -1592,9 +1592,27 @@ def kldiv_loss(x, target, reduction='mean', name=None): .. code-block:: python import paddle.fluid as fluid - x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') + + # 'batchmean' reduction, loss shape will be [N] + x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2] + target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32') + loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean') # shape=[-1] + + # 'mean' reduction, loss shape will be [1] + x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2] + target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32') + loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='mean') # shape=[1] + + # 'sum' reduction, loss shape will be [1] + x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2] + target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32') + loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='sum') # shape=[1] + + # 'none' reduction, loss shape is same with X shape + x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2] target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32') - loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean') + loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='none') # shape=[-1, 4, 2, 2] + """ helper = LayerHelper('kldiv_loss', **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'kldiv_loss') -- GitLab