"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import distutils.util
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from data_utils.utils import read_manifest

parser = argparse.ArgumentParser(description=__doc__)


def add_arg(argname, type, default, help, **kwargs):
    """Register a command line argument on the module-level parser.

    :param argname: Argument name; registered as --<argname>.
    :param type: Python type of the argument. bool is mapped to
        distutils.util.strtobool so that e.g. "--flag False" parses as
        expected instead of bool("False") == True.
    :param default: Default value; echoed into the help text.
    :param help: Help string; " Default: ..." is appended automatically.
    :param kwargs: Extra keyword arguments forwarded to add_argument
        (e.g. choices).
    """
    # 'is' is the correct identity comparison for type objects.
    type = distutils.util.strtobool if type is bool else type
    parser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)


# yapf: disable
add_arg('host_port',        int,    8086,    "Server's IP port.")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  0.36,   "Coef of LM for beam search.")
add_arg('beta',             float,  0.25,   "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  0.99,   "Cutoff probability for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of Simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('host_ip',          str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir',  str,
        'demo_cache',
        "Directory to save demo audios.")
add_arg('warmup_manifest',  str,
        'datasets/manifest.test',
        "Filepath of manifest to warm up.")
add_arg('mean_std_path',    str,
        'mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'datasets/vocab/eng_vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path',       str,
        './checkpoints/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path',  str,
        'lm/data/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoder_method',   str,
        'ctc_beam_search',
        "Decoder method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
args = parser.parse_args()
# yapf: enable


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server.

    :param server_address: (host, port) tuple to listen on.
    :param RequestHandlerClass: Handler class instantiated per connection.
    :param speech_save_dir: Directory where received audio is saved.
    :param audio_process_handler: Callable mapping an audio filepath to a
        transcript; invoked by the request handler.
    :param bind_and_activate: Forwarded to TCPServer; pass False to bind
        and activate the socket manually.
    """

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        # Bug fix: honor the caller's bind_and_activate instead of
        # hard-coding True.
        SocketServer.TCPServer.__init__(
            self,
            server_address,
            RequestHandlerClass,
            bind_and_activate=bind_and_activate)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
92
    """The ASR request handler."""
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120

    def handle(self):
        # receive data through TCP socket
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript)

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
121
            timestamp + "_" + self.client_address[0] + ".wav")
122 123 124 125 126 127 128 129 130 131
        # write to wav file
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(4)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test: run a few inference passes before serving.

    :param audio_process_handler: Callable mapping an audio filepath to a
        transcript.
    :param manifest_path: Filepath of the manifest to sample utterances from.
    :param num_test_cases: Number of utterances to sample and decode.
    :param random_seed: Seed for reproducible sampling.
    """
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        # Bug fix: the original passed (idx, path) as extra print arguments
        # instead of %-formatting them into the message.
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


def start_server():
    """Start the ASR server: build the data pipeline and model, define the
    inference handler, warm up, then serve TCP requests forever."""
    # prepare data generator (no augmentation at inference time)
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        """Decode a single audio file and return its transcript string."""
        feature = data_generator.process_utterance(filename, "")
        result_transcript = ds2_model.infer_batch(
            infer_data=[feature],
            decoder_method=args.decoder_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.lang_model_path,
            num_processes=1)
        return result_transcript[0]

    # warming up with utterances sampled from Librispeech
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()


def print_arguments(args):
    """Print all parsed CLI arguments, one per line, sorted by name.

    :param args: argparse.Namespace holding the parsed arguments.
    """
    print("-----------  Configuration Arguments -----------")
    # .items() (not the Python-2-only .iteritems()) keeps this helper
    # runnable under both Python 2 and 3 with identical output.
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")


def main():
    """Entry point: echo the configuration, initialize paddle, then serve."""
    print_arguments(args)
    # A single trainer is sufficient for inference-only serving.
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()


# Run the demo server only when executed as a script (not on import).
if __name__ == "__main__":
    main()