From 4212d9ad28cfe2c16430c69bfcc7b89036393b22 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Thu, 20 Apr 2023 22:09:10 +0800 Subject: [PATCH] [XPU] update numel/size op registration (#53094) * [XPU] add numel op * [XPU] update numel/size op registration --- paddle/phi/backends/xpu/xpu2_op_list.cc | 14 +++++++------- ...est_numel_op_xpu.py => test_size_op_xpu.py} | 18 +++++++++--------- 2 files changed, 16 insertions(+), 16 deletions(-) rename test/xpu/{test_numel_op_xpu.py => test_size_op_xpu.py} (83%) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 7af60173275..43461e696da 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -508,13 +508,6 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32, phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, - {"numel", - XPUKernelSet({phi::DataType::INT64, - phi::DataType::INT32, - phi::DataType::INT16, - phi::DataType::BOOL, - phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, {"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, {"one_hot_v2", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, @@ -628,6 +621,13 @@ XPUOpMap& get_kl2_ops() { {"silu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"silu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"size", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, {"sigmoid_cross_entropy_with_logits_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"sigmoid_cross_entropy_with_logits", diff --git a/test/xpu/test_numel_op_xpu.py b/test/xpu/test_size_op_xpu.py similarity index 83% rename from test/xpu/test_numel_op_xpu.py rename to test/xpu/test_size_op_xpu.py index e7776d25ce7..d7a0a5a1e54 100644 --- a/test/xpu/test_numel_op_xpu.py +++ b/test/xpu/test_size_op_xpu.py @@ -27,12 +27,12 @@ import paddle paddle.enable_static() -class XPUTestNumelOP(XPUOpTestWrapper): +class XPUTestSizeOP(XPUOpTestWrapper): def __init__(self): self.op_name = 'size' self.use_dynamic_create_class = False - class TestXPUNumelOp(XPUOpTest): + class TestXPUSizeOp(XPUOpTest): def setUp(self): self.place = paddle.XPUPlace(0) self.init_dtype() @@ -54,30 +54,30 @@ class XPUTestNumelOP(XPUOpTestWrapper): def test_check_output(self): self.check_output_with_place(self.place) - class TestNumel1(TestXPUNumelOp): + class TestSize1(TestXPUSizeOp): def initTestCase(self): self.shape = (11, 66) - class TestNumel2(TestXPUNumelOp): + class TestSize2(TestXPUSizeOp): def initTestCase(self): self.shape = (0,) - class TestNumel3(TestXPUNumelOp): + class TestSize3(TestXPUSizeOp): def initTestCase(self): self.shape = (2, 3, 4, 5, 6) - class TestNumel4(TestXPUNumelOp): + class TestSize4(TestXPUSizeOp): def initTestCase(self): self.shape = (12, 24) - class TestNumel5(TestXPUNumelOp): + class TestSize5(TestXPUSizeOp): def initTestCase(self): self.shape = (1, 64, 16) -support_types = get_xpu_op_support_types('numel') +support_types = get_xpu_op_support_types('size') for stype in support_types: - create_test_class(globals(), XPUTestNumelOP, stype) + create_test_class(globals(), XPUTestSizeOP, stype) if __name__ == '__main__': unittest.main() -- GitLab