demo_server.py 7.2 KB
Newer Older
1
"""Server-end for the ASR demo."""
2 3
import os
import time
4
import random
5
import argparse
X
Xinghai Sun 已提交
6
import functools
7 8 9 10 11
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
12
import _init_paths
13
from data_utils.data import DataGenerator
14
from models.model import DeepSpeech2Model
15
from data_utils.utils import read_manifest
16
from utils.utility import add_arguments, print_arguments
17

18
# Command-line configuration for the demo server. Every option is registered
# through `add_arg`, which is `add_arguments` pre-bound to this parser; the
# positional values are (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port',        int,    8086,    "Server's IP port.")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  0.36,   "Coef of LM for beam search.")
add_arg('beta',             float,  0.25,   "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  0.99,   "Cutoff probability for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('host_ip',          str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir',  str,
        'demo_cache',
        "Directory to save demo audios.")
add_arg('warmup_manifest',  str,
        'data/librispeech/manifest.test-clean',
        "Filepath of manifest to warm up.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/eng_vocab.txt',
        "Filepath of vocabulary.")
# The help text used to describe resuming *training*, which does not apply to
# this inference-only server.
add_arg('model_path',       str,
        './checkpoints/params.latest.tar.gz',
        "Filepath of the pre-trained model to load for inference.")
add_arg('lang_model_path',  str,
        'lm/data/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
# NOTE: arguments are parsed at import time; the functions below read the
# module-level `args`.
args = parser.parse_args()
65 66 67


class AsrTCPServer(SocketServer.TCPServer):
    """TCP server that carries the ASR demo's shared state.

    Request handlers reach the two extra attributes through `self.server`:

    * `speech_save_dir`: directory where received utterances are stored.
    * `audio_process_handler`: callable that takes an audio file path and
      returns its transcript.
    """

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        """Store the demo state, then initialize the underlying TCP server.

        :param server_address: (host, port) pair to listen on.
        :param RequestHandlerClass: Handler class instantiated per request.
        :param speech_save_dir: Directory for the received audio files.
        :param audio_process_handler: Callable mapping an audio file path to
                                      a transcript.
        :param bind_and_activate: Forwarded to `TCPServer`; when False the
                                  caller must bind and activate the server
                                  itself.
        """
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        # Forward the caller's choice instead of a hard-coded True, so that
        # bind_and_activate=False is actually honored.
        SocketServer.TCPServer.__init__(
            self,
            server_address,
            RequestHandlerClass,
            bind_and_activate=bind_and_activate)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """Handle one ASR request per TCP connection.

    Wire format read from the client: a 4-byte big-endian signed integer
    giving the payload length, followed by that many bytes of raw audio
    frames. The payload is saved as a wav file, transcribed through
    `self.server.audio_process_handler`, and the transcript is sent back.
    """

    def handle(self):
        """Receive an utterance, save it, transcribe it and reply.

        Malformed requests (connection closed early, or a negative length)
        are logged and dropped without a reply.
        """
        client_ip = self.client_address[0]
        # receive data through TCP socket
        header = self._recv_exactly(4)
        if header is None:
            print("Connection from %s closed before the length header "
                  "was received." % client_ip)
            return
        target_len = struct.unpack('>i', header)[0]
        data = self._recv_exactly(target_len) if target_len >= 0 else None
        if data is None:
            print("Dropped malformed or truncated utterance"
                  "[declared length=%d] from %s." % (target_len, client_ip))
            return
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript)

    def _recv_exactly(self, num_bytes):
        """Read exactly `num_bytes` bytes from the client socket.

        :return: The bytes read, or None if the peer closed the connection
                 before `num_bytes` bytes arrived.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = self.request.recv(min(remaining, 1024))
            if not chunk:
                # An empty read means the peer closed the connection; without
                # this check the loop would spin forever.
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        # Join once at the end instead of repeatedly concatenating.
        return b"".join(chunks)

    def _write_to_file(self, data):
        """Save `data` as a wav file in the server's save directory.

        The file is written as mono, 4 bytes per sample, 16000 Hz, and is
        named `<UTC timestamp>_<client ip>.wav`.

        :return: Path of the written wav file.
        """
        # prepare save dir and filename
        save_dir = self.server.speech_save_dir
        if not os.path.isdir(save_dir):
            try:
                os.makedirs(save_dir)
            except OSError:
                # Another request may have created the directory meanwhile.
                if not os.path.isdir(save_dir):
                    raise
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            save_dir, timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        wav_file = wave.open(out_filename, 'wb')
        try:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(4)
            wav_file.setframerate(16000)
            wav_file.writeframes(data)
        finally:
            # Always release the file handle, even if writing fails.
            wav_file.close()
        return out_filename


123 124 125 126
def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Run the handler on a few manifest samples before serving.

    :param audio_process_handler: Callable mapping an audio file path to a
                                  transcript.
    :param manifest_path: Path of the manifest passed to `read_manifest`.
    :param num_test_cases: Number of samples to draw; capped at the number
                           of manifest entries.
    :param random_seed: Seed for the sampling RNG, so the same samples are
                        drawn on every run.
    """
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    # Cap the sample size so a short manifest does not raise ValueError.
    samples = rng.sample(manifest, min(num_test_cases, len(manifest)))
    for idx, sample in enumerate(samples):
        # Format with `%`; passing the values as extra print arguments left
        # the placeholders unformatted.
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


140
def start_server():
    """Build the inference pipeline, warm it up and serve ASR requests.

    Reads all settings from the module-level `args`. The call does not
    return while the server is running, since it ends in `serve_forever()`.
    """
    # feature extraction pipeline
    feature_pipeline = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)

    # acoustic model, loaded from the configured checkpoint path
    acoustic_model = DeepSpeech2Model(
        vocab_size=feature_pipeline.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    def transcribe(audio_path):
        """Return the transcript of the single audio file at `audio_path`."""
        utterance = feature_pipeline.process_utterance(audio_path, "")
        transcripts = acoustic_model.infer_batch(
            infer_data=[utterance],
            decoding_method=args.decoding_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=feature_pipeline.vocab_list,
            language_model_path=args.lang_model_path,
            num_processes=1)
        # a batch of one utterance was submitted, so take its only result
        return transcripts[0]

    # warm up with a few utterances sampled from the warm-up manifest
    rule = '-----------------------------------------------------------'
    print(rule)
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=transcribe,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print(rule)

    # bring up the TCP server and block on its request loop
    listen_address = (args.host_ip, args.host_port)
    asr_server = AsrTCPServer(
        server_address=listen_address,
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=transcribe)
    print("ASR Server Started.")
    asr_server.serve_forever()


def main():
    """Print the parsed arguments, initialize PaddlePaddle, run the server."""
    print_arguments(args)
    paddle.init(trainer_count=1, use_gpu=args.use_gpu)
    start_server()


# Start the demo server only when this file is executed as a script.
if __name__ == "__main__":
    main()