提交 d7f5ee66 编写于 作者: Y Yibing Liu

follow comments in code format

上级 eec7cb48
...@@ -5,6 +5,7 @@ from __future__ import print_function ...@@ -5,6 +5,7 @@ from __future__ import print_function
from itertools import groupby from itertools import groupby
import numpy as np import numpy as np
from math import log
import multiprocessing import multiprocessing
...@@ -97,13 +98,8 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -97,13 +98,8 @@ def ctc_beam_search_decoder(probs_seq,
# prefix_set_prev: the set containing selected prefixes # prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step # probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step # probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev, probs_b_prev, probs_nb_prev = { prefix_set_prev = {'\t': 1.0}
'\t': 1.0 probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
}, {
'\t': 1.0
}, {
'\t': 0.0
}
## extend prefix in loop ## extend prefix in loop
for time_step in xrange(len(probs_seq)): for time_step in xrange(len(probs_seq)):
...@@ -179,7 +175,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -179,7 +175,7 @@ def ctc_beam_search_decoder(probs_seq,
# score last word by external scorer # score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '): if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result) prob = prob * ext_scoring_func(result)
log_prob = np.log(prob) log_prob = log(prob)
beam_result.append((log_prob, result)) beam_result.append((log_prob, result))
## output top beam_size decoding results ## output top beam_size decoding results
......
...@@ -62,7 +62,7 @@ parser.add_argument( ...@@ -62,7 +62,7 @@ parser.add_argument(
) )
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/en.00.UNKNOWN.klm", default="lm/data/1Billion.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -88,7 +88,7 @@ parser.add_argument( ...@@ -88,7 +88,7 @@ parser.add_argument(
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='data/manifest.libri.test-clean', default='datasets/manifest.test',
type=str, type=str,
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
...@@ -89,7 +89,7 @@ parser.add_argument( ...@@ -89,7 +89,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)") help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="lm/data/en.00.UNKNOWN.klm", default="lm/data/1Billion.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
...@@ -62,9 +62,7 @@ class LmScorer(object): ...@@ -62,9 +62,7 @@ class LmScorer(object):
lm = self._language_model_score(sentence) lm = self._language_model_score(sentence)
word_cnt = self._word_count(sentence) word_cnt = self._word_count(sentence)
if log == False: if log == False:
score = np.power(lm, self._alpha) \ score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
* np.power(word_cnt, self._beta)
else: else:
score = self._alpha * np.log(lm) \ score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
+ self._beta * np.log(word_cnt)
return score return score
...@@ -77,7 +77,7 @@ parser.add_argument( ...@@ -77,7 +77,7 @@ parser.add_argument(
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="lm/data/en.00.UNKNOWN.klm", default="lm/data/1Billion.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册