diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index a2955c12fc0c4671822aad99730ec7baf38a5531..1dbc1c056128cf0abee1aa4bde30e4d9b3b98ffd 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest): def setUp(self): self.op_type = "gather_nd" - shape = (20, 9, 8, 1, 31) + shape = (5, 2, 3, 1, 10) xnp = np.random.rand(*shape).astype("float64") - index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T + index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T self.inputs = {'X': xnp, 'Index': index.astype("int32")} self.outputs = {'Out': xnp[tuple(index.T)]} @@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest): def setUp(self): self.op_type = "gather_nd" - shape = (20, 9, 8, 1, 31) + shape = (2, 3, 4, 1, 10) xnp = np.random.rand(*shape).astype("float64") - index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T - index_re = index.reshape([10, 5, 20, 5]) + index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T + index_re = index.reshape([20, 5, 2, 5]) self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])} + self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} def test_check_output(self): self.check_output()