未验证 提交 0a06140f 编写于 作者: L lijin23 提交者: GitHub

[XPU][PHI Kernels] bind bitwise_and kernel & add int32/int64 support to...

[XPU][PHI Kernels] bind bitwise_and kernel & add int32/int64 support to scatter_nd_add kernel for xpu (#54066)

* bind new kernels to xpu

* refine code

* fix bugs in unittest
上级 e419f434
......@@ -81,6 +81,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_and", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allgather",
XPUKernelSet({phi::DataType::FLOAT16,
......@@ -644,7 +645,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT32})},
{"scatter_grad",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"scatter_nd_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"scatter_nd_add",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"sampling_id",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"set_value",
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/logical_kernel.h"
namespace phi {
......@@ -29,8 +30,21 @@ void BitwiseNotKernel(const Context& ctx,
reinterpret_cast<const XPUDataType*>(x.data<T>()),
reinterpret_cast<XPUDataType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise not");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "logical_not");
}
// XPU implementation of the bitwise_and kernel.
//
// Forwards all four arguments unchanged to LogicalAndKernel<T, Context>;
// this function performs no computation of its own.
//
// Params:
//   ctx - device context, passed through to LogicalAndKernel.
//   x   - first input tensor.
//   y   - second input tensor.
//   out - output tensor, written by LogicalAndKernel.
// NOTE(review): shape/broadcast handling and output allocation are whatever
// LogicalAndKernel provides - confirm against its implementation.
template <typename T, typename Context>
void BitwiseAndKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
// The XPU API does not support bitwise operations yet.
// However, because bitwise and logical operations are identical for the bool
// type, the bool bitwise_and kernel can be implemented by calling its logical
// counterpart. This must be changed when support for other types is added,
// since logical AND and bitwise AND differ for non-bool element types.
LogicalAndKernel<T, Context>(ctx, x, y, out);
}
} // namespace phi
// Register the XPU bitwise kernels. Both are registered with bool as the only
// element type; no integer types are listed here.
PD_REGISTER_KERNEL(bitwise_not, XPU, ALL_LAYOUT, phi::BitwiseNotKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_and, XPU, ALL_LAYOUT, phi::BitwiseAndKernel, bool) {}
......@@ -38,13 +38,12 @@ void ScatterNdAddKernel(const Context &ctx,
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
for (int i = 0; i < loop_time; i++) {
// xpu::add only support float or float16 template typename
// now, register this op only with float type
r = xpu::add<T>(ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
out->numel());
r = xpu::broadcast_add<T>(ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
{out->numel()},
{out->numel()});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
return;
......@@ -100,5 +99,10 @@ void ScatterNdAddKernel(const Context &ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(
scatter_nd_add, XPU, ALL_LAYOUT, phi::ScatterNdAddKernel, float) {}
PD_REGISTER_KERNEL(scatter_nd_add,
XPU,
ALL_LAYOUT,
phi::ScatterNdAddKernel,
float,
int64_t,
int) {}
......@@ -39,12 +39,18 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper):
class XPUTestBitwiseAndBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_and'
# special range for bool dtype
if self.dtype == np.bool_:
self.low = 0
self.high = 2
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype
)
......@@ -61,7 +67,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper):
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -75,7 +80,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper):
class XPUTestBitwiseAndCase1(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -83,7 +87,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper):
class XPUTestBitwiseAndCase2(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
......@@ -91,7 +94,6 @@ class XPUTestBitwiseAnd(XPUOpTestWrapper):
class XPUTestBitwiseAndCase3(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
......@@ -111,12 +113,18 @@ class XPUTestBitwiseOr(XPUOpTestWrapper):
class XPUTestBitwiseOrBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_or'
# special range for bool dtype
if self.dtype == np.bool_:
self.low = 0
self.high = 2
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype
)
......@@ -133,7 +141,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper):
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -147,7 +154,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper):
class XPUTestBitwiseOrCase1(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -155,7 +161,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper):
class XPUTestBitwiseOrCase2(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
......@@ -163,7 +168,6 @@ class XPUTestBitwiseOr(XPUOpTestWrapper):
class XPUTestBitwiseOrCase3(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
......@@ -183,11 +187,16 @@ class XPUTestBitwiseXor(XPUOpTestWrapper):
class XPUTestBitwiseXorBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_xor'
# special case for bool dtype
if self.dtype == np.bool_:
self.low = 0
self.high = 2
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype
......@@ -205,7 +214,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper):
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -219,7 +227,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper):
class XPUTestBitwiseXorCase1(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
......@@ -227,7 +234,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper):
class XPUTestBitwiseXorCase2(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
......@@ -235,7 +241,6 @@ class XPUTestBitwiseXor(XPUOpTestWrapper):
class XPUTestBitwiseXorCase3(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册