提交 1da8f7a6 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #186 from xinghai-sun/demo2_backup

Add a realtime ASR demo (both server and client) for DS2 users to try with own voice.
......@@ -143,3 +143,28 @@ python tune.py --help
```
Then reset parameters with the tuning result before inference or evaluating.
### Playing with the ASR Demo
A real-time ASR demo is built for users to try out the ASR model with their own voice. Please do the following installation on the machine you'd like to run the demo's client (no need for the machine running the demo's server).
For example, on MAC OS X:
```
brew install portaudio
pip install pyaudio
pip install pynput
```
After a model and language model is prepared, we can first start the demo's server:
```
CUDA_VISIBLE_DEVICES=0 python demo_server.py
```
And then in another console, start the demo's client:
```
python demo_client.py
```
On the client console, press and hold the "white-space" key on the keyboard to start talking, until you finish your speech and then release the "white-space" key. The decoding results (infered transcription) will be displayed.
It could be possible to start the server and the client in two seperate machines, e.g. `demo_client.py` is usually started in a machine with a microphone hardware, while `demo_server.py` is usually started in a remote server with powerful GPUs. Please first make sure that these two machines have network access to each other, and then use `--host_ip` and `--host_port` to indicate the server machine's actual IP address (instead of the `localhost` as default) and TCP port, in both `demo_server.py` and `demo_client.py`.
......@@ -83,6 +83,23 @@ class DataGenerator(object):
self._rng = random.Random(random_seed)
self._epoch = 0
def process_utterance(self, filename, transcript):
"""Load, augment, featurize and normalize for speech data.
:param filename: Audio filepath
:type filename: basestring
:param transcript: Transcription text.
:type transcript: basestring
:return: Tuple of audio feature tensor and list of token ids for
transcription.
:rtype: tuple of (2darray, list)
"""
speech_segment = SpeechSegment.from_file(filename, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
specgram = self._normalizer.apply(specgram)
return specgram, text_ids
def batch_reader_creator(self,
manifest_path,
batch_size,
......@@ -198,14 +215,6 @@ class DataGenerator(object):
"""
return self._speech_featurizer.vocab_list
def _process_utterance(self, filename, transcript):
"""Load, augment, featurize and normalize for speech data."""
speech_segment = SpeechSegment.from_file(filename, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
specgram = self._normalizer.apply(specgram)
return specgram, text_ids
def _instance_reader_creator(self, manifest):
"""
Instance reader creator. Create a callable function to produce
......@@ -220,7 +229,7 @@ class DataGenerator(object):
yield instance
def mapper(instance):
return self._process_utterance(instance["audio_filepath"],
return self.process_utterance(instance["audio_filepath"],
instance["text"])
return paddle.reader.xmap_readers(
......
"""Client-end for the ASR demo."""
from pynput import keyboard
import struct
import socket
import sys
import argparse
import pyaudio
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--host_ip",
default="localhost",
type=str,
help="Server IP address. (default: %(default)s)")
parser.add_argument(
"--host_port",
default=8086,
type=int,
help="Server Port. (default: %(default)s)")
args = parser.parse_args()
is_recording = False
enable_trigger_record = True
def on_press(key):
"""On-press keyboard callback function."""
global is_recording, enable_trigger_record
if key == keyboard.Key.space:
if (not is_recording) and enable_trigger_record:
sys.stdout.write("Start Recording ... ")
sys.stdout.flush()
is_recording = True
def on_release(key):
"""On-release keyboard callback function."""
global is_recording, enable_trigger_record
if key == keyboard.Key.esc:
return False
elif key == keyboard.Key.space:
if is_recording == True:
is_recording = False
data_list = []
def callback(in_data, frame_count, time_info, status):
"""Audio recorder's stream callback function."""
global data_list, is_recording, enable_trigger_record
if is_recording:
data_list.append(in_data)
enable_trigger_record = False
elif len(data_list) > 0:
# Connect to server and send data
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((args.host_ip, args.host_port))
sent = ''.join(data_list)
sock.sendall(struct.pack('>i', len(sent)) + sent)
print('Speech[length=%d] Sent.' % len(sent))
# Receive data from the server and shut down
received = sock.recv(1024)
print "Recognition Results: {}".format(received)
sock.close()
data_list = []
enable_trigger_record = True
return (in_data, pyaudio.paContinue)
def main():
# prepare audio recorder
p = pyaudio.PyAudio()
stream = p.open(
format=pyaudio.paInt32,
channels=1,
rate=16000,
input=True,
stream_callback=callback)
stream.start_stream()
# prepare keyboard listener
with keyboard.Listener(
on_press=on_press, on_release=on_release) as listener:
listener.join()
# close up
stream.stop_stream()
stream.close()
p.terminate()
if __name__ == "__main__":
main()
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import distutils.util
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from data_utils.utils import read_manifest
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--host_ip",
default="localhost",
type=str,
help="Server IP address. (default: %(default)s)")
parser.add_argument(
"--host_port",
default=8086,
type=int,
help="Server Port. (default: %(default)s)")
parser.add_argument(
"--speech_save_dir",
default="demo_cache",
type=str,
help="Directory for saving demo speech. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--warmup_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument(
"--specgram_type",
default='linear',
type=str,
help="Feature type of audio data: 'linear' (power spectrum)"
" or 'mfcc'. (default: %(default)s)")
parser.add_argument(
"--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
"--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument(
"--rnn_layer_size",
default=512,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search',
type=str,
help="Method for ctc decoding: best_path or beam_search. "
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=100,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.36,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.25,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
args = parser.parse_args()
class AsrTCPServer(SocketServer.TCPServer):
"""The ASR TCP Server."""
def __init__(self,
server_address,
RequestHandlerClass,
speech_save_dir,
audio_process_handler,
bind_and_activate=True):
self.speech_save_dir = speech_save_dir
self.audio_process_handler = audio_process_handler
SocketServer.TCPServer.__init__(
self, server_address, RequestHandlerClass, bind_and_activate=True)
class AsrRequestHandler(SocketServer.BaseRequestHandler):
"""The ASR request handler."""
def handle(self):
# receive data through TCP socket
chunk = self.request.recv(1024)
target_len = struct.unpack('>i', chunk[:4])[0]
data = chunk[4:]
while len(data) < target_len:
chunk = self.request.recv(1024)
data += chunk
# write to file
filename = self._write_to_file(data)
print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename))
start_time = time.time()
transcript = self.server.audio_process_handler(filename)
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
self.request.sendall(transcript)
def _write_to_file(self, data):
# prepare save dir and filename
if not os.path.exists(self.server.speech_save_dir):
os.mkdir(self.server.speech_save_dir)
timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join(
self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file
file = wave.open(out_filename, 'wb')
file.setnchannels(1)
file.setsampwidth(4)
file.setframerate(16000)
file.writeframes(data)
file.close()
return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
"""Warming-up test."""
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server():
"""Start the ASR server"""
# prepare data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1)
# prepare ASR model
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath)
# prepare ASR inference handler
def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "")
result_transcript = ds2_model.infer_batch(
infer_data=[feature],
decode_method=args.decode_method,
beam_alpha=args.alpha,
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
language_model_path=args.language_model_path,
num_processes=1)
return result_transcript[0]
# warming up with utterrances sampled from Librispeech
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest_path,
num_test_cases=3)
print('-----------------------------------------------------------')
# start the server
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
speech_save_dir=args.speech_save_dir,
audio_process_handler=file_to_transcript)
print("ASR Server Started.")
server.serve_forever()
def main():
print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
start_server()
if __name__ == "__main__":
main()
......@@ -83,18 +83,13 @@ parser.add_argument(
"--decode_method",
default='beam_search',
type=str,
help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
)
help="Method for ctc decoding: best_path or beam_search. "
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--num_results_per_sample",
default=1,
type=int,
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/common_crawl_00.prune01111.trie.klm",
......@@ -102,12 +97,12 @@ parser.add_argument(
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.26,
default=0.36,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.1,
default=0.25,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
......
......@@ -35,6 +35,7 @@ class DeepSpeech2Model(object):
rnn_layer_size)
self._create_parameters(pretrained_model_path)
self._inferer = None
self._loss_inferer = None
self._ext_scorer = None
def train(self,
......@@ -118,15 +119,33 @@ class DeepSpeech2Model(object):
num_passes=num_passes,
feeding=feeding_dict)
def infer_loss_batch(self, infer_data):
"""Model inference. Infer the ctc loss for a batch of speech
utterances.
:param infer_data: List of utterances to infer, with each utterance a
tuple of audio features and transcription text (empty
string).
:type infer_data: list
:return: List of ctc loss.
:rtype: List of float
"""
# define inferer
if self._loss_inferer == None:
self._loss_inferer = paddle.inference.Inference(
output_layer=self._loss, parameters=self._parameters)
# run inference
return self._loss_inferer.infer(input=infer_data)
def infer_batch(self, infer_data, decode_method, beam_alpha, beam_beta,
beam_size, cutoff_prob, vocab_list, language_model_path,
num_processes):
"""Model inference. Infer the transcription for a batch of speech
utterances.
:param infer_data: List of utterances to infer, with each utterance a
tuple of audio features and transcription text (empty
string).
:param infer_data: List of utterances to infer, with each utterance
consisting of a tuple of audio features and
transcription text (empty string).
:type infer_data: list
:param decode_method: Decoding method name, 'best_path' or
'beam search'.
......@@ -187,6 +206,7 @@ class DeepSpeech2Model(object):
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob)
results = [result[0][1] for result in beam_search_results]
else:
raise ValueError("Decoding method [%s] is not supported." %
......
文件模式从 100755 更改为 100644
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册