"""Server-end for the ASR demo."""
import os
import time
4
import random
5
import argparse
X
Xinghai Sun 已提交
6
import functools
7 8 9 10 11 12 13
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
14
from data_utils.utils import read_manifest
X
Xinghai Sun 已提交
15
from utils import add_arguments, print_arguments
16

17
# Command-line interface of the demo server. The module docstring doubles as
# the parser description, and every option is registered through `add_arg`,
# i.e. `add_arguments` with the parser pre-bound.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port',        int,    8086,    "Server's IP port.")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  0.36,   "Coef of LM for beam search.")
add_arg('beta',             float,  0.25,   "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  0.99,   "Cutoff probability for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('host_ip',          str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir',  str,
        'demo_cache',
        "Directory to save demo audios.")
add_arg('warmup_manifest',  str,
        'datasets/manifest.test',
        "Filepath of manifest to warm up.")
add_arg('mean_std_path',    str,
        'mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'datasets/vocab/eng_vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path',       str,
        './checkpoints/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path',  str,
        'lm/data/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
# Parsed once at import time; the rest of the module reads options from `args`.
args = parser.parse_args()
64 65 66


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server.

    Besides the plain TCP server state, the instance stores the directory
    for received audio and the speech-to-text callable, so that request
    handlers can reach both through ``self.server``.

    :param server_address: ``(host, port)`` pair the server listens on.
    :param RequestHandlerClass: Handler class instantiated per request.
    :param speech_save_dir: Directory where received utterances are saved.
    :param audio_process_handler: Callable taking an audio file path and
                                  returning its transcript.
    :param bind_and_activate: Forwarded to ``TCPServer``; when False the
                              socket is created but neither bound nor
                              put into listening state.
    """

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        # Forward the caller's choice rather than a hard-coded True, so that
        # bind_and_activate=False is actually honoured.
        SocketServer.TCPServer.__init__(
            self,
            server_address,
            RequestHandlerClass,
            bind_and_activate=bind_and_activate)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler.

    As read by ``handle``, a request is a 4-byte big-endian signed integer
    giving the payload length, followed by at least that many bytes of audio
    data. The reply is the transcript returned by the server's
    ``audio_process_handler``.
    """

    def handle(self):
        """Receive one utterance, transcribe it and send the transcript back.

        If the client closes the connection before the announced number of
        bytes has arrived, the request is dropped and nothing is sent.
        """
        header_size = struct.calcsize('>i')
        chunks = []
        num_received = 0
        target_len = None
        # receive data through TCP socket; the length header itself may be
        # split across several recv() calls, so it is parsed only once enough
        # bytes have been collected
        while target_len is None or num_received - header_size < target_len:
            chunk = self.request.recv(1024)
            if not chunk:
                # recv() returns an empty string once the peer has closed the
                # connection; without this check the loop would never end
                print("Connection from %s closed before the whole utterance "
                      "arrived, request dropped." % self.client_address[0])
                return
            chunks.append(chunk)
            num_received += len(chunk)
            if target_len is None and num_received >= header_size:
                header = b''.join(chunks)[:header_size]
                target_len = struct.unpack('>i', header)[0]
        data = b''.join(chunks)[header_size:]
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript)

    def _write_to_file(self, data):
        """Save raw audio bytes as a wav file and return the file path.

        The file is written as mono, 4 bytes per sample, 16000 Hz, under the
        server's ``speech_save_dir`` and named
        ``<UTC timestamp>_<client ip>.wav``.

        :param data: Raw audio frames received from the client.
        :return: Path of the written wav file.
        """
        # prepare save dir and filename
        save_dir = self.server.speech_save_dir
        if not os.path.isdir(save_dir):
            # makedirs also creates missing parent directories
            os.makedirs(save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            save_dir, timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file; always close the writer, even if a write fails
        wav_writer = wave.open(out_filename, 'wb')
        try:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(4)
            wav_writer.setframerate(16000)
            wav_writer.writeframes(data)
        finally:
            wav_writer.close()
        return out_filename


122 123 124 125
def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test.

    Runs the handler on a few utterances sampled from a manifest and prints
    the response time and transcript of each.

    :param audio_process_handler: Callable taking an audio file path and
                                  returning its transcript.
    :param manifest_path: Path of the manifest passed to ``read_manifest``;
                          each entry must provide ``'audio_filepath'``.
    :param num_test_cases: Number of utterances to sample; capped at the
                           number of manifest entries.
    :param random_seed: Seed of the sampler, so runs are reproducible.
    """
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    # Never ask for more samples than exist; rng.sample would raise ValueError.
    samples = rng.sample(manifest, min(num_test_cases, len(manifest)))
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


139
def start_server():
    """Start the ASR server.

    Builds the data generator and the DeepSpeech2 model from the parsed
    command-line options, runs a short warm-up on utterances from the
    warm-up manifest, then serves requests on ``(host_ip, host_port)`` until
    ``serve_forever`` exits. The listening socket is closed on the way out.
    """
    # prepare data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        """Return the transcript inferred for a single audio file."""
        # NOTE(review): the second argument is presumably the reference
        # transcript, left empty at inference time -- confirm against
        # DataGenerator.process_utterance.
        feature = data_generator.process_utterance(filename, "")
        result_transcript = ds2_model.infer_batch(
            infer_data=[feature],
            decoding_method=args.decoding_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.lang_model_path,
            num_processes=1)
        # a batch of one utterance goes in, so take the single result
        return result_transcript[0]

    # warming up with utterances sampled from the warm-up manifest
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    try:
        server.serve_forever()
    finally:
        # release the listening socket however serve_forever() ends
        # (e.g. KeyboardInterrupt); the exception still propagates
        server.server_close()


def main():
    """Print the parsed arguments, initialize PaddlePaddle and run the server."""
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()


# Run the demo server only when this file is executed as a script.
if __name__ == "__main__":
    main()