未验证 提交 d01109fc 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor for xpu compare kernel, test=kunlun (#47812)

上级 03f976d6
......@@ -36,6 +36,13 @@ void XPUCompareKernelImpl(const Context& dev_ctx,
auto x_shape = vectorize<int>(x.dims());
auto y_shape = vectorize<int>(y.dims());
if (x.dims().size() == 0) {
x_shape = std::vector<int>({1});
}
if (y.dims().size() == 0) {
y_shape = std::vector<int>({1});
}
auto x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto y_data = reinterpret_cast<const XPUType*>(y.data<T>());
auto* out_data = dev_ctx.template Alloc<bool>(out);
......
......@@ -283,54 +283,6 @@ def create_paddle_case(op_type, callback):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_zero_dim_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_zero_dim_api_2(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_zero_dim_api_3(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_broadcast_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
......@@ -383,6 +335,54 @@ def create_paddle_case(op_type, callback):
)
self.assertEqual((res == real_result).all(), True)
def test_zero_dim_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_zero_dim_api_2(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_zero_dim_api_3(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)
def test_bool_api_4(self):
paddle.enable_static()
with program_guard(Program(), Program()):
......
......@@ -100,6 +100,27 @@ class XPUTestLessThanOP(XPUOpTestWrapper):
self.x_shape = [128, 128, 512]
self.y_shape = [128, 128, 512]
class LessThanOpTestCase_ZeroDim1(LessThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class LessThanOpTestCase_ZeroDim2(LessThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class LessThanOpTestCase_ZeroDim3(LessThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('less_than')
for stype in support_types:
......@@ -152,6 +173,27 @@ class XPUTestLessEqualOp(XPUOpTestWrapper):
self.x_shape = [128, 128, 512]
self.y_shape = [128, 128, 512]
class LessEqualOpTestCase_ZeroDim1(LessEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class LessEqualOpTestCase_ZeroDim2(LessEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class LessEqualOpTestCase_ZeroDim3(LessEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('less_equal')
for stype in support_types:
......@@ -204,6 +246,27 @@ class XPUTestGreaterThanOp(XPUOpTestWrapper):
self.x_shape = [10, 10, 20, 20]
self.y_shape = [10, 10, 20, 20]
class GreaterThanOpTestCase_ZeroDim1(GreaterThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class GreaterThanOpTestCase_ZeroDim2(GreaterThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class GreaterThanOpTestCase_ZeroDim3(GreaterThanOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('greater_than')
for stype in support_types:
......@@ -256,6 +319,27 @@ class XPUTestGreaterEqualOp(XPUOpTestWrapper):
self.x_shape = [10, 30, 15]
self.y_shape = [10, 30, 15]
class GreaterEqualOpTestCase_ZeroDim1(GreaterEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class GreaterEqualOpTestCase_ZeroDim2(GreaterEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class GreaterEqualOpTestCase_ZeroDim3(GreaterEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('greater_equal')
for stype in support_types:
......@@ -308,6 +392,27 @@ class XPUTestEqualOp(XPUOpTestWrapper):
self.x_shape = [11, 17]
self.y_shape = [1]
class EqualOpTestCase_ZeroDim1(EqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class EqualOpTestCase_ZeroDim2(EqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class EqualOpTestCase_ZeroDim3(EqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('equal')
for stype in support_types:
......@@ -360,6 +465,27 @@ class XPUTestNotEqualOp(XPUOpTestWrapper):
self.x_shape = [512, 128]
self.y_shape = [512, 128]
class NotEqualOpTestCase_ZeroDim1(NotEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = []
class NotEqualOpTestCase_ZeroDim2(NotEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = []
self.y_shape = [11, 17]
class NotEqualOpTestCase_ZeroDim3(NotEqualOpTestCase1):
def set_data(self):
self.lbound = -100
self.hbound = 100
self.x_shape = [11, 17]
self.y_shape = []
support_types = get_xpu_op_support_types('not_equal')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册