From 0ebf36b98fb8484b44bb512e76f6b94a1799a1c2 Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Wed, 2 Aug 2017 21:45:08 +0800
Subject: [PATCH] Add a realtime ASR demo for users to test their own voice with mic.

---
 data_utils/audio.py |   2 +
 data_utils/data.py  |  29 +++---
 demo_client.py      |  75 ++++++++++++++++
 demo_server.py      | 208 ++++++++++++++++++++++++++++++++++++++++++++
 infer.py            |  13 +--
 model.py            |  10 +++
 requirements.txt    |   2 +
 7 files changed, 320 insertions(+), 19 deletions(-)
 create mode 100644 demo_client.py
 create mode 100644 demo_server.py
 mode change 100755 => 100644 requirements.txt

diff --git a/data_utils/audio.py b/data_utils/audio.py
index 3891f5b9..29fdd0bd 100644
--- a/data_utils/audio.py
+++ b/data_utils/audio.py
@@ -64,6 +64,8 @@ class AudioSegment(object):
         :rtype: AudioSegment
         """
         samples, sample_rate = soundfile.read(file, dtype='float32')
+        print(samples)
+        print(sample_rate)
         return cls(samples, sample_rate)
 
     @classmethod
diff --git a/data_utils/data.py b/data_utils/data.py
index d01ca8cc..fe064b80 100644
--- a/data_utils/data.py
+++ b/data_utils/data.py
@@ -83,6 +83,23 @@ class DataGenerator(object):
         self._rng = random.Random(random_seed)
         self._epoch = 0
 
+    def process_utterance(self, filename, transcript):
+        """Load, augment, featurize and normalize for speech data.
+
+        :param filename: Audio filepath
+        :type filename: basestring
+        :param transcript: Transcription text.
+        :type transcript: basestring
+        :return: Tuple of audio feature tensor and list of token ids for
+                 transcription.
+        :rtype: tuple of (2darray, list)
+        """
+        speech_segment = SpeechSegment.from_file(filename, transcript)
+        self._augmentation_pipeline.transform_audio(speech_segment)
+        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram = self._normalizer.apply(specgram)
+        return specgram, text_ids
+
     def batch_reader_creator(self,
                              manifest_path,
                              batch_size,
@@ -198,14 +215,6 @@ class DataGenerator(object):
         """
         return self._speech_featurizer.vocab_list
 
-    def _process_utterance(self, filename, transcript):
-        """Load, augment, featurize and normalize for speech data."""
-        speech_segment = SpeechSegment.from_file(filename, transcript)
-        self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
-        specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
-
     def _instance_reader_creator(self, manifest):
         """
         Instance reader creator. Create a callable function to produce
@@ -220,8 +229,8 @@ class DataGenerator(object):
             yield instance
 
         def mapper(instance):
-            return self._process_utterance(instance["audio_filepath"],
-                                           instance["text"])
+            return self.process_utterance(instance["audio_filepath"],
+                                          instance["text"])
 
         return paddle.reader.xmap_readers(
             mapper, reader, self._num_threads, 1024, order=True)
diff --git a/demo_client.py b/demo_client.py
new file mode 100644
index 00000000..97649fd4
--- /dev/null
+++ b/demo_client.py
@@ -0,0 +1,75 @@
+from pynput import keyboard
+import struct
+import socket
+import sys
+import pyaudio
+
+HOST, PORT = "10.104.18.14", 8086
+
+is_recording = False
+enable_trigger_record = True
+
+
+def on_press(key):
+    global is_recording, enable_trigger_record
+    if key == keyboard.Key.space:
+        if (not is_recording) and enable_trigger_record:
+            sys.stdout.write("Start Recording ... ")
") + sys.stdout.flush() + is_recording = True + + +def on_release(key): + global is_recording, enable_trigger_record + if key == keyboard.Key.esc: + return False + elif key == keyboard.Key.space: + if is_recording == True: + is_recording = False + + +data_list = [] + + +def callback(in_data, frame_count, time_info, status): + global data_list, is_recording, enable_trigger_record + if is_recording: + data_list.append(in_data) + enable_trigger_record = False + elif len(data_list) > 0: + # Connect to server and send data + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((HOST, PORT)) + sent = ''.join(data_list) + sock.sendall(struct.pack('>i', len(sent)) + sent) + print('Speech[length=%d] Sent.' % len(sent)) + # Receive data from the server and shut down + received = sock.recv(1024) + print "Recognition Results: {}".format(received) + sock.close() + data_list = [] + enable_trigger_record = True + return (in_data, pyaudio.paContinue) + + +def main(): + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt32, + channels=1, + rate=16000, + input=True, + stream_callback=callback) + stream.start_stream() + + with keyboard.Listener( + on_press=on_press, on_release=on_release) as listener: + listener.join() + + stream.stop_stream() + stream.close() + p.terminate() + + +if __name__ == "__main__": + main() diff --git a/demo_server.py b/demo_server.py new file mode 100644 index 00000000..4a3feb13 --- /dev/null +++ b/demo_server.py @@ -0,0 +1,208 @@ +import os +import time +import argparse +import distutils.util +from time import gmtime, strftime +import SocketServer +import struct +import wave +import pyaudio +import paddle.v2 as paddle +from data_utils.data import DataGenerator +from model import DeepSpeech2Model +import utils + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--host_ip", + default="10.104.18.14", + type=str, + help="Server IP address. (default: %(default)s)") +parser.add_argument( + "--host_port", + default=8086, + type=int, + help="Server Port. (default: %(default)s)") +parser.add_argument( + "--speech_save_dir", + default="demo_cache", + type=str, + help="Directory for saving demo speech. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='checkpoints/params.latest.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search', + type=str, + help="Method for ctc decoding: best_path or beam_search. 
" + "(default: %(default)s)") +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="lm/data/common_crawl_00.prune01111.trie.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha", + default=0.36, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.25, + type=float, + help="Parameter associated with word count. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") +args = parser.parse_args() + + +class AsrTCPServer(SocketServer.TCPServer): + def __init__(self, + server_address, + RequestHandlerClass, + speech_save_dir, + audio_process_handler, + bind_and_activate=True): + self.speech_save_dir = speech_save_dir + self.audio_process_handler = audio_process_handler + SocketServer.TCPServer.__init__( + self, server_address, RequestHandlerClass, bind_and_activate=True) + + +class AsrRequestHandler(SocketServer.BaseRequestHandler): + """The ASR request handler. + """ + + def handle(self): + # receive data through TCP socket + chunk = self.request.recv(1024) + target_len = struct.unpack('>i', chunk[:4])[0] + data = chunk[4:] + while len(data) < target_len: + chunk = self.request.recv(1024) + data += chunk + # write to file + filename = self._write_to_file(data) + + print("Received utterance[length=%d] from %s, saved to %s." % + (len(data), self.client_address[0], filename)) + #filename = "/home/work/.cache/paddle/dataset/speech/Libri/train-other-500/LibriSpeech/train-other-500/811/130143/811-130143-0025.flac" + start_time = time.time() + transcript = self.server.audio_process_handler(filename) + finish_time = time.time() + print("Response Time: %f, Transcript: %s" % + (finish_time - start_time, transcript)) + self.request.sendall(transcript) + + def _write_to_file(self, data): + # prepare save dir and filename + if not os.path.exists(self.server.speech_save_dir): + os.mkdir(self.server.speech_save_dir) + timestamp = strftime("%Y%m%d%H%M%S", gmtime()) + out_filename = os.path.join( + self.server.speech_save_dir, + timestamp + "_" + self.client_address[0] + "_" + ".wav") + # write to wav file + file = wave.open(out_filename, 'wb') + file.setnchannels(1) + file.setsampwidth(4) + file.setframerate(16000) + file.writeframes(data) + file.close() + return out_filename + + +def start_server(): + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}', + specgram_type=args.specgram_type, + num_threads=1) + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.model_filepath) + + def file_to_transcript(filename): + feature = data_generator.process_utterance(filename, "") + result_transcript = ds2_model.infer_batch( + infer_data=[feature], + decode_method=args.decode_method, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + vocab_list=data_generator.vocab_list, + language_model_path=args.language_model_path, + num_processes=1) + return result_transcript[0] + + server = AsrTCPServer( + 
+        server_address=(args.host_ip, args.host_port),
+        RequestHandlerClass=AsrRequestHandler,
+        speech_save_dir=args.speech_save_dir,
+        audio_process_handler=file_to_transcript)
+
+    print("ASR Server Started.")
+    server.serve_forever()
+
+
+def main():
+    utils.print_arguments(args)
+    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
+    start_server()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/infer.py b/infer.py
index bc77dab7..8fd27dce 100644
--- a/infer.py
+++ b/infer.py
@@ -83,18 +83,13 @@ parser.add_argument(
     "--decode_method",
     default='beam_search',
     type=str,
-    help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
-)
+    help="Method for ctc decoding: best_path or beam_search. "
+    "(default: %(default)s)")
 parser.add_argument(
     "--beam_size",
     default=500,
     type=int,
     help="Width for beam search decoding. (default: %(default)d)")
-parser.add_argument(
-    "--num_results_per_sample",
-    default=1,
-    type=int,
-    help="Number of output per sample in beam search. (default: %(default)d)")
 parser.add_argument(
     "--language_model_path",
     default="lm/data/common_crawl_00.prune01111.trie.klm",
@@ -102,12 +97,12 @@ parser.add_argument(
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
-    default=0.26,
+    default=0.36,
     type=float,
     help="Parameter associated with language model. (default: %(default)f)")
 parser.add_argument(
     "--beta",
-    default=0.1,
+    default=0.25,
     type=float,
     help="Parameter associated with word count. (default: %(default)f)")
 parser.add_argument(
diff --git a/model.py b/model.py
index f5333f17..c8766deb 100644
--- a/model.py
+++ b/model.py
@@ -35,6 +35,7 @@ class DeepSpeech2Model(object):
                              rnn_layer_size)
         self._create_parameters(pretrained_model_path)
         self._inferer = None
+        self._loss_inferer = None
         self._ext_scorer = None
 
     def train(self,
@@ -118,6 +119,14 @@ class DeepSpeech2Model(object):
             num_passes=num_passes,
             feeding=feeding_dict)
 
+    def infer_loss_batch(self, infer_data):
+        # define inferer
+        if self._loss_inferer == None:
+            self._loss_inferer = paddle.inference.Inference(
+                output_layer=self._loss, parameters=self._parameters)
+        # run inference
+        return self._loss_inferer.infer(input=infer_data)
+
     def infer_batch(self, infer_data, decode_method, beam_alpha, beam_beta,
                     beam_size, cutoff_prob, vocab_list, language_model_path,
                     num_processes):
@@ -187,6 +196,7 @@ class DeepSpeech2Model(object):
                 num_processes=num_processes,
                 ext_scoring_func=self._ext_scorer,
                 cutoff_prob=cutoff_prob)
+
             results = [result[0][1] for result in beam_search_results]
         else:
             raise ValueError("Decoding method [%s] is not supported." %
diff --git a/requirements.txt b/requirements.txt
old mode 100755
new mode 100644
index 131f75ff..9297f659
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,5 @@ resampy==0.1.5
 SoundFile==0.9.0.post1
 python_speech_features
 https://github.com/luotao1/kenlm/archive/master.zip
+pyaudio
+pynput
-- 
GitLab
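Usage sketch (not part of the patch): demo_server.py speaks a small TCP protocol in which the client sends a 4-byte big-endian length prefix followed by raw 16 kHz, 32-bit, mono PCM samples, and the server replies with the decoded transcript. The snippet below is a minimal, assumed example of exercising that protocol from a pre-recorded wav file instead of a live microphone; the script name, host, port, and wav path are placeholders rather than values taken from the patch.

# client_from_file.py -- hypothetical helper, not included in this patch.
# Sends a wav file to demo_server.py using the same framing as demo_client.py:
# a 4-byte big-endian length prefix followed by the raw PCM frames.
import socket
import struct
import wave


def recognize(wav_path, host="127.0.0.1", port=8086):
    # demo_server.py assumes 16 kHz, 32-bit, mono samples (see _write_to_file)
    wav = wave.open(wav_path, 'rb')
    frames = wav.readframes(wav.getnframes())
    wav.close()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((host, port))
    # same '>i' length packing as demo_client.py's callback
    sock.sendall(struct.pack('>i', len(frames)) + frames)
    transcript = sock.recv(1024)
    sock.close()
    return transcript


if __name__ == "__main__":
    print(recognize("demo_cache/sample.wav"))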