From a51ed5108b7016e0629499736259fcb607214673 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Fri, 2 Jun 2017 09:30:30 +0800
Subject: [PATCH] update code & add test

---
 .../ctc_beam_search_decoder.py      | 162 ++++++++++++++++++
 .../test_ctc_beam_search_decoder.py |  69 ++++++++
 2 files changed, 231 insertions(+)
 create mode 100644 ctc_beam_search_decoder/ctc_beam_search_decoder.py
 create mode 100644 ctc_beam_search_decoder/test_ctc_beam_search_decoder.py

diff --git a/ctc_beam_search_decoder/ctc_beam_search_decoder.py b/ctc_beam_search_decoder/ctc_beam_search_decoder.py
new file mode 100644
index 00000000..873121b1
--- /dev/null
+++ b/ctc_beam_search_decoder/ctc_beam_search_decoder.py
@@ -0,0 +1,162 @@
+## This is a prototype of the CTC beam search decoder
+
+import copy
+import random
+
+import numpy as np
+
+# vocab = blank + space + English characters
+#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]
+
+vocab = ['-', '_', 'a']
+
+
+def ids_str2list(ids_str):
+    """Convert a space-separated string of ids into a list of ints."""
+    return [int(elem) for elem in ids_str.split(' ')]
+
+
+def ids_list2str(ids_list):
+    """Convert a list of int ids into a space-separated string."""
+    return ' '.join(str(elem) for elem in ids_list)
+
+
+def ids_id2token(ids_list):
+    """Map a list of int ids to the corresponding token string."""
+    return ''.join(vocab[ids] for ids in ids_list)
+
+
+def ctc_beam_search_decoder(input_probs_matrix,
+                            beam_size,
+                            max_time_steps=None,
+                            lang_model=None,
+                            alpha=1.0,
+                            beta=1.0,
+                            blank_id=0,
+                            space_id=1,
+                            num_results_per_sample=None):
+    '''
+    Beam search decoder for a CTC-trained network, called outside of the
+    recurrent group. Adapted from Algorithm 1 in
+    https://arxiv.org/abs/1408.2873.
+
+    :param input_probs_matrix: probability matrix for the input sequence;
+                               each row is the distribution over the
+                               vocabulary at one time step
+    :type input_probs_matrix: 2-D matrix
+    :param beam_size: width for beam search
+    :type beam_size: int
+    :param max_time_steps: maximum number of time steps to decode,
+                           <= len(input_probs_matrix); defaults to the
+                           full sequence length
+    :type max_time_steps: int
+    :param lang_model: language model for scoring prefixes at word
+                       boundaries; None disables language model scoring
+    :type lang_model: function
+    :param alpha: exponent applied to the language model score
+    :type alpha: float
+    :param beta: reserved word-insertion weight, not used yet in this
+                 prototype
+    :type beta: float
+    :param blank_id: index of the blank symbol in the vocabulary
+    :type blank_id: int
+    :param space_id: index of the space symbol in the vocabulary
+    :type space_id: int
+    :param num_results_per_sample: number of decoding results to return,
+                                   <= beam_size; defaults to beam_size
+    :type num_results_per_sample: int
+    '''
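+    # Notation, following Algorithm 1 in the paper referenced above: for
+    # each prefix l kept in the beam, probs_b[l] is the probability of
+    # emitting l with a blank as the last frame, probs_nb[l] with a
+    # non-blank. Each time step extends every prefix l with every symbol c:
+    #   c == blank_id: prob[c] * (probs_b[l] + probs_nb[l]) flows to
+    #                  probs_b[l], leaving the prefix unchanged;
+    #   c == end of l: a repeated symbol only extends l across a blank, so
+    #                  prob[c] * probs_b[l] flows to l_plus while
+    #                  prob[c] * probs_nb[l] stays with l;
+    #   c == space_id: the extension is additionally weighted by the
+    #                  (optional) language model score raised to alpha.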
+    if num_results_per_sample is None:
+        num_results_per_sample = beam_size
+    assert num_results_per_sample <= beam_size
+
+    if max_time_steps is None:
+        max_time_steps = len(input_probs_matrix)
+    else:
+        max_time_steps = min(max_time_steps, len(input_probs_matrix))
+    assert max_time_steps > 0
+
+    vocab_dim = len(input_probs_matrix[0])
+    assert blank_id < vocab_dim
+    assert space_id < vocab_dim
+
+    ## initialize
+    start_id = -1
+    # the set containing selected prefixes
+    prefix_set_prev = {str(start_id): 1.0}
+    probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}
+
+    ## extend prefixes in loop
+    for time_step in range(max_time_steps):
+        # the set containing candidate prefixes
+        prefix_set_next = {}
+        probs_b_cur, probs_nb_cur = {}, {}
+        for l in prefix_set_prev:
+            prob = input_probs_matrix[time_step]
+
+            # convert ids in string to list
+            ids_list = ids_str2list(l)
+            end_id = ids_list[-1]
+            if l not in prefix_set_next:
+                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
+
+            # extend prefix by traversing the vocabulary
+            for c in range(0, vocab_dim):
+                if c == blank_id:
+                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
+                else:
+                    l_plus = l + ' ' + str(c)
+                    if l_plus not in prefix_set_next:
+                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
+
+                    if c == end_id:
+                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
+                        probs_nb_cur[l] += prob[c] * probs_nb[l]
+                    elif c == space_id:
+                        lm = 1.0 if lang_model is None \
+                            else np.power(lang_model(ids_list), alpha)
+                        probs_nb_cur[l_plus] += lm * prob[c] * (
+                            probs_b[l] + probs_nb[l])
+                    else:
+                        probs_nb_cur[l_plus] += prob[c] * (
+                            probs_b[l] + probs_nb[l])
+                    # add l_plus into prefix_set_next
+                    prefix_set_next[l_plus] = probs_nb_cur[
+                        l_plus] + probs_b_cur[l_plus]
+            # add l into prefix_set_next
+            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
+        # update probs
+        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
+            probs_nb_cur)
+
+        ## store top beam_size prefixes
+        prefix_set_prev = sorted(
+            prefix_set_next.items(), key=lambda entry: entry[1], reverse=True)
+        if beam_size < len(prefix_set_prev):
+            prefix_set_prev = prefix_set_prev[:beam_size]
+        prefix_set_prev = dict(prefix_set_prev)
+
+    beam_result = []
+    for (seq, prob) in prefix_set_prev.items():
+        if prob > 0.0:
+            ids_list = ids_str2list(seq)
+            log_prob = np.log(prob)
+            # drop the leading start_id before returning the id sequence
+            beam_result.append([log_prob, ids_list[1:]])
+
+    ## output top beam_size decoding results
+    beam_result = sorted(beam_result, key=lambda entry: entry[0], reverse=True)
+    if num_results_per_sample < beam_size:
+        beam_result = beam_result[:num_results_per_sample]
+    return beam_result
+
+
+def language_model(ids_list):
+    # TODO: replace this placeholder with a real scorer; for now it returns
+    # a random score so the language model code path can be exercised.
+    return random.uniform(0, 1)
+
+
+def simple_test():
+    input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]
+
+    beam_result = ctc_beam_search_decoder(
+        input_probs_matrix=input_probs_matrix,
+        beam_size=20,
+        blank_id=0,
+        space_id=1)
+
+    print("\nbeam search output:")
+    for result in beam_result:
+        print("%.6f\t%s" % (result[0], ids_id2token(result[1])))
+
+
+if __name__ == '__main__':
+    simple_test()
diff --git a/ctc_beam_search_decoder/test_ctc_beam_search_decoder.py b/ctc_beam_search_decoder/test_ctc_beam_search_decoder.py
new file mode 100644
index 00000000..f7970444
--- /dev/null
+++ b/ctc_beam_search_decoder/test_ctc_beam_search_decoder.py
@@ -0,0 +1,69 @@
+from __future__ import absolute_import
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+import ctc_beam_search_decoder as tested_decoder
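+
+# Strategy of this test: run TensorFlow's tf.nn.ctc_beam_search_decoder on
+# the same probability matrix and compare the two result lists. With
+# merge_repeated=False, no language model, and the blank mapped to the last
+# class (TensorFlow's convention), both decoders search the same prefix
+# space, so the top paths and their log probabilities should roughly agree.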
+
+
+def test_beam_search_decoder():
+    max_time_steps = 6
+    beam_size = 20
+    num_results_per_sample = 20
+
+    input_prob_matrix_0 = np.asarray(
+        [
+            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
+            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
+            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
+            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
+            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
+            # Random entry added in at time=5
+            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
+        ],
+        dtype=np.float32)
+
+    # An arbitrary offset could be added here - this is fine
+    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)  #+ 2.0
+
+    # len max_time_steps array of batch_size x depth matrices
+    inputs = ([
+        input_log_prob_matrix_0[t, :][np.newaxis, :]
+        for t in range(max_time_steps)
+    ])
+
+    inputs_t = [ops.convert_to_tensor(x) for x in inputs]
+    inputs_t = array_ops.stack(inputs_t)
+
+    # run the CTC beam search decoder in TensorFlow
+    with tf.Session() as sess:
+        decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
+            inputs_t, [max_time_steps],
+            beam_width=beam_size,
+            top_paths=num_results_per_sample,
+            merge_repeated=False)
+        tf_decoded, tf_log_probs = sess.run([decoded, log_probabilities])
+
+    # run the CTC beam search decoder under test
+    beam_result = tested_decoder.ctc_beam_search_decoder(
+        input_probs_matrix=input_prob_matrix_0,
+        beam_size=beam_size,
+        blank_id=5,  # the default blank_id in the tensorflow decoder is (num_classes - 1)
+        space_id=4,  # doesn't matter here since no language model is used
+        max_time_steps=max_time_steps,
+        num_results_per_sample=num_results_per_sample)
+
+    # compare decoding results
+    print(
+        "{tf_decoder log probs} \t {tested_decoder log probs}: {tf_decoder result} {tested_decoder result}"
+    )
+    for index in range(len(beam_result)):
+        print(('%.6f\t%.6f: ') % (tf_log_probs[0][index], beam_result[index][0]),
+              tf_decoded[index].values, ' ', beam_result[index][1])
+
+
+if __name__ == '__main__':
+    test_beam_search_decoder()
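
Usage note (not part of the patch): neither simple_test() nor the TensorFlow comparison exercises the lang_model/alpha hook, so a minimal sketch of that path follows. It assumes ctc_beam_search_decoder.py from this patch is importable, and wires in a hypothetical toy scorer; the probability values are made up for illustration.

from ctc_beam_search_decoder import ctc_beam_search_decoder, ids_id2token

def toy_lang_model(ids_list):
    # Hypothetical stand-in scorer: favors shorter prefixes.
    # A real decoder would look up n-gram scores here.
    return 1.0 / (len(ids_list) + 1)

# Three time steps over the 3-symbol vocab ['-', '_', 'a']
# (blank, space, 'a'); each row sums to 1.
probs = [[0.1, 0.3, 0.6],
         [0.2, 0.1, 0.7],
         [0.5, 0.2, 0.3]]

for log_prob, ids in ctc_beam_search_decoder(
        input_probs_matrix=probs,
        beam_size=4,
        lang_model=toy_lang_model,
        alpha=0.5,
        blank_id=0,
        space_id=1):
    print("%.6f\t%s" % (log_prob, ids_id2token(ids)))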