import unittest

import numpy as np

from op_test import OpTest


class TestLookupTableOp(OpTest):
    """Operator test for ``lookup_table``.

    Builds a random embedding table and a handful of random row indices,
    and declares the NumPy fancy-indexing result ``table[ids]`` as the
    expected output of the operator.
    """

    def setUp(self):
        """Prepare the operator type, its inputs and the reference output."""
        self.op_type = "lookup_table"

        # 17 rows x 31 columns of float32 values in [0, 1).
        table = np.random.random((17, 31)).astype("float32")
        # Four row indices drawn from [0, 17), i.e. always valid rows.
        ids = np.random.randint(0, 17, 4).astype("int64")
        # The operator is fed the indices as a column: shape (4,) -> (4, 1).
        ids_expand = np.expand_dims(ids, axis=1)

        self.inputs = {'W': table, 'Ids': ids_expand}
        # Reference result uses the flat ids, giving one table row per id.
        self.outputs = {'Out': table[ids]}

    def test_check_output(self):
        """Run the inherited output check against ``self.outputs``."""
        self.check_output()

    def test_check_grad(self):
        """Run the inherited gradient check for ``W`` with respect to ``Out``."""
        # A set literal is required here: set('Ids') would iterate the string
        # and yield {'I', 'd', 's'}, so the 'Ids' input would never actually
        # be named in no_grad_set.
        self.check_grad(['W'], 'Out', no_grad_set={'Ids'})


if __name__ == "__main__":
    unittest.main()