提交 2a834865 编写于 作者: X Xinghai Sun

Refactor decoder interfaces and add ./data directory.

上级 8313895e
......@@ -16,7 +16,9 @@ For some machines, we also need to install libsndfile1. Details to be added.
### Preparing Dataset(s)
```
cd data
python librispeech.py
cd ..
```
More help for arguments:
......
"""
CTC-like decoder utilitis.
"""
from itertools import groupby
import numpy as np
def ctc_best_path_decode(probs_seq, vocabulary):
"""
Best path decoding, also called argmax decoding or greedy decoding.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: baseline
"""
# dimension verification
for probs in probs_seq:
if not len(probs) == len(vocabulary) + 1:
raise ValueError("probs_seq dimension mismatchedd with vocabulary")
# argmax to get the best index for each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
# remove consecutive duplicate indexes
index_list = [index_group[0] for index_group in groupby(max_index_list)]
# remove blank indexes
blank_index = len(vocabulary)
index_list = [index for index in index_list if index != blank_index]
# convert index list to string
return ''.join([vocabulary[index] for index in index_list])
def ctc_decode(probs_seq, vocabulary, method):
"""
CTC-like sequence decoding from a sequence of likelihood probablilites.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param method: Decoding method name, with options: "best_path".
:type method: basestring
:return: Decoding result string.
:rtype: baseline
"""
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary)
else:
raise ValueError("Decoding method [%s] is not supported.")
......@@ -3,12 +3,12 @@
"""
import paddle.v2 as paddle
from itertools import groupby
import distutils.util
import argparse
import gzip
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import ctc_decode
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.')
......@@ -39,12 +39,12 @@ parser.add_argument(
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='./manifest.libri.train-clean-100',
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='./manifest.libri.test-clean',
default='data/manifest.libri.test-clean',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
......@@ -52,34 +52,28 @@ parser.add_argument(
default='./params.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
def remove_duplicate_and_blank(id_list, blank_id):
"""
Postprocessing for max-ctc-decoder.
- remove consecutive duplicate tokens.
- remove blanks.
"""
# remove consecutive duplicate tokens
id_list = [x[0] for x in groupby(id_list)]
# remove blanks
return [id for id in id_list if id != blank_id]
def best_path_decode():
def infer():
"""
Max-ctc-decoding for DeepSpeech2.
"""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt',
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
# create network config
dict_size = data_generator.vocabulary_size()
vocab_list = data_generator.vocabulary_list()
......@@ -91,13 +85,14 @@ def best_path_decode():
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
_, max_id = deep_speech2(
output_probs = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=dict_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size)
rnn_size=args.rnn_layer_size,
is_inference=True)
# load parameters
parameters = paddle.parameters.Parameters.from_tar(
......@@ -114,30 +109,28 @@ def best_path_decode():
shuffle=False)
infer_data = test_batch_reader().next()
# run max-ctc-decoding
max_id_results = paddle.infer(
output_layer=max_id,
parameters=parameters,
input=infer_data,
field=['id'])
# postprocess
instance_length = len(max_id_results) / args.num_samples
instance_list = [
max_id_results[i * instance_length:(i + 1) * instance_length]
for i in xrange(0, args.num_samples)
# run inference
infer_results = paddle.infer(
output_layer=output_probs, parameters=parameters, input=infer_data)
num_steps = len(infer_results) / len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data))
]
for i, instance in enumerate(instance_list):
id_list = remove_duplicate_and_blank(instance, dict_size)
output_transcript = ''.join([vocab_list[id] for id in id_list])
target_transcript = ''.join([vocab_list[id] for id in infer_data[i][1]])
print("Target Transcript: %s \nOutput Transcript: %s \n" %
(target_transcript, output_transcript))
# decode and print
for i, probs in enumerate(probs_split):
output_transcription = ctc_decode(
probs_seq=probs, vocabulary=vocab_list, method="best_path")
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("Target Transcription: %s \nOutput Transcription: %s \n" %
(target_transcription, output_transcription))
def main():
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
best_path_decode()
infer()
if __name__ == '__main__':
......
......@@ -85,7 +85,8 @@ def deep_speech2(audio_data,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=256):
rnn_size=256,
is_inference=False):
"""
The whole DeepSpeech2 model structure (a simplified version).
......@@ -101,7 +102,12 @@ def deep_speech2(audio_data,
:type num_rnn_layers: int
:param rnn_size: RNN layer size (number of RNN cells).
:type rnn_size: int
:return: Tuple of the cost layer and the max_id decoder layer.
:param is_inference: False in the training mode, and True in the
inferene mode.
:type is_inference: bool
:return: If is_inference set False, return a ctc cost layer;
if is_inference set True, return a sequence layer of output
probability distribution.
:rtype: tuple of LayerOutput
"""
# convolution group
......@@ -118,19 +124,21 @@ def deep_speech2(audio_data,
# rnn group
rnn_group_output = rnn_group(
input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers)
# output token distribution
fc = paddle.layer.fc(
input=rnn_group_output,
size=dict_size + 1,
act=paddle.activation.Linear(),
bias_attr=True)
# ctc cost
cost = paddle.layer.warp_ctc(
input=fc,
label=text_data,
size=dict_size + 1,
blank=dict_size,
norm_by_times=True)
# max decoder
max_id = paddle.layer.max_id(input=fc)
return cost, max_id
if is_inference:
# probability distribution with softmax
return paddle.layer.mixed(
input=paddle.layer.identity_projection(input=fc),
act=paddle.activation.Softmax())
else:
# ctc cost
return paddle.layer.warp_ctc(
input=fc,
label=text_data,
size=dict_size + 1,
blank=dict_size,
norm_by_times=True)
......@@ -60,19 +60,24 @@ parser.add_argument(
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='./manifest.libri.train-clean-100',
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--train_manifest_path",
default='./manifest.libri.train-clean-100',
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
"--dev_manifest_path",
default='./manifest.libri.dev-clean',
default='data/manifest.libri.dev-clean',
type=str,
help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
......@@ -82,7 +87,7 @@ def train():
"""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt',
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
......@@ -100,13 +105,14 @@ def train():
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
cost, _ = deep_speech2(
cost = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=dict_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size)
rnn_size=args.rnn_layer_size,
is_inference=False)
# create parameters and optimizer
parameters = paddle.parameters.create(cost)
......@@ -118,21 +124,21 @@ def train():
# prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size // args.trainer_count,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
......@@ -141,9 +147,7 @@ def train():
# create event handler
def event_handler(event):
global start_time
global cost_sum
global cost_counter
global start_time, cost_sum, cost_counter
if isinstance(event, paddle.event.EndIteration):
cost_sum += event.cost
cost_counter += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册