提交 63a72c1e 编写于 作者: Y Yibing Liu

refine ctc_beam_search_decoder

上级 accaf924
......@@ -8,8 +8,8 @@ import numpy as np
import multiprocessing
def ctc_best_path_decode(probs_seq, vocabulary):
"""Best path decoding, also called argmax decoding or greedy decoding.
def ctc_best_path_decoder(probs_seq, vocabulary):
"""Best path decoder, also called argmax decoder or greedy decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
......@@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id=0,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
'''Beam search decoder for CTC-trained network, using beam search with width
beam_size to find many paths to one label, return beam_size labels in
the descending order of probabilities. The implementation is based on Prefix
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned.
:param probs_seq: 2-D list with length num_time_steps, each element
is a list of normalized probabilities over vocabulary
and blank for one time step.
"""Beam search decoder for CTC-trained network. It utilizes beam search
to approximately select top best decoding labels and returning results
in the descending order. The implementation is based on Prefix
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
in A_prev then" after probabilities' computation is deprecated for it is
hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
or language model.
:type external_scoring_func: callable
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: Decoding log probabilities and result sentences in descending order.
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
'''
"""
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatched with vocabulary")
num_time_steps = len(probs_seq)
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")
# blank_id check
probs_dim = len(probs_seq[0])
if not blank_id < probs_dim:
if not blank_id < len(probs_seq[0]):
raise ValueError("blank_id shouldn't be greater than probs dimension")
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_nproc().
# instantiated in ctc_beam_search_decoder_batch().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer
## initialize
# the set containing selected prefixes
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev, probs_b_prev, probs_nb_prev = {
'\t': 1.0
}, {
'\t': 1.0
}, {
'\t': 0.0
}
## extend prefix in loop
for time_step in xrange(num_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
prob = probs_seq[time_step]
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
for time_step in xrange(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_nb_cur: prefixes' probability ending with non-blank in current step
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
if (cutoff_prob < 1.0):
if cutoff_prob < 1.0:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0
cum_prob = 0.0
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
......@@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq,
prefix_set_prev = dict(prefix_set_prev)
beam_result = []
for (seq, prob) in prefix_set_prev.items():
for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = np.log(prob)
beam_result.append([log_prob, result])
beam_result.append((log_prob, result))
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result
def ctc_beam_search_decoder_nproc(probs_split,
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
blank_id=0,
blank_id,
num_processes,
cutoff_prob=1.0,
ext_scoring_func=None,
num_processes=None):
'''Beam search decoder using multiple processes.
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
:param probs_seq: 3-D list with length batch_size, each element
is a 2-D list of probabilities can be used by
ctc_beam_search_decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning.
default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probabilities and result sentences in descending order.
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
'''
if num_processes is None:
num_processes = multiprocessing.cpu_count()
"""
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
......@@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool.close()
pool.join()
beam_search_results = []
for result in results:
beam_search_results.append(result.get())
beam_search_results = [result.get() for result in results]
return beam_search_results
......@@ -3,22 +3,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.v2 as paddle
import distutils.util
import argparse
import gzip
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import *
from scorer import Scorer
from lm.lm_scorer import LmScorer
from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
"--batch_size",
default=100,
type=int,
help="Number of samples for evaluation. (default: %(default)s)")
help="Minibatch size for evaluation. (default: %(default)s)")
parser.add_argument(
"--num_conv_layers",
default=2,
......@@ -39,6 +39,16 @@ parser.add_argument(
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
......@@ -46,10 +56,10 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search_nproc',
default='beam_search',
type=str,
help="Method for ctc decoding, best_path, "
"beam_search or beam_search_nproc. (default: %(default)s)")
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
)
parser.add_argument(
"--language_model_path",
default="data/en.00.UNKNOWN.klm",
......@@ -76,11 +86,6 @@ parser.add_argument(
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-clean',
......@@ -88,7 +93,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='./params.tar.gz',
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
......@@ -101,12 +106,12 @@ args = parser.parse_args()
def evaluate():
"""Evaluate on whole test data for DeepSpeech2."""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}')
augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config
# paddle.data_type.dense_array is used for variable batch input.
......@@ -133,7 +138,7 @@ def evaluate():
# prepare infer data
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
batch_size=args.batch_size,
sortagrad=False,
shuffle_method=None)
......@@ -142,9 +147,8 @@ def evaluate():
output_layer=output_probs, parameters=parameters)
# initialize external scorer for beam search decoding
if args.decode_method == 'beam_search' or \
args.decode_method == 'beam_search_nproc':
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
if args.decode_method == 'beam_search':
ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
wer_counter, wer_sum = 0, 0.0
for infer_data in batch_reader():
......@@ -155,56 +159,39 @@ def evaluate():
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data))
]
# target transcription
target_transcription = [
''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
]) for i, probs in enumerate(probs_split)
]
# decode and print
# best path decode
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
output_transcription = ctc_best_path_decode(
output_transcription = ctc_best_path_decoder(
probs_seq=probs, vocabulary=data_generator.vocab_list)
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
wer_sum += wer(target_transcription, output_transcription)
wer_sum += wer(target_transcription[i], output_transcription)
wer_counter += 1
# beam search decode in single process
# beam search decode
elif args.decode_method == "beam_search":
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
# beam search using multiple processes
elif args.decode_method == "beam_search_nproc":
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
wer_sum += wer(target_transcription, beam_search_result[0][1])
for i, beam_search_result in enumerate(beam_search_results):
wer_sum += wer(target_transcription[i],
beam_search_result[0][1])
wer_counter += 1
else:
raise ValueError("Decoding method [%s] is not supported." %
decode_method)
print("Cur WER = %f" % (wer_sum / wer_counter))
print("Final WER = %f" % (wer_sum / wer_counter))
......
......@@ -11,14 +11,14 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import *
from scorer import Scorer
from lm.lm_scorer import LmScorer
from error_rate import wer
import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=100,
default=10,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
......@@ -46,6 +46,11 @@ parser.add_argument(
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
......@@ -53,12 +58,12 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-100sample',
default='datasets/manifest.test',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='checkpoints/params.latest.tar.gz',
default='checkpoints/params.tar.gz.41',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
......@@ -68,12 +73,10 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search_nproc',
default='beam_search',
type=str,
help="Method for ctc decoding:"
" best_path,"
" beam_search, "
" or beam_search_nproc. (default: %(default)s)")
help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
)
parser.add_argument(
"--beam_size",
default=500,
......@@ -86,7 +89,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="data/en.00.UNKNOWN.klm",
default="lm/data/en.00.UNKNOWN.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
......@@ -143,6 +146,7 @@ def infer():
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
min_batch_size=1,
sortagrad=False,
shuffle_method=None)
infer_data = batch_reader().next()
......@@ -156,68 +160,45 @@ def infer():
for i in xrange(len(infer_data))
]
# targe transcription
target_transcription = [
''.join(
[data_generator.vocab_list[index] for index in infer_data[i][1]])
for i, probs in enumerate(probs_split)
]
## decode and print
# best path decode
wer_sum, wer_counter = 0, 0
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
])
best_path_transcription = ctc_best_path_decode(
best_path_transcription = ctc_best_path_decoder(
probs_seq=probs, vocabulary=data_generator.vocab_list)
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target_transcription, best_path_transcription))
wer_cur = wer(target_transcription, best_path_transcription)
(target_transcription[i], best_path_transcription))
wer_cur = wer(target_transcription[i], best_path_transcription)
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f, average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search decode
elif args.decode_method == "beam_search":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription)
for index in xrange(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
beam_search_batch_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
])
print("\nTarget Transcription:\t%s" % target_transcription)
for i, beam_search_result in enumerate(beam_search_batch_results):
print("\nTarget Transcription:\t%s" % target_transcription[i])
for index in xrange(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_cur = wer(target_transcription[i], beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
......
......@@ -8,13 +8,16 @@ import kenlm
import numpy as np
class Scorer(object):
"""External defined scorer to evaluate a sentence in beam search
decoding, consisting of language model and word count.
class LmScorer(object):
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
model and word count.
:param alpha: Parameter associated with language model.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count.
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
......@@ -28,14 +31,14 @@ class Scorer(object):
self._language_model = kenlm.LanguageModel(model_path)
# n-gram language model scoring
def language_model_score(self, sentence):
def _language_model_score(self, sentence):
#log10 prob of last word
log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob)
# word insertion term
def word_count(self, sentence):
def _word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)
......@@ -51,8 +54,8 @@ class Scorer(object):
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
lm = self._language_model_score(sentence)
word_cnt = self._word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
......
echo "Downloading language model."
wget -c ftp://xxx/xxx/en.00.UNKNOWN.klm -P ./data
SoundFile==0.9.0.post1
wget==3.2
scipy==0.13.1
https://github.com/kpu/kenlm/archive/master.zip
......@@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase):
self.beam_search_result = ['acdc', "b'a"]
def test_best_path_decoder_1(self):
bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list)
bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[0])
def test_best_path_decoder_2(self):
bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list)
bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[1])
def test_beam_search_decoder_1(self):
......@@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase):
self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_nproc_decoder(self):
beam_results = ctc_beam_search_decoder_nproc(
beam_results = ctc_beam_search_decoder_batch(
probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size,
vocabulary=self.vocab_list,
......
......@@ -3,14 +3,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.v2 as paddle
import distutils.util
import argparse
import gzip
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import *
from scorer import Scorer
from lm.lm_scorer import LmScorer
from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__)
......@@ -39,24 +39,29 @@ parser.add_argument(
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-100sample',
default='datasets/manifest.test',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='./params.tar.gz',
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
......@@ -64,25 +69,14 @@ parser.add_argument(
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search_nproc',
type=str,
help="Method for decoding, beam_search or beam_search_nproc. (default: %(default)s)"
)
parser.add_argument(
"--beam_size",
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--num_results_per_sample",
default=1,
type=int,
help="Number of outputs per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="data/en.00.UNKNOWN.klm",
default="lm/data/en.00.UNKNOWN.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
......@@ -137,7 +131,8 @@ def tune():
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}')
augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config
# paddle.data_type.dense_array is used for variable batch input.
......@@ -188,42 +183,22 @@ def tune():
## tune parameters in loop
for (alpha, beta) in params_grid:
wer_sum, wer_counter = 0, 0
ext_scorer = Scorer(alpha, beta, args.language_model_path)
# beam search decode
if args.decode_method == "beam_search":
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
ext_scorer = LmScorer(alpha, beta, args.language_model_path)
# beam search using multiple processes
elif args.decode_method == "beam_search_nproc":
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results):
for i, beam_search_result in enumerate(beam_search_results):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
data_generator.vocab_list[index] for index in infer_data[i][1]
])
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
else:
raise ValueError("Decoding method [%s] is not supported." %
decode_method)
print("alpha = %f\tbeta = %f\tWER = %f" %
(alpha, beta, wer_sum / wer_counter))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册