提交 7d0458c7 编写于 作者: Y Yibing Liu

adapt to the new folder structure of DS2

上级 11ede80a
...@@ -12,9 +12,9 @@ python -u infer.py \ ...@@ -12,9 +12,9 @@ python -u infer.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
pushd ../.. pushd ../..
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u evaluate.py \ python -u test.py \
--batch_size=128 \ --batch_size=128 \
--trainer_count=8 \ --trainer_count=8 \
--beam_size=500 \ --beam_size=500 \
...@@ -12,9 +12,9 @@ python -u evaluate.py \ ...@@ -12,9 +12,9 @@ python -u evaluate.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -84,6 +84,8 @@ def infer(): ...@@ -84,6 +84,8 @@ def infer():
use_gru=args.use_gru, use_gru=args.use_gru,
pretrained_model_path=args.model_path, pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights) share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
result_transcripts = ds2_model.infer_batch( result_transcripts = ds2_model.infer_batch(
infer_data=infer_data, infer_data=infer_data,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
...@@ -91,7 +93,7 @@ def infer(): ...@@ -91,7 +93,7 @@ def infer():
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list, vocab_list=vocab_list,
language_model_path=args.lang_model_path, language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch) num_processes=args.num_proc_bsearch)
......
...@@ -8,8 +8,9 @@ import os ...@@ -8,8 +8,9 @@ import os
import time import time
import gzip import gzip
import paddle.v2 as paddle import paddle.v2 as paddle
from lm.lm_scorer import LmScorer from models.swig_decoders_wrapper import Scorer
from models.decoder import ctc_greedy_decoder, ctc_beam_search_decoder from models.swig_decoders_wrapper import ctc_greedy_decoder
from models.swig_decoders_wrapper import ctc_beam_search_decoder_batch
from models.network import deep_speech_v2_network from models.network import deep_speech_v2_network
...@@ -199,9 +200,12 @@ class DeepSpeech2Model(object): ...@@ -199,9 +200,12 @@ class DeepSpeech2Model(object):
elif decoding_method == "ctc_beam_search": elif decoding_method == "ctc_beam_search":
# initialize external scorer # initialize external scorer
if self._ext_scorer == None: if self._ext_scorer == None:
self._ext_scorer = LmScorer(beam_alpha, beam_beta, self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path) language_model_path)
self._loaded_lm_path = language_model_path self._loaded_lm_path = language_model_path
self._ext_scorer.set_char_map(vocab_list)
if (not self._ext_scorer.is_character_based()):
self._ext_scorer.fill_dictionary(True)
else: else:
self._ext_scorer.reset_params(beam_alpha, beam_beta) self._ext_scorer.reset_params(beam_alpha, beam_beta)
assert self._loaded_lm_path == language_model_path assert self._loaded_lm_path == language_model_path
......
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "path_trie.h" #include "path_trie.h"
std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq, std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary) { std::vector<std::string> vocabulary) {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq, std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary); std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder /* CTC Beam Search Decoder
......
...@@ -23,7 +23,7 @@ class Scorer(swig_decoders.Scorer): ...@@ -23,7 +23,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path) swig_decoders.Scorer.__init__(self, alpha, beta, model_path)
def ctc_best_path_decoder(probs_seq, vocabulary): def ctc_greedy_decoder(probs_seq, vocabulary):
"""Wrapper for ctc best path decoder in swig. """Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
...@@ -35,7 +35,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary): ...@@ -35,7 +35,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
:return: Decoding result string. :return: Decoding result string.
:rtype: basestring :rtype: basestring
""" """
return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary) return swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
......
...@@ -85,6 +85,7 @@ def evaluate(): ...@@ -85,6 +85,7 @@ def evaluate():
pretrained_model_path=args.model_path, pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights) share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
error_rate_func = cer if args.error_rate_type == 'cer' else wer error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0 error_sum, num_ins = 0.0, 0
for infer_data in batch_reader(): for infer_data in batch_reader():
...@@ -95,7 +96,7 @@ def evaluate(): ...@@ -95,7 +96,7 @@ def evaluate():
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list, vocab_list=vocab_list,
language_model_path=args.lang_model_path, language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch) num_processes=args.num_proc_bsearch)
target_transcripts = [ target_transcripts = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册