# test_lookup_table_op.py
import unittest
import numpy as np
from op_test import OpTest


class TestLookupTableOp(OpTest):
    """Test case for the ``lookup_table`` operator.

    The expected output ``Out`` is the table ``W`` indexed by the row ids,
    computed with NumPy in ``setUp``.
    """

    def setUp(self):
        """Build a random float32 table, random int64 ids and the expected rows."""
        self.op_type = "lookup_table"
        table = np.random.random((17, 31)).astype("float32")
        # Four row indices drawn from [0, 17), i.e. valid rows of ``table``.
        ids = np.random.randint(0, 17, 4).astype("int64")
        # The operator input carries the ids as a column of shape (4, 1);
        # the flat ids are used to index the table for the expected output.
        ids_expand = np.expand_dims(ids, axis=1)
        self.inputs = {'W': table, 'Ids': ids_expand}
        self.outputs = {'Out': table[ids]}

    def test_check_output(self):
        """Run the output check inherited from ``OpTest``."""
        self.check_output()

    def test_check_grad(self):
        """Run the inherited gradient check for ``W`` against ``Out``.

        ``Ids`` is passed in the no-grad set.
        """
        # ``set('Ids')`` would iterate the string and yield {'I', 'd', 's'};
        # a set literal passes the single variable name 'Ids' as intended.
        self.check_grad(['W'], 'Out', no_grad_set={'Ids'})


if __name__ == "__main__":
    # Run the test cases in this module only when executed as a script.
    unittest.main()