# demo_server.py
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import distutils.util
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from data_utils.utils import read_manifest

def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean flags need an explicit parser.

    Accepts (case-insensitively) y/yes/t/true/on/1 and n/no/f/false/off/0;
    anything else is reported as a usage error by argparse.
    """
    text = value.strip().lower()
    if text in ("y", "yes", "t", "true", "on", "1"):
        return True
    if text in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("invalid boolean value: %r" % value)


# Command-line options of the demo server; the module docstring is reused as
# the program description.
parser = argparse.ArgumentParser(description=__doc__)
# --- network endpoint and local cache ---
parser.add_argument(
    "--host_ip",
    default="localhost",
    type=str,
    help="Server IP address. (default: %(default)s)")
parser.add_argument(
    "--host_port",
    default=8086,
    type=int,
    help="Server Port. (default: %(default)s)")
parser.add_argument(
    "--speech_save_dir",
    default="demo_cache",
    type=str,
    help="Directory for saving demo speech. (default: %(default)s)")
# --- data / feature extraction ---
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
    "--warmup_manifest_path",
    default='datasets/manifest.test',
    type=str,
    help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
# --- model architecture and runtime ---
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--use_gru",
    default=True,
    # Not `type=bool`: bool("False") is True, which made the flag
    # impossible to switch off from the command line.
    type=_str2bool,
    help="Use GRU or simple RNN. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=_str2bool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
# --- CTC decoding ---
parser.add_argument(
    "--decode_method",
    default='beam_search',
    type=str,
    help="Method for ctc decoding: best_path or beam_search. "
    "(default: %(default)s)")
parser.add_argument(
    "--beam_size",
    default=100,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
    "--language_model_path",
    default="lm/data/common_crawl_00.prune01111.trie.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.36,
    type=float,
    help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.25,
    type=float,
    help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
# Parsed once at import time; start_server() and main() read their
# configuration from this module-level namespace.
args = parser.parse_args()


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server.

    A plain TCP server that additionally carries the two pieces of state the
    request handler needs: the directory received speech is saved to, and the
    callable that turns a saved audio file into a transcript.
    """

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        """Create the server.

        :param server_address: (host, port) pair to listen on.
        :param RequestHandlerClass: Handler class instantiated per request.
        :param speech_save_dir: Directory in which received speech is saved.
        :param audio_process_handler: Callable taking an audio file path and
                                      returning its transcript.
        :param bind_and_activate: Forwarded to TCPServer; when False the
                                  socket is created but not bound/listening.
        """
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        # Explicit base-class call (not super()): under Python 2 the
        # SocketServer classes are old-style. Forward the caller's choice
        # instead of always binding.
        SocketServer.TCPServer.__init__(
            self,
            server_address,
            RequestHandlerClass,
            bind_and_activate=bind_and_activate)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler.

    Protocol, one utterance per connection: the client sends a 4-byte
    big-endian signed length followed by that many bytes of audio samples;
    the handler saves them as a wav file, runs the server's
    ``audio_process_handler`` on that file and sends the transcript back.
    """

    def handle(self):
        """Receive one utterance, save it and reply with its transcript."""
        client_ip = self.client_address[0]
        # receive data through TCP socket; recv() may return fewer bytes than
        # requested, so the header and the payload are each read to the end.
        header = self._recv_exactly(4)
        if len(header) < 4:
            print("Connection from %s closed before the length header "
                  "was received." % client_ip)
            return
        target_len = struct.unpack('>i', header)[0]
        if target_len < 0:
            print("Invalid utterance length %d from %s." %
                  (target_len, client_ip))
            return
        data = self._recv_exactly(target_len)
        if len(data) < target_len:
            # The peer went away mid-transfer; do not transcribe a fragment.
            print("Connection from %s closed after %d of %d bytes." %
                  (client_ip, len(data), target_len))
            return
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), client_ip, filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript)

    def _recv_exactly(self, num_bytes):
        """Read up to `num_bytes` from the socket.

        Returns fewer bytes only if the peer closes the connection first.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = self.request.recv(min(remaining, 1024))
            if not chunk:  # empty read: connection closed by the peer
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _write_to_file(self, data):
        """Save raw samples as a wav file and return the file path.

        The file is written as mono, 4 bytes per sample, 16 kHz.
        NOTE(review): this presumably mirrors the recording format of the
        demo client -- confirm against the client.
        """
        # prepare save dir and filename
        save_dir = self.server.speech_save_dir
        try:
            os.makedirs(save_dir)
        except OSError:
            # Already existing (possibly created concurrently) is fine;
            # anything else is a real error.
            if not os.path.isdir(save_dir):
                raise
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            save_dir, timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        wav_file = wave.open(out_filename, 'wb')
        try:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(4)
            wav_file.setframerate(16000)
            wav_file.writeframes(data)
        finally:
            wav_file.close()
        return out_filename


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test.

    Runs `audio_process_handler` on utterances randomly drawn from the
    manifest and prints the response time and transcript of each.

    :param audio_process_handler: Callable taking an audio file path and
                                  returning its transcript.
    :param manifest_path: Path of the manifest to draw utterances from.
    :param num_test_cases: Number of utterances to run; capped at the number
                           of manifest entries.
    :param random_seed: Seed of the sampling, for reproducible warm-ups.
    """
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    # random.sample raises if asked for more items than exist; a short
    # manifest should shorten the warm-up, not abort server start-up.
    samples = rng.sample(manifest, min(num_test_cases, len(manifest)))
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


def start_server():
    """Start the ASR server.

    Builds the data generator and the DeepSpeech2 model from the module-level
    ``args``, runs a short warm-up on utterances sampled from
    ``args.warmup_manifest_path``, then serves transcription requests over
    TCP until the process is interrupted.
    """
    # prepare data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_filepath)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        """Return the decoded transcript of the audio file at `filename`."""
        # NOTE(review): the empty string is presumably a placeholder for the
        # (unknown) reference transcript -- confirm against DataGenerator.
        feature = data_generator.process_utterance(filename, "")
        result_transcript = ds2_model.infer_batch(
            infer_data=[feature],
            decode_method=args.decode_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.language_model_path,
            num_processes=1)
        # A single utterance was submitted, so take the single result.
        return result_transcript[0]

    # warming up with utterances sampled from the warm-up manifest
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest_path,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    try:
        server.serve_forever()
    finally:
        # Release the listening socket even when serving is interrupted
        # (e.g. by Ctrl-C).
        server.server_close()


def main():
    """Print the parsed arguments, initialise PaddlePaddle, start the server."""
    print_arguments(args)
    # paddle must be initialised before start_server() builds the model.
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()