Unverified commit 4212d9ad, authored by houj04, committed by GitHub

[XPU] update numel/size op registration (#53094)

* [XPU] add numel op

* [XPU] update numel/size op registration
Parent commit: d3f6e2d5
@@ -508,13 +508,6 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT32,
                    phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
-    {"numel",
-     XPUKernelSet({phi::DataType::INT64,
-                   phi::DataType::INT32,
-                   phi::DataType::INT16,
-                   phi::DataType::BOOL,
-                   phi::DataType::FLOAT16,
-                   phi::DataType::FLOAT32})},
     {"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
     {"one_hot_v2",
      XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
@@ -628,6 +621,13 @@ XPUOpMap& get_kl2_ops() {
     {"silu_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"silu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"size",
+     XPUKernelSet({phi::DataType::INT64,
+                   phi::DataType::INT32,
+                   phi::DataType::INT16,
+                   phi::DataType::BOOL,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT32})},
     {"sigmoid_cross_entropy_with_logits_grad",
      XPUKernelSet({phi::DataType::FLOAT32})},
     {"sigmoid_cross_entropy_with_logits",
...
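A minimal usage sketch (not part of this commit) of how the kernel registered above is reached from Python. It assumes a Paddle build with XPU support and that paddle.numel lowers to the "size" op; it falls back to CPU so the snippet also runs without an XPU device.

# Minimal sketch, not part of the commit: exercise the "size" kernel via the public API.
import paddle

paddle.set_device("xpu:0" if paddle.device.is_compiled_with_xpu() else "cpu")

x = paddle.ones([2, 3, 4], dtype="float32")
# paddle.numel returns the total element count as an integer tensor; on an
# XPU device it is expected to run through the "size" kernel listed above.
n = paddle.numel(x)
print(n.item())  # 24

Per the dtype list in the registration, the same call is intended to work for bool, int16, int32, int64, float16, and float32 inputs on XPU.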
@@ -27,12 +27,12 @@ import paddle

 paddle.enable_static()


-class XPUTestNumelOP(XPUOpTestWrapper):
+class XPUTestSizeOP(XPUOpTestWrapper):
     def __init__(self):
         self.op_name = 'size'
         self.use_dynamic_create_class = False

-    class TestXPUNumelOp(XPUOpTest):
+    class TestXPUSizeOp(XPUOpTest):
         def setUp(self):
             self.place = paddle.XPUPlace(0)
             self.init_dtype()
@@ -54,30 +54,30 @@ class XPUTestNumelOP(XPUOpTestWrapper):
         def test_check_output(self):
             self.check_output_with_place(self.place)

-    class TestNumel1(TestXPUNumelOp):
+    class TestSize1(TestXPUSizeOp):
         def initTestCase(self):
             self.shape = (11, 66)

-    class TestNumel2(TestXPUNumelOp):
+    class TestSize2(TestXPUSizeOp):
         def initTestCase(self):
             self.shape = (0,)

-    class TestNumel3(TestXPUNumelOp):
+    class TestSize3(TestXPUSizeOp):
         def initTestCase(self):
             self.shape = (2, 3, 4, 5, 6)

-    class TestNumel4(TestXPUNumelOp):
+    class TestSize4(TestXPUSizeOp):
         def initTestCase(self):
             self.shape = (12, 24)

-    class TestNumel5(TestXPUNumelOp):
+    class TestSize5(TestXPUSizeOp):
         def initTestCase(self):
             self.shape = (1, 64, 16)


-support_types = get_xpu_op_support_types('numel')
+support_types = get_xpu_op_support_types('size')
 for stype in support_types:
-    create_test_class(globals(), XPUTestNumelOP, stype)
+    create_test_class(globals(), XPUTestSizeOP, stype)

 if __name__ == '__main__':
     unittest.main()
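For readers unfamiliar with the XPU test harness, here is a simplified, self-contained illustration of the pattern at the bottom of the test file. This is an assumption-laden sketch, not Paddle's actual get_xpu_op_support_types/create_test_class implementation: one unittest class is generated per supported dtype and injected into the module namespace.

# Simplified illustration only; the helper below is hypothetical and stands in
# for Paddle's create_test_class utility.
import unittest
import numpy as np


def make_dtype_test(base_cls, dtype):
    # Build a subclass whose name encodes the dtype and which pins self.in_type.
    return type(f"{base_cls.__name__}_{dtype}", (base_cls,), {"in_type": dtype})


class SizeReferenceTest(unittest.TestCase):
    in_type = "float32"

    def test_numel_matches_numpy(self):
        # Reference check: the element count of a (11, 66) array is 11 * 66.
        x = np.zeros((11, 66), dtype=self.in_type)
        self.assertEqual(x.size, 11 * 66)


# Roughly the role of create_test_class(globals(), XPUTestSizeOP, stype) above:
# stamp out one class per dtype and register it so unittest discovery finds it.
for stype in ["float32", "float16", "int32", "int64"]:
    cls = make_dtype_test(SizeReferenceTest, stype)
    globals()[cls.__name__] = cls

if __name__ == "__main__":
    unittest.main()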