提交 17ebb40a 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #139 from kuke/ctc_decoder_deploy

Add optimized decoders for DS2
#!/usr/bin/env bash
set -e
readonly VERSION="3.8"
readonly VERSION="3.9"
version=$(clang-format -version)
......
......@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
nproc=False):
"""CTC Beam search decoder.
......@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
......@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")
# blank_id check
if not blank_id < len(probs_seq[0]):
raise ValueError("blank_id shouldn't be greater than probs dimension")
# blank_id assign
blank_id = len(vocabulary)
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
......@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
if cutoff_prob < 1.0:
if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
......@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
cutoff_len = min(cutoff_len, cutoff_top_n)
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
......@@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
blank_id,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
......@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
......@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
None, nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
......
......@@ -8,7 +8,7 @@ import kenlm
import numpy as np
class LmScorer(object):
class Scorer(object):
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
model and word count.
......
#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "ThreadPool.h"
#include "fst/fstlib.h"
#include "decoder_utils.h"
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign blank id
size_t blank_id = vocabulary.size();
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_toscore = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_toscore = prefix_new;
} else {
prefix_toscore = prefix;
}
double score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_toscore);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(prefixes, vocabulary, beam_size);
}
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(num_processes);
// number of samples
size_t batch_size = probs_split.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i],
vocabulary,
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
/* CTC Beam Search Decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
/* CTC Beam Search Decoder for batch data
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
#endif // CTC_BEAM_SEARCH_DECODER_H_
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
const std::vector<double> &probs_step = probs_seq[i];
for (size_t j = 0; j < probs_step.size(); ++j) {
if (max_prob < probs_step[j]) {
max_idx = j;
max_prob = probs_step[j];
}
}
// id with maximum probability in current time step
max_idx_vec[i] = max_idx;
// deduplicate
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
}
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
#include <string>
#include <vector>
/* CTC Greedy (Best Path) Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result in string
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
#endif // CTC_GREEDY_DECODER_H
#include "decoder_utils.h"
#include <algorithm>
#include <cmath>
#include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
return log_prob_idx;
}
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
output_vecs.emplace_back(output_pair);
}
return output_vecs;
}
size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0;
for (char c : str) {
str_len += ((c & 0xc0) != 0x80);
}
return str_len;
}
std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result;
std::string out_str;
for (char c : str) {
if ((c & 0xc0) != 0x80) // new UTF-8 character
{
if (!out_str.empty()) {
result.push_back(out_str);
out_str.clear();
}
}
out_str.append(1, c);
}
result.push_back(out_str);
return result;
}
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) {
if (x->character == y->character) {
return false;
} else {
return (x->character < y->character);
}
} else {
return x->score > y->score;
}
}
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0);
dictionary->SetStart(start);
}
fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst;
for (auto c : word) {
dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst;
}
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word);
std::vector<int> int_word;
for (auto &c : characters) {
if (c == " ") {
int_word.push_back(SPACE_ID);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
int_word.push_back(int_c->second);
} else {
return false; // return without adding
}
}
}
if (add_space) {
int_word.push_back(SPACE_ID);
}
add_word_to_fst(int_word, dictionary);
return true; // return with successful adding
}
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check
inline void check(
bool x, const char *expr, const char *file, int line, const char *err) {
if (!x) {
std::cout << "[" << file << ":" << line << "] ";
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
}
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.first > b.first;
}
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.second > b.second;
}
// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size);
// Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str);
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary);
// Add a word in string to dictionary
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H
%module swig_decoders
%{
#include "scorer.h"
#include "ctc_greedy_decoder.h"
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h"
%}
%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%import "decoder_utils.h"
namespace std {
%template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>;
%template(VectorOfStructVector) std::vector<std::vector<double> >;
%template(FloatVector) std::vector<float>;
%template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
}
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
%include "scorer.h"
%include "ctc_greedy_decoder.h"
%include "ctc_beam_search_decoder.h"
#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "decoder_utils.h"
PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF;
log_prob_nb_prev = -NUM_FLT_INF;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = -NUM_FLT_INF;
ROOT_ = -1;
character = ROOT_;
exists_ = true;
parent = nullptr;
dictionary_ = nullptr;
dictionary_state_ = 0;
has_dictionary_ = false;
matcher_ = nullptr;
}
PathTrie::~PathTrie() {
for (auto child : children_) {
delete child.second;
}
}
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
break;
}
}
if (child != children_.end()) {
if (!child->second->exists_) {
child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
bool found = matcher_->Find(new_char);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
}
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, ROOT_);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end());
return this;
} else {
output.push_back(character);
return parent->get_path_vec(output, stop, max_steps);
}
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
delete this;
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();
has_dictionary_ = true;
}
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
matcher_ = matcher;
}
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
class PathTrie {
public:
PathTrie();
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, bool reset = true);
// get the prefix in index from root to current node
PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
bool is_empty() { return ROOT_ == character; }
// remove current path from root
void remove();
float log_prob_b_prev;
float log_prob_nb_prev;
float log_prob_b_cur;
float log_prob_nb_cur;
float score;
float approx_ctc;
int character;
PathTrie* parent;
private:
int ROOT_;
bool exists_;
bool has_dictionary_;
std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST
fst::StdVectorFst* dictionary_;
fst::StdVectorFst::StateId dictionary_state_;
// true if finding ars in FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};
#endif // PATH_TRIE_H
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha,
double beta,
const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
this->alpha = alpha;
this->beta = beta;
dictionary = nullptr;
is_character_based_ = true;
language_model_ = nullptr;
max_order_ = 0;
dict_size_ = 0;
SPACE_ID_ = -1;
setup(lm_path, vocab_list);
}
Scorer::~Scorer() {
if (language_model_ != nullptr) {
delete static_cast<lm::base::Model*>(language_model_);
}
if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary);
}
}
void Scorer::setup(const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
// load language model
load_lm(lm_path);
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
fill_dictionary(true);
}
}
void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config;
config.enumerate_vocab = &enumerate;
language_model_ = lm::ngram::LoadVirtual(filename, config);
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
vocabulary_ = enumerate.vocabulary;
for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
is_character_based_ = false;
}
}
}
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob;
lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
}
// return log10 prob
return cond_prob;
}
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > max_order_);
double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_);
score += get_log_cond_prob(ngram);
}
return score;
}
void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha;
this->beta = beta;
}
std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += char_list_[ind];
}
return word;
}
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {};
std::string s = vec2str(labels);
std::vector<std::string> words;
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
return words;
}
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
char_list_ = char_list;
char_map_.clear();
for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") {
SPACE_ID_ = i;
char_map_[' '] = i;
} else if (char_list_[i].size() == 1) {
char_map_[char_list_[i][0]] = i;
}
}
}
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < max_order_; order++) {
std::vector<int> prefix_vec;
if (is_character_based_) {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
current_node = new_node->parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// First reverse char_list so ints can be accessed by chars
std::unordered_map<std::string, int> char_map;
for (size_t i = 0; i < char_list_.size(); i++) {
char_map[char_list_[i]] = i;
}
// For each unigram convert to ints and put in trie
int dict_size = 0;
for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary(
word, char_map, add_space, SPACE_ID_, &dictionary);
dict_size += added ? 1 : 0;
}
dict_size_ = dict_size;
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary = new_dict;
}
#ifndef SCORER_H_
#define SCORER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"
#include "path_trie.h"
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
// Implement a callback to retrive the dictionary of language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
/* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
Scorer(double alpha,
double beta,
const std::string &lm_path,
const std::vector<std::string> &vocabulary);
~Scorer();
double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string> &words);
// return the max order
size_t get_max_order() const { return max_order_; }
// return the dictionary size of language model
size_t get_dict_size() const { return dict_size_; }
// retrun true if the language model is character based
bool is_character_based() const { return is_character_based_; }
// reset params alpha & beta
void reset_params(float alpha, float beta);
// make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix);
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
std::vector<std::string> split_labels(const std::vector<int> &labels);
// language model weight
double alpha;
// word insertion weight
double beta;
// pointer to the dictionary of FST
void *dictionary;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list);
// load language model from given path
void load_lm(const std::string &lm_path);
// fill dictionary for FST
void fill_dictionary(bool add_space);
// set char map
void set_char_map(const std::vector<std::string> &char_list);
double get_log_prob(const std::vector<std::string> &words);
// translate the vector in index to string
std::string vec2str(const std::vector<int> &input);
private:
void *language_model_;
bool is_character_based_;
size_t max_order_;
size_t dict_size_;
int SPACE_ID_;
std::vector<std::string> char_list_;
std::unordered_map<char, int> char_map_;
std::vector<std::string> vocabulary_;
};
#endif // SCORER_H_
"""Script to build and install decoder package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from setuptools import setup, Extension, distutils
import glob
import platform
import os, sys
import multiprocessing.pool
import argparse
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_processes",
default=1,
type=int,
help="Number of cpu processes to build package. (default: %(default)d)")
args = parser.parse_known_args()
# reconstruct sys.argv to pass to setup below
sys.argv = [sys.argv[0]] + args[1]
# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
sources,
output_dir=None,
macros=None,
include_dirs=None,
debug=0,
extra_preargs=None,
extra_postargs=None,
depends=None):
# those lines are copied from distutils.ccompiler.CCompiler directly
macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
output_dir, macros, include_dirs, sources, depends, extra_postargs)
cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
# parallel code
def _single_compile(obj):
try:
src, ext = build[obj]
except KeyError:
return
self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
# convert to list, imap is evaluated on-demand
thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
list(thread_pool.imap(_single_compile, objects))
return objects
def compile_test(header, library):
dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
command = "bash -c \"g++ -include " + header \
+ " -l" + library + " -x c++ - <<<'int main() {}' -o " \
+ dummy_path + " >/dev/null 2>/dev/null && rm " \
+ dummy_path + " 2>/dev/null\""
return os.system(command) == 0
# hack compile to support parallel compiling
distutils.ccompiler.CCompiler.compile = parallelCCompile
FILES = glob.glob('kenlm/util/*.cc') \
+ glob.glob('kenlm/lm/*.cc') \
+ glob.glob('kenlm/util/double-conversion/*.cc')
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# FILES + glob.glob('glog/src/*.cc')
FILES = [
fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
]
LIBS = ['stdc++']
if platform.system() != 'Darwin':
LIBS.append('rt')
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']
if compile_test('zlib.h', 'z'):
ARGS.append('-DHAVE_ZLIB')
LIBS.append('z')
if compile_test('bzlib.h', 'bz2'):
ARGS.append('-DHAVE_BZLIB')
LIBS.append('bz2')
if compile_test('lzma.h', 'lzma'):
ARGS.append('-DHAVE_XZLIB')
LIBS.append('lzma')
os.system('swig -python -c++ ./decoders.i')
decoders_module = [
Extension(
name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='c++',
include_dirs=[
'.',
'kenlm',
'openfst-1.6.3/src/include',
'ThreadPool',
#'glog/src'
],
libraries=LIBS,
extra_compile_args=ARGS)
]
setup(
name='swig_decoders',
version='0.1',
description="""CTC decoders""",
ext_modules=decoders_module,
py_modules=['swig_decoders'], )
#!/usr/bin/env bash
if [ ! -d kenlm ]; then
git clone https://github.com/luotao1/kenlm.git
echo -e "\n"
fi
if [ ! -d openfst-1.6.3 ]; then
echo "Download and extract openfst ..."
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
echo -e "\n"
fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
echo "Install decoders ..."
python setup.py install --num_processes 4
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import swig_decoders
class Scorer(swig_decoders.Scorer):
"""Wrapper for Scorer.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: basestring
"""
return swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), vocabulary,
beam_size, cutoff_prob,
cutoff_top_n, ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
"""Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
return swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func)
......@@ -4,7 +4,7 @@ from __future__ import division
from __future__ import print_function
import unittest
from model_utils import decoder
from decoders import decoders_deprecated as decoder
class TestDecoders(unittest.TestCase):
......@@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase):
beam_result = decoder.ctc_beam_search_decoder(
probs_seq=self.probs_seq1,
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list))
vocabulary=self.vocab_list)
self.assertEqual(beam_result[0][1], self.beam_search_result[0])
def test_beam_search_decoder_2(self):
beam_result = decoder.ctc_beam_search_decoder(
probs_seq=self.probs_seq2,
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list))
vocabulary=self.vocab_list)
self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_decoder_batch(self):
......@@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase):
probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size,
vocabulary=self.vocab_list,
blank_id=len(self.vocab_list),
num_processes=24)
self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
......
......@@ -21,9 +21,10 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -30,9 +30,10 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -22,9 +22,9 @@ python -u test.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -31,9 +31,10 @@ python -u test.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -21,9 +21,9 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -30,9 +30,9 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -22,9 +22,9 @@ python -u test.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -31,9 +31,9 @@ python -u test.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
--beta=0.25 \
--cutoff_prob=0.99 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
......
......@@ -21,9 +21,10 @@ add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
add_arg('beta', float, 0.35, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
......@@ -84,6 +85,10 @@ def infer():
use_gru=args.use_gru,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
# decoders only accept string encoded in utf-8
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
result_transcripts = ds2_model.infer_batch(
infer_data=infer_data,
decoding_method=args.decoding_method,
......@@ -91,7 +96,8 @@ def infer():
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
......@@ -106,6 +112,7 @@ def infer():
print("Current error rate [%s] = %f" %
(args.error_rate_type, error_rate_func(target, result)))
ds2_model.logger.info("finish inference")
def main():
print_arguments(args)
......
......@@ -6,14 +6,18 @@ from __future__ import print_function
import sys
import os
import time
import logging
import gzip
from distutils.dir_util import mkpath
import paddle.v2 as paddle
from model_utils.lm_scorer import LmScorer
from model_utils.decoder import ctc_greedy_decoder, ctc_beam_search_decoder
from model_utils.decoder import ctc_beam_search_decoder_batch
from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from model_utils.network import deep_speech_v2_network
logging.basicConfig(
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
class DeepSpeech2Model(object):
"""DeepSpeech2Model class.
......@@ -44,6 +48,8 @@ class DeepSpeech2Model(object):
self._inferer = None
self._loss_inferer = None
self._ext_scorer = None
self.logger = logging.getLogger("")
self.logger.setLevel(level=logging.INFO)
def train(self,
train_batch_reader,
......@@ -157,8 +163,8 @@ class DeepSpeech2Model(object):
return self._loss_inferer.infer(input=infer_data)
def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
beam_size, cutoff_prob, vocab_list, language_model_path,
num_processes):
beam_size, cutoff_prob, cutoff_top_n, vocab_list,
language_model_path, num_processes):
"""Model inference. Infer the transcription for a batch of speech
utterances.
......@@ -178,6 +184,10 @@ class DeepSpeech2Model(object):
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:param language_model_path: Filepath for language model.
......@@ -209,21 +219,33 @@ class DeepSpeech2Model(object):
elif decoding_method == "ctc_beam_search":
# initialize external scorer
if self._ext_scorer == None:
self._ext_scorer = LmScorer(beam_alpha, beam_beta,
language_model_path)
self._loaded_lm_path = language_model_path
self.logger.info("begin to initialize the external scorer "
"for decoding")
self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path, vocab_list)
lm_char_based = self._ext_scorer.is_character_based()
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
self.logger.info("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
self.logger.info("end initializing scorer. Start decoding ...")
else:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
assert self._loaded_lm_path == language_model_path
# beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
blank_id=len(vocab_list),
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob)
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n)
results = [result[0][1] for result in beam_search_results]
else:
......
......@@ -2,4 +2,3 @@ scipy==0.13.1
resampy==0.1.5
SoundFile==0.9.0.post1
python_speech_features
https://github.com/luotao1/kenlm/archive/master.zip
#! /usr/bin/env bash
#! /usr/bin/env bash
# install python dependencies
if [ -f "requirements.txt" ]; then
......@@ -20,10 +20,19 @@ if [ $? != 0 ]; then
fi
tar -zxvf libsndfile-1.0.28.tar.gz
cd libsndfile-1.0.28
./configure && make && make install
./configure > /dev/null && make > /dev/null && make install > /dev/null
cd ..
rm -rf libsndfile-1.0.28
rm libsndfile-1.0.28.tar.gz
fi
# install decoders
python -c "import swig_decoders"
if [ $? != 0 ]; then
cd decoders/swig > /dev/null
sh setup.sh
cd - > /dev/null
fi
echo "Install all dependencies successfully."
......@@ -22,9 +22,10 @@ add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
add_arg('beta', float, 0.35, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
......@@ -85,6 +86,9 @@ def evaluate():
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
# decoders only accept string encoded in utf-8
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0
for infer_data in batch_reader():
......@@ -95,7 +99,8 @@ def evaluate():
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
target_transcripts = [
......@@ -110,6 +115,7 @@ def evaluate():
print("Final error rate [%s] (%d/%d) = %f" %
(args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
ds2_model.logger.info("finish evaluation")
def main():
print_arguments(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册