From b6fd262951838ef2fd7f6f097f9d38f6ee6d0bb6 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Thu, 31 Dec 2020 14:29:28 +0800 Subject: [PATCH] fix gather nd for untest (#30037) --- .../fluid/tests/unittests/test_gather_nd_op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index a2955c12fc0..1dbc1c05612 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest): def setUp(self): self.op_type = "gather_nd" - shape = (20, 9, 8, 1, 31) + shape = (5, 2, 3, 1, 10) xnp = np.random.rand(*shape).astype("float64") - index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T + index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T self.inputs = {'X': xnp, 'Index': index.astype("int32")} self.outputs = {'Out': xnp[tuple(index.T)]} @@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest): def setUp(self): self.op_type = "gather_nd" - shape = (20, 9, 8, 1, 31) + shape = (2, 3, 4, 1, 10) xnp = np.random.rand(*shape).astype("float64") - index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T - index_re = index.reshape([10, 5, 20, 5]) + index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T + index_re = index.reshape([20, 5, 2, 5]) self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])} + self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} def test_check_output(self): self.check_output() -- GitLab