提交 04d9db19 编写于 作者: H huangyuxin

add blank_id parameter

上级 48438066
......@@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
......@@ -48,7 +49,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// assign blank id
// size_t blank_id = vocabulary.size();
size_t blank_id = 0;
// size_t blank_id = 0;
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
......@@ -57,7 +58,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
......@@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer,
size_t blank_id) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(num_processes);
......@@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch(
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer));
ext_scorer,
blank_id));
}
// get decoding results
......
......@@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data
......@@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
#endif // CTC_BEAM_SEARCH_DECODER_H_
......@@ -17,17 +17,18 @@
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
const std::vector<std::string> &vocabulary,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
// size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
......
......@@ -29,6 +29,7 @@
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
const std::vector<std::string>& vocabulary,
size_t blank_id);
#endif // CTC_GREEDY_DECODER_H
......@@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable
FILES = [
fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
or fn.endswith('unittest.cc'))
]
# yapf: enable
......
......@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary):
def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
......@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
:return: Decoding result string.
:rtype: str
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
return result
......@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
......@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func)
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results
......@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
......@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func)
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results
......@@ -141,7 +141,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True):
share_rnn_weights=True,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
......@@ -156,7 +157,7 @@ class DeepSpeech2Model(nn.Layer):
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank>
blank_id=blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
......@@ -221,7 +222,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
......@@ -246,7 +248,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights)
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
return model
......@@ -258,7 +261,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True):
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
......@@ -266,7 +270,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def forward(self, audio, audio_len):
"""export model function
......
......@@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
blank_id=0, # index of blank in vocob.txt
))
if config is not None:
config.merge_from_other_cfg(default)
......@@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
use_gru=False,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
......@@ -284,7 +286,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank>
blank_id=blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
......@@ -353,7 +355,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru)
use_gru=config.model.use_gru,
blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
......@@ -380,7 +383,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru)
use_gru=config.use_gru,
blank_id=config.blank_id)
return model
......@@ -394,7 +398,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
use_gru=False,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
......@@ -404,7 +409,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru)
use_gru=use_gru,
blank_id=blank_id)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):
......
......@@ -136,7 +136,7 @@ class CTCDecoder(nn.Layer):
results = []
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list)
probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
results.append(output_transcription)
return results
......@@ -216,7 +216,8 @@ class CTCDecoder(nn.Layer):
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n)
cutoff_top_n=cutoff_top_n,
blank_id=self.blank_id)
results = [result[0][1] for result in beam_search_results]
return results
......
......@@ -40,6 +40,7 @@ model:
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 0
training:
n_epoch: 80
......
......@@ -36,17 +36,18 @@ collator:
model:
num_conv_layers: 2
num_rnn_layers: 3
num_rnn_layers: 5
rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1
fc_layers_size_list: 512,
num_fc_layers: 0
fc_layers_size_list: -1,
use_gru: False
blank_id: 0
training:
n_epoch: 50
lr: 2e-3
lr_decay: 0.91 # 0.83
lr_decay: 0.9 # 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
......@@ -59,7 +60,7 @@ decoding:
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 1.9
alpha: 2.2 #1.9
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
......
......@@ -40,6 +40,7 @@ model:
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 0
training:
n_epoch: 50
......
......@@ -42,6 +42,7 @@ model:
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: False
blank_id: 0
training:
n_epoch: 50
......
......@@ -41,6 +41,7 @@ model:
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 0
training:
n_epoch: 10
......
......@@ -43,6 +43,7 @@ model:
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: True
blank_id: 0
training:
n_epoch: 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册