提交 cb9370f3 编写于 作者: X Xinghai Sun

Add warming-up to demo_server.py for DS2 and clean codes.

上级 0ebf36b9
...@@ -64,8 +64,6 @@ class AudioSegment(object): ...@@ -64,8 +64,6 @@ class AudioSegment(object):
:rtype: AudioSegment :rtype: AudioSegment
""" """
samples, sample_rate = soundfile.read(file, dtype='float32') samples, sample_rate = soundfile.read(file, dtype='float32')
print(samples)
print(sample_rate)
return cls(samples, sample_rate) return cls(samples, sample_rate)
@classmethod @classmethod
......
import os import os
import time import time
import random
import argparse import argparse
import distutils.util import distutils.util
from time import gmtime, strftime from time import gmtime, strftime
...@@ -8,9 +9,10 @@ import struct ...@@ -8,9 +9,10 @@ import struct
import wave import wave
import pyaudio import pyaudio
import paddle.v2 as paddle import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import DeepSpeech2Model from model import DeepSpeech2Model
import utils from data_utils.utils import read_manifest
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
...@@ -38,6 +40,11 @@ parser.add_argument( ...@@ -38,6 +40,11 @@ parser.add_argument(
default='mean_std.npz', default='mean_std.npz',
type=str, type=str,
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--warmup_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--specgram_type", "--specgram_type",
default='linear', default='linear',
...@@ -77,7 +84,7 @@ parser.add_argument( ...@@ -77,7 +84,7 @@ parser.add_argument(
"(default: %(default)s)") "(default: %(default)s)")
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=500, default=100,
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
...@@ -134,7 +141,6 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): ...@@ -134,7 +141,6 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
print("Received utterance[length=%d] from %s, saved to %s." % print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename)) (len(data), self.client_address[0], filename))
#filename = "/home/work/.cache/paddle/dataset/speech/Libri/train-other-500/LibriSpeech/train-other-500/811/130143/811-130143-0025.flac"
start_time = time.time() start_time = time.time()
transcript = self.server.audio_process_handler(filename) transcript = self.server.audio_process_handler(filename)
finish_time = time.time() finish_time = time.time()
...@@ -149,7 +155,7 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): ...@@ -149,7 +155,7 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
timestamp = strftime("%Y%m%d%H%M%S", gmtime()) timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join( out_filename = os.path.join(
self.server.speech_save_dir, self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + "_" + ".wav") timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file # write to wav file
file = wave.open(out_filename, 'wb') file = wave.open(out_filename, 'wb')
file.setnchannels(1) file.setnchannels(1)
...@@ -160,6 +166,22 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler): ...@@ -160,6 +166,22 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
return out_filename return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server(): def start_server():
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
...@@ -188,6 +210,14 @@ def start_server(): ...@@ -188,6 +210,14 @@ def start_server():
num_processes=1) num_processes=1)
return result_transcript[0] return result_transcript[0]
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest_path,
num_test_cases=3)
print('-----------------------------------------------------------')
server = AsrTCPServer( server = AsrTCPServer(
server_address=(args.host_ip, args.host_port), server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler, RequestHandlerClass=AsrRequestHandler,
...@@ -199,7 +229,7 @@ def start_server(): ...@@ -199,7 +229,7 @@ def start_server():
def main(): def main():
utils.print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=1)
start_server() start_server()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册