提交 1043ea51 编写于 作者: X Xinghai Sun

Refactor decoder interfaces and add ./data directory.

上级 ec9cce9e
...@@ -16,7 +16,9 @@ For some machines, we also need to install libsndfile1. Details to be added. ...@@ -16,7 +16,9 @@ For some machines, we also need to install libsndfile1. Details to be added.
### Preparing Dataset(s) ### Preparing Dataset(s)
``` ```
cd data
python librispeech.py python librispeech.py
cd ..
``` ```
More help for arguments: More help for arguments:
......
"""
CTC-like decoder utilities.
"""
from itertools import groupby
import numpy as np
def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path consisting of the most probable token at each time step is
    post-processed to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character, with one extra trailing entry for
                      the blank token.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string (empty if probs_seq is empty).
    :rtype: basestring
    :raises ValueError: If any row of probs_seq does not have exactly
                        len(vocabulary) + 1 entries.
    """
    # the blank token uses the extra index right after the vocabulary
    blank_index = len(vocabulary)
    # dimension verification
    for probs in probs_seq:
        if len(probs) != blank_index + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # an empty sequence has no time step to take the argmax over; np.array([])
    # is 1-D, so argmax(axis=1) would fail on it
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes (groupby yields (key, group) pairs)
    index_list = [index for index, _ in groupby(max_index_list)]
    # remove blank indexes
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
def ctc_decode(probs_seq, vocabulary, method):
    """
    CTC-like sequence decoding from a sequence of likelihood probabilities.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character, with one extra trailing entry for
                      the blank token.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param method: Decoding method name, with options: "best_path".
    :type method: basestring
    :return: Decoding result string.
    :rtype: basestring
    :raises ValueError: If any row of probs_seq does not have exactly
                        len(vocabulary) + 1 entries, or if the decoding
                        method is not supported.
    """
    # each row must cover the whole vocabulary plus the blank token
    expected_size = len(vocabulary) + 1
    for prob_list in probs_seq:
        if len(prob_list) != expected_size:
            raise ValueError("probs dimension mismatched with vocabulary")
    if method == "best_path":
        return ctc_best_path_decode(probs_seq, vocabulary)
    # interpolate the offending name so the message is actually informative
    raise ValueError("Decoding method [%s] is not supported." % method)
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
from itertools import groupby
import distutils.util import distutils.util
import argparse import argparse
import gzip import gzip
from audio_data_utils import DataGenerator from audio_data_utils import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import ctc_decode
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.') description='Simplified version of DeepSpeech2 inference.')
...@@ -39,12 +39,12 @@ parser.add_argument( ...@@ -39,12 +39,12 @@ parser.add_argument(
help="Use gpu or not. (default: %(default)s)") help="Use gpu or not. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--normalizer_manifest_path", "--normalizer_manifest_path",
default='./manifest.libri.train-clean-100', default='data/manifest.libri.train-clean-100',
type=str, type=str,
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='./manifest.libri.test-clean', default='data/manifest.libri.test-clean',
type=str, type=str,
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -52,34 +52,28 @@ parser.add_argument( ...@@ -52,34 +52,28 @@ parser.add_argument(
default='./params.tar.gz', default='./params.tar.gz',
type=str, type=str,
help="Model filepath. (default: %(default)s)") help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
def remove_duplicate_and_blank(id_list, blank_id): def infer():
"""
Postprocessing for max-ctc-decoder.
- remove consecutive duplicate tokens.
- remove blanks.
"""
# remove consecutive duplicate tokens
id_list = [x[0] for x in groupby(id_list)]
# remove blanks
return [id for id in id_list if id != blank_id]
def best_path_decode():
""" """
Max-ctc-decoding for DeepSpeech2. Max-ctc-decoding for DeepSpeech2.
""" """
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt', vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path, normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200, normalizer_num_samples=200,
max_duration=20.0, max_duration=20.0,
min_duration=0.0, min_duration=0.0,
stride_ms=10, stride_ms=10,
window_ms=20) window_ms=20)
# create network config # create network config
dict_size = data_generator.vocabulary_size() dict_size = data_generator.vocabulary_size()
vocab_list = data_generator.vocabulary_list() vocab_list = data_generator.vocabulary_list()
...@@ -91,13 +85,14 @@ def best_path_decode(): ...@@ -91,13 +85,14 @@ def best_path_decode():
text_data = paddle.layer.data( text_data = paddle.layer.data(
name="transcript_text", name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size)) type=paddle.data_type.integer_value_sequence(dict_size))
_, max_id = deep_speech2( output_probs = deep_speech2(
audio_data=audio_data, audio_data=audio_data,
text_data=text_data, text_data=text_data,
dict_size=dict_size, dict_size=dict_size,
num_conv_layers=args.num_conv_layers, num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers, num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size) rnn_size=args.rnn_layer_size,
is_inference=True)
# load parameters # load parameters
parameters = paddle.parameters.Parameters.from_tar( parameters = paddle.parameters.Parameters.from_tar(
...@@ -114,30 +109,28 @@ def best_path_decode(): ...@@ -114,30 +109,28 @@ def best_path_decode():
shuffle=False) shuffle=False)
infer_data = test_batch_reader().next() infer_data = test_batch_reader().next()
# run max-ctc-decoding # run inference
max_id_results = paddle.infer( infer_results = paddle.infer(
output_layer=max_id, output_layer=output_probs, parameters=parameters, input=infer_data)
parameters=parameters, num_steps = len(infer_results) / len(infer_data)
input=infer_data, probs_split = [
field=['id']) infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data))
# postprocess
instance_length = len(max_id_results) / args.num_samples
instance_list = [
max_id_results[i * instance_length:(i + 1) * instance_length]
for i in xrange(0, args.num_samples)
] ]
for i, instance in enumerate(instance_list):
id_list = remove_duplicate_and_blank(instance, dict_size) # decode and print
output_transcript = ''.join([vocab_list[id] for id in id_list]) for i, probs in enumerate(probs_split):
target_transcript = ''.join([vocab_list[id] for id in infer_data[i][1]]) output_transcription = ctc_decode(
print("Target Transcript: %s \nOutput Transcript: %s \n" % probs_seq=probs, vocabulary=vocab_list, method="best_path")
(target_transcript, output_transcript)) target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("Target Transcription: %s \nOutput Transcription: %s \n" %
(target_transcription, output_transcription))
def main(): def main():
paddle.init(use_gpu=args.use_gpu, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=1)
best_path_decode() infer()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -85,7 +85,8 @@ def deep_speech2(audio_data, ...@@ -85,7 +85,8 @@ def deep_speech2(audio_data,
dict_size, dict_size,
num_conv_layers=2, num_conv_layers=2,
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=256): rnn_size=256,
is_inference=False):
""" """
The whole DeepSpeech2 model structure (a simplified version). The whole DeepSpeech2 model structure (a simplified version).
...@@ -101,7 +102,12 @@ def deep_speech2(audio_data, ...@@ -101,7 +102,12 @@ def deep_speech2(audio_data,
:type num_rnn_layers: int :type num_rnn_layers: int
:param rnn_size: RNN layer size (number of RNN cells). :param rnn_size: RNN layer size (number of RNN cells).
:type rnn_size: int :type rnn_size: int
:return: Tuple of the cost layer and the max_id decoder layer. :param is_inference: False in the training mode, and True in the
inference mode.
:type is_inference: bool
:return: If is_inference set False, return a ctc cost layer;
if is_inference set True, return a sequence layer of output
probability distribution.
:rtype: tuple of LayerOutput :rtype: tuple of LayerOutput
""" """
# convolution group # convolution group
...@@ -118,19 +124,21 @@ def deep_speech2(audio_data, ...@@ -118,19 +124,21 @@ def deep_speech2(audio_data,
# rnn group # rnn group
rnn_group_output = rnn_group( rnn_group_output = rnn_group(
input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers)
# output token distribution
fc = paddle.layer.fc( fc = paddle.layer.fc(
input=rnn_group_output, input=rnn_group_output,
size=dict_size + 1, size=dict_size + 1,
act=paddle.activation.Linear(), act=paddle.activation.Linear(),
bias_attr=True) bias_attr=True)
# ctc cost if is_inference:
cost = paddle.layer.warp_ctc( # probability distribution with softmax
input=fc, return paddle.layer.mixed(
label=text_data, input=paddle.layer.identity_projection(input=fc),
size=dict_size + 1, act=paddle.activation.Softmax())
blank=dict_size, else:
norm_by_times=True) # ctc cost
# max decoder return paddle.layer.warp_ctc(
max_id = paddle.layer.max_id(input=fc) input=fc,
return cost, max_id label=text_data,
size=dict_size + 1,
blank=dict_size,
norm_by_times=True)
...@@ -60,19 +60,24 @@ parser.add_argument( ...@@ -60,19 +60,24 @@ parser.add_argument(
help="Trainer number. (default: %(default)s)") help="Trainer number. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--normalizer_manifest_path", "--normalizer_manifest_path",
default='./manifest.libri.train-clean-100', default='data/manifest.libri.train-clean-100',
type=str, type=str,
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--train_manifest_path", "--train_manifest_path",
default='./manifest.libri.train-clean-100', default='data/manifest.libri.train-clean-100',
type=str, type=str,
help="Manifest path for training. (default: %(default)s)") help="Manifest path for training. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--dev_manifest_path", "--dev_manifest_path",
default='./manifest.libri.dev-clean', default='data/manifest.libri.dev-clean',
type=str, type=str,
help="Manifest path for validation. (default: %(default)s)") help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
...@@ -82,7 +87,7 @@ def train(): ...@@ -82,7 +87,7 @@ def train():
""" """
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt', vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path, normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200, normalizer_num_samples=200,
max_duration=20.0, max_duration=20.0,
...@@ -100,13 +105,14 @@ def train(): ...@@ -100,13 +105,14 @@ def train():
text_data = paddle.layer.data( text_data = paddle.layer.data(
name="transcript_text", name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size)) type=paddle.data_type.integer_value_sequence(dict_size))
cost, _ = deep_speech2( cost = deep_speech2(
audio_data=audio_data, audio_data=audio_data,
text_data=text_data, text_data=text_data,
dict_size=dict_size, dict_size=dict_size,
num_conv_layers=args.num_conv_layers, num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers, num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size) rnn_size=args.rnn_layer_size,
is_inference=False)
# create parameters and optimizer # create parameters and optimizer
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
...@@ -118,21 +124,21 @@ def train(): ...@@ -118,21 +124,21 @@ def train():
# prepare data reader # prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator( train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count, batch_size=args.batch_size,
padding_to=2000, padding_to=2000,
flatten=True, flatten=True,
sort_by_duration=True, sort_by_duration=True,
shuffle=False) shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator( train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count, batch_size=args.batch_size,
padding_to=2000, padding_to=2000,
flatten=True, flatten=True,
sort_by_duration=False, sort_by_duration=False,
shuffle=True) shuffle=True)
test_batch_reader = data_generator.batch_reader_creator( test_batch_reader = data_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size // args.trainer_count, batch_size=args.batch_size,
padding_to=2000, padding_to=2000,
flatten=True, flatten=True,
sort_by_duration=False, sort_by_duration=False,
...@@ -141,9 +147,7 @@ def train(): ...@@ -141,9 +147,7 @@ def train():
# create event handler # create event handler
def event_handler(event): def event_handler(event):
global start_time global start_time, cost_sum, cost_counter
global cost_sum
global cost_counter
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
cost_sum += event.cost cost_sum += event.cost
cost_counter += 1 cost_counter += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册