Unverified commit 422f55a5, authored by Yibing Liu, committed by GitHub

Merge pull request #122 from kuke/fix_tune

Decouple ext scorer init & inference & decoding for the convenience o…
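The refactor below splits the old all-in-one infer_batch into init_ext_scorer, infer_batch_probs, decode_batch_greedy, and decode_batch_beam_search, so the expensive network forward pass can run once per batch while decoding is repeated with different hyperparameters. A minimal sketch of the resulting call pattern (illustrative only, not part of this diff; it assumes ds2_model, data_generator, batch_reader, vocab_list, params_grid, and args as set up in the scripts changed below):

    ds2_model.init_ext_scorer(args.alpha_from, args.beta_from,
                              args.lang_model_path, vocab_list)
    for infer_data in batch_reader():
        # the network forward pass runs once per batch ...
        probs_split = ds2_model.infer_batch_probs(
            infer_data=infer_data, feeding_dict=data_generator.feeding)
        # ... while beam-search decoding can be re-run for every candidate
        # (alpha, beta) pair without re-running the network
        for alpha, beta in params_grid:
            transcripts = ds2_model.decode_batch_beam_search(
                probs_split=probs_split, beam_alpha=alpha, beam_beta=beta,
                beam_size=args.beam_size, cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n, vocab_list=vocab_list,
                num_processes=args.num_proc_bsearch)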
......
@@ -160,22 +160,30 @@ def start_server():
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+    if args.decoding_method == "ctc_beam_search":
+        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
+                                  vocab_list)

     # prepare ASR inference handler
     def file_to_transcript(filename):
         feature = data_generator.process_utterance(filename, "")
-        result_transcript = ds2_model.infer_batch(
+        probs_split = ds2_model.infer_batch_probs(
             infer_data=[feature],
-            decoding_method=args.decoding_method,
-            beam_alpha=args.alpha,
-            beam_beta=args.beta,
-            beam_size=args.beam_size,
-            cutoff_prob=args.cutoff_prob,
-            cutoff_top_n=args.cutoff_top_n,
-            vocab_list=vocab_list,
-            language_model_path=args.lang_model_path,
-            num_processes=1,
             feeding_dict=data_generator.feeding)
+        if args.decoding_method == "ctc_greedy":
+            result_transcript = ds2_model.decode_batch_greedy(
+                probs_split=probs_split,
+                vocab_list=vocab_list)
+        else:
+            result_transcript = ds2_model.decode_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=args.alpha,
+                beam_beta=args.beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=1)
         return result_transcript[0]

     # warming up with utterances sampled from Librispeech
......
......
@@ -7,7 +7,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \
 python -u tools/tune.py \
 --num_batches=-1 \
 --batch_size=128 \
---trainer_count=8 \
+--trainer_count=4 \
 --beam_size=500 \
 --num_proc_bsearch=12 \
 --num_conv_layers=2 \
......
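The trainer_count change from 8 to 4 presumably brings the flag in line with the four GPUs exposed by CUDA_VISIBLE_DEVICES=0,1,2,3 in this script; with only four visible devices, requesting eight trainers would oversubscribe them.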
......
@@ -90,18 +90,28 @@ def infer():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
-    result_transcripts = ds2_model.infer_batch(
-        infer_data=infer_data,
-        decoding_method=args.decoding_method,
-        beam_alpha=args.alpha,
-        beam_beta=args.beta,
-        beam_size=args.beam_size,
-        cutoff_prob=args.cutoff_prob,
-        cutoff_top_n=args.cutoff_top_n,
-        vocab_list=vocab_list,
-        language_model_path=args.lang_model_path,
-        num_processes=args.num_proc_bsearch,
-        feeding_dict=data_generator.feeding)
+    if args.decoding_method == "ctc_greedy":
+        ds2_model.logger.info("start inference ...")
+        probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
+                                                  feeding_dict=data_generator.feeding)
+        result_transcripts = ds2_model.decode_batch_greedy(
+            probs_split=probs_split,
+            vocab_list=vocab_list)
+    else:
+        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
+                                  vocab_list)
+        ds2_model.logger.info("start inference ...")
+        probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
+                                                  feeding_dict=data_generator.feeding)
+        result_transcripts = ds2_model.decode_batch_beam_search(
+            probs_split=probs_split,
+            beam_alpha=args.alpha,
+            beam_beta=args.beta,
+            beam_size=args.beam_size,
+            cutoff_prob=args.cutoff_prob,
+            cutoff_top_n=args.cutoff_top_n,
+            vocab_list=vocab_list,
+            num_processes=args.num_proc_bsearch)

     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [data[1] for data in infer_data]
......
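The probs_split returned by infer_batch_probs above is a list with one per-utterance matrix of frame-level posteriors. A sketch of how an entry might be inspected (illustrative, not from the diff; the extra blank class and the use of .shape are assumptions about the usual CTC output layout):

    probs_split = ds2_model.infer_batch_probs(
        infer_data=infer_data, feeding_dict=data_generator.feeding)
    for probs in probs_split:
        # one row per time frame; roughly len(vocab_list) + 1 columns,
        # the extra column being the CTC blank (exact layout assumed here)
        num_frames, num_classes = probs.shape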
......
@@ -173,43 +173,19 @@ class DeepSpeech2Model(object):
         # run inference
         return self._loss_inferer.infer(input=infer_data)

-    def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
-                    beam_size, cutoff_prob, cutoff_top_n, vocab_list,
-                    language_model_path, num_processes, feeding_dict):
-        """Model inference. Infer the transcription for a batch of speech
-        utterances.
+    def infer_batch_probs(self, infer_data, feeding_dict):
+        """Infer the prob matrices for a batch of speech utterances.

         :param infer_data: List of utterances to infer, with each utterance
                            consisting of a tuple of audio features and
                            transcription text (empty string).
         :type infer_data: list
-        :param decoding_method: Decoding method name, 'ctc_greedy' or
-                                'ctc_beam_search'.
-        :type decoding_method: string
-        :param beam_alpha: Parameter associated with language model.
-        :type beam_alpha: float
-        :param beam_beta: Parameter associated with word count.
-        :type beam_beta: float
-        :param beam_size: Width for Beam search.
-        :type beam_size: int
-        :param cutoff_prob: Cutoff probability in pruning,
-                            default 1.0, no pruning.
-        :type cutoff_prob: float
-        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
-                             characters with highest probs in vocabulary will be
-                             used in beam search, default 40.
-        :type cutoff_top_n: int
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
-        :param language_model_path: Filepath for language model.
-        :type language_model_path: basestring|None
-        :param num_processes: Number of processes (CPU) for decoder.
-        :type num_processes: int
         :param feeding_dict: Feeding is a map of field name and tuple index
                              of the data that reader returns.
         :type feeding_dict: dict|list
-        :return: List of transcription texts.
-        :rtype: List of basestring
+        :return: List of 2-D probability matrices, each holding the prob
+                 vectors for one speech utterance.
+        :rtype: List of matrix
         """
         # define inferer
         if self._inferer == None:
......
@@ -227,49 +203,102 @@ class DeepSpeech2Model(object):
             infer_results[start_pos[i]:start_pos[i + 1]]
             for i in xrange(0, len(adapted_infer_data))
         ]
-
-        # run decoder
+        return probs_split
+
+    def decode_batch_greedy(self, probs_split, vocab_list):
+        """Decode by best path for a batch of probs matrix input.
+
+        :param probs_split: List of 2-D probability matrix, and each consists
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
         results = []
-        if decoding_method == "ctc_greedy":
-            # best path decode
-            for i, probs in enumerate(probs_split):
-                output_transcription = ctc_greedy_decoder(
-                    probs_seq=probs, vocabulary=vocab_list)
-                results.append(output_transcription)
-        elif decoding_method == "ctc_beam_search":
-            # initialize external scorer
-            if self._ext_scorer == None:
-                self._loaded_lm_path = language_model_path
-                self.logger.info("begin to initialize the external scorer "
-                                 "for decoding")
-                self._ext_scorer = Scorer(beam_alpha, beam_beta,
-                                          language_model_path, vocab_list)
-                lm_char_based = self._ext_scorer.is_character_based()
-                lm_max_order = self._ext_scorer.get_max_order()
-                lm_dict_size = self._ext_scorer.get_dict_size()
-                self.logger.info("language model: "
-                                 "is_character_based = %d," % lm_char_based +
-                                 " max_order = %d," % lm_max_order +
-                                 " dict_size = %d" % lm_dict_size)
-                self.logger.info("end initializing scorer. Start decoding ...")
-            else:
-                self._ext_scorer.reset_params(beam_alpha, beam_beta)
-                assert self._loaded_lm_path == language_model_path
-            # beam search decode
-            num_processes = min(num_processes, len(probs_split))
-            beam_search_results = ctc_beam_search_decoder_batch(
-                probs_split=probs_split,
-                vocabulary=vocab_list,
-                beam_size=beam_size,
-                num_processes=num_processes,
-                ext_scoring_func=self._ext_scorer,
-                cutoff_prob=cutoff_prob,
-                cutoff_top_n=cutoff_top_n)
-            results = [result[0][1] for result in beam_search_results]
+        for i, probs in enumerate(probs_split):
+            output_transcription = ctc_greedy_decoder(
+                probs_seq=probs, vocabulary=vocab_list)
+            results.append(output_transcription)
+        return results
+
+    def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
+                        vocab_list):
+        """Initialize the external scorer.
+
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param language_model_path: Filepath for language model. If it is
+                                    empty, the external scorer will be set to
+                                    None, and the decoding method will be pure
+                                    beam search without scorer.
+        :type language_model_path: basestring|None
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        """
+        if language_model_path != '':
+            self.logger.info("begin to initialize the external scorer "
+                             "for decoding")
+            self._ext_scorer = Scorer(beam_alpha, beam_beta,
+                                      language_model_path, vocab_list)
+            lm_char_based = self._ext_scorer.is_character_based()
+            lm_max_order = self._ext_scorer.get_max_order()
+            lm_dict_size = self._ext_scorer.get_dict_size()
+            self.logger.info("language model: "
+                             "is_character_based = %d," % lm_char_based +
+                             " max_order = %d," % lm_max_order +
+                             " dict_size = %d" % lm_dict_size)
+            self.logger.info("end initializing scorer")
         else:
-            raise ValueError("Decoding method [%s] is not supported." %
-                             decoding_method)
+            self._ext_scorer = None
+            self.logger.info("no language model provided, "
+                             "decoding by pure beam search without scorer.")
+
+    def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                 beam_size, cutoff_prob, cutoff_top_n,
+                                 vocab_list, num_processes):
+        """Decode by beam search for a batch of probs matrix input.
+
+        :param probs_split: List of 2-D probability matrix, and each consists
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param beam_size: Width for Beam search.
+        :type beam_size: int
+        :param cutoff_prob: Cutoff probability in pruning,
+                            default 1.0, no pruning.
+        :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                             characters with highest probs in vocabulary will be
+                             used in beam search, default 40.
+        :type cutoff_top_n: int
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :param num_processes: Number of processes (CPU) for decoder.
+        :type num_processes: int
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
+        if self._ext_scorer != None:
+            self._ext_scorer.reset_params(beam_alpha, beam_beta)
+        # beam search decode
+        num_processes = min(num_processes, len(probs_split))
+        beam_search_results = ctc_beam_search_decoder_batch(
+            probs_split=probs_split,
+            vocabulary=vocab_list,
+            beam_size=beam_size,
+            num_processes=num_processes,
+            ext_scoring_func=self._ext_scorer,
+            cutoff_prob=cutoff_prob,
+            cutoff_top_n=cutoff_top_n)
+        results = [result[0][1] for result in beam_search_results]
         return results

     def _adapt_feeding_dict(self, feeding_dict):
......
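decode_batch_greedy delegates to ctc_greedy_decoder, i.e. best-path decoding: take the most probable class per frame, collapse consecutive repeats, then drop CTC blanks. A minimal sketch of that idea (illustrative only, not the repo's SWIG implementation; the blank is assumed to sit at the last index):

    import numpy as np

    def greedy_decode_sketch(probs_seq, vocabulary):
        # best path: most probable class at every time frame
        blank_index = len(vocabulary)  # assumed blank position
        best_path = np.argmax(np.asarray(probs_seq), axis=1)
        tokens = []
        prev = None
        for index in best_path:
            # collapse consecutive repeats, then drop blanks
            if index != prev and index != blank_index:
                tokens.append(vocabulary[index])
            prev = index
        return "".join(tokens)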
......
@@ -90,22 +90,33 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+    if args.decoding_method == "ctc_beam_search":
+        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
+                                  vocab_list)

     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     errors_sum, len_refs, num_ins = 0.0, 0, 0
+    ds2_model.logger.info("start evaluation ...")
     for infer_data in batch_reader():
-        result_transcripts = ds2_model.infer_batch(
+        probs_split = ds2_model.infer_batch_probs(
             infer_data=infer_data,
-            decoding_method=args.decoding_method,
-            beam_alpha=args.alpha,
-            beam_beta=args.beta,
-            beam_size=args.beam_size,
-            cutoff_prob=args.cutoff_prob,
-            cutoff_top_n=args.cutoff_top_n,
-            vocab_list=vocab_list,
-            language_model_path=args.lang_model_path,
-            num_processes=args.num_proc_bsearch,
             feeding_dict=data_generator.feeding)
+
+        if args.decoding_method == "ctc_greedy":
+            result_transcripts = ds2_model.decode_batch_greedy(
+                probs_split=probs_split,
+                vocab_list=vocab_list)
+        else:
+            result_transcripts = ds2_model.decode_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=args.alpha,
+                beam_beta=args.beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)
+
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
......
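In evaluate() above, errors_func returns a pair that is accumulated into errors_sum and len_refs; the corpus-level CER or WER then falls out as errors_sum / len_refs. A sketch of the contract implied for char_errors (illustrative only: plain Levenshtein edit distance plus the reference length):

    def char_errors_sketch(reference, hypothesis):
        # dynamic-programming Levenshtein distance between the two strings
        prev_row = list(range(len(hypothesis) + 1))
        for i, ref_char in enumerate(reference, 1):
            cur_row = [i]
            for j, hyp_char in enumerate(hypothesis, 1):
                cur_row.append(min(
                    prev_row[j] + 1,                         # deletion
                    cur_row[j - 1] + 1,                      # insertion
                    prev_row[j - 1] + (ref_char != hyp_char)))  # substitution
            prev_row = cur_row
        # (edit distance, reference length): the pair evaluate() accumulates
        return float(prev_row[-1]), len(reference)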
......
@@ -13,9 +13,7 @@ import logging
 import paddle.v2 as paddle
 import _init_paths
 from data_utils.data import DataGenerator
-from decoders.swig_wrapper import Scorer
-from decoders.swig_wrapper import ctc_beam_search_decoder_batch
-from model_utils.model import deep_speech_v2_network
+from model_utils.model import DeepSpeech2Model
 from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
......
@@ -72,9 +70,6 @@ add_arg('specgram_type', str,
 args = parser.parse_args()

 logging.basicConfig(
     format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
-
-
-
 def tune():
     """Tune parameters alpha and beta incrementally."""
     if not args.num_alphas >= 0:
......
@@ -88,40 +83,7 @@ def tune():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=args.num_proc_data,
-        keep_transcription_text=True,
-        num_conv_layers=args.num_conv_layers)
-
-    audio_data = paddle.layer.data(
-        name="audio_spectrogram",
-        type=paddle.data_type.dense_array(161 * 161))
-    text_data = paddle.layer.data(
-        name="transcript_text",
-        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
-    seq_offset_data = paddle.layer.data(
-        name='sequence_offset',
-        type=paddle.data_type.integer_value_sequence(1))
-    seq_len_data = paddle.layer.data(
-        name='sequence_length',
-        type=paddle.data_type.integer_value_sequence(1))
-    index_range_datas = []
-    for i in xrange(args.num_rnn_layers):
-        index_range_datas.append(
-            paddle.layer.data(
-                name='conv%d_index_range' % i,
-                type=paddle.data_type.dense_vector(6)))
-
-    output_probs, _ = deep_speech_v2_network(
-        audio_data=audio_data,
-        text_data=text_data,
-        seq_offset_data=seq_offset_data,
-        seq_len_data=seq_len_data,
-        index_range_datas=index_range_datas,
-        dict_size=data_generator.vocab_size,
-        num_conv_layers=args.num_conv_layers,
-        num_rnn_layers=args.num_rnn_layers,
-        rnn_size=args.rnn_layer_size,
-        use_gru=args.use_gru,
-        share_rnn_weights=args.share_rnn_weights)
+        keep_transcription_text=True)

     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.tune_manifest,
......
@@ -129,35 +91,17 @@ def tune():
         sortagrad=False,
         shuffle_method=None)

-    # load parameters
-    if not os.path.isfile(args.model_path):
-        raise IOError("Invaid model path: %s" % args.model_path)
-    parameters = paddle.parameters.Parameters.from_tar(
-        gzip.open(args.model_path))
+    ds2_model = DeepSpeech2Model(
+        vocab_size=data_generator.vocab_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
+        pretrained_model_path=args.model_path,
+        share_rnn_weights=args.share_rnn_weights)

-    inferer = paddle.inference.Inference(
-        output_layer=output_probs, parameters=parameters)
-
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

-    # init logger
-    logger = logging.getLogger("")
-    logger.setLevel(level=logging.INFO)
-    # init external scorer
-    logger.info("begin to initialize the external scorer for tuning")
-    if not os.path.isfile(args.lang_model_path):
-        raise IOError("Invaid language model path: %s" % args.lang_model_path)
-    ext_scorer = Scorer(
-        alpha=args.alpha_from,
-        beta=args.beta_from,
-        model_path=args.lang_model_path,
-        vocabulary=vocab_list)
-    logger.info("language model: "
-                "is_character_based = %d," % ext_scorer.is_character_based() +
-                " max_order = %d," % ext_scorer.get_max_order() +
-                " dict_size = %d" % ext_scorer.get_dict_size())
-    logger.info("end initializing scorer. Start tuning ...")
-
     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
......
@@ -168,37 +112,32 @@ def tune():
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
     num_ins, len_refs, cur_batch = 0, 0, 0
+    # initialize external scorer
+    ds2_model.init_ext_scorer(args.alpha_from, args.beta_from,
+                              args.lang_model_path, vocab_list)
     ## incremental tuning parameters over multiple batches
+    ds2_model.logger.info("start tuning ...")
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
             break
-        infer_results = inferer.infer(input=infer_data,
-                                      feeding=data_generator.feeding)
-        start_pos = [0] * (len(infer_data) + 1)
-        for i in xrange(len(infer_data)):
-            start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
-        probs_split = [
-            infer_results[start_pos[i]:start_pos[i + 1]]
-            for i in xrange(0, len(infer_data))
-        ]
+        probs_split = ds2_model.infer_batch_probs(
+            infer_data=infer_data,
+            feeding_dict=data_generator.feeding)

         target_transcripts = [data[1] for data in infer_data]
         num_ins += len(target_transcripts)
         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            # reset alpha & beta
-            ext_scorer.reset_params(alpha, beta)
-            beam_search_results = ctc_beam_search_decoder_batch(
+            result_transcripts = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
-                vocabulary=vocab_list,
+                beam_alpha=alpha,
+                beam_beta=beta,
                 beam_size=args.beam_size,
-                num_processes=args.num_proc_bsearch,
                 cutoff_prob=args.cutoff_prob,
                 cutoff_top_n=args.cutoff_top_n,
-                ext_scoring_func=ext_scorer, )
-            result_transcripts = [res[0][1] for res in beam_search_results]
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)

             for target, result in zip(target_transcripts, result_transcripts):
                 errors, len_ref = errors_func(target, result)
                 err_sum[index] += errors
......
@@ -235,7 +174,7 @@ def tune():
             % (cur_batch, "%.3f" % params_grid[min_index][0],
                "%.3f" % params_grid[min_index][1]))

-    logger.info("finish tuning")
+    ds2_model.logger.info("finish tuning")

 def main():
......
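The params_grid iterated in tune() is built from the cand_alphas line shown above; a sketch of the full grid construction, consistent with the --alpha_from/--alpha_to style flags in tune.sh (illustrative only; cand_betas and the num_betas flag are assumed to mirror the alpha side):

    import itertools
    import numpy as np

    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    # every (alpha, beta) pair is tried against the same cached probs_split
    params_grid = [(alpha, beta)
                   for alpha, beta in itertools.product(cand_alphas, cand_betas)]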