From 9a1855ffd2291498e69ee439444c75caa4632cf3 Mon Sep 17 00:00:00 2001 From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com> Date: Thu, 29 Sep 2022 14:05:48 +0800 Subject: [PATCH] Add index_select, index_select_grad, reduce_min kernel and their unittests for kunlun. Add registers of index_select, index_select_grad, reduce_min, sqrt, sqrt_grad to xpu2_op_list.test=kunlun. (#46557) --- .../fluid/platform/device/xpu/xpu2_op_list.h | 7 + paddle/phi/kernels/reduce_min_kernel.cc | 6 +- paddle/phi/kernels/xpu/index_select_kernel.cc | 79 +++++++++ paddle/phi/kernels/xpu/reduce_min_kernel.cc | 43 +++++ .../unittests/xpu/test_index_select_op_xpu.py | 165 ++++++++++++++++++ 5 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 paddle/phi/kernels/xpu/index_select_kernel.cc create mode 100644 paddle/phi/kernels/xpu/reduce_min_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_index_select_op_xpu.py diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index ba03a3b53db..01fd563f8e3 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -301,6 +301,10 @@ XPUOpMap& get_kl2_ops() { {"huber_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"iou_similarity", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"index_select", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"instance_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"instance_norm_grad", @@ -422,6 +426,7 @@ XPUOpMap& get_kl2_ops() { {"reduce_mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_prod", 
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, @@ -510,6 +515,8 @@ XPUOpMap& get_kl2_ops() { {"split", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, + {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"squeeze2_grad", diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc index 981c7afa621..1df454e2976 100644 --- a/paddle/phi/kernels/reduce_min_kernel.cc +++ b/paddle/phi/kernels/reduce_min_kernel.cc @@ -42,7 +42,7 @@ PD_REGISTER_KERNEL( min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {} #endif -#if defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {} #endif @@ -50,3 +50,7 @@ PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {} PD_REGISTER_KERNEL( min, OneDNN, ALL_LAYOUT, phi::MinKernel, float, phi::dtype::bfloat16) {} #endif + +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(min, XPU, ALL_LAYOUT, phi::MinKernel, float) {} +#endif diff --git a/paddle/phi/kernels/xpu/index_select_kernel.cc b/paddle/phi/kernels/xpu/index_select_kernel.cc new file mode 100644 index 00000000000..cbe6e99c43a --- /dev/null +++ b/paddle/phi/kernels/xpu/index_select_kernel.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_select_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+namespace phi {
+// XPU index_select: gathers x along dim at the positions in index (xpu::gather).
+template <typename T, typename Context>
+void IndexSelectKernel(const Context& ctx,
+                       const DenseTensor& x,      // source tensor
+                       const DenseTensor& index,  // INT32 or INT64 positions
+                       int dim,                   // gather axis; may be negative
+                       DenseTensor* output) {  // dims are read, then data allocated
+  auto input_dim = x.dims();
+  dim = dim >= 0 ? dim : dim + input_dim.size();  // wrap a negative axis
+  const auto& index_type = index.dtype();
+  // Reject index tensors that are neither INT32 nor INT64.
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        index_type,
+                        phi::DataType::INT32,
+                        phi::DataType::INT64));
+  // Pick the xpu::gather instantiation matching the index element type.
+  auto* in_data = x.data<T>();
+  std::vector<int> in_shape = phi::vectorize<int>(input_dim);
+  int index_len = output->dims()[dim];  // taken from output's extent along dim
+  T* out_data = ctx.template Alloc<T>(output);
+  int r = 0;
+  if (index_type == phi::DataType::INT64) {
+    const int64_t* index_data = index.data<int64_t>();
+    r = xpu::gather<T, int64_t>(ctx.x_context(),
+                                in_data,
+                                index_data,
+                                out_data,
+                                in_shape,
+                                index_len,
+                                dim);
+  } else {
+    const int* index_data = index.data<int>();
+    r = xpu::gather<T, int>(ctx.x_context(),
+                            in_data,
+                            index_data,
+                            out_data,
+                            in_shape,
+                            index_len,
+                            dim);
+  }
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");  // raise if XDNN reported an error
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(index_select,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::IndexSelectKernel,
+                   float,
+                   int,
+                   int64_t) {}
diff --git 
a/paddle/phi/kernels/xpu/reduce_min_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
new file mode 100644
index 00000000000..c54aca1830b
--- /dev/null
+++ b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_min_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/xpu/reduce.h"
+
+namespace phi {
+// XPU min reduction of x over dims: forwards to XPUReduce with xpu::reduce_min.
+template <typename T, typename Context>
+void MinRawKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const IntArray& dims,
+                  bool keep_dim,
+                  bool reduce_all,
+                  DenseTensor* out) {
+  int r = XPUReduce<Context, T>(dev_ctx,
+                                x,
+                                dims.GetData(),
+                                keep_dim,
+                                reduce_all,
+                                out,
+                                xpu::reduce_min<T>);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_min");  // raise if XDNN reported an error
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(min_raw, XPU, ALL_LAYOUT, phi::MinRawKernel, float) {}
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_index_select_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_index_select_op_xpu.py
new file mode 100644
index 00000000000..766ebdd2567
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_index_select_op_xpu.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import sys
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+sys.path.append("..")
+
+import numpy as np
+
+import paddle
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestIndexSelect(XPUOpTestWrapper):
+    # Wrapper handed to create_test_class() at the bottom of this file, once per supported type.
+    def __init__(self):
+        self.op_name = 'index_select'
+
+    class TestXPUIndexSelectOp(XPUOpTest):
+        # setUp builds the op inputs and a NumPy reference output for index_select.
+        def setUp(self):
+            self.op_type = "index_select"
+            self.place = paddle.XPUPlace(0)
+            self.dtype = self.in_type  # in_type is not set in this file; presumably injected by create_test_class — TODO confirm
+            # init_dtype_type() sets dim, index_type, x_shape and index_size (overridden in Case2).
+            self.init_dtype_type()
+            index_np = np.random.randint(low=0,
+                                         high=self.x_shape[self.dim],
+                                         size=self.index_size).astype(
+                                             self.index_type)
+            x_np = np.random.random(self.x_shape).astype(self.dtype)
+            self.inputs = {'X': x_np, 'Index': index_np}
+            self.attrs = {'dim': self.dim}
+            outer_loop = np.prod(self.x_shape[:self.dim])  # product of the extents before dim
+            x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
+            x_np_reshape = np.reshape(x_np, tuple(x_reshape))  # collapse leading axes into one
+            out_list = []  # selected slices, ordered by (outer row, index position)
+            for i in range(outer_loop):
+                for j in range(self.index_size):
+                    out_list.append(x_np_reshape[i, index_np[j]])
+            self.out_shape = list(self.x_shape)
+            self.out_shape[self.dim] = self.index_size
+            self.out_shape = tuple(self.out_shape)
+            # NOTE(review): x_np is all zeros when in_type is an integer type (random() < 1); confirm intended.
+            out = np.reshape(out_list, self.out_shape)
+            self.outputs = {'Out': out}
+
+        def init_dtype_type(self):
+            self.dim = 1
+            self.index_type = np.int64
+            self.x_shape = (100, 4, 5)
+            self.index_size = 100
+
+        def test_check_output(self):
+            if paddle.is_compiled_with_xpu():
+                self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            if paddle.is_compiled_with_xpu():
+                self.check_grad_with_place(self.place, ['X'], 'Out')
+
+    class TestXPUIndexSelectOpCase2(TestXPUIndexSelectOp):
+        # Variant: int32 index, negative dim, 4-D input.
+        def init_dtype_type(self):
+            self.index_type = np.int32
+            self.dim = -2
+            self.x_shape = (10, 10, 4, 10)
+            self.index_size = 10
+
+
+class TestIndexSelectAPI(unittest.TestCase):
+    # Calls paddle.index_select with fixed 3x4 data and compares against hand-written expected arrays.
+    def input_data(self):
+        self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0]])
+        self.data_index = np.array([0, 1, 1]).astype('int32')
+
+    def test_index_select_api(self):
+        self.input_data()
+
+        # case 1:
+        with program_guard(Program(), Program()):
+            x = fluid.layers.data(name='x', shape=[-1, 4])
+            index = fluid.layers.data(name='index',
+                                      shape=[3],
+                                      dtype='int32',
+                                      append_batch_size=False)
+            z = paddle.index_select(x, index, axis=1)
+            exe = fluid.Executor(fluid.XPUPlace(0))
+            res, = exe.run(feed={
+                'x': self.data_x,
+                'index': self.data_index
+            },
+                           fetch_list=[z.name],
+                           return_numpy=False)
+        expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
+                               [9.0, 10.0, 10.0]])
+        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
+
+        # case 2:
+        with program_guard(Program(), Program()):
+            x = fluid.layers.data(name='x', shape=[-1, 4])
+            index = fluid.layers.data(name='index',
+                                      shape=[3],
+                                      dtype='int32',
+                                      append_batch_size=False)
+            z = paddle.index_select(x, index)
+            exe = fluid.Executor(fluid.XPUPlace(0))
+            res, = exe.run(feed={
+                'x': self.data_x,
+                'index': self.data_index
+            },
+                           fetch_list=[z.name],
+                           return_numpy=False)
+        expect_out = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
+                               [5.0, 6.0, 7.0, 8.0]])
+        np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
+
+    def test_dygraph_api(self):
+        self.input_data()
+        # case 1:
+        with fluid.dygraph.guard():
+            x = fluid.dygraph.to_variable(self.data_x)
+            index = fluid.dygraph.to_variable(self.data_index)
+            z = paddle.index_select(x, index)
+            np_z = z.numpy()
+        expect_out = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
+                               [5.0, 6.0, 7.0, 8.0]])
+        np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
+
+        # case 2:
+        with fluid.dygraph.guard():
+            x = fluid.dygraph.to_variable(self.data_x)
+            index = fluid.dygraph.to_variable(self.data_index)
+            z = paddle.index_select(x, index, axis=1)
+            np_z = z.numpy()
+        expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
+                               [9.0, 10.0, 10.0]])
+        np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
+
+
+support_types = get_xpu_op_support_types('index_select')
+for stype in support_types:
+    create_test_class(globals(), XPUTestIndexSelect, stype)
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab