Commit c753b9dd authored by Haoxin Ma

fix runtime.py and server.py

Parent d55e6b5a
@@ -81,15 +81,15 @@ def inference(config, args):
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size=1
+    config.collator.num_workers=0
     collate_fn = SpeechCollator.from_config(config)
-    test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
     model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
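
Taken together, the new lines in this hunk build the inference data pipeline as sketched below. This is a minimal reconstruction, not verbatim repository code; the ManifestDataset import path is an assumption based on the deepspeech.io import shown in the second file's hunk header.

from paddle.io import DataLoader

from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset


def build_test_loader(config):
    """Build a batch-1 test loader; `config` is the experiment config node."""
    config.defrost()
    # Dataset options stay under config.data; only the manifest is switched.
    config.data.manifest = config.data.test_manifest
    dataset = ManifestDataset.from_config(config)

    # Feature/collation options now live under config.collator.
    config.collator.augmentation_config = ""        # disable augmentation at inference
    config.collator.keep_transcription_text = True  # keep raw text, not token ids
    config.collator.batch_size = 1                  # serve one utterance at a time
    config.collator.num_workers = 0
    collate_fn = SpeechCollator.from_config(config)

    return DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
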
@@ -97,15 +97,15 @@ def start_server(config, args):
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32') #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32') #[1, T, D]
+        audio_len = feature[0].shape[0]
         audio_len = np.array([audio_len]).astype('int64') # [1]
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
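
The layout comments in this hunk carry the real fix: the collator now emits features as [T, D], so the utterance length is shape[0] (frames), not shape[1] (feature bins). A toy check with hypothetical dimensions (200 frames, 161 spectrogram bins):

import numpy as np

feat = np.zeros((200, 161), dtype='float32')  # hypothetical [T, D] feature

audio = np.array([feat]).astype('float32')             # add batch axis -> [1, T, D]
audio_len = np.array([feat.shape[0]]).astype('int64')  # frames, i.e. axis 0

assert audio.shape == (1, 200, 161)
assert audio_len[0] == 200
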
@@ -146,7 +146,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8089, "Server's IP port.")
     add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")
@@ -34,15 +34,15 @@ from deepspeech.io.collator import SpeechCollator
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size=1
+    config.collator.num_workers=0
     collate_fn = SpeechCollator.from_config(config)
-    test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
     model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
@@ -50,15 +50,19 @@ def start_server(config, args):
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32') #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32') #[1, T, D]
+        # audio = audio.swapaxes(1,2)
+        print('---file_to_transcript feature----')
+        print(audio.shape)
+        audio_len = feature[0].shape[0]
+        print(audio_len)
         audio_len = np.array([audio_len]).astype('int64') # [1]
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
@@ -99,7 +103,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8088, "Server's IP port.")
    add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")
@@ -242,6 +242,7 @@ class SpeechCollator():
         # specgram augment
         specgram = self._augmentation_pipeline.transform_feature(specgram)
+        specgram=specgram.transpose([1,0])
         return specgram, transcript_part

     def __call__(self, batch):
@@ -269,7 +270,7 @@ class SpeechCollator():
             #utt
             utts.append(utt)
             # audio
-            audios.append(audio.T) # [T, D]
+            audios.append(audio) # [T, D]
             audio_lens.append(audio.shape[1])
             # text
             # for training, text is token ids
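
Since process_utterance now does the transpose itself, __call__ can append each feature unchanged and pad along the time axis. A minimal sketch of that batching under the new layout, with hypothetical shapes (note that once features are [T, D], the per-utterance length is shape[0]):

import numpy as np

feats = [np.zeros((120, 161), 'float32'),  # hypothetical utterances,
         np.zeros((80, 161), 'float32')]   # already in [T, D] layout

max_t = max(f.shape[0] for f in feats)
batch = np.zeros((len(feats), max_t, feats[0].shape[1]), 'float32')
lens = np.zeros(len(feats), 'int64')
for i, f in enumerate(feats):
    batch[i, :f.shape[0], :] = f  # pad along the time axis
    lens[i] = f.shape[0]          # frame count, axis 0 of [T, D]

assert batch.shape == (2, 120, 161)
assert list(lens) == [120, 80]
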
@@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler,
     rng = random.Random(random_seed)
     samples = rng.sample(manifest, num_test_cases)
     for idx, sample in enumerate(samples):
-        print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
+        print("Warm-up Test Case %d: %s"%(idx, sample['feat']))
         start_time = time.time()
-        transcript = audio_process_handler(sample['audio_filepath'])
+        transcript = audio_process_handler(sample['feat'])
         finish_time = time.time()
         print("Response Time: %f, Transcript: %s" %
               (finish_time - start_time, transcript))
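
The print change here is a correctness fix, not a cosmetic one: passing the arguments after a comma hands them to print as separate values, so the %d/%s placeholders are never interpolated. A quick illustration with hypothetical values:

idx, feat = 0, 'data/demo.wav'  # hypothetical sample values

print("Warm-up Test Case %d: %s", idx, feat)
# -> Warm-up Test Case %d: %s 0 data/demo.wav

print("Warm-up Test Case %d: %s" % (idx, feat))
# -> Warm-up Test Case 0: data/demo.wav
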