提交 5a318e99 编写于 作者: Y Yibing Liu

adapt to the new folder structure of DS2

上级 efc5d9b1
......@@ -12,9 +12,9 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -3,7 +3,7 @@
pushd ../..
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u evaluate.py \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=500 \
......@@ -12,9 +12,9 @@ python -u evaluate.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -84,6 +84,8 @@ def infer():
use_gru=args.use_gru,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
result_transcripts = ds2_model.infer_batch(
infer_data=infer_data,
decoding_method=args.decoding_method,
......@@ -91,7 +93,7 @@ def infer():
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
......
......@@ -8,8 +8,9 @@ import os
import time
import gzip
import paddle.v2 as paddle
from lm.lm_scorer import LmScorer
from models.decoder import ctc_greedy_decoder, ctc_beam_search_decoder
from models.swig_decoders_wrapper import Scorer
from models.swig_decoders_wrapper import ctc_greedy_decoder
from models.swig_decoders_wrapper import ctc_beam_search_decoder_batch
from models.network import deep_speech_v2_network
......@@ -199,9 +200,12 @@ class DeepSpeech2Model(object):
elif decoding_method == "ctc_beam_search":
# initialize external scorer
if self._ext_scorer == None:
self._ext_scorer = LmScorer(beam_alpha, beam_beta,
self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path)
self._loaded_lm_path = language_model_path
self._ext_scorer.set_char_map(vocab_list)
if (not self._ext_scorer.is_character_based()):
self._ext_scorer.fill_dictionary(True)
else:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
assert self._loaded_lm_path == language_model_path
......
......@@ -10,7 +10,7 @@
#include "fst/fstlib.h"
#include "path_trie.h"
std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary) {
// dimension check
int num_time_steps = probs_seq.size();
......
......@@ -16,7 +16,7 @@
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder
......
......@@ -23,7 +23,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path)
def ctc_best_path_decoder(probs_seq, vocabulary):
def ctc_greedy_decoder(probs_seq, vocabulary):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
......@@ -35,7 +35,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
:return: Decoding result string.
:rtype: basestring
"""
return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary)
return swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder(probs_seq,
......
......@@ -85,6 +85,7 @@ def evaluate():
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0
for infer_data in batch_reader():
......@@ -95,7 +96,7 @@ def evaluate():
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
target_transcripts = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册