test_decoders.py 3.3 KB
Newer Older
Y
Yibing Liu committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
"""Test decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
from decoder import *


class TestDecoders(unittest.TestCase):
    def setUp(self):
        self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd']
        self.beam_size = 20
        self.probs_seq1 = [[
            0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
            0.18184413, 0.16493624
        ], [
            0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
            0.0094893, 0.06890021
        ], [
            0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
            0.08424043, 0.08120984
        ], [
            0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
            0.05206269, 0.09772094
        ], [
            0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
            0.41317442, 0.01946335
        ], [
            0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
            0.04377724, 0.01457421
        ]]
        self.probs_seq2 = [[
            0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
            0.04468023, 0.10903471
        ], [
            0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
            0.10219457, 0.20640612
        ], [
            0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
            0.12298585, 0.01654384
        ], [
            0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
            0.22538587, 0.13483174
        ], [
            0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
            0.07113197, 0.04139363
        ], [
            0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
            0.05294827, 0.22298418
        ]]
        self.best_path_result = ["ac'bdc", "b'da"]
        self.beam_search_result = ['acdc', "b'a"]

    def test_best_path_decoder_1(self):
Y
Yibing Liu 已提交
56
        bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list)
Y
Yibing Liu 已提交
57 58 59
        self.assertEqual(bst_result, self.best_path_result[0])

    def test_best_path_decoder_2(self):
Y
Yibing Liu 已提交
60
        bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list)
Y
Yibing Liu 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
        self.assertEqual(bst_result, self.best_path_result[1])

    def test_beam_search_decoder_1(self):
        beam_result = ctc_beam_search_decoder(
            probs_seq=self.probs_seq1,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            blank_id=len(self.vocab_list))
        self.assertEqual(beam_result[0][1], self.beam_search_result[0])

    def test_beam_search_decoder_2(self):
        beam_result = ctc_beam_search_decoder(
            probs_seq=self.probs_seq2,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            blank_id=len(self.vocab_list))
        self.assertEqual(beam_result[0][1], self.beam_search_result[1])

Y
Yibing Liu 已提交
79
    def test_beam_search_decoder_batch(self):
Y
Yibing Liu 已提交
80
        beam_results = ctc_beam_search_decoder_batch(
Y
Yibing Liu 已提交
81 82 83
            probs_split=[self.probs_seq1, self.probs_seq2],
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
Y
Yibing Liu 已提交
84 85
            blank_id=len(self.vocab_list),
            num_processes=24)
Y
Yibing Liu 已提交
86 87 88 89 90 91
        self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
        self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])


if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main()