未验证 提交 4212d9ad 编写于 作者: H houj04 提交者: GitHub

[XPU] update numel/size op registration (#53094)

* [XPU] add numel op

* [XPU] update numel/size op registration
上级 d3f6e2d5
......@@ -508,13 +508,6 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"numel",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT16,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"one_hot_v2",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
......@@ -628,6 +621,13 @@ XPUOpMap& get_kl2_ops() {
{"silu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"silu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"size",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT16,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"sigmoid_cross_entropy_with_logits_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid_cross_entropy_with_logits",
......
......@@ -27,12 +27,12 @@ import paddle
paddle.enable_static()
class XPUTestNumelOP(XPUOpTestWrapper):
class XPUTestSizeOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'size'
self.use_dynamic_create_class = False
class TestXPUNumelOp(XPUOpTest):
class TestXPUSizeOp(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_dtype()
......@@ -54,30 +54,30 @@ class XPUTestNumelOP(XPUOpTestWrapper):
def test_check_output(self):
self.check_output_with_place(self.place)
class TestNumel1(TestXPUNumelOp):
class TestSize1(TestXPUSizeOp):
def initTestCase(self):
self.shape = (11, 66)
class TestNumel2(TestXPUNumelOp):
class TestSize2(TestXPUSizeOp):
def initTestCase(self):
self.shape = (0,)
class TestNumel3(TestXPUNumelOp):
class TestSize3(TestXPUSizeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
class TestNumel4(TestXPUNumelOp):
class TestSize4(TestXPUSizeOp):
def initTestCase(self):
self.shape = (12, 24)
class TestNumel5(TestXPUNumelOp):
class TestSize5(TestXPUSizeOp):
def initTestCase(self):
self.shape = (1, 64, 16)
support_types = get_xpu_op_support_types('numel')
support_types = get_xpu_op_support_types('size')
for stype in support_types:
create_test_class(globals(), XPUTestNumelOP, stype)
create_test_class(globals(), XPUTestSizeOP, stype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册