From 3e3a983a6902572049046f38b5ead4097cad969e Mon Sep 17 00:00:00 2001
From: dengkaipeng
Date: Sat, 2 Mar 2019 13:52:32 +0800
Subject: [PATCH] add kldiv_loss op. test=develop

---
 paddle/fluid/operators/kldiv_loss_op.cc            | 150 ++++++++++++++++++
 paddle/fluid/operators/kldiv_loss_op.cu            |  21 +++
 paddle/fluid/operators/kldiv_loss_op.h             | 117 ++++++++++++++
 .../tests/unittests/test_kldiv_loss_op.py          |  82 ++++++++++
 4 files changed, 370 insertions(+)
 create mode 100644 paddle/fluid/operators/kldiv_loss_op.cc
 create mode 100644 paddle/fluid/operators/kldiv_loss_op.cu
 create mode 100644 paddle/fluid/operators/kldiv_loss_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py

diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc
new file mode 100644
index 000000000..d04221054
--- /dev/null
+++ b/paddle/fluid/operators/kldiv_loss_op.cc
@@ -0,0 +1,150 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/kldiv_loss_op.h"
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class KLDivLossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of KLDivLossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Target"),
+                   "Input(Target) of KLDivLossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
+                   "Output(Loss) of KLDivLossOp should not be null.");
+
+    auto dim_x = ctx->GetInputDim("X");
+    auto dim_target = ctx->GetInputDim("Target");
+    PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
+                      "Input(X) rank and Input(Target) rank should be same.");
+    for (int i = 0; i < dim_x.size(); i++) {
+      PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
+                        "Input(X) and Input(Target) should be in same shape.");
+    }
+
+    auto reduction = ctx->Attrs().Get<std::string>("reduction");
+
+    PADDLE_ENFORCE(
+        "mean" == reduction || "sum" == reduction ||
+            "batchmean" == reduction || "none" == reduction,
+        "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'.");
+
+    if ("none" == reduction) {
+      ctx->SetOutputDim("Loss", dim_x);
+    } else {
+      ctx->SetOutputDim("Loss", framework::make_ddim({1}));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor of KL divergence loss operator. "
+             "This is a tensor with shape of [N, *], where N is the "
+             "batch size and * means any number of additional dimensions.");
+    AddInput("Target",
+             "The target tensor of KL divergence loss operator. "
+             "This is a tensor with the same shape as Input(X).");
+    AddOutput("Loss",
+              "The output KL divergence loss tensor. If Attr(reduction) is "
+              "'none', this tensor has the same shape as Input(X); "
+              "otherwise it has shape [1].");
+
+    AddAttr<std::string>(
+        "reduction",
+        "The reduction type to apply to the output, available types "
+        "are 'none' | 'batchmean' | 'mean' | 'sum': 'none' for no "
+        "reduction, 'batchmean' for the sum of the output divided by "
+        "the batch size, 'mean' for the average value of all elements, "
+        "'sum' for the sum of the output.")
+        .SetDefault("mean");
+
+    AddComment(R"DOC(
+         This operator calculates the Kullback-Leibler divergence loss
+         between Input(X) and Input(Target). For each element, the loss is
+
+             l(x, y) = y * (log(y) - x)
+
+         where x comes from Input(X) (typically holding log-probabilities)
+         and y from Input(Target); elements with non-positive target
+         contribute zero loss. The elementwise result is then reduced
+         according to Attr(reduction).
+         )DOC");
+  }
+};
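+
+// Note on the backward definition below: only X@GRAD is produced and
+// Target receives no gradient (the unit test accordingly passes
+// no_grad_set={"Target"}); the actual formula lives in
+// KLDivLossGradKernel in kldiv_loss_op.h.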
"Loss", + "The output KL divergence loss tensor. if Attr(reduction) is " + "'none', this tensor should be in same shape of of Input(X), else " + "this tensor should be in shape of [1]."); + + AddAttr( + "reduction", + "The reduction type to apply to the output, available types " + "are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no " + "reduction, 'batchmean' for the sum of output divided by " + "batch size, 'mean' for the average valud of all output, " + "'sum' for the sum of the output.") + .SetDefault("mean"); + + AddComment(R"DOC( + This operator calculates the Kullback-Leibler divergence loss + between Input(X) and Input(Target). + + )DOC"); + } +}; + +class KLDivLossOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(Target) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), + "Input(Loss@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class KLDivLossOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("kldiv_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("Target", Input("Target")); + op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker, + ops::KLDivLossOpGradMaker); +REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad); +REGISTER_OP_CPU_KERNEL( + kldiv_loss, ops::KLDivLossKernel, + ops::KLDivLossKernel); +REGISTER_OP_CPU_KERNEL( + kldiv_loss_grad, + ops::KLDivLossGradKernel, + ops::KLDivLossGradKernel); diff --git a/paddle/fluid/operators/kldiv_loss_op.cu b/paddle/fluid/operators/kldiv_loss_op.cu new file mode 100644 index 000000000..ef394feb6 --- /dev/null +++ b/paddle/fluid/operators/kldiv_loss_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+template <typename DeviceContext, typename T>
+class KLDivLossGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto* input = ctx.Input<Tensor>("X");
+    auto* target = ctx.Input<Tensor>("Target");
+    auto reduction = ctx.Attr<std::string>("reduction");
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+
+    const int n = input->dims()[0];
+    const int numel = input->numel();
+    const int expand = numel / loss_grad->numel();
+
+    input_grad->mutable_data<T>(ctx.GetPlace());
+
+    auto target_t = EigenVector<T>::Flatten(*target);
+
+    auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
+    auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);
+    auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
+
+    auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
+    input_grad_t.device(place) =
+        target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
+
+    if ("mean" == reduction) {
+      input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
+    } else if ("batchmean" == reduction) {
+      input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
new file mode 100644
index 000000000..21bac6732
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def kldiv_loss(x, target, reduction):
+    output = target * (np.log(target) - x)
+    loss = np.where(target > 0, output, np.zeros_like(x))
+
+    if reduction == "batchmean":
+        return loss.sum() / x.shape[0]
+    if reduction == "mean":
+        return loss.mean()
+    if reduction == "sum":
+        return loss.sum()
+
+    return loss
+
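+
+# A quick numeric sketch of the reference above (hypothetical values,
+# kept as comments only):
+#   x = np.zeros(3)                # interpreted as log-probabilities
+#   t = np.array([0.2, 0.3, 0.5])
+#   kldiv_loss(x, t, 'sum')        # = sum(t * np.log(t)) ~= -1.0297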
+
+class TestKLDivLossOp(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'kldiv_loss'
+        x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
+        target = np.random.uniform(-10, 10, self.x_shape).astype('float32')
+
+        self.attrs = {"reduction": self.reduction}
+
+        self.inputs = {
+            'X': x,
+            'Target': target,
+        }
+        loss = kldiv_loss(x, target, self.reduction)
+        self.outputs = {'Loss': loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.1)
+
+    def initTestCase(self):
+        self.x_shape = (2, 3, 5, 5)
+        self.reduction = 'batchmean'
+
+
+class TestKLDivLossOp2(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (3, 7, 7)
+        self.reduction = 'batchmean'
+
+
+class TestKLDivLossOp3(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (2, 3, 5, 7, 9)
+        self.reduction = 'mean'
+
+
+class TestKLDivLossOp4(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (5, 7)
+        self.reduction = 'sum'
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab