提交 12f540cd 编写于 作者: H Hui Zhang

ctc decoder with blankid

上级 e411e0bd
...@@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer,
size_t blank_id) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
...@@ -45,19 +46,13 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -45,19 +46,13 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
} }
// assign blank id
// size_t blank_id = vocabulary.size();
size_t blank_id = 0;
// assign space id // assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) { if ((size_t)space_id >= vocabulary.size()) {
space_id = -2; space_id = -2;
} }
// init prefixes' root // init prefixes' root
PathTrie root; PathTrie root;
root.score = root.log_prob_b_prev = 0.0; root.score = root.log_prob_b_prev = 0.0;
...@@ -218,7 +213,8 @@ ctc_beam_search_decoder_batch( ...@@ -218,7 +213,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer,
size_t blank_id) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool // thread pool
ThreadPool pool(num_processes); ThreadPool pool(num_processes);
...@@ -234,7 +230,8 @@ ctc_beam_search_decoder_batch( ...@@ -234,7 +230,8 @@ ctc_beam_search_decoder_batch(
beam_size, beam_size,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer,
blank_id));
} }
// get decoding results // get decoding results
......
...@@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
...@@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( ...@@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_
...@@ -17,17 +17,18 @@ ...@@ -17,17 +17,18 @@
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) { const std::vector<std::string> &vocabulary,
size_t blank_id) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1, vocabulary.size(),
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
} }
size_t blank_id = vocabulary.size(); // size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0); std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec; std::vector<size_t> idx_vec;
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
*/ */
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary); const std::vector<std::string>& vocabulary,
size_t blank_id);
#endif // CTC_GREEDY_DECODER_H #endif // CTC_GREEDY_DECODER_H
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
#ifndef DECODER_UTILS_H_ #ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_ #define DECODER_UTILS_H_
#include <string>
#include <utility> #include <utility>
#include "fst/log.h" #include "fst/log.h"
#include "path_trie.h" #include "path_trie.h"
const std::string kSPACE = "<space>";
const float NUM_FLT_INF = std::numeric_limits<float>::max(); const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min(); const float NUM_FLT_MIN = std::numeric_limits<float>::min();
......
...@@ -165,7 +165,7 @@ void Scorer::set_char_map(const std::vector<std::string>& char_list) { ...@@ -165,7 +165,7 @@ void Scorer::set_char_map(const std::vector<std::string>& char_list) {
// Set the char map for the FST for spelling correction // Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) { for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") { if (char_list_[i] == kSPACE) {
SPACE_ID_ = i; SPACE_ID_ = i;
} }
// The initial state of FST is state 0, hence the index of chars in // The initial state of FST is state 0, hence the index of chars in
......
...@@ -83,10 +83,12 @@ FILES = glob.glob('kenlm/util/*.cc') \ ...@@ -83,10 +83,12 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable
FILES = [ FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
or fn.endswith('unittest.cc')) or fn.endswith('unittest.cc'))
] ]
# yapf: enable
LIBS = ['stdc++'] LIBS = ['stdc++']
if platform.system() != 'Darwin': if platform.system() != 'Darwin':
......
...@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer): ...@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary): def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig. """Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
...@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): ...@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
:return: Decoding result string. :return: Decoding result string.
:rtype: str :rtype: str
""" """
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
return result return result
...@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
...@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
""" """
beam_results = swig_decoders.ctc_beam_search_decoder( beam_results = swig_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func) ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results return beam_results
...@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
...@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob, probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results] batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results] for beam_results in batch_beam_results]
return batch_beam_results return batch_beam_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册