From 0a06140f21d771a0941de9ef423923f923f74cfc Mon Sep 17 00:00:00 2001 From: lijin23 <41257772+lj970926@users.noreply.github.com> Date: Wed, 24 May 2023 12:43:29 +0800 Subject: [PATCH] [XPU][PHI Kernels] bind bitwise_and kernel & add int32/int64 support to scatter_nd_add kernel for xpu (#54066) * bind new kernels to xpu * refine code * fix bugs in unittest --- paddle/phi/backends/xpu/xpu2_op_list.cc | 6 +++- paddle/phi/kernels/xpu/bitwise.cc | 16 +++++++++- .../phi/kernels/xpu/scatter_nd_add_kernel.cc | 22 ++++++++------ test/xpu/test_bitwise_op_xpu.py | 29 +++++++++++-------- 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 5b7c847d76d..155b70260ea 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -81,6 +81,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, + {"bitwise_and", XPUKernelSet({phi::DataType::BOOL})}, {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_allgather", XPUKernelSet({phi::DataType::FLOAT16, @@ -644,7 +645,10 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT32})}, {"scatter_grad", XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, - {"scatter_nd_add", XPUKernelSet({phi::DataType::FLOAT32})}, + {"scatter_nd_add", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"sampling_id", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, {"set_value", diff --git a/paddle/phi/kernels/xpu/bitwise.cc b/paddle/phi/kernels/xpu/bitwise.cc index 019acf52f82..80cdd87c5a1 100644 --- a/paddle/phi/kernels/xpu/bitwise.cc +++ b/paddle/phi/kernels/xpu/bitwise.cc @@ -16,6 +16,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" 
+#include "paddle/phi/kernels/logical_kernel.h" namespace phi { @@ -29,8 +30,21 @@ void BitwiseNotKernel(const Context& ctx, reinterpret_cast(x.data()), reinterpret_cast(out->data()), x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise not"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "logical_not"); +} + +template +void BitwiseAndKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + // XPU APIs do not support bitwise operations now. + // However, because bitwise and logical operations are identical for bool type, + // we can implement bitwise_and_bool kernel by calling their logical + // counterpart. Need to be changed when adding support to other types. + LogicalAndKernel(ctx, x, y, out); } } // namespace phi PD_REGISTER_KERNEL(bitwise_not, XPU, ALL_LAYOUT, phi::BitwiseNotKernel, bool) {} +PD_REGISTER_KERNEL(bitwise_and, XPU, ALL_LAYOUT, phi::BitwiseAndKernel, bool) {} diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc index 8df0eb5a4d2..c760a2d0166 100644 --- a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc @@ -38,13 +38,12 @@ void ScatterNdAddKernel(const Context &ctx, static_cast(index.dims().size() == 0 ? 
1 : index.dims()[0]); for (int i = 0; i < loop_time; i++) { - // xpu::add only support float or float16 template typename - // now, register this op only with float type - r = xpu::add(ctx.x_context(), - updates_ptr + out->numel() * i, - out_ptr, - out_ptr, - out->numel()); + r = xpu::broadcast_add(ctx.x_context(), + updates_ptr + out->numel() * i, + out_ptr, + out_ptr, + {out->numel()}, + {out->numel()}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); } return; @@ -100,5 +99,10 @@ void ScatterNdAddKernel(const Context &ctx, } } // namespace phi -PD_REGISTER_KERNEL( - scatter_nd_add, XPU, ALL_LAYOUT, phi::ScatterNdAddKernel, float) {} +PD_REGISTER_KERNEL(scatter_nd_add, + XPU, + ALL_LAYOUT, + phi::ScatterNdAddKernel, + float, + int64_t, + int) {} diff --git a/test/xpu/test_bitwise_op_xpu.py b/test/xpu/test_bitwise_op_xpu.py index 1d21108bf8c..ff493b79a7f 100644 --- a/test/xpu/test_bitwise_op_xpu.py +++ b/test/xpu/test_bitwise_op_xpu.py @@ -39,12 +39,18 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper): class XPUTestBitwiseAndBase(XPUOpTest): def setUp(self): self.place = paddle.XPUPlace(0) + self.dtype = self.in_type self.init_case() self.set_case() def set_case(self): self.op_type = 'bitwise_and' + # special range for bool dtype + if self.dtype == np.bool_: + self.low = 0 + self.high = 2 + x = np.random.randint( self.low, self.high, self.x_shape, dtype=self.dtype ) @@ -61,7 +67,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper): self.outputs = {'Out': out} def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -75,7 +80,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper): class XPUTestBitwiseAndCase1(XPUTestBitwiseAndBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -83,7 +87,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper): class XPUTestBitwiseAndCase2(XPUTestBitwiseAndBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] 
self.y_shape = [4, 1] self.low = -100 @@ -91,7 +94,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper): class XPUTestBitwiseAndCase3(XPUTestBitwiseAndBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = 0 @@ -111,12 +113,18 @@ class XPUTestBitwiseOr(XPUOpTestWrapper): class XPUTestBitwiseOrBase(XPUOpTest): def setUp(self): self.place = paddle.XPUPlace(0) + self.dtype = self.in_type self.init_case() self.set_case() def set_case(self): self.op_type = 'bitwise_or' + # special range for bool dtype + if self.dtype == np.bool_: + self.low = 0 + self.high = 2 + x = np.random.randint( self.low, self.high, self.x_shape, dtype=self.dtype ) @@ -133,7 +141,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper): self.outputs = {'Out': out} def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -147,7 +154,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper): class XPUTestBitwiseOrCase1(XPUTestBitwiseOrBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -155,7 +161,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper): class XPUTestBitwiseOrCase2(XPUTestBitwiseOrBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [4, 1] self.low = -100 @@ -163,7 +168,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper): class XPUTestBitwiseOrCase3(XPUTestBitwiseOrBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = 0 @@ -183,11 +187,16 @@ class XPUTestBitwiseXor(XPUOpTestWrapper): class XPUTestBitwiseXorBase(XPUOpTest): def setUp(self): self.place = paddle.XPUPlace(0) + self.dtype = self.in_type self.init_case() self.set_case() def set_case(self): self.op_type = 'bitwise_xor' + # special case for bool dtype + if self.dtype == np.bool_: + self.low = 0 + self.high = 2 x = np.random.randint( self.low, self.high, self.x_shape, 
dtype=self.dtype @@ -205,7 +214,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper): self.outputs = {'Out': out} def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -219,7 +227,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper): class XPUTestBitwiseXorCase1(XPUTestBitwiseXorBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [4, 5] self.y_shape = [2, 3, 4, 5] self.low = -100 @@ -227,7 +234,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper): class XPUTestBitwiseXorCase2(XPUTestBitwiseXorBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [4, 1] self.low = -100 @@ -235,7 +241,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper): class XPUTestBitwiseXorCase3(XPUTestBitwiseXorBase): def init_case(self): - self.dtype = np.int32 self.x_shape = [2, 3, 4, 5] self.y_shape = [2, 3, 4, 5] self.low = 0 -- GitLab