提交 51f35a53 编写于 作者: Y Yibing Liu

code clean & add external scorer

上级 dedbfb26
## This is a prototype of ctc beam search decoder
import copy
import random
import numpy as np
# vocab = blank + space + English characters
#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]
vocab = ['-', '_', 'a']
def ids_list2str(ids_list):
ids_str = [str(elem) for elem in ids_list]
ids_str = ' '.join(ids_str)
return ids_str
def ids_id2token(ids_list):
ids_str = ''
for ids in ids_list:
ids_str += vocab[ids]
return ids_str
def language_model(ids_list, vocabulary):
# lookup ptb vocabulary
ptb_vocab_path = "./data/ptb_vocab.txt"
sentence = ''.join([vocabulary[ids] for ids in ids_list])
words = sentence.split(' ')
last_word = words[-1]
with open(ptb_vocab_path, 'r') as ptb_vocab:
f = ptb_vocab.readline()
while f:
if f == last_word:
return 1.0
f = ptb_vocab.readline()
return 0.0
def ctc_beam_search_decoder(input_probs_matrix,
beam_size,
vocabulary,
max_time_steps=None,
lang_model=language_model,
alpha=1.0,
beta=1.0,
blank_id=0,
space_id=1,
num_results_per_sample=None):
'''
Beam search decoder for CTC-trained network, adapted from Algorithm 1
in https://arxiv.org/abs/1408.2873.
:param input_probs_matrix: probs matrix for input sequence, row major
:type input_probs_matrix: 2D matrix.
:param beam_size: width for beam search
:type beam_size: int
:max_time_steps: maximum steps' number for input sequence,
<=len(input_probs_matrix)
:type max_time_steps: int
:lang_model: language model for scoring
:type lang_model: function
:param alpha: parameter associated with language model.
:type alpha: float
:param beta: parameter associated with word count
:type beta: float
:param blank_id: id of blank, default 0.
:type blank_id: int
:param space_id: id of space, default 1.
:type space_id: int
:param num_result_per_sample: the number of output decoding results
per given sample, <=beam_size.
:type num_result_per_sample: int
'''
# function to convert ids in string to list
def ids_str2list(ids_str):
ids_str = ids_str.split(' ')
ids_list = [int(elem) for elem in ids_str]
return ids_list
# counting words in a character list
def word_count(ids_list):
cnt = 0
for elem in ids_list:
if elem == space_id:
cnt += 1
return cnt
if num_results_per_sample is None:
num_results_per_sample = beam_size
assert num_results_per_sample <= beam_size
if max_time_steps is None:
max_time_steps = len(input_probs_matrix)
else:
max_time_steps = min(max_time_steps, len(input_probs_matrix))
assert max_time_steps > 0
vocab_dim = len(input_probs_matrix[0])
assert blank_id < vocab_dim
assert space_id < vocab_dim
## initialize
start_id = -1
# the set containing selected prefixes
prefix_set_prev = {str(start_id): 1.0}
probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}
## extend prefix in loop
for time_step in range(max_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
for l in prefix_set_prev:
prob = input_probs_matrix[time_step]
# convert ids in string to list
ids_list = ids_str2list(l)
end_id = ids_list[-1]
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
# extend prefix by travering vocabulary
for c in range(0, vocab_dim):
if c == blank_id:
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
else:
l_plus = l + ' ' + str(c)
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if c == end_id:
probs_nb_cur[l_plus] += prob[c] * probs_b[l]
probs_nb_cur[l] += prob[c] * probs_nb[l]
elif c == space_id:
lm = 1.0 if lang_model is None \
else np.power(lang_model(ids_list, vocabulary), alpha)
probs_nb_cur[l_plus] += lm * prob[c] * (
probs_b[l] + probs_nb[l])
else:
probs_nb_cur[l_plus] += prob[c] * (
probs_b[l] + probs_nb[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
probs_nb_cur)
## store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)
beam_result = []
for (seq, prob) in prefix_set_prev.items():
if prob > 0.0:
ids_list = ids_str2list(seq)[1:]
result = ''.join([vocabulary[ids] for ids in ids_list])
log_prob = np.log(prob)
beam_result.append([log_prob, result])
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
if num_results_per_sample < beam_size:
beam_result = beam_result[:num_results_per_sample]
return beam_result
def simple_test():
input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]
beam_result = ctc_beam_search_decoder(
input_probs_matrix=input_probs_matrix,
beam_size=20,
blank_id=0,
space_id=1, )
print "\nbeam search output:"
for result in beam_result:
print("%6f\t%s" % (result[0], ids_id2token(result[1])))
if __name__ == '__main__':
simple_test()
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
from itertools import groupby from itertools import groupby
import numpy as np import numpy as np
from ctc_beam_search_decoder import * import copy
import kenlm
def ctc_best_path_decode(probs_seq, vocabulary): def ctc_best_path_decode(probs_seq, vocabulary):
...@@ -37,36 +38,165 @@ def ctc_best_path_decode(probs_seq, vocabulary): ...@@ -37,36 +38,165 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list]) return ''.join([vocabulary[index] for index in index_list])
def ctc_decode(probs_seq, class Scorer(object):
vocabulary,
method,
beam_size=None,
num_results_per_sample=None):
""" """
CTC-like sequence decoding from a sequence of likelihood probablilites. External defined scorer to evaluate a sentence in beam search
decoding, consisting of language model and word count.
:param probs_seq: 2-D list of probabilities over the vocabulary for each :param alpha: Parameter associated with language model.
character. Each element is a list of float probabilities :type alpha: float
for one character. :param beta: Parameter associated with word count.
:type probs_seq: list :type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path):
self._alpha = alpha
self._beta = beta
self._language_model = kenlm.LanguageModel(model_path)
def language_model_score(self, sentence, bos=True, eos=False):
log_prob = self._language_model.score(sentence, bos, eos)
return np.power(10, log_prob)
def word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)
# execute evaluation
def evaluate(self, sentence, bos=True, eos=False):
lm = self.language_model_score(sentence, bos, eos)
word_count = self.word_count(sentence)
score = np.power(lm, self._alpha) \
* np.power(word_count, self._beta)
return score
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
ext_scoring_func=None,
blank_id=0):
'''
Beam search decoder for CTC-trained network, using beam search with width
beam_size to find many paths to one label, return beam_size labels in
the order of probabilities. The implementation is based on Prefix Beam
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned, need to be verified.
:param probs_seq: 2-D list with length max_time_steps, each element
is a list of normalized probabilities over vocabulary
and blank for one time step.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param method: Decoding method name, with options: "best_path". :param ext_scoring_func: External defined scoring function for
:type method: basestring partially decoded sentence, e.g. word count
:return: Decoding result string. and language model.
:rtype: baseline :type external_scoring_function: function
""" :param blank_id: id of blank, default 0.
:type blank_id: int
:return: Decoding log probability and result string.
:rtype: list
'''
for prob_list in probs_seq: for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1: if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary") raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary) max_time_steps = len(probs_seq)
elif method == "beam_search": if not max_time_steps > 0:
return ctc_beam_search_decoder( raise ValueError("probs_seq shouldn't be empty")
input_probs_matrix=probs_seq,
vocabulary=vocabulary, probs_dim = len(probs_seq[0])
beam_size=beam_size, if not blank_id < probs_dim:
blank_id=len(vocabulary), raise ValueError("blank_id shouldn't be greater than probs dimension")
num_results_per_sample=num_results_per_sample)
else: if ' ' not in vocabulary:
raise ValueError("Decoding method [%s] is not supported." % method) raise ValueError("space doesn't exist in vocabulary")
space_id = vocabulary.index(' ')
# function to convert ids in string to list
def ids_str2list(ids_str):
ids_str = ids_str.split(' ')
ids_list = [int(elem) for elem in ids_str]
return ids_list
# function to convert ids list to sentence
def ids2sentence(ids_list, vocab):
return ''.join([vocab[ids] for ids in ids_list])
## initialize
# the set containing selected prefixes
prefix_set_prev = {'-1': 1.0}
probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}
## extend prefix in loop
for time_step in range(max_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
for l in prefix_set_prev:
prob = probs_seq[time_step]
# convert ids in string to list
ids_list = ids_str2list(l)
end_id = ids_list[-1]
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
# extend prefix by travering vocabulary
for c in range(0, probs_dim):
if c == blank_id:
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
else:
l_plus = l + ' ' + str(c)
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if c == end_id:
probs_nb_cur[l_plus] += prob[c] * probs_b[l]
probs_nb_cur[l] += prob[c] * probs_nb[l]
elif c == space_id:
if ext_scoring_func is None:
score = 1.0
else:
prefix_sent = ids2sentence(ids_list, vocabulary)
score = ext_scoring_func(prefix_sent)
probs_nb_cur[l_plus] += score * prob[c] * (
probs_b[l] + probs_nb[l])
else:
probs_nb_cur[l_plus] += prob[c] * (
probs_b[l] + probs_nb[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
probs_nb_cur)
## store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)
beam_result = []
for (seq, prob) in prefix_set_prev.items():
if prob > 0.0:
ids_list = ids_str2list(seq)[1:]
result = ids2sentence(ids_list, vocabulary)
log_prob = np.log(prob)
beam_result.append([log_prob, result])
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result
...@@ -8,7 +8,7 @@ import argparse ...@@ -8,7 +8,7 @@ import argparse
import gzip import gzip
from audio_data_utils import DataGenerator from audio_data_utils import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import ctc_decode from decoder import *
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.') description='Simplified version of DeepSpeech2 inference.')
...@@ -59,7 +59,7 @@ parser.add_argument( ...@@ -59,7 +59,7 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_method", "--decode_method",
default='best_path', default='beam_search',
type=str, type=str,
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
) )
...@@ -69,11 +69,25 @@ parser.add_argument( ...@@ -69,11 +69,25 @@ parser.add_argument(
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--num_result_per_sample", "--num_results_per_sample",
default=2, default=1,
type=int, type=int,
help="Number of results per given sample in beam search. (default: %(default)d)" help="Number of output per sample in beam search. (default: %(default)d)")
) parser.add_argument(
"--language_model_path",
default="./data/1Billion.klm",
type=str,
help="Path for language model. (default: %(default)d)")
parser.add_argument(
"--alpha",
default=0.0,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.0,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
args = parser.parse_args() args = parser.parse_args()
...@@ -135,24 +149,34 @@ def infer(): ...@@ -135,24 +149,34 @@ def infer():
for i in xrange(0, len(infer_data)) for i in xrange(0, len(infer_data))
] ]
# decode and print ## decode and print
for i, probs in enumerate(probs_split): # best path decode
best_path_transcription = ctc_decode( if args.decode_method == "best_path":
probs_seq=probs, vocabulary=vocab_list, method="best_path") for i, probs in enumerate(probs_split):
target_transcription = ''.join( target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]]) [vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription: %s \nBst_path Transcription: %s" % best_path_transcription = ctc_best_path_decode(
(target_transcription, best_path_transcription)) probs_seq=probs, vocabulary=vocab_list)
beam_search_transcription = ctc_decode( print("\nTarget Transcription: %s\nOutput Transcription: %s" %
probs_seq=probs, (target_transcription, best_path_transcription))
vocabulary=vocab_list, # beam search decode
method="beam_search", elif args.decode_method == "beam_search":
beam_size=args.beam_size, for i, probs in enumerate(probs_split):
num_results_per_sample=args.num_result_per_sample) target_transcription = ''.join(
for index in range(len(beam_search_transcription)): [vocab_list[index] for index in infer_data[i][1]])
print("LM No, %d - %4f: %s " % ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
(index, beam_search_transcription[index][0], beam_search_result = ctc_beam_search_decoder(
beam_search_transcription[index][1])) probs_seq=probs,
vocabulary=vocab_list,
beam_size=args.beam_size,
ext_scoring_func=ext_scorer.evaluate,
blank_id=len(vocab_list))
print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample):
result = beam_search_result[index]
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
else:
raise ValueError("Decoding method [%s] is not supported." % method)
def main(): def main():
......
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
import ctc_beam_search_decoder as tested_decoder
def test_beam_search_decoder():
max_time_steps = 6
beam_size = 20
num_results_per_sample = 20
input_prob_matrix_0 = np.asarray(
[
[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
[0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
[0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
[0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
[0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
# Random entry added in at time=5
[0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
],
dtype=np.float32)
# Add arbitrary offset - this is fine
input_log_prob_matrix_0 = np.log(input_prob_matrix_0) #+ 2.0
# len max_time_steps array of batch_size x depth matrices
inputs = ([
input_log_prob_matrix_0[t, :][np.newaxis, :]
for t in range(max_time_steps)
])
inputs_t = [ops.convert_to_tensor(x) for x in inputs]
inputs_t = array_ops.stack(inputs_t)
# run CTC beam search decoder in tensorflow
with tf.Session() as sess:
decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
inputs_t, [max_time_steps],
beam_width=beam_size,
top_paths=num_results_per_sample,
merge_repeated=False)
tf_decoded = sess.run(decoded)
tf_log_probs = sess.run(log_probabilities)
# run tested CTC beam search decoder
beam_result = tested_decoder.ctc_beam_search_decoder(
input_probs_matrix=input_prob_matrix_0,
beam_size=beam_size,
blank_id=5, # default blank_id in tensorflow decoder is (num classes-1)
space_id=4, # doesn't matter
max_time_steps=max_time_steps,
num_results_per_sample=num_results_per_sample)
# compare decoding result
print(
"{tf_decoder log probs} \t {tested_decoder log probs}: {tf_decoder result} {tested_decoder result}"
)
for index in range(len(beam_result)):
print(('%6f\t%6f: ') % (tf_log_probs[0][index], beam_result[index][0]),
tf_decoded[index].values, ' ', beam_result[index][1])
if __name__ == '__main__':
test_beam_search_decoder()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册