提交 80338456 编写于 作者: Y Yibing Liu

add unit test for decoders

上级 f20c6d64
...@@ -3,10 +3,8 @@ from __future__ import absolute_import ...@@ -3,10 +3,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
from itertools import groupby from itertools import groupby
import numpy as np import numpy as np
import kenlm
import multiprocessing import multiprocessing
...@@ -39,59 +37,6 @@ def ctc_best_path_decode(probs_seq, vocabulary): ...@@ -39,59 +37,6 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list]) return ''.join([vocabulary[index] for index in index_list])
class Scorer(object):
"""External defined scorer to evaluate a sentence in beam search
decoding, consisting of language model and word count.
:param alpha: Parameter associated with language model.
:type alpha: float
:param beta: Parameter associated with word count.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path):
self._alpha = alpha
self._beta = beta
if not os.path.isfile(model_path):
raise IOError("Invaid language model path: %s" % model_path)
self._language_model = kenlm.LanguageModel(model_path)
# n-gram language model scoring
def language_model_score(self, sentence):
#log prob of last word
log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob)
# word insertion term
def word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)
# execute evaluation
def __call__(self, sentence, log=False):
"""Evaluation function, gathering all the scores.
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
vocabulary, vocabulary,
......
...@@ -10,6 +10,7 @@ import gzip ...@@ -10,6 +10,7 @@ import gzip
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -51,7 +52,7 @@ parser.add_argument( ...@@ -51,7 +52,7 @@ parser.add_argument(
"beam_search or beam_search_nproc. (default: %(default)s)") "beam_search or beam_search_nproc. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/1Billion.klm", default="data/en.00.UNKNOWN.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
...@@ -11,6 +11,7 @@ import paddle.v2 as paddle ...@@ -11,6 +11,7 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer
from error_rate import wer from error_rate import wer
import utils import utils
...@@ -67,7 +68,7 @@ parser.add_argument( ...@@ -67,7 +68,7 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_method", "--decode_method",
default='best_path', default='beam_search_nproc',
type=str, type=str,
help="Method for ctc decoding:" help="Method for ctc decoding:"
" best_path," " best_path,"
...@@ -85,7 +86,7 @@ parser.add_argument( ...@@ -85,7 +86,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)") help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/1Billion.klm", default="data/en.00.UNKNOWN.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
"""External Scorer for Beam Search Decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import kenlm
import numpy as np
class Scorer(object):
"""External defined scorer to evaluate a sentence in beam search
decoding, consisting of language model and word count.
:param alpha: Parameter associated with language model.
:type alpha: float
:param beta: Parameter associated with word count.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path):
self._alpha = alpha
self._beta = beta
if not os.path.isfile(model_path):
raise IOError("Invaid language model path: %s" % model_path)
self._language_model = kenlm.LanguageModel(model_path)
# n-gram language model scoring
def language_model_score(self, sentence):
#log10 prob of last word
log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob)
# word insertion term
def word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)
# execute evaluation
def __call__(self, sentence, log=False):
"""Evaluation function, gathering all the different scores
and return the final one.
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score
"""Test decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from decoder import *
class TestDecoders(unittest.TestCase):
def setUp(self):
self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd']
self.beam_size = 20
self.probs_seq1 = [[
0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
0.18184413, 0.16493624
], [
0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
0.0094893, 0.06890021
], [
0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
0.08424043, 0.08120984
], [
0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
0.05206269, 0.09772094
], [
0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
0.41317442, 0.01946335
], [
0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
0.04377724, 0.01457421
]]
self.probs_seq2 = [[
0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
0.04468023, 0.10903471
], [
0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
0.10219457, 0.20640612
], [
0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
0.12298585, 0.01654384
], [
0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
0.22538587, 0.13483174
], [
0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
0.07113197, 0.04139363
], [
0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
0.05294827, 0.22298418
]]
self.best_path_result = ["ac'bdc", "b'da"]
self.beam_search_result = ['acdc', "b'a"]
def test_best_path_decoder_1(self):
bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[0])
def test_best_path_decoder_2(self):
bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[1])
def test_beam_search_decoder_1(self):
beam_result = ctc_beam_search_decoder(
probs_seq=self.probs_seq1,
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list))
self.assertEqual(beam_result[0][1], self.beam_search_result[0])
def test_beam_search_decoder_2(self):
beam_result = ctc_beam_search_decoder(
probs_seq=self.probs_seq2,
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list))
self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_nproc_decoder(self):
beam_results = ctc_beam_search_decoder_nproc(
probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list))
self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
if __name__ == '__main__':
unittest.main()
...@@ -10,6 +10,7 @@ import gzip ...@@ -10,6 +10,7 @@ import gzip
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -81,7 +82,7 @@ parser.add_argument( ...@@ -81,7 +82,7 @@ parser.add_argument(
help="Number of outputs per sample in beam search. (default: %(default)d)") help="Number of outputs per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/1Billion.klm", default="data/en.00.UNKNOWN.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册