test_gather_op.py 535 字节
Newer Older
Z
zchen0211 已提交
1
import unittest
Q
qijun 已提交
2 3
import numpy as np
from op_test import OpTest
Z
zchen0211 已提交
4 5


Q
qijun 已提交
6
class TestGatherOp(OpTest):
    """Unit test for the ``gather`` operator.

    The expected ``Out`` is built with NumPy integer-array indexing on the
    first axis of ``X``, i.e. the rows of ``X`` listed in ``Index``.
    """

    def setUp(self):
        """Create a random ``X``, a fixed ``Index`` and the reference ``Out``."""
        self.op_type = "gather"
        # 10x20 float32 matrix of uniform random values in [0, 1).
        x = np.random.random((10, 20)).astype("float32")
        # Row positions to gather, stored as int32.
        index = np.array([1, 3, 5], dtype="int32")
        self.inputs = {'X': x, 'Index': index}
        # NumPy reference: rows 1, 3 and 5 of x, giving a (3, 20) array.
        self.outputs = {'Out': x[index]}

    def test_check_output(self):
        """Run OpTest's forward-output check.

        Presumably this compares the operator's result against
        ``self.outputs``; see ``OpTest.check_output``.
        """
        self.check_output()

    def test_check_grad(self):
        """Run OpTest's gradient check of ``Out`` with respect to ``X``.

        ``Index`` holds int32 positions and is not listed, so only ``X``
        gets a gradient check.
        """
        self.check_grad(['X'], 'Out')
Z
zchen0211 已提交
18 19 20 21


if __name__ == "__main__":
    # When the file is executed directly, let unittest collect and run the
    # TestCase classes defined in this module.
    unittest.main(module="__main__")