From 73d706ce25e2d5063e7c82c5ae0eae4709cf9a32 Mon Sep 17 00:00:00 2001
From: RuohengMa <120699764+RuohengMa@users.noreply.github.com>
Date: Tue, 23 May 2023 16:49:49 +0800
Subject: [PATCH] [PHI] bind nll_loss xpu kernel (#54043)

---
 cmake/external/xpu.cmake                      |   2 +-
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   2 +
 .../phi/kernels/xpu/nll_loss_grad_kernel.cc   |  95 ++++++
 paddle/phi/kernels/xpu/nll_loss_kernel.cc     |  93 ++++++
 test/xpu/test_nll_loss_op_xpu.py              | 288 ++++++++++++++++++
 5 files changed, 479 insertions(+), 1 deletion(-)
 create mode 100644 paddle/phi/kernels/xpu/nll_loss_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/nll_loss_kernel.cc
 create mode 100644 test/xpu/test_nll_loss_op_xpu.py

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index 61188ae383a..1ba00fe42c6 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
 set(XPU_XFT_LIB_NAME "libxft.so")
 
-set(XPU_BASE_DATE "20230519")
+set(XPU_BASE_DATE "20230523")
 set(XPU_XCCL_BASE_VERSION "1.0.49.2")
 set(XPU_XFT_BASE_VERSION "latest")
 
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index a8bf526cf87..5b7c847d76d 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -525,6 +525,8 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::FLOAT16,
                      phi::DataType::INT64})},
       {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"nll_loss", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"nll_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"not_equal",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
diff --git a/paddle/phi/kernels/xpu/nll_loss_grad_kernel.cc b/paddle/phi/kernels/xpu/nll_loss_grad_kernel.cc
new file mode 100644
index 00000000000..1dbe679e674
--- /dev/null
+++ b/paddle/phi/kernels/xpu/nll_loss_grad_kernel.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/nll_loss_grad_kernel.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+// XPU backward kernel for nll_loss: checks the label dtype, maps `reduction`
+// to an integer id, and calls xpu::nll_loss_grad to fill `d_x`.
+template <typename T, typename Context>
+void NllLossGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,  // not read by this kernel
+                       const DenseTensor& label,
+                       const paddle::optional<DenseTensor>& weight,
+                       const DenseTensor& total_weight,
+                       const DenseTensor& d_out,
+                       int64_t ignore_index,
+                       const std::string& reduction,
+                       DenseTensor* d_x) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  const auto& label_type = label.dtype();
+  bool label_type_match =
+      label_type == phi::DataType::INT32 || label_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(label_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Label) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        label_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  auto d_out_data = d_out.data<XPUType>();
+  auto d_x_data = dev_ctx.template Alloc<XPUType>(d_x);
+
+  auto d_x_dims = d_x->dims();
+  std::vector<int64_t> d_x_shape = phi::vectorize<int64_t>(d_x_dims);
+
+  auto weight_data =
+      weight.get_ptr() ? weight.get_ptr()->data<float>() : nullptr;
+  // "none" (and any other string) keeps id 0; "mean" -> 1; "sum" -> 2.
+  int64_t reduction_id = 0;
+  if (reduction == "mean") {
+    reduction_id = 1;
+  } else if (reduction == "sum") {
+    reduction_id = 2;
+  }
+
+  auto total_weight_data = total_weight.data<XPUType>();
+
+  int r;
+  if (label_type == phi::DataType::INT32) {
+    const int* label_data = label.data<int>();
+    r = xpu::nll_loss_grad(dev_ctx.x_context(),
+                           d_out_data,
+                           d_x_data,
+                           d_x_shape,
+                           label_data,
+                           weight_data,
+                           reduction_id,
+                           ignore_index,
+                           total_weight_data);
+  } else {  // INT64: the only other dtype that passes the check above
+    const int64_t* label_data = label.data<int64_t>();
+    r = xpu::nll_loss_grad(dev_ctx.x_context(),
+                           d_out_data,
+                           d_x_data,
+                           d_x_shape,
+                           label_data,
+                           weight_data,
+                           reduction_id,
+                           ignore_index,
+                           total_weight_data);
+  }
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "nll_loss_grad");
+}
+
+}  // namespace phi
+
+// TODO(xiongkun): add the non-raw kernel register here.
+PD_REGISTER_KERNEL(
+    nll_loss_grad, XPU, ALL_LAYOUT, phi::NllLossGradKernel, float) {}
diff --git a/paddle/phi/kernels/xpu/nll_loss_kernel.cc b/paddle/phi/kernels/xpu/nll_loss_kernel.cc
new file mode 100644
index 00000000000..2d9bf5baf57
--- /dev/null
+++ b/paddle/phi/kernels/xpu/nll_loss_kernel.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/nll_loss_kernel.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+// XPU forward kernel for nll_loss: checks the label dtype, maps `reduction`
+// to an integer id, and calls xpu::nll_loss to fill `out` and `total_weight`.
+template <typename T, typename Context>
+void NllLossRawKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& label,
+                      const paddle::optional<DenseTensor>& weight,
+                      int64_t ignore_index,
+                      const std::string& reduction,
+                      DenseTensor* out,
+                      DenseTensor* total_weight) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  const auto& label_type = label.dtype();
+  bool label_type_match =
+      label_type == phi::DataType::INT32 || label_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(label_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Label) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        label_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+
+  auto x_data = x.data<XPUType>();
+  auto out_data = dev_ctx.template Alloc<XPUType>(out);
+
+  auto weight_data =
+      weight.get_ptr() ? weight.get_ptr()->data<float>() : nullptr;
+
+  auto total_weight_data = dev_ctx.template Alloc<XPUType>(total_weight);
+
+  auto x_dims = x.dims();
+  std::vector<int64_t> x_shape = phi::vectorize<int64_t>(x_dims);
+  // "none" (and any other string) keeps id 0; "mean" -> 1; "sum" -> 2.
+  int64_t reduction_id = 0;
+  if (reduction == "mean") {
+    reduction_id = 1;
+  } else if (reduction == "sum") {
+    reduction_id = 2;
+  }
+
+  int r;
+  if (label_type == phi::DataType::INT32) {
+    const int* label_data = label.data<int>();
+    r = xpu::nll_loss(dev_ctx.x_context(),
+                      x_data,
+                      out_data,
+                      total_weight_data,
+                      x_shape,
+                      label_data,
+                      weight_data,
+                      reduction_id,
+                      ignore_index);
+  } else {  // INT64: the only other dtype that passes the check above
+    const int64_t* label_data = label.data<int64_t>();
+    r = xpu::nll_loss(dev_ctx.x_context(),
+                      x_data,
+                      out_data,
+                      total_weight_data,
+                      x_shape,
+                      label_data,
+                      weight_data,
+                      reduction_id,
+                      ignore_index);
+  }
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "nll_loss");
+}
+
+}  // namespace phi
+
+// TODO(xiongkun): add the non-raw kernel register here.
+PD_REGISTER_KERNEL(nll_loss, XPU, ALL_LAYOUT, phi::NllLossRawKernel, float) {}
diff --git a/test/xpu/test_nll_loss_op_xpu.py b/test/xpu/test_nll_loss_op_xpu.py
new file mode 100644
index 00000000000..71ce3829334
--- /dev/null
+++ b/test/xpu/test_nll_loss_op_xpu.py
@@ -0,0 +1,288 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from get_test_cover_info import (
+    XPUOpTestWrapper,
+    create_test_class,
+    get_xpu_op_support_types,
+)
+from op_test_xpu import XPUOpTest
+
+import paddle
+
+paddle.enable_static()
+
+
+def nll_loss_1d(  # NumPy reference; logs is indexed as logs[i][class]
+    logs, dtype, targets, weight=None, reduction='mean', ignore_index=-100
+):
+    input_shape = logs.shape
+    N = input_shape[0]
+    C = input_shape[1]  # not used below
+    out = np.zeros_like(targets).astype(dtype)
+    total_weight = 0
+    for i in range(N):
+        cur_target = targets[i]
+        if cur_target == ignore_index:  # ignored: zero loss, no weight added
+            out[i] = 0
+            continue
+        cur_weight = weight[cur_target] if weight is not None else 1
+        total_weight += cur_weight
+        out[i] = -logs[i][cur_target] * cur_weight
+    if reduction == 'sum':
+        out = np.sum(out)
+        total_weight = np.array([total_weight]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+    elif reduction == 'mean':
+        out = np.sum(out)
+        if total_weight != 0:  # avoid dividing by zero
+            out /= total_weight
+        total_weight = np.array([total_weight]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+    elif reduction == 'none':  # Total_weight is reported as 0 in this mode
+        total_weight = np.array([0]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+
+
+def nll_loss_2d(  # NumPy reference; logs is indexed as logs[i][class][h][w]
+    logs, dtype, targets, weight=None, reduction='mean', ignore_index=-100
+):
+    input_shape = logs.shape
+    N = input_shape[0]
+    H = input_shape[2]
+    W = input_shape[3]
+    out = np.zeros_like(targets).astype(dtype)
+    total_weight = 0
+    for i in range(N):
+        for h in range(H):
+            for w in range(W):
+                cur_target = targets[i][h][w]
+                if cur_target == ignore_index:  # ignored: zero loss
+                    out[i][h][w] = 0
+                    continue
+                cur_weight = weight[cur_target] if weight is not None else 1
+                total_weight += cur_weight
+                out[i][h][w] = -logs[i][cur_target][h][w] * cur_weight
+    if reduction == 'sum':
+        out = np.sum(out)
+        total_weight = np.array([total_weight]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+    elif reduction == 'mean':
+        out = np.sum(out)
+        if total_weight != 0:  # avoid dividing by zero
+            out /= total_weight
+        total_weight = np.array([total_weight]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+    elif reduction == 'none':  # Total_weight is reported as 0 in this mode
+        total_weight = np.array([0]).astype(dtype)
+        return {'Out': out, 'Total_weight': total_weight}
+
+
+class XPUTestNLLLossOP(XPUOpTestWrapper):  # groups the nll_loss XPU cases
+    def __init__(self):
+        self.op_name = 'nll_loss'
+        self.use_dynamic_create_class = False
+
+    class TestNLLLossOpBase1D(XPUOpTest):  # weighted, reduction='none'
+        op_type = 'nll_loss'
+
+        def setUp(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.set_attrs()
+            self.set_inputs()
+            self.inputs = {
+                'X': self.x,
+                'Label': self.label,
+            }
+            if self.weight is not None:  # 'Weight' is fed only when set
+                self.inputs['Weight'] = self.weight
+            self.outputs = nll_loss_1d(  # ignore_index stays at default -100
+                self.x,
+                self.dtype,
+                self.label,
+                self.weight,
+                self.attrs['reduction'],
+            )
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'none'}
+
+        def set_inputs(self):
+            self.class_num = 3
+            x_shape = [5, self.class_num]
+            label_shape = [5]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = np.random.random(self.class_num).astype(self.dtype)
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, ['X'], 'Out')
+
+    class TestNLLLossOpWithWeightMean1D(TestNLLLossOpBase1D):
+        def set_attrs(self):
+            self.attrs = {'reduction': 'mean'}
+
+    class TestNLLLossOpWithWeightSum1D(TestNLLLossOpBase1D):
+        def set_attrs(self):
+            self.attrs = {'reduction': 'sum'}
+
+    class TestNLLLossOpWithoutWeightNone1D(TestNLLLossOpBase1D):
+        def set_inputs(self):
+            self.class_num = 3
+            x_shape = [5, self.class_num]
+            label_shape = [5]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'none'}
+
+    class TestNLLLossOpWithoutWeightMean1D(TestNLLLossOpBase1D):
+        def set_inputs(self):
+            self.class_num = 3
+            x_shape = [5, self.class_num]
+            label_shape = [5]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'mean'}
+
+    class TestNLLLossOpWithoutWeightSum1D(TestNLLLossOpBase1D):
+        def set_inputs(self):
+            self.class_num = 3
+            x_shape = [5, self.class_num]
+            label_shape = [5]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'sum'}
+
+    class TestNLLLossOpBase2D(XPUOpTest):  # weighted, reduction='none'
+        op_type = 'nll_loss'
+
+        def setUp(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.set_attrs()
+            self.set_inputs()
+            self.inputs = {'X': self.x, 'Label': self.label}
+            if self.weight is not None:  # 'Weight' is fed only when set
+                self.inputs['Weight'] = self.weight
+            self.outputs = nll_loss_2d(  # ignore_index stays at default -100
+                self.x,
+                self.dtype,
+                self.label,
+                self.weight,
+                self.attrs['reduction'],
+            )
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'none'}
+
+        def set_inputs(self):
+            self.class_num = 3
+            x_shape = [5, self.class_num, 7, 11]
+            label_shape = [5, 7, 11]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = np.random.random(self.class_num).astype(self.dtype)
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, ['X'], 'Out')
+
+    class TestNLLLossOpWithWeightMean2D(TestNLLLossOpBase2D):
+        def set_attrs(self):
+            self.attrs = {'reduction': 'mean'}
+
+    class TestNLLLossOpWithWeightSum2D(TestNLLLossOpBase2D):
+        def set_attrs(self):
+            self.attrs = {'reduction': 'sum'}
+
+    class TestNLLLossOpWithoutWeightNone2D(TestNLLLossOpBase2D):
+        def set_inputs(self):
+            self.dtype = self.in_type  # already assigned in setUp
+            self.class_num = 3
+            x_shape = [5, self.class_num, 7, 11]
+            label_shape = [5, 7, 11]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'none'}
+
+    class TestNLLLossOpWithoutWeightMean2D(TestNLLLossOpBase2D):
+        def set_inputs(self):
+            self.dtype = self.in_type  # already assigned in setUp
+            self.class_num = 3
+            x_shape = [5, self.class_num, 7, 11]
+            label_shape = [5, 7, 11]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'mean'}
+
+    class TestNLLLossOpWithoutWeightSum2D(TestNLLLossOpBase2D):
+        def set_inputs(self):
+            self.dtype = self.in_type  # already assigned in setUp
+            self.class_num = 3
+            x_shape = [5, self.class_num, 7, 11]
+            label_shape = [5, 7, 11]
+            self.x = np.random.random(x_shape).astype(self.dtype)
+            self.label = np.random.randint(
+                low=0, high=self.class_num, size=label_shape
+            ).astype(np.int64)
+            self.weight = None  # setUp then omits the 'Weight' input
+
+        def set_attrs(self):
+            self.attrs = {'reduction': 'sum'}
+
+
+support_types = get_xpu_op_support_types('nll_loss')
+for stype in support_types:
+    create_test_class(globals(), XPUTestNLLLossOP, stype)
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab