提交 12f540cd 编写于 作者: H Hui Zhang

ctc decoder with blankid

上级 e411e0bd
......@@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
......@@ -45,19 +46,13 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign blank id
// size_t blank_id = vocabulary.size();
size_t blank_id = 0;
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
......@@ -218,7 +213,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer,
size_t blank_id) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(num_processes);
......@@ -234,7 +230,8 @@ ctc_beam_search_decoder_batch(
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer));
ext_scorer,
blank_id));
}
// get decoding results
......
......@@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data
......@@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
#endif // CTC_BEAM_SEARCH_DECODER_H_
......@@ -17,17 +17,18 @@
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
const std::vector<std::string> &vocabulary,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
// size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
......
......@@ -29,6 +29,7 @@
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
const std::vector<std::string>& vocabulary,
size_t blank_id);
#endif // CTC_GREEDY_DECODER_H
......@@ -15,10 +15,12 @@
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <string>
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const std::string kSPACE = "<space>";
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
......
......@@ -165,7 +165,7 @@ void Scorer::set_char_map(const std::vector<std::string>& char_list) {
// Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") {
if (char_list_[i] == kSPACE) {
SPACE_ID_ = i;
}
// The initial state of FST is state 0, hence the index of chars in
......
......@@ -83,10 +83,12 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable
FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
or fn.endswith('unittest.cc'))
]
# yapf: enable
LIBS = ['stdc++']
if platform.system() != 'Darwin':
......
......@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary):
def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
......@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
:return: Decoding result string.
:rtype: str
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
return result
......@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
......@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func)
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results
......@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
......@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func)
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册