提交 1db71425 编写于 作者: Y Yibing Liu

avoid repeated infer for same batch in the tuning of DS2

上级 6cac602b
...@@ -7,10 +7,14 @@ import sys ...@@ -7,10 +7,14 @@ import sys
import numpy as np import numpy as np
import argparse import argparse
import functools import functools
import gzip
import logging
import paddle.v2 as paddle import paddle.v2 as paddle
import _init_paths import _init_paths
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from model_utils.model import deep_speech_v2_network
from utils.error_rate import wer, cer from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments from utils.utility import add_arguments, print_arguments
...@@ -66,6 +70,9 @@ add_arg('specgram_type', str, ...@@ -66,6 +70,9 @@ add_arg('specgram_type', str,
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
def tune(): def tune():
"""Tune parameters alpha and beta incrementally.""" """Tune parameters alpha and beta incrementally."""
if not args.num_alphas >= 0: if not args.num_alphas >= 0:
...@@ -79,29 +86,55 @@ def tune(): ...@@ -79,29 +86,55 @@ def tune():
augmentation_config='{}', augmentation_config='{}',
specgram_type=args.specgram_type, specgram_type=args.specgram_type,
num_threads=1) num_threads=1)
audio_data = paddle.layer.data(
name="audio_spectrogram",
type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
output_probs, _ = deep_speech_v2_network(
audio_data=audio_data,
text_data=text_data,
dict_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
use_gru=args.use_gru,
share_rnn_weights=args.share_rnn_weights)
batch_reader = data_generator.batch_reader_creator( batch_reader = data_generator.batch_reader_creator(
manifest_path=args.tune_manifest, manifest_path=args.tune_manifest,
batch_size=args.batch_size, batch_size=args.batch_size,
sortagrad=False, sortagrad=False,
shuffle_method=None) shuffle_method=None)
tune_data = batch_reader().next()
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in tune_data
]
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
use_gru=args.use_gru,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
# load parameters
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.model_path))
inferer = paddle.inference.Inference(
output_layer=output_probs, parameters=parameters)
# decoders only accept string encoded in utf-8 # decoders only accept string encoded in utf-8
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
# init logger
logger = logging.getLogger("")
logger.setLevel(level=logging.INFO)
# init external scorer
logger.info("begin to initialize the external scorer for tuning")
ext_scorer = Scorer(
alpha=args.alpha_from,
beta=args.beta_from,
model_path=args.lang_model_path,
vocabulary=vocab_list)
logger.info("language model: "
"is_character_based = %d," % ext_scorer.is_character_based() +
" max_order = %d," % ext_scorer.get_max_order() +
" dict_size = %d" % ext_scorer.get_dict_size())
logger.info("end initializing scorer. Start tuning ...")
error_rate_func = cer if args.error_rate_type == 'cer' else wer error_rate_func = cer if args.error_rate_type == 'cer' else wer
# create grid for search # create grid for search
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
...@@ -116,6 +149,13 @@ def tune(): ...@@ -116,6 +149,13 @@ def tune():
for infer_data in batch_reader(): for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches): if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break break
infer_results = inferer.infer(input=infer_data)
num_steps = len(infer_results) // len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(len(infer_data))
]
target_transcripts = [ target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript]) ''.join([data_generator.vocab_list[token] for token in transcript])
...@@ -125,18 +165,18 @@ def tune(): ...@@ -125,18 +165,18 @@ def tune():
num_ins += len(target_transcripts) num_ins += len(target_transcripts)
# grid search # grid search
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):
result_transcripts = ds2_model.infer_batch( # reset alpha & beta
infer_data=infer_data, ext_scorer.reset_params(alpha, beta)
decoding_method='ctc_beam_search', beam_search_results = ctc_beam_search_decoder_batch(
beam_alpha=alpha, probs_split=probs_split,
beam_beta=beta, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, num_processes=args.num_proc_bsearch,
cutoff_top_n=args.cutoff_top_n, cutoff_prob=args.cutoff_prob,
vocab_list=vocab_list, cutoff_top_n=args.cutoff_top_n,
language_model_path=args.lang_model_path, ext_scoring_func=ext_scorer, )
num_processes=args.num_proc_bsearch)
result_transcripts = [res[0][1] for res in beam_search_results]
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
err_sum[index] += error_rate_func(target, result) err_sum[index] += error_rate_func(target, result)
err_ave[index] = err_sum[index] / num_ins err_ave[index] = err_sum[index] / num_ins
...@@ -167,7 +207,7 @@ def tune(): ...@@ -167,7 +207,7 @@ def tune():
% (args.num_batches, "%.3f" % params_grid[min_index][0], % (args.num_batches, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1])) "%.3f" % params_grid[min_index][1]))
ds2_model.logger.info("finish inference") logger.info("finish tuning")
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册