提交 aee3e11f 编写于 作者: X Xinghai Sun

Add warming-up to demo_server.py for DS2 and clean codes.

上级 ae84c6f6
......@@ -64,8 +64,6 @@ class AudioSegment(object):
:rtype: AudioSegment
"""
samples, sample_rate = soundfile.read(file, dtype='float32')
print(samples)
print(sample_rate)
return cls(samples, sample_rate)
@classmethod
......
import os
import time
import random
import argparse
import distutils.util
from time import gmtime, strftime
......@@ -8,9 +9,10 @@ import struct
import wave
import pyaudio
import paddle.v2 as paddle
from utils import print_arguments
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
import utils
from data_utils.utils import read_manifest
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -38,6 +40,11 @@ parser.add_argument(
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--warmup_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for warmup test. (default: %(default)s)")
parser.add_argument(
"--specgram_type",
default='linear',
......@@ -77,7 +84,7 @@ parser.add_argument(
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=500,
default=100,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
......@@ -134,7 +141,6 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
print("Received utterance[length=%d] from %s, saved to %s." %
(len(data), self.client_address[0], filename))
#filename = "/home/work/.cache/paddle/dataset/speech/Libri/train-other-500/LibriSpeech/train-other-500/811/130143/811-130143-0025.flac"
start_time = time.time()
transcript = self.server.audio_process_handler(filename)
finish_time = time.time()
......@@ -149,7 +155,7 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
timestamp = strftime("%Y%m%d%H%M%S", gmtime())
out_filename = os.path.join(
self.server.speech_save_dir,
timestamp + "_" + self.client_address[0] + "_" + ".wav")
timestamp + "_" + self.client_address[0] + ".wav")
# write to wav file
file = wave.open(out_filename, 'wb')
file.setnchannels(1)
......@@ -160,6 +166,22 @@ class AsrRequestHandler(SocketServer.BaseRequestHandler):
return out_filename
def warm_up_test(audio_process_handler,
manifest_path,
num_test_cases,
random_seed=0):
manifest = read_manifest(manifest_path)
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
def start_server():
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
......@@ -188,6 +210,14 @@ def start_server():
num_processes=1)
return result_transcript[0]
print('-----------------------------------------------------------')
print('Warming up ...')
warm_up_test(
audio_process_handler=file_to_transcript,
manifest_path=args.warmup_manifest_path,
num_test_cases=3)
print('-----------------------------------------------------------')
server = AsrTCPServer(
server_address=(args.host_ip, args.host_port),
RequestHandlerClass=AsrRequestHandler,
......@@ -199,7 +229,7 @@ def start_server():
def main():
utils.print_arguments(args)
print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
start_server()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册