diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 4089119097dc183a9754643f9def9596cd4e1203..f9e66631c1d221cf05e3a3707b8565d74f6733df 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -306,6 +306,9 @@ XPUOpMap& get_kl2_ops() {
     {"huber_loss_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"huber_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"kldiv_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"kldiv_loss_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"iou_similarity",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"index_select",
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5d2c750a4dfa331977f62fec1f2fdb4f985e6a59
--- /dev/null
+++ b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void KLDivLossGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& label,
+                         const DenseTensor& d_out,
+                         const std::string& reduction,
+                         DenseTensor* d_x) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(d_x);
+  if (d_x->numel() == 0) {
+    return;
+  }
+
+  // Only reduction == "none" is computed by the XDNN primitive, so reject
+  // other modes before launching any device work.
+  if ("none" != reduction) {
+    PADDLE_THROW(phi::errors::Unavailable(
+        "Not supported reduction [%s] in kldiv_loss_grad", reduction));
+  }
+
+  int r = xpu::kldiv_loss_grad(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(label.data<T>()),
+      reinterpret_cast<const XPUType*>(d_out.data<T>()),
+      reinterpret_cast<XPUType*>(d_x->data<T>()),
+      d_x->numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss_grad");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    kldiv_loss_grad, XPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float) {}
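For reference, the per-element backward that `xpu::kldiv_loss_grad` is expected to compute for `reduction='none'` follows directly from the forward definition `loss = target * (log(target) - x)`: the derivative with respect to the log-space input `x` is `-target`, zeroed wherever the forward loss is masked out. A minimal numpy sketch, assuming the XDNN primitive matches the CPU reference semantics (the helper name is illustrative, not part of the patch):

```python
import numpy as np

def kldiv_loss_grad_ref(label, d_out):
    # d(loss)/dx = -label for active elements of
    # loss = label * (log(label) - x); elements masked out in the
    # forward pass (label <= 0) contribute zero gradient.
    mask = (label > 0).astype(label.dtype)
    return -label * d_out * mask

label = np.random.uniform(0, 1, (4, 5, 5)).astype('float32')
d_out = np.ones_like(label)  # upstream gradient for reduction='none'
d_x = kldiv_loss_grad_ref(label, d_out)
```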
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef917f008ab9e491cbc580b7b808122c612e938
--- /dev/null
+++ b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -0,0 +1,50 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void KLDivLossKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& label,
+                     const std::string& reduction,
+                     DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) {
+    return;
+  }
+
+  // Only reduction == "none" is computed by the XDNN primitive, so reject
+  // other modes before launching any device work.
+  if ("none" != reduction) {
+    PADDLE_THROW(phi::errors::Unavailable(
+        "Not supported reduction [%s] in kldiv_loss", reduction));
+  }
+
+  int r = xpu::kldiv_loss(dev_ctx.x_context(),
+                          reinterpret_cast<const XPUType*>(x.data<T>()),
+                          reinterpret_cast<const XPUType*>(label.data<T>()),
+                          reinterpret_cast<XPUType*>(out->data<T>()),
+                          out->numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "kldiv_loss");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(kldiv_loss, XPU, ALL_LAYOUT, phi::KLDivLossKernel, float) {}
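A short dygraph usage sketch for the new forward kernel; this assumes a WITH_XPU build of Paddle running on a KL2 device (the device string and shapes are illustrative). Note that only `reduction='none'` is computed in-kernel; any other mode hits the `Unavailable` error thrown above.

```python
import paddle

paddle.disable_static()
paddle.set_device('xpu:0')  # assumes an XPU build and an available device

x = paddle.rand([5, 20], dtype='float32')       # log-space input
target = paddle.rand([5, 20], dtype='float32')

# reduction='none' returns the per-element loss, same shape as x.
loss = paddle.nn.functional.kl_div(x, target, reduction='none')
print(loss.shape)  # [5, 20]
```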
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_kldiv_loss_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_kldiv_loss_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a4327d6ce5950f215e74134c8229c26d09e7c6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_kldiv_loss_op_xpu.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+sys.path.append("..")
+import paddle
+import unittest
+import numpy as np
+from paddle.nn.functional import kl_div
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import (
+    create_test_class,
+    get_xpu_op_support_types,
+    XPUOpTestWrapper,
+)
+
+paddle.enable_static()
+
+
+def kldiv_loss(x, target, reduction):
+    output = target * (np.log(target) - x)
+    loss = np.where(target >= 0, output, np.zeros_like(x))
+
+    if reduction == "batchmean":
+        if len(x.shape) > 0:
+            return loss.sum() / x.shape[0]
+        else:
+            return loss.sum()
+    if reduction == "mean":
+        return loss.mean()
+    if reduction == "sum":
+        return loss.sum()
+
+    return loss
+
+
+class XPUTestKLDivLossOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'kldiv_loss'
+        self.use_dynamic_create_class = False
+
+    class TestKLDivLossOp(XPUOpTest):
+        def setUp(self):
+            self.initTestCase()
+            self.op_type = 'kldiv_loss'
+            self.dtype = np.float32
+            self.__class__.use_xpu = True
+            self.python_api = kl_div
+            x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
+            target = np.random.uniform(-10, 10, self.x_shape).astype('float32')
+
+            self.attrs = {"reduction": self.reduction}
+
+            self.inputs = {
+                'X': x,
+                'Target': target,
+            }
+            loss = kldiv_loss(x, target, self.reduction)
+            self.outputs = {'Loss': loss.astype('float32')}
+
+        def test_check_output(self):
+            self.check_output(check_eager=True)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(
+                paddle.XPUPlace(0),
+                ['X'],
+                'Loss',
+                no_grad_set=set(["Target"]),
+                check_eager=True,
+            )
+
+        def initTestCase(self):
+            self.x_shape = (4, 5, 5)
+            self.reduction = 'none'
+
+    class TestKLDivLossOp2(TestKLDivLossOp):
+        def initTestCase(self):
+            self.x_shape = (3, 2, 7, 7)
+            self.reduction = 'none'
+
+    class TestKLDivLossOp3(TestKLDivLossOp):
+        def initTestCase(self):
+            self.x_shape = (2, 3, 5, 7, 9)
+            self.reduction = 'none'
+
+    class TestKLDivLossOp4(TestKLDivLossOp):
+        def initTestCase(self):
+            self.x_shape = (5, 20)
+            self.reduction = 'none'
+
+    class TestKLDivLossDygraph(unittest.TestCase):
+        def run_kl_loss(self, reduction, shape=(5, 20)):
+            x = np.random.uniform(-10, 10, shape).astype('float32')
+            target = np.random.uniform(-10, 10, shape).astype('float32')
+            gt_loss = kldiv_loss(x, target, reduction)
+
+            with paddle.fluid.dygraph.guard():
+                kldiv_criterion = paddle.nn.KLDivLoss(reduction)
+                pred_loss = kldiv_criterion(
+                    paddle.to_tensor(x), paddle.to_tensor(target)
+                )
+                np.testing.assert_allclose(
+                    pred_loss.numpy(), gt_loss, rtol=1e-05
+                )
+
+        def test_kl_loss_none(self):
+            self.run_kl_loss('none')
+
+        def test_kl_loss_static_api(self):
+            input = paddle.fluid.data(name='input', shape=[5, 20])
+            label = paddle.fluid.data(name='label', shape=[5, 20])
+
+            paddle.nn.functional.kl_div(input, label)
+
+    class TestKLDivLossTypePromotion(unittest.TestCase):
+        def test_kl_div_promotion(self):
+
+            with paddle.fluid.dygraph.guard():
+                x1 = paddle.rand([5, 20], dtype='float32')
+                target1 = paddle.rand([5, 20], dtype='float32')
+
+                kldiv_criterion = paddle.nn.KLDivLoss()
+                pred_loss1 = kldiv_criterion(x1, target1)
+
+                x2 = paddle.rand([5, 20], dtype='float32')
+                target2 = paddle.rand([5, 20], dtype='float32')
+                pred_loss2 = paddle.nn.functional.kl_div(x2, target2)
+
+
+support_types = get_xpu_op_support_types('kldiv_loss')
+for stype in support_types:
+    create_test_class(globals(), XPUTestKLDivLossOp, stype)
+
+if __name__ == "__main__":
+    unittest.main()
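Beyond the generated OpTest classes, an end-to-end autograd sanity check ties the two kernels together: with `reduction='none'` and a summed loss, the gradient flowing back to `x` should be exactly `-target` elementwise (see the backward sketch above). A minimal dygraph example, runnable on CPU or, with an XPU build, after `paddle.set_device('xpu:0')`:

```python
import numpy as np
import paddle

paddle.disable_static()
x = paddle.rand([5, 20], dtype='float32')
x.stop_gradient = False
target = paddle.rand([5, 20], dtype='float32')

loss = paddle.nn.functional.kl_div(x, target, reduction='none')
loss.sum().backward()  # upstream gradient of ones per element

# d(loss)/dx = -target elementwise for reduction='none'.
np.testing.assert_allclose(x.grad.numpy(), -target.numpy(), rtol=1e-5)
```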