test_lookup_table.py 570 字节
Newer Older
1 2
import unittest
import numpy as np
Q
qijun 已提交
3
from op_test import OpTest
4 5


Q
qijun 已提交
6
class TestLookupTableOp(OpTest):
    """Test case for the ``lookup_table`` operator.

    The expected output is computed with NumPy integer-array indexing
    (``table[ids]``); the actual comparison and gradient checking are
    delegated to the ``OpTest`` base class helpers.
    """

    def setUp(self):
        """Prepare the operator type, its inputs and the expected output."""
        self.op_type = "lookup_table"
        # A 17-row, 31-column float32 table of random values.
        table = np.random.random((17, 31)).astype("float32")
        # Four random row indices in [0, 17), so every id is a valid row.
        ids = np.random.randint(0, 17, 4).astype("int32")
        self.inputs = {'W': table, 'Ids': ids}
        # Reference result: the rows of ``table`` selected by ``ids``.
        self.outputs = {'Out': table[ids]}

    def test_check_output(self):
        """Run the base-class output check for the configured operator."""
        self.check_output()

    def test_check_grad(self):
        """Run the base-class gradient check for ``W`` with respect to ``Out``."""
        # A set literal is required here: set('Ids') iterates the string and
        # yields {'I', 'd', 's'} rather than a set holding the name 'Ids'.
        self.check_grad(['W'], 'Out', no_grad_set={'Ids'})
19 20


Q
qijun 已提交
21
if __name__ == "__main__":
    # Run the tests defined in this module when it is executed as a script.
    unittest.main()