未验证 提交 b6fd2629 编写于 作者: S ShenLiang 提交者: GitHub

fix gather nd for untest (#30037)

上级 a253a78a
...@@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest): ...@@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest):
def setUp(self): def setUp(self):
self.op_type = "gather_nd" self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31) shape = (5, 2, 3, 1, 10)
xnp = np.random.rand(*shape).astype("float64") xnp = np.random.rand(*shape).astype("float64")
index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T
self.inputs = {'X': xnp, 'Index': index.astype("int32")} self.inputs = {'X': xnp, 'Index': index.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)]} self.outputs = {'Out': xnp[tuple(index.T)]}
...@@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest): ...@@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
def setUp(self): def setUp(self):
self.op_type = "gather_nd" self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31) shape = (2, 3, 4, 1, 10)
xnp = np.random.rand(*shape).astype("float64") xnp = np.random.rand(*shape).astype("float64")
index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
index_re = index.reshape([10, 5, 20, 5]) index_re = index.reshape([20, 5, 2, 5])
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])} self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册