Commit ff01d048 authored by Yibing Liu

Final refining on the old data provider: enable pruning, add evaluation, and code cleanup

Parent a633eb9c
@@ -5,7 +5,6 @@
import os
from itertools import groupby
import numpy as np
import kenlm
import multiprocessing
@@ -73,11 +72,25 @@ class Scorer(object):
return len(words)
# execute evaluation
def __call__(self, sentence, log=False):
"""
Evaluation function.
:param sentence: The input sentence for evaluation.
:type sentence: basestring
:param log: Whether to return the score in log representation.
:type log: bool
:return: Evaluation score, in decimal or log scale.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score
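As a quick sanity check of the combined score above, here is a minimal sketch (the probability and word count are made-up numbers; alpha and beta follow the defaults used in the scripts below). It simply verifies that the log-scale branch is the logarithm of the decimal-scale branch:

import numpy as np

alpha, beta = 0.26, 0.1      # assumed values, matching the script defaults below
lm_prob = 1e-12              # hypothetical language model probability of a prefix
word_cnt = 5                 # hypothetical word count of the prefix

score = np.power(lm_prob, alpha) * np.power(word_cnt, beta)      # log=False branch
log_score = alpha * np.log(lm_prob) + beta * np.log(word_cnt)    # log=True branch
assert np.isclose(score, np.exp(log_score))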
@@ -85,13 +98,14 @@ def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
'''
Beam search decoder for CTC-trained network. It uses beam search with width
beam_size to find many paths to one label, returning beam_size labels in
descending order of probability. The implementation is based on Prefix
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part has been
redesigned and still needs to be verified.
:param probs_seq: 2-D list with length num_time_steps, each element
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: Externally defined scoring function for
partially decoded sentences, e.g. word count
and language model.
:type ext_scoring_func: function
:param nproc: Whether the decoder is used in multiple processes.
:type nproc: bool
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatched with vocabulary")
num_time_steps = len(probs_seq)
# blank_id check
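For reference, a hypothetical toy input illustrating the dimension check: every row of probs_seq must hold one probability per vocabulary character plus one for the blank, and the scripts below append the blank after the vocabulary (blank_id = len(vocab_list)):

vocabulary = ['a', 'b']      # made-up two-character vocabulary
blank_id = len(vocabulary)   # blank appended last, as in the scripts below
probs_seq = [
    [0.6, 0.3, 0.1],         # time step 0: P('a'), P('b'), P(blank)
    [0.2, 0.5, 0.3],         # time step 1
]
assert all(len(row) == len(vocabulary) + 1 for row in probs_seq)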
@@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
## extend prefix in loop
for time_step in xrange(num_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
prob = probs_seq[time_step]
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
cutoff_len = len(prob_idx)
# If pruning is enabled
if (cutoff_prob < 1.0):
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0
cum_prob = 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
# extend prefix by traversing prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]
if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
@@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
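The pruning added above keeps, at each time step, only the smallest set of most probable characters whose cumulative probability reaches cutoff_prob. A standalone sketch of that step (the helper name and toy distribution are hypothetical, not part of the module):

def prune_vocab(prob, cutoff_prob=0.99):
    """Return (index, probability) pairs covering cutoff_prob of the mass."""
    prob_idx = sorted(enumerate(prob), key=lambda x: x[1], reverse=True)
    cum_prob, cutoff_len = 0.0, 0
    for _, p in prob_idx:
        cum_prob += p
        cutoff_len += 1
        if cum_prob >= cutoff_prob:
            break
    return prob_idx[:cutoff_len]

# toy distribution over four symbols; the top three already cover 90% of the mass
print(prune_vocab([0.6, 0.25, 0.1, 0.05], cutoff_prob=0.9))   # [(0, 0.6), (1, 0.25), (2, 0.1)]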
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
num_processes=None):
'''
@@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: Externally defined scoring function for
partially decoded sentences, e.g. word count
and language model.
:type ext_scoring_func: function
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
......
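Putting the decoder changes together, a hedged usage sketch of the new cutoff_prob argument (the vocabulary, probabilities, and language model path are placeholders; the function and parameter names come from the diff above):

vocab_list = ['a', 'b', 'c', ' ']                 # made-up vocabulary
probs_seq = [[0.1, 0.2, 0.3, 0.3, 0.1]] * 50      # 50 time steps, len(vocab_list)+1 probs each

ext_scorer = Scorer(0.26, 0.1, './data/1Billion.klm')
beam_results = ctc_beam_search_decoder(
    probs_seq=probs_seq,
    beam_size=500,
    vocabulary=vocab_list,
    blank_id=len(vocab_list),     # blank appended after the vocabulary
    cutoff_prob=0.99,             # keep only the top characters covering 99% of the mass
    ext_scoring_func=ext_scorer)
# beam_results is a list of (log probability, sentence) tuples in descending order
best_log_prob, best_sentence = beam_results[0]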
"""
Evaluation for a simplified version of the Baidu DeepSpeech2 model.
"""
import paddle.v2 as paddle
import distutils.util
import argparse
import gzip
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import *
from error_rate import wer
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 evaluation.')
parser.add_argument(
"--num_samples",
default=100,
type=int,
help="Number of samples for evaluation. (default: %(default)s)")
parser.add_argument(
"--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
"--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument(
"--rnn_layer_size",
default=512,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search_nproc',
type=str,
help="Method for ctc decoding, best_path, "
"beam_search or beam_search_nproc. (default: %(default)s)")
parser.add_argument(
"--language_model_path",
default="./data/1Billion.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.26,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.1,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
parser.add_argument(
"--beam_size",
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-clean',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='./params.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
def evaluate():
"""
Evaluate on whole test data for DeepSpeech2.
"""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
# create network config
dict_size = data_generator.vocabulary_size()
vocab_list = data_generator.vocabulary_list()
audio_data = paddle.layer.data(
name="audio_spectrogram",
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
output_probs = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=dict_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
is_inference=True)
# load parameters
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.model_filepath))
# prepare infer data
feeding = data_generator.data_name_feeding()
test_batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
# define inferer
inferer = paddle.inference.Inference(
output_layer=output_probs, parameters=parameters)
# initialize external scorer for beam search decoding
if args.decode_method == 'beam_search' or \
args.decode_method == 'beam_search_nproc':
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
wer_counter, wer_sum = 0, 0.0
for infer_data in test_batch_reader():
# run inference
infer_results = inferer.infer(input=infer_data)
num_steps = len(infer_results) / len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data))
]
# decode and print
# best path decode
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
output_transcription = ctc_decode(
probs_seq=probs, vocabulary=vocab_list, method="best_path")
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
wer_sum += wer(target_transcription, output_transcription)
wer_counter += 1
# beam search decode in single process
elif args.decode_method == "beam_search":
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=vocab_list,
beam_size=args.beam_size,
blank_id=len(vocab_list),
ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
# beam search using multiple processes
elif args.decode_method == "beam_search_nproc":
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
blank_id=len(vocab_list),
ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
else:
raise ValueError("Decoding method [%s] is not supported." % method)
print("Cur WER = %f" % (wer_sum / wer_counter))
print("Final WER = %f" % (wer_sum / wer_counter))
def main():
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
evaluate()
if __name__ == '__main__':
main()
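Assuming the new evaluation script is saved as evaluate.py and the default model, manifest, and language model paths exist, it could be invoked along these lines (all flags come from the argparse definitions above; the file name is an assumption):

python evaluate.py \
    --decode_method beam_search_nproc \
    --beam_size 500 \
    --alpha 0.26 --beta 0.1 \
    --cutoff_prob 0.99 \
    --language_model_path ./data/1Billion.klm \
    --model_filepath ./params.tar.gz \
    --decode_manifest_path data/manifest.libri.test-clean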
@@ -9,14 +9,14 @@ import gzip
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import *
from error_rate import wer
import time
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.')
parser.add_argument(
"--num_samples",
default=100,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
@@ -46,7 +46,7 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-100sample',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
@@ -63,11 +63,13 @@ parser.add_argument(
"--decode_method",
default='beam_search_nproc',
type=str,
help="Method for ctc decoding:"
" best_path,"
" beam_search, "
" or beam_search_nproc. (default: %(default)s)")
parser.add_argument(
"--beam_size",
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
@@ -82,14 +84,20 @@ parser.add_argument(
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.26,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.1,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning "
"in beam search. (default: %(default)f)")
args = parser.parse_args()
@@ -154,6 +162,7 @@ def infer():
## decode and print
# best path decode
wer_sum, wer_counter = 0, 0
total_time = 0.0
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
@@ -177,11 +186,12 @@ def infer():
probs_seq=probs,
vocabulary=vocab_list,
beam_size=args.beam_size,
blank_id=len(vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription)
for index in xrange(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
@@ -190,21 +200,21 @@ def infer():
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search using multiple processes
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
blank_id=len(vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription:\t%s" % target_transcription)
for index in xrange(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
......
""" """
Tune parameters for beam search decoder in Deep Speech 2. Parameters tuning for beam search decoder in Deep Speech 2.
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
...@@ -12,7 +12,7 @@ from decoder import * ...@@ -12,7 +12,7 @@ from decoder import *
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Parameters tuning script for ctc beam search decoder in Deep Speech 2.' description='Parameters tuning for ctc beam search decoder in Deep Speech 2.'
) )
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
@@ -82,34 +82,40 @@ parser.add_argument(
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha_from",
default=0.1,
type=float,
help="Where alpha starts from. (default: %(default)f)")
parser.add_argument(
"--num_alphas",
default=14,
type=int,
help="Number of candidate alphas. (default: %(default)d)")
parser.add_argument(
"--alpha_to",
default=0.36,
type=float,
help="Where alpha ends at. (default: %(default)f)")
parser.add_argument(
"--beta_from",
default=0.05,
type=float,
help="Where beta starts from. (default: %(default)f)")
parser.add_argument(
"--num_betas",
default=20,
type=int,
help="Number of candidate betas. (default: %(default)d)")
parser.add_argument(
"--beta_to",
default=1.0,
type=float,
help="Where beta ends at. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning "
"in beam search. (default: %(default)f)")
args = parser.parse_args()
@@ -118,15 +124,11 @@ def tune():
Tune parameters alpha and beta on one minibatch.
"""
if not args.num_alphas >= 0:
raise ValueError("num_alphas must be non-negative!")
if not args.num_betas >= 0:
raise ValueError("num_betas must be non-negative!")
# initialize data generator
data_generator = DataGenerator(
@@ -171,6 +173,7 @@ def tune():
flatten=True,
sort_by_duration=False,
shuffle=False)
# get one batch data for tuning
infer_data = test_batch_reader().next()
# run inference
@@ -182,11 +185,12 @@ def tune():
for i in xrange(0, len(infer_data))
]
# create grid for search
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
params_grid = [(alpha, beta) for alpha in cand_alphas
for beta in cand_betas]
## tune parameters in loop
for (alpha, beta) in params_grid:
wer_sum, wer_counter = 0, 0
@@ -200,8 +204,9 @@ def tune():
probs_seq=probs,
vocabulary=vocab_list,
beam_size=args.beam_size,
blank_id=len(vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
# beam search using multiple processes
@@ -210,9 +215,9 @@ def tune():
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
blank_id=len(vocab_list),
ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
......
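For the tuning script, the (alpha, beta) grid is now built with np.linspace rather than fixed strides. A small sketch with the new defaults shows the size of the search space (numbers taken from the argparse defaults above):

import numpy as np

cand_alphas = np.linspace(0.1, 0.36, 14)    # --alpha_from, --alpha_to, --num_alphas
cand_betas = np.linspace(0.05, 1.0, 20)     # --beta_from, --beta_to, --num_betas
params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas]
print(len(params_grid))                     # 14 * 20 = 280 (alpha, beta) pairs to evaluate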