diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d472c6062e94ce499c915a5f2f9559d94cd695d7..0ea9cb625bb32601798009f6e3cb1f72de9aeb47 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -230,6 +230,7 @@ paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d')) paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99')) paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7')) +paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec')) paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607')) paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a43f22c0496f89943d2fd5110446f1aae6a99315 --- /dev/null +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -0,0 +1,171 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/kldiv_loss_op.h" +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class KLDivLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of KLDivLossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Target"), + "Input(Target) of KLDivLossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Loss"), + "Output(Loss) of KLDivLossOp should not be null."); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_target = ctx->GetInputDim("Target"); + PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), + "Input(X) rank and Input(Target) rank should be same."); + for (int i = 0; i < dim_x.size(); i++) { + PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], + "Input(X) and Input(Target) should in same shape."); + } + + auto reduction = ctx->Attrs().Get("reduction"); + + PADDLE_ENFORCE( + "mean" == reduction || "sum" == reduction || "batchmean" == reduction || + "none" == reduction, + "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."); + + if ("none" == reduction) { + ctx->SetOutputDim("Loss", dim_x); + } else { + ctx->SetOutputDim("Loss", {1}); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of KL divergence loss operator. " + "This is a tensor with shape of [N, *], where N is the " + "batch size, * means any number of additional dimensions."); + AddInput("Target", + "The tensor of KL divergence loss operator. " + "This is a tensor with shape of Input(X)."); + AddOutput( + "Loss", + "The output KL divergence loss tensor. if Attr(reduction) is " + "'none', this tensor should be in same shape of of Input(X), else " + "this tensor should be in shape of [1]."); + + AddAttr( + "reduction", + "The reduction type to apply to the output, available types " + "are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no " + "reduction, 'batchmean' for the sum of output divided by " + "batch size, 'mean' for the average value of all output, " + "'sum' for the sum of the output.") + .SetDefault("mean"); + + AddComment(R"DOC( + This operator calculates the Kullback-Leibler divergence loss + between Input(X) and Input(Target). + + KL divergence loss is calculated as follows: + + $$l(x, y) = y * (\log(y) - x)$$ + + While :math:`x` is Input(X) and :math:`y` is Input(Target). + + While :attr:`reduction` is :attr:`none`, output loss is in + the same shape as Input(X), loss in each point is calculated + seperately and no reduction is applied. + + While :attr:`reduction` is :attr:`mean`, output loss is in + shape of [1] and loss value is the mean value of all losses. + + While :attr:`reduction` is :attr:`sum`, output loss is in + shape of [1] and loss value is the sum value of all losses. + + While :attr:`reduction` is :attr:`batchmean`, output loss is + in shape of [1] and loss value is the sum value of all losses + divided by batch size. + + )DOC"); + } +}; + +class KLDivLossOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(Target) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), + "Input(Loss@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class KLDivLossOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("kldiv_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("Target", Input("Target")); + op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker, + ops::KLDivLossOpGradMaker); +REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad); +REGISTER_OP_CPU_KERNEL( + kldiv_loss, ops::KLDivLossKernel, + ops::KLDivLossKernel); +REGISTER_OP_CPU_KERNEL( + kldiv_loss_grad, + ops::KLDivLossGradKernel, + ops::KLDivLossGradKernel); diff --git a/paddle/fluid/operators/kldiv_loss_op.cu b/paddle/fluid/operators/kldiv_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5226cb8c08e3db4a0bfbbe4440c27264903f06e3 --- /dev/null +++ b/paddle/fluid/operators/kldiv_loss_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/kldiv_loss_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + kldiv_loss, + ops::KLDivLossKernel, + ops::KLDivLossKernel); +REGISTER_OP_CUDA_KERNEL( + kldiv_loss_grad, + ops::KLDivLossGradKernel, + ops::KLDivLossGradKernel); diff --git a/paddle/fluid/operators/kldiv_loss_op.h b/paddle/fluid/operators/kldiv_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..625e16e298d9f842fa621aca727c6df2cb045301 --- /dev/null +++ b/paddle/fluid/operators/kldiv_loss_op.h @@ -0,0 +1,119 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +using Array1 = Eigen::DSizes; + +template +struct KLDivLossForward { + HOSTDEVICE KLDivLossForward() {} + + HOSTDEVICE T operator()(const T& target, const T& input) const { + if (target <= 0) { + return 0; + } else { + return target * (std::log(target) - input); + } + } +}; + +template +struct KLDivLossBackward { + HOSTDEVICE KLDivLossBackward() {} + + HOSTDEVICE T operator()(const T& target, const T& grad) const { + if (target <= 0) { + return 0; + } else { + return static_cast(-1.) * grad; + } + } +}; + +template +class KLDivLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto* input = ctx.Input("X"); + auto* target = ctx.Input("Target"); + auto* loss = ctx.Output("Loss"); + auto reduction = ctx.Attr("reduction"); + + const int n = input->dims()[0]; + + loss->mutable_data(ctx.GetPlace()); + auto input_t = EigenVector::Flatten(*input); + auto target_t = EigenVector::Flatten(*target); + auto loss_t = EigenVector::Flatten(*loss); + auto output = target_t.binaryExpr(input_t, KLDivLossForward()); + if ("none" == reduction) { + loss_t.device(place) = output; + } else if ("batchmean" == reduction) { + auto output_sum = output.sum().eval(); + loss_t.device(place) = output_sum / output_sum.constant(n); + } else if ("mean" == reduction) { + loss_t.device(place) = output.mean(); + } else if ("sum" == reduction) { + loss_t.device(place) = output.sum(); + } + } +}; + +template +class KLDivLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto* target = ctx.Input("Target"); + auto reduction = ctx.Attr("reduction"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + + const int n = input_grad->dims()[0]; + const int numel = input_grad->numel(); + const int expand = numel / loss_grad->numel(); + + input_grad->mutable_data(ctx.GetPlace()); + + auto target_t = EigenVector::Flatten(*target); + + auto input_grad_t = EigenVector::Flatten(*input_grad); + auto loss_grad_t = EigenVector::Flatten(*loss_grad); + + auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand)); + auto grad_t = target_t * loss_grad_expand; + input_grad_t.device(place) = + target_t.binaryExpr(grad_t, KLDivLossBackward()); + + if ("mean" == reduction) { + input_grad_t.device(place) = input_grad_t / static_cast(numel); + } else if ("batchmean" == reduction) { + input_grad_t.device(place) = input_grad_t / static_cast(n); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index baa7d93cbcf1c85e026a559ae026ac35b5249a23..91414fdeb207781afd5e28afa5a3fa6e1018efb1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -188,6 +188,7 @@ __all__ = [ 'psroi_pool', 'teacher_student_sigmoid_loss', 'huber_loss', + 'kldiv_loss', 'tree_conv', 'npair_loss', 'fsp_matrix', @@ -10762,6 +10763,38 @@ def huber_loss(input, label, delta): return out +@templatedoc() +def kldiv_loss(x, target, reduction='mean', name=None): + """ + ${comment} + + Args: + x (Variable): ${x_comment} + target (Variable): ${target_comment} + reduction (Variable): ${reduction_comment} + name (str, default None): The name of this layer. + + Returns: + kldiv\_loss (Variable): The KL divergence loss. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[4,2,2], dtype='float32') + target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32') + loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean') + """ + helper = LayerHelper('kldiv_loss', **locals()) + loss = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='kldiv_loss', + inputs={'X': x, + 'Target': target}, + outputs={'Loss': loss}, + attrs={'reduction': reduction}) + return loss + + @templatedoc() def tree_conv(nodes_vector, edge_set, diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d0212d177e6f1c60b916a0cb0eef7cd7f54a3585 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py @@ -0,0 +1,82 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest +import numpy as np +from op_test import OpTest + + +def kldiv_loss(x, target, reduction): + output = target * (np.log(target) - x) + loss = np.where(target >= 0, output, np.zeros_like(x)) + + if reduction == "batchmean": + return loss.sum() / x.shape[0] + if reduction == "mean": + return loss.mean() + if reduction == "sum": + return loss.sum() + + return loss + + +class TestKLDivLossOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'kldiv_loss' + x = np.random.uniform(-10, 10, self.x_shape).astype('float32') + target = np.random.uniform(-10, 10, self.x_shape).astype('float32') + + self.attrs = {"reduction": self.reduction} + + self.inputs = { + 'X': x, + 'Target': target, + } + loss = kldiv_loss(x, target, self.reduction) + self.outputs = {'Loss': loss.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.06) + + def initTestCase(self): + self.x_shape = (2, 5, 5) + self.reduction = 'batchmean' + + +class TestKLDivLossOp2(TestKLDivLossOp): + def initTestCase(self): + self.x_shape = (3, 2, 7, 7) + self.reduction = 'none' + + +class TestKLDivLossOp3(TestKLDivLossOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 7, 9) + self.reduction = 'mean' + + +class TestKLDivLossOp4(TestKLDivLossOp): + def initTestCase(self): + self.x_shape = (5, 7) + self.reduction = 'sum' + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index d5041f800b58883d2a457350e921ed32860dc854..e92ece7acb41b5a63adaae8edba78486ca3adcf8 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1591,6 +1591,15 @@ class TestBook(unittest.TestCase): out = layers.spectral_norm(weight, dim=1, power_iters=1) self.assertIsNotNone(out) + def test_kldiv_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[32, 128, 128], dtype="float32") + target = layers.data( + name='target', shape=[32, 128, 128], dtype="float32") + loss = layers.kldiv_loss(x=x, target=target, reduction='batchmean') + self.assertIsNotNone(loss) + print(str(program)) def test_temporal_shift(self):