未验证 提交 923f2458 编写于 作者: X xysheng-baidu 提交者: GitHub

[Zero_Dim][unittest] add repeat_interleave unittest for zero_dim (#49596)

上级 35fa30d0
...@@ -104,6 +104,7 @@ class TestRepeatInterleaveOp2(OpTest): ...@@ -104,6 +104,7 @@ class TestRepeatInterleaveOp2(OpTest):
class TestIndexSelectAPI(unittest.TestCase): class TestIndexSelectAPI(unittest.TestCase):
def input_data(self): def input_data(self):
self.data_zero_dim_x = np.array(0.5)
self.data_x = np.array( self.data_x = np.array(
[ [
[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0],
...@@ -170,6 +171,19 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -170,6 +171,19 @@ class TestIndexSelectAPI(unittest.TestCase):
expect_out = np.repeat(self.data_x, repeats, axis=0) expect_out = np.repeat(self.data_x, repeats, axis=0)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
# case 3 zero_dim:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[])
z = paddle.repeat_interleave(x, repeats)
exe = fluid.Executor(fluid.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_zero_dim_x},
fetch_list=[z.name],
return_numpy=False,
)
expect_out = np.repeat(self.data_zero_dim_x, repeats)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
def test_dygraph_api(self): def test_dygraph_api(self):
self.input_data() self.input_data()
# case axis none # case axis none
...@@ -220,6 +234,15 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -220,6 +234,15 @@ class TestIndexSelectAPI(unittest.TestCase):
expect_out = np.repeat(self.data_x, index, axis=0) expect_out = np.repeat(self.data_x, index, axis=0)
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
# case 3 zero_dim:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_zero_dim_x)
index = 2
z = paddle.repeat_interleave(x, index, None)
np_z = z.numpy()
expect_out = np.repeat(self.data_zero_dim_x, index, axis=None)
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -966,6 +966,34 @@ class TestSundryAPI(unittest.TestCase): ...@@ -966,6 +966,34 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0)
def test_repeat_interleave(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
x = paddle.randn(())
x.stop_gradient = False
out = paddle.repeat_interleave(x, 2, None)
out.backward()
# check shape of output
self.assertEqual(out.shape, [2])
# check grad shape
self.assertEqual(x.grad.shape, [])
repeats = paddle.to_tensor([3], dtype='int32')
out = paddle.repeat_interleave(x, repeats, None)
# check shape of output with 1D repeats
self.assertEqual(out.shape, [3])
# check grad shape with 1D repeats
self.assertEqual(x.grad.shape, [])
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -1380,6 +1408,24 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1380,6 +1408,24 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
@prog_scope()
def test_repeat_interleave(self):
x = paddle.full([], 1.0, 'float32')
out = paddle.repeat_interleave(x, 2, None)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, (2,))
repeats = paddle.to_tensor([3], dtype='int32')
out = paddle.repeat_interleave(x, repeats, None)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, (3,))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册