"""Server-end for the ASR demo."""
import os
import time
import argparse
import distutils.util
from time import gmtime, strftime
import SocketServer
import struct
import wave
import pyaudio
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
import utils

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--host_ip",
    default="10.104.18.14",
    type=str,
    help="Server IP address. (default: %(default)s)")
parser.add_argument(
    "--host_port",
    default=8086,
    type=int,
    help="Server port. (default: %(default)s)")
parser.add_argument(
    "--speech_save_dir",
    default="demo_cache",
    type=str,
    help="Directory for saving demo speech. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Mean and stddev filepath for the feature normalizer. "
    "(default: %(default)s)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--decode_method",
    default='beam_search',
    type=str,
    help="Method for CTC decoding: best_path or beam_search. "
    "(default: %(default)s)")
parser.add_argument(
    "--beam_size",
    default=500,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
    "--language_model_path",
    default="lm/data/common_crawl_00.prune01111.trie.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.36,
    type=float,
    help="Language model weight for beam search decoding. "
    "(default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.25,
    type=float,
    help="Word insertion weight for beam search decoding. "
    "(default: %(default)f)")
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
args = parser.parse_args()


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP server, which dispatches each received utterance to an
    audio-processing handler."""

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        SocketServer.TCPServer.__init__(
            self, server_address, RequestHandlerClass, bind_and_activate)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler."""

    def handle(self):
        # receive data through the TCP socket: a 4-byte big-endian length
        # header followed by the raw audio bytes
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write the received audio to a wav file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        # send the transcript back to the client
        self.request.sendall(transcript)

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file (mono, 32-bit samples, 16 kHz)
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(4)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename


def start_server():
    """Start the ASR server."""
    # prepare data generator for feature extraction
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    # prepare the DeepSpeech2 model for inference
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        pretrained_model_path=args.model_filepath)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        feature = data_generator.process_utterance(filename, "")
        result_transcript = ds2_model.infer_batch(
            infer_data=[feature],
            decode_method=args.decode_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.language_model_path,
            num_processes=1)
        return result_transcript[0]

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()


def main():
    utils.print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()


if __name__ == "__main__":
    main()
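
# ---------------------------------------------------------------------------
# Example client (a minimal sketch, not part of this server): handle() above
# expects a 4-byte big-endian length header followed by the raw audio bytes,
# and replies with the transcript as plain text. The audio should match what
# _write_to_file() assumes (mono, 32-bit samples, 16 kHz). The function name
# and the sample wav path below are illustrative assumptions only.
#
#   import socket
#   import struct
#   import wave
#
#   def send_utterance(host, port, wav_path):
#       f = wave.open(wav_path, 'rb')
#       data = f.readframes(f.getnframes())  # raw PCM frames to transcribe
#       f.close()
#       sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#       sock.connect((host, port))
#       # 4-byte big-endian payload length, then the payload itself
#       sock.sendall(struct.pack('>i', len(data)) + data)
#       transcript = sock.recv(4096)  # transcript sent back by the server
#       sock.close()
#       return transcript
#
#   print(send_utterance(args.host_ip, args.host_port, "sample.wav"))
# ---------------------------------------------------------------------------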