提交 a8c3de0c 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #341 from pkuyym/fix-340

Fix bugs for demo_server.py.
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)
this_dir = os.path.dirname(__file__)
# Add project path to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
...@@ -12,7 +12,7 @@ import paddle.v2 as paddle ...@@ -12,7 +12,7 @@ import paddle.v2 as paddle
import _init_paths import _init_paths
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model from model_utils.model import DeepSpeech2Model
from data_utils.utils import read_manifest from data_utils.utility import read_manifest
from utils.utility import add_arguments, print_arguments from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -23,9 +23,10 @@ add_arg('beam_size', int, 500, "Beam search width.") ...@@ -23,9 +23,10 @@ add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.") add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.") add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.") add_arg('beta', float, 0.35, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
...@@ -156,6 +157,8 @@ def start_server(): ...@@ -156,6 +157,8 @@ def start_server():
pretrained_model_path=args.model_path, pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights) share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
# prepare ASR inference handler # prepare ASR inference handler
def file_to_transcript(filename): def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "") feature = data_generator.process_utterance(filename, "")
...@@ -166,7 +169,8 @@ def start_server(): ...@@ -166,7 +169,8 @@ def start_server():
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list, cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path, language_model_path=args.lang_model_path,
num_processes=1) num_processes=1)
return result_transcript[0] return result_transcript[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册