diff --git a/paddle/fluid/operators/kldiv_loss_op_npu.cc b/paddle/fluid/operators/kldiv_loss_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7d7cdd4c786712adb73d67cf8f8027f5cba06263
--- /dev/null
+++ b/paddle/fluid/operators/kldiv_loss_op_npu.cc
@@ -0,0 +1,163 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <limits>
+#include <string>
+
+#include "paddle/fluid/operators/kldiv_loss_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class KLDivLossNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* target = ctx.Input<Tensor>("Target");
+    auto* loss = ctx.Output<Tensor>("Loss");
+    auto reduction = ctx.Attr<std::string>("reduction");
+    loss->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
+    auto stream = dev_ctx.stream();
+
+    if ("none" == reduction) {
+      // log(label): Log1p(label - 1) == log(label)
+      auto ones_tensor = ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>(
+          target->dims(), dev_ctx);
+      const auto& ones_runner =
+          NpuOpRunner("OnesLike", {*target}, {ones_tensor}, {});
+      ones_runner.Run(stream);
+
+      auto sub_tensor = ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>(
+          target->dims(), dev_ctx);
+      const auto& sub_runner =
+          NpuOpRunner("Sub", {*target, ones_tensor}, {sub_tensor}, {});
+      sub_runner.Run(stream);
+
+      auto log_target = ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>(
+          target->dims(), dev_ctx);
+      const auto& log_runner =
+          NpuOpRunner("Log1p", {sub_tensor}, {log_target}, {});
+      log_runner.Run(stream);
+
+      // log(label) - input
+      const auto& sub_runner2 =
+          NpuOpRunner("Sub", {log_target, *input}, {*loss}, {});
+      sub_runner2.Run(stream);
+
+      // label * (log(label) - input), with label clipped to be non-negative
+      // before the final multiply
+      auto min_value =
+          ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>({1}, dev_ctx);
+      auto max_value =
+          ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>({1}, dev_ctx);
+      FillNpuTensorWithConstant<T>(&min_value, static_cast<T>(0));
+      FillNpuTensorWithConstant<T>(&max_value, std::numeric_limits<T>::max());
+
+      auto cliped_target =
+          ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>(target->dims(),
+                                                               dev_ctx);
+      const auto& clip_runner = NpuOpRunner(
+          "ClipByValue", {*target, min_value, max_value}, {cliped_target}, {});
+      clip_runner.Run(stream);
+
+      const auto& mul_runner =
+          NpuOpRunner("Mul", {*loss, cliped_target}, {*loss}, {});
+      mul_runner.Run(stream);
+    } else if ("batchmean" == reduction || "sum" == reduction) {
+      const auto& runner = NpuOpRunner("KLDiv", {*input, *target}, {*loss},
+                                       {{"reduction", reduction}});
+      runner.Run(stream);
+    } else if ("mean" == reduction) {
+      const auto& runner = NpuOpRunner("KLDiv", {*input, *target}, {*loss},
+                                       {{"reduction", std::string("sum")}});
+      runner.Run(stream);
+
+      const int numel = input->numel();
+      const auto& muls_runner =
+          NpuOpRunner("Muls", {*loss}, {*loss},
+                      {{"value", static_cast<float>(1.0 / numel)}});
+      muls_runner.Run(stream);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class KLDivLossGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* target = ctx.Input<Tensor>("Target");
+    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto reduction = ctx.Attr<std::string>("reduction");
+    input_grad->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
+    auto stream = dev_ctx.stream();
+
+    Tensor loss_grad_transformed;
+    if ("none" == reduction) {
+      loss_grad_transformed.ShareDataWith(*loss_grad);
+    } else {
+      loss_grad_transformed.mutable_data<T>(input_grad->dims(),
+                                            ctx.GetPlace());
+
+      NpuOpRunner broadcast_runner;
+      broadcast_runner.SetType("BroadcastTo");
+      broadcast_runner.AddInput(*loss_grad);
+      broadcast_runner.AddInput(framework::vectorize(input_grad->dims()));
+      broadcast_runner.AddOutput(loss_grad_transformed);
+      broadcast_runner.Run(stream);
+    }
+    auto min_value =
+        ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>({1}, dev_ctx);
+    auto max_value =
+        ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>({1}, dev_ctx);
+    FillNpuTensorWithConstant<T>(&min_value, static_cast<T>(0));
+    FillNpuTensorWithConstant<T>(&max_value, std::numeric_limits<T>::max());
+
+    auto cliped_target =
+        ctx.AllocateTmpTensor<T, platform::NPUDeviceContext>(target->dims(),
+                                                             dev_ctx);
+    const auto& clip_runner = NpuOpRunner(
+        "ClipByValue", {*target, min_value, max_value}, {cliped_target}, {});
+    clip_runner.Run(stream);
+
+    const auto& mul_runner = NpuOpRunner(
+        "Mul", {cliped_target, loss_grad_transformed}, {*input_grad}, {});
+    mul_runner.Run(stream);
+
+    // dX = -label * dLoss, scaled by 1/numel for "mean" and 1/batch_size for
+    // "batchmean"
+    float k = -1.0f;
+
+    if ("mean" == reduction) {
+      k = static_cast<float>(-1.0 / input_grad->numel());
+    } else if ("batchmean" == reduction) {
+      k = static_cast<float>(-1.0 / input_grad->dims()[0]);
+    }
+
+    const auto& muls_runner =
+        NpuOpRunner("Muls", {*input_grad}, {*input_grad}, {{"value", k}});
+    muls_runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(
+    kldiv_loss, ops::KLDivLossNPUKernel<plat::NPUDeviceContext, float>,
+    ops::KLDivLossNPUKernel<plat::NPUDeviceContext, plat::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    kldiv_loss_grad,
+    ops::KLDivLossGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::KLDivLossGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_kldiv_loss_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_kldiv_loss_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed1775fa5e6dbd5cb7809cd687d1600695924da
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_kldiv_loss_op_npu.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, division
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from test_kldiv_loss_op import kldiv_loss
+
+paddle.enable_static()
+
+
+class TestKLDivLossOp(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = 'float32'
+
+    def setUp(self):
+        self.set_npu()
+        self.init_dtype()
+        self.initTestCase()
+        self.op_type = 'kldiv_loss'
+        x = np.random.uniform(-10, 10, self.x_shape).astype(self.dtype)
+        target = np.random.uniform(-10, 10, self.x_shape).astype(self.dtype)
+
+        self.attrs = {"reduction": self.reduction}
+
+        self.inputs = {
+            'X': x,
+            'Target': target,
+        }
+        loss = kldiv_loss(x, target, self.reduction)
+        self.outputs = {'Loss': loss.astype(self.dtype)}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Loss',
+            no_grad_set=set(["Target"]),
+            max_relative_error=0.15)
+
+    def initTestCase(self):
+        self.x_shape = (4, 5, 5)
+        self.reduction = 'batchmean'
+
+
+class TestKLDivLossOp2(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (3, 2, 7, 7)
+        self.reduction = 'none'
+
+
+class TestKLDivLossOp3(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (2, 3, 5, 7, 9)
+        self.reduction = 'mean'
+
+
+class TestKLDivLossOp4(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (5, 20)
+        self.reduction = 'sum'
+
+
+class TestKLDivLossOp_fp16(TestKLDivLossOp):
+    def init_dtype(self):
+        self.dtype = 'float16'
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=3e-1)
+
+    def test_check_grad(self):
+        input_grad = -self.inputs['Target'] * (
+            self.inputs['Target'] > 0) / self.inputs['Target'].shape[0]
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Loss',
+            no_grad_set=set(["Target"]),
+            max_relative_error=0.2,
+            user_defined_grads=[input_grad])
+
+
+class TestKLDivLossDygraph(unittest.TestCase):
+    def run_kl_loss(self, reduction, shape=(5, 20)):
+        x = np.random.uniform(-10, 10, shape).astype('float32')
+        target = np.random.uniform(-10, 10, shape).astype('float32')
+        gt_loss = kldiv_loss(x, target, reduction)
+
+        with paddle.fluid.dygraph.guard(paddle.NPUPlace(0)):
+            kldiv_criterion = paddle.nn.KLDivLoss(reduction)
+            pred_loss = kldiv_criterion(
+                paddle.to_tensor(x), paddle.to_tensor(target))
+            self.assertTrue(np.allclose(pred_loss.numpy(), gt_loss))
+
+    def test_kl_loss_batchmean(self):
+        self.run_kl_loss('batchmean')
+
+    def test_kl_loss_batchmean_shape(self):
+        self.run_kl_loss('batchmean', ())
+
+    def test_kl_loss_mean(self):
+        self.run_kl_loss('mean')
+
+    def test_kl_loss_sum(self):
+        self.run_kl_loss('sum')
+
+    def test_kl_loss_none(self):
+        self.run_kl_loss('none')
+
+    def test_kl_loss_static_api(self):
+        input = paddle.fluid.data(name='input', shape=[5, 20])
+        label = paddle.fluid.data(name='label', shape=[5, 20])
+
+        pred_loss = paddle.nn.functional.kl_div(input, label)
+
+
+class TestKLDivLossTypePromotion(unittest.TestCase):
+    def test_kl_div_promotion(self):
+        with paddle.fluid.dygraph.guard(paddle.NPUPlace(0)):
+            x1 = paddle.rand([5, 20], dtype='float32')
+            target1 = paddle.rand([5, 20], dtype='float32')
+
+            kldiv_criterion = paddle.nn.KLDivLoss()
+            pred_loss1 = kldiv_criterion(x1, target1)
+
+            x2 = paddle.rand([5, 20], dtype='float32')
+            target2 = paddle.rand([5, 20], dtype='float32')
+            pred_loss2 = paddle.nn.functional.kl_div(x2, target2)
+
+
+if __name__ == "__main__":
+    unittest.main()
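
For reference, the unit tests above compare the NPU kernels against the kldiv_loss helper imported from test_kldiv_loss_op. A minimal NumPy sketch of that reference follows; the function name, the scalar-shape handling, and the exact zero-masking are assumptions made for illustration, shown only to document the semantics the kernels are expected to match, not the canonical helper itself.

# Hypothetical reference sketch; the canonical helper is test_kldiv_loss_op.kldiv_loss.
import numpy as np

def kldiv_loss_reference(x, target, reduction):
    # Elementwise KL term: target * (log(target) - x), zeroed where target <= 0,
    # mirroring the ClipByValue + Mul sequence in the "none" branch of the kernel.
    loss = np.where(target > 0, target * (np.log(target) - x), np.zeros_like(x))
    if reduction == 'batchmean':
        return loss.sum() / x.shape[0] if x.ndim > 0 else loss.sum()
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # 'none'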