提交 d14ee800 编写于 作者: S SmileGoat

add decodable & ctc_beam_search_decoder

上级 e57efcb3
#include "base/basic_types.h"
// One decoding hypothesis as reported by a decoder.
// NOTE(review): nothing in this file reads or writes these fields, so the
// per-field notes below are inferred from the names only -- confirm against
// the code that fills the struct.
struct DecoderResult {
    // presumably the acoustic-model score of the hypothesis
    BaseFloat acoustic_score;
    // presumably vocabulary indices of the decoded words
    std::vector<int32> words_idx;
    // presumably one (begin, end) pair per word; units not visible here
    std::vector<std::pair<int32, int32>> time_stamp;
};
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
// Builds a CTC beam-search decoder.
//
// Copies the options out of `opts` (a null pointer falls back to the default
// CTCBeamSearchOptions), loads the vocabulary from opts_.dict_file and creates
// the external scorer from opts_.alpha / opts_.beta / opts_.lm_path.
// A failed dictionary load is logged and construction continues with an empty
// vocabulary, as before. InitDecoder() must still be called before decoding:
// it is what sets blank_id / space_id to usable values and creates the trie.
CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts)
    : opts_(opts ? *opts : CTCBeamSearchOptions()),
      init_ext_scorer_(nullptr),
      vocabulary_(std::make_shared<vector<string>>()),
      blank_id(-1),  // sentinel until InitDecoder(); wraps because blank_id is size_t
      space_id(-1),
      root(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    // NOTE(review): the helper lives in namespace `basr` while this file is in
    // `ppspeech`; left untouched because its declaration is not visible here.
    if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
        LOG(ERROR) << "load the dict failed";
    } else {
        LOG(INFO) << "read the vocabulary success, dict size: "
                  << vocabulary_->size();
    }

    LOG(INFO) << "language model path: " << opts_.lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
                                                opts_.beta,
                                                opts_.lm_path,
                                                *vocabulary_);
}
// Resets the search state so a new utterance can be decoded.
//
// Sets blank_id to 0, looks up the index of " " in the vocabulary (space_id
// becomes -2 when there is none), replaces the prefix trie with a fresh root
// and seeds `prefixes` with it. When a word-based external scorer is present,
// a copy of its dictionary FST and a matching sorted matcher are attached to
// the new root.
//
// Returns true; the return type matches the declaration in the class.
bool CTCBeamSearch::InitDecoder() {
    blank_id = 0;

    const auto space_it =
        std::find(vocabulary_->begin(), vocabulary_->end(), " ");
    // -2 marks "no space in vocabulary"
    space_id = (space_it == vocabulary_->end())
                   ? -2
                   : static_cast<int>(space_it - vocabulary_->begin());

    // `prefixes` only borrows pointers (root.get() is pushed below), so the
    // old entries are dropped, not deleted, before the old root is released.
    prefixes.clear();

    root = std::make_shared<PathTrie>();
    root->score = root->log_prob_b_prev = 0.0;
    prefixes.push_back(root.get());

    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
        fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
        root->set_dictionary(dict_ptr);

        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
        root->set_matcher(matcher);
    }
    return true;
}
// Forgets every prefix currently tracked by the decoder.
//
// `prefixes` does not own its entries: InitDecoder() stores root.get() in it
// (root is a std::shared_ptr) and AdvanceDecoding() refills it through
// root->iterate_to_vec(). Calling `delete` on the entries, as this function
// used to, would free the root a second time once `root` is released, and it
// left null pointers in the vector that the search loops dereference without
// checking. Dropping the pointers is therefore all that is done here.
void CTCBeamSearch::ResetPrefixes() {
    prefixes.clear();
}
// Decodes one sequence of per-frame probability rows.
//
// probs        one row per time step; each row is widened to double and the
//              whole sequence is handed to AdvanceDecoding().
// nbest_words  the transcription of every returned hypothesis is appended,
//              in the order AdvanceDecoding() returns them.
//
// Always returns 0. An empty `probs` (or rows of differing length) is passed
// through as-is instead of indexing probs[0].
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     vector<string>& nbest_words) {
    Timer timer;

    vector<vector<double>> double_probs;
    double_probs.reserve(probs.size());
    for (const auto& frame : probs) {
        // range construction converts float -> double element by element
        double_probs.emplace_back(frame.begin(), frame.end());
    }

    // Only the search itself is timed, not the conversion above.
    timer.Reset();
    vector<std::pair<double, string>> results = AdvanceDecoding(double_probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;

    for (const auto& item : results) {
        nbest_words.push_back(item.second);
    }
    return 0;
}
// Runs the prefix beam search over `probs_seq` (one probability row per time
// step) and returns the result of get_beam_search_result() for the surviving
// prefixes.
//
// Per time step:
//   1. with an external scorer, sort the leading prefixes and derive the
//      pruning threshold `min_cutoff` from the last of them;
//   2. prune the row with get_pruned_log_probs() and extend every prefix with
//      every surviving character via SearchOneChar();
//   3. rebuild `prefixes` from the trie and remove everything beyond the
//      best `beam_size` entries.
// After the last step the trailing word of each prefix is scored
// (LMRescore) and the approximate CTC score is computed
// (CalculateApproxScore).
vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(
    const vector<vector<double>>& probs_seq) {
    const size_t num_time_steps = probs_seq.size();
    const size_t beam_size = opts_.beam_size;
    const double cutoff_prob = opts_.cutoff_prob;
    const size_t cutoff_top_n = opts_.cutoff_top_n;

    for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
        const auto& prob = probs_seq[time_step];

        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (init_ext_scorer_ != nullptr) {
            size_t num_prefixes = std::min(prefixes.size(), beam_size);
            std::sort(prefixes.begin(),
                      prefixes.begin() + num_prefixes,
                      prefix_compare);

            if (num_prefixes == 0) {
                continue;
            }
            min_cutoff = prefixes[num_prefixes - 1]->score +
                         std::log(prob[blank_id]) -
                         std::max(0.0, init_ext_scorer_->beta);
            full_beam = (num_prefixes == beam_size);
        }

        vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);

        // loop over chars
        size_t log_prob_idx_len = log_prob_idx.size();
        for (size_t index = 0; index < log_prob_idx_len; index++) {
            SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
        }  // end of loop over chars

        // The beam is rebuilt once per time step, after all characters have
        // been applied to the prefixes of the previous step.
        prefixes.clear();
        // update log probs
        root->iterate_to_vec(prefixes);

        // only preserve top beam_size prefixes
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(),
                             prefixes.begin() + beam_size,
                             prefixes.end(),
                             prefix_compare);
            for (size_t i = beam_size; i < prefixes.size(); ++i) {
                prefixes[i]->remove();
            }
        }  // if
    }      // for probs_seq

    // score the last word of each prefix that doesn't end with space
    LMRescore();
    CalculateApproxScore();
    return get_beam_search_result(prefixes, *vocabulary_, beam_size);
}
int CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, float>& log_prob_idx,
const float& min_cutoff) {
size_t beam_size = _opts.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = _opts.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = _opts.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech
\ No newline at end of file
#include "base/basic_types.h"
#pragma once
namespace ppspeech {
// Configuration of the CTC beam-search decoder, with defaults and
// command-line registration.
struct CTCBeamSearchOptions {
    std::string dict_file;  // vocabulary file read by the decoder
    std::string lm_path;    // language model file handed to the scorer
    BaseFloat alpha;        // weight applied to language-model log probs
    BaseFloat beta;         // bonus added per scored word
    BaseFloat cutoff_prob;  // passed to get_pruned_log_probs() by the decoder
    int beam_size;          // number of prefixes kept per time step
    int cutoff_top_n;       // passed to get_pruned_log_probs() by the decoder
    int num_proc_bsearch;   // not read by the decoder code in this change

    // The initializer list follows the declaration order above, which is the
    // order in which members are actually initialized.
    CTCBeamSearchOptions()
        : dict_file("./model/words.txt"),
          lm_path("./model/lm.arpa"),
          alpha(1.9f),
          beta(5.0f),
          cutoff_prob(0.99f),
          beam_size(300),
          cutoff_top_n(40),
          num_proc_bsearch(0) {}

    // Registers every option with `opts` under its command-line name.
    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register(
            "beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};
class CTCBeamSearch {
public:
CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts);
~CTCBeamSearch() {
}
bool InitDecoder();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words);
std::vector<DecodeResult>& GetDecodeResult() {
return decoder_results_;
}
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
std::vector<std::pair<double, std::string>>
AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<DecodeResult> decoder_results_;
std::vector<std::vector<std::string>> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
};
} // namespace ppspeech
\ No newline at end of file
../../../third_party/ctc_decoders
\ No newline at end of file
此差异已折叠。
此差异已折叠。
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include "decoder/lattice-faster-online-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {
// Self-test for GetBestPath(): the traced-back best path must be equivalent
// (on one random path, within a tolerance of 0.1) to the shortest path
// extracted from the raw lattice. Warns and returns false on mismatch.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
    bool use_final_probs) const {
  // Reference: shortest path through the complete raw lattice.
  Lattice reference_best;
  {
    Lattice full_lattice;
    this->GetRawLattice(&full_lattice, use_final_probs);
    ShortestPath(full_lattice, &reference_best);
  }

  // Candidate: the path obtained by following backpointers.
  Lattice traced_best;
  GetBestPath(&traced_best, use_final_probs);

  BaseFloat delta = 0.1;
  int32 num_paths = 1;
  const bool equivalent = fst::RandEquivalent(reference_best, traced_best,
                                              num_paths, delta, rand());
  if (equivalent) {
    return true;
  }
  KALDI_WARN << "Best-path test failed";
  return false;
}
// Writes the single best path to *olat as a linear FST and returns true, or
// returns false (leaving *olat empty) when BestPathEnd() yields no token.
// The path is assembled from its end towards its start: the first state
// created is the final state and the last one created becomes the start.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(
    Lattice *olat, bool use_final_probs) const {
  olat->DeleteStates();

  BaseFloat final_graph_cost;
  BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
  if (iter.Done()) {
    return false;  // would have printed warning.
  }

  StateId head = olat->AddState();
  olat->SetFinal(head, LatticeWeight(final_graph_cost, 0.0));

  // iter is known not to be done here, so the body runs at least once.
  do {
    LatticeArc arc;
    iter = TraceBackBestPath(iter, &arc);
    arc.nextstate = head;
    StateId predecessor = olat->AddState();
    olat->AddArc(predecessor, arc);
    head = predecessor;
  } while (!iter.Done());

  olat->SetStart(head);
  return true;
}
// Picks the lowest-cost token on the last decoded frame and returns an
// iterator positioned on it (frame index NumFramesDecoded() - 1).
//
// When use_final_probs is true and at least one final cost is available,
// a token's final cost is added to its total cost and tokens without a final
// cost are treated as infinitely expensive. The final cost of the chosen
// token (0 when final costs were not applied) is written to *final_cost_out
// if that pointer is non-null.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
    bool use_final_probs, BaseFloat *final_cost_out) const {
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "BestPathEnd() with use_final_probs == false";
  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
               "You cannot call BestPathEnd if no frames were decoded.");

  typedef unordered_map<Token*, BaseFloat> FinalCostMap;
  const BaseFloat kInfinity = std::numeric_limits<BaseFloat>::infinity();

  // After finalization the stored costs are used; otherwise they are
  // computed into a local map (only when final-probs are requested).
  FinalCostMap final_costs_local;
  const FinalCostMap &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  // The map is not modified below, so this decision is loop-invariant.
  const bool apply_final_costs = use_final_probs && !final_costs.empty();

  BaseFloat best_cost = kInfinity;
  BaseFloat best_final_cost = 0;
  Token *best_tok = NULL;

  // Walk the singly linked list of tokens on the last frame.
  Token *tok = this->active_toks_.back().toks;
  while (tok != NULL) {
    BaseFloat cost = tok->tot_cost;
    BaseFloat final_cost = 0.0;
    if (apply_final_costs) {
      typename FinalCostMap::const_iterator found = final_costs.find(tok);
      if (found == final_costs.end()) {
        cost = kInfinity;
      } else {
        final_cost = found->second;
        cost += final_cost;
      }
    }
    if (cost < best_cost) {
      best_cost = cost;
      best_tok = tok;
      best_final_cost = final_cost;
    }
    tok = tok->next;
  }

  if (best_tok == NULL) {
    // this should not happen, and is likely a code error or caused by
    // infinities in likelihoods, but it is not treated as fatal for now.
    KALDI_WARN << "No final token found.";
  }
  if (final_cost_out)
    *final_cost_out = best_final_cost;
  return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
// Steps one link backwards along the best path.
//
// iter  must not be Done(); its token is the destination of the link.
// oarc  must be non-NULL; receives ilabel, olabel and weight of the chosen
//       link. Its nextstate field is not written here.
//
// Returns an iterator on the token's backpointer. The frame index is
// decremented by one when the chosen link has a nonzero ilabel and is left
// unchanged otherwise.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
BestPathIterator iter, LatticeArc *oarc) const {
KALDI_ASSERT(!iter.Done() && oarc != NULL);
Token *tok = static_cast<Token*>(iter.tok);
int32 cur_t = iter.frame, step_t = 0;
if (tok->backpointer != NULL) {
// retrieve the correct forward link(with the best link cost)
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
ForwardLinkT *link;
// Several links of the predecessor may lead to "tok"; the one with the
// lowest graph + acoustic cost wins (first one on ties, because of '<').
for (link = tok->backpointer->links;
link != NULL; link = link->next) {
if (link->next_tok == tok) { // this is a link to "tok"
BaseFloat graph_cost = link->graph_cost,
acoustic_cost = link->acoustic_cost;
// Links are compared on the raw costs, before the offset is removed.
BaseFloat cost = graph_cost + acoustic_cost;
if (cost < best_cost) {
oarc->ilabel = link->ilabel;
oarc->olabel = link->olabel;
if (link->ilabel != 0) {
// Nonzero ilabel: remove this frame's cost offset from the acoustic
// cost and move one frame back.
KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
acoustic_cost -= this->cost_offsets_[cur_t];
step_t = -1;
} else {
// Zero ilabel (epsilon): stay on the same frame.
step_t = 0;
}
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
best_cost = cost;
}
}
}
// The loop always ends with link == NULL, so the test effectively is
// "no link with finite cost was found".
if (link == NULL &&
best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
KALDI_ERR << "Error tracing best-path back (likely "
<< "bug in token-pruning algorithm)";
}
} else {
// No predecessor: emit an epsilon arc with unit weight.
oarc->ilabel = 0;
oarc->olabel = 0;
oarc->weight = LatticeWeight::One(); // zero costs.
}
return BestPathIterator(tok->backpointer, cur_t + step_t);
}
// Builds a raw lattice in *ofst that contains only tokens reachable from the
// start token through links whose destination has extra_cost < beam.
//
// ofst             output lattice; cleared first.
// use_final_probs  when true and final costs exist, only last-frame tokens
//                  that have a final cost become final states (with that
//                  cost); otherwise every last-frame token visited becomes
//                  final with unit weight.
// beam             threshold on Token::extra_cost.
//
// Returns false if some frame has no active tokens, otherwise whether the
// resulting lattice has at least one state.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
// Use the stored final costs after finalization, otherwise compute them
// into a local map (only when final-probs are requested).
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
// num-frames plus one (since frames are one-based, and we have
// an extra frame for the start-state).
int32 num_frames = this->active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
// Refuse to produce a lattice if any frame has an empty token list.
for (int32 f = 0; f <= num_frames; f++) {
if (this->active_toks_[f].toks == NULL) {
KALDI_WARN << "No tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
}
// tok_map: token -> lattice state already created for it.
// tok_queue: (token, frame) pairs whose outgoing links still need visiting.
unordered_map<Token*, StateId> tok_map;
std::queue<std::pair<Token*, int32> > tok_queue;
// First initialize the queue and states. Put the initial state on the queue;
// this is the last token in the list active_toks_[0].toks.
for (Token *tok = this->active_toks_[0].toks;
tok != NULL; tok = tok->next) {
if (tok->next == NULL) {
tok_map[tok] = ofst->AddState();
ofst->SetStart(tok_map[tok]);
std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
tok_queue.push(tok_pair);
}
}
// Next create states for "good" tokens
while (!tok_queue.empty()) {
std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
tok_queue.pop();
Token *cur_tok = cur_tok_pair.first;
int32 cur_frame = cur_tok_pair.second;
KALDI_ASSERT(cur_frame >= 0 &&
cur_frame <= this->cost_offsets_.size());
typename unordered_map<Token*, StateId>::const_iterator iter =
tok_map.find(cur_tok);
KALDI_ASSERT(iter != tok_map.end());
StateId cur_state = iter->second;
for (ForwardLinkT *l = cur_tok->links;
l != NULL;
l = l->next) {
Token *next_tok = l->next_tok;
if (next_tok->extra_cost < beam) {
// so both the current and the next token are good; create the arc
// A link with a nonzero ilabel advances to the next frame.
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
StateId nextstate;
// Create the destination state on first sight and queue the token so
// that its own links get visited exactly once.
if (tok_map.find(next_tok) == tok_map.end()) {
nextstate = tok_map[next_tok] = ofst->AddState();
tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
} else {
nextstate = tok_map[next_tok];
}
// The per-frame cost offset is removed from the acoustic cost of
// links with a nonzero ilabel only.
BaseFloat cost_offset = (l->ilabel != 0 ?
this->cost_offsets_[cur_frame] : 0);
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(cur_state, arc);
}
}
// Tokens on the last frame are candidates for final states.
if (cur_frame == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token*, BaseFloat>::const_iterator iter =
final_costs.find(cur_tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
// The member functions above are defined in this .cc file, so each FST type
// used elsewhere must be explicitly instantiated here.
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi.
// decoder/lattice-faster-online-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.h, about how to maintain this
// file in sync with lattice-faster-decoder.h
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "decoder/lattice-faster-decoder.h"
namespace kaldi {
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
supports an efficient way to get the best path (see the function
BestPathEnd()), which is useful in endpointing and in situations where you
might want to frequently access the best path.
This is only templated on the FST type, since the Token type is required to
be BackpointerToken. Actually it only makes sense to instantiate
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
this child class.
*/
// Decoder class templated on the FST type only; the token type of the base
// class is fixed to decoder::BackpointerToken.
template <typename FST>
class LatticeFasterOnlineDecoderTpl:
public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
public:
// Shorthands for the arc-related types of FST and for the token/link types
// used by the base class.
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Token = decoder::BackpointerToken;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterOnlineDecoderTpl(const FST &fst,
const LatticeFasterDecoderConfig &config):
LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
// This version of the initializer takes ownership of 'fst', and will delete
// it when this object is destroyed.
LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
FST *fst):
LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
// Position on the best path: a token (stored as an untyped pointer) plus a
// frame index. A NULL token marks the end of the trace-back.
struct BestPathIterator {
void *tok;
int32 frame;
// note, "frame" is the frame-index of the frame you'll get the
// transition-id for next time, if you call TraceBackBestPath on this
// iterator (assuming it's not an epsilon transition). Note that this
// is one less than you might reasonably expect, e.g. it's -1 for
// the nonemitting transitions before the first frame.
BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
// True once the trace-back has run past the first token.
bool Done() const { return tok == NULL; }
};
/// Outputs an FST corresponding to the single best path through the lattice.
/// This is quite efficient because it doesn't get the entire raw lattice and find
/// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
/// so it basically traces it back through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
bool GetBestPath(Lattice *ofst,
bool use_final_probs = true) const;
/// This function does a self-test of GetBestPath(). Returns true on
/// success; returns false and prints a warning on failure.
bool TestGetBestPath(bool use_final_probs = true) const;
/// This function returns an iterator that can be used to trace back
/// the best path. If use_final_probs == true and at least one final state
/// survived till the end, it will use the final-probs in working out the best
/// final Token, and will output the final cost to *final_cost (if non-NULL),
/// else it will use only the forward likelihood, and will put zero in
/// *final_cost (if non-NULL).
/// Requires that NumFramesDecoded() > 0.
BestPathIterator BestPathEnd(bool use_final_probs,
BaseFloat *final_cost = NULL) const;
/// This function can be used in conjunction with BestPathEnd() to trace back
/// the best path one link at a time (e.g. this can be useful in endpoint
/// detection). By "link" we mean a link in the graph; not all links cross
/// frame boundaries, but each time you see a nonzero ilabel you can interpret
/// that as a frame. The return value is the updated iterator. It outputs
/// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
/// while leaving its "nextstate" variable unchanged.
BestPathIterator TraceBackBestPath(
BestPathIterator iter, LatticeArc *arc) const;
/// Behaves the same as GetRawLattice but only processes tokens whose
/// extra_cost is smaller than the best-cost plus the specified beam.
/// It is only worthwhile to call this function if beam is less than
/// the lattice_beam specified in the config; otherwise, it would
/// return essentially the same thing as GetRawLattice, but more slowly.
bool GetRawLatticePruned(Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const;
// Presumably disables copy construction and assignment, going by the macro
// name; its definition is not visible in this file.
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
};
// Convenience name for the decoder instantiated over fst::StdFst.
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
} // end namespace kaldi.
#endif
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code. They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.
// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinzable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
typedef kaldi::int32 Int;
typedef typename Arc::Weight Weight;
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
for(int i = 0; i < 100; i++) {
RandFstOptions opts;
opts.n_states = 4;
opts.n_arcs = 10;
opts.n_final = 2;
opts.allow_empty = false;
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
opts.acyclic = true;
// to be exactly representable in float,
// or this test fails because numerical differences can cause symmetry in
// weights to be broken, which causes the wrong path to be chosen as far
// as the string part is concerned.
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
bool sorted = TopSort(fst);
KALDI_ASSERT(sorted);
ILabelCompare<Arc> ilabel_comp;
if (kaldi::Rand() % 2 == 0)
ArcSort(fst, ilabel_comp);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> det_fst;
try {
DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
// OK, now determinize it a different way and check equivalence.
// [note: it's not normal determinization, it's taking the best path
// for any input-symbol sequence....
VectorFst<Arc> pruned_fst(*fst);
if (pruned_fst.NumStates() != 0)
kaldi::PruneLattice(10.0, &pruned_fst);
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
std::cout << "Compact pruned FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
std::cout << "Compact version of determinized FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
if (ans)
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
} catch (...) {
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
}
delete fst;
}
}
// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
typedef typename Arc::Weight Weight;
RandFstOptions opts;
opts.acyclic = true;
for(int i = 0; i < 100; i++) {
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> ofst;
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
delete fst;
}
}
} // end namespace fst
int main() {
using namespace fst;
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
std::cout << "Tests succeeded\n";
}
此差异已折叠。
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
#include "itf/transition-information.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
sometimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state as the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
/// Options for DeterminizeLatticePruned().  Register() exposes every field as
/// a command-line option through kaldi::OptionsItf.
struct DeterminizeLatticePrunedOptions {
  float delta;  // A small offset used to measure equality of weights.
  int max_mem;  // If >0, determinization will fail and return false
  // when the algorithm's (approximate) memory consumption crosses this
  // threshold.
  int max_loop;  // If >0, can be used to detect non-determinizable input
  // (a case that wouldn't be caught by max_mem).
  int max_states;  // Limit on the number of states in the output FST (total).
  // NOTE(review): the default of -1 presumably means "no limit", by analogy
  // with max_mem/max_loop -- confirm against the implementation.
  int max_arcs;  // Limit on the number of arcs in the output FST (total, not
  // per state).  Default -1; same caveat as max_states.
  float retry_cutoff;  // If effective-beam < retry_cutoff * beam, the raw
  // lattice is pruned and determinization is retried (see the help text
  // registered below).

  DeterminizeLatticePrunedOptions(): delta(kDelta),
                                     max_mem(-1),
                                     max_loop(-1),
                                     max_states(-1),
                                     max_arcs(-1),
                                     retry_cutoff(0.5) { }

  /// Registers all fields with `opts` under their command-line names.
  void Register (kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this)");
    opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
                   "output FST (total, not per state)");
    // Fixed: this help text used to say "arcs" (copy-paste from max-arcs).
    opts->Register("max-states", &max_states, "Maximum number of states in "
                   "output FST (total, not per state)");
    opts->Register("max-loop", &max_loop, "Option used to detect a particular "
                   "type of determinization failure, typically due to invalid input "
                   "(e.g., negative-cost loops)");
    opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
                   "lattice and retrying determinization: if effective-beam < "
                   "retry-cutoff * beam, we prune the raw lattice and retry.  Avoids "
                   "ever getting empty output for long segments.");
  }
};
/// Options for the phone-pruned determinization routines
/// (DeterminizeLatticePhonePruned() and its wrapper).
struct DeterminizeLatticePhonePrunedOptions {
  /// Small offset used when measuring equality of weights.
  float delta = kDelta;
  /// If > 0, determinization fails and returns false once the algorithm's
  /// (approximate) memory consumption crosses this threshold.
  int max_mem = 50000000;
  /// If true, run a first determinization pass on both phones and words.
  bool phone_determinize = true;
  /// If true, run a second determinization pass on words only.
  bool word_determinize = true;
  /// If true, push and minimize after determinization.
  bool minimize = false;

  // Defaults are supplied by the in-class initializers above.
  DeterminizeLatticePhonePrunedOptions() {}

  /// Registers every field with `opts` under its command-line name.
  void Register (kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta,
                   "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem,
                   "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this).");
    opts->Register("phone-determinize", &phone_determinize,
                   "If true, do an "
                   "initial pass of determinization on both phones and words (see"
                   " also --word-determinize)");
    opts->Register("word-determinize", &word_determinize,
                   "If true, do a second "
                   "pass of determinization on words only (see also "
                   "--phone-determinize)");
    opts->Register("minimize", &minimize,
                   "If true, push and minimize after "
                   "determinization.");
  }
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter.

@param ifst  Input FST. It must be topologically sorted in order for the
algorithm to work; for efficiency it is recommended to sort on ilabel as well.
@param prune Pruning beam.
@param ofst  Output FST that receives the determinized lattice.
@param opts  Determinization options (see DeterminizeLatticePrunedOptions).
@return true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam -- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.

CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
template<class Weight>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output
format, where the output sequences are encoded using the CompactLatticeArcTpl
template (i.e. the sequences of output symbols are represented directly as
strings). The input FST must be topologically sorted in order for the algorithm
to work. For efficiency it is recommended to sort the ilabel for the input FST
as well.
Returns true on normal success, and false if it had to terminate the
determinization earlier than specified by the "prune" beam -- that is, if it
terminated because of the max_mem, max_loop or max_arcs constraints in the
options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The
return value is the starting index of the phone labels. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (return_value + original_phone_label) in the new lattice. The
return value will be used by DeterminizeLatticeDeletePhones(), which
works out the phones according to this value.

@param trans_model Transition information used for the transition_id to
phone mapping.
@param fst         Lattice that is modified in place.
@return The first label index used for the inserted phone labels.
*/
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *fst);
/** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label, because
when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones().

@param first_phone_label Value returned by DeterminizeLatticeInsertPhones().
@param fst               Lattice that is modified in place.
*/
template<class Weight>
void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst);
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set
to true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copies the lattices without determinization.
Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default.
This overload takes the input lattice by const reference.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed.  It takes `ifst` as a mutable pointer instead of a
const reference; the remaining parameters are declared identically to the
non-destructive overload.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling
TopSort(), Invert() and ArcSort() for you.
Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the
output side.
This function can be used as the top-level interface to all the determinization
code.

@param trans_model Transition information for the model that produced the
lattice.
@param ifst  Input lattice (transition-ids on input side, words on output
side); passed as a mutable pointer.
@param prune Pruning beam.
@param ofst  Output compact lattice.
@param opts  Options; see DeterminizeLatticePhonePrunedOptions.
*/
bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#endif
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {
/// Produces a CompactLattice from `ifst`, taking ownership of the input.
/// A non-NULL input is converted into a freshly allocated CompactLattice via
/// ConvertLattice() and then deleted; a NULL input yields NULL.  The caller
/// owns the returned object.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
  CompactLattice *converted = NULL;
  if (ifst != NULL) {
    converted = new CompactLattice();
    ConvertLattice(*ifst, converted);
    delete ifst;
  }
  return converted;
}

// Specialization for the case where the input already is a CompactLattice:
// no conversion is needed, so the very same pointer is handed back
// (for efficiency).
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
  return ifst;
}
/// Produces a Lattice from `ifst`, taking ownership of the input.
/// A non-NULL input is converted into a freshly allocated Lattice via
/// ConvertLattice() and then deleted; a NULL input yields NULL.  The caller
/// owns the returned object.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
  Lattice *converted = NULL;
  if (ifst != NULL) {
    converted = new Lattice();
    ConvertLattice(*ifst, converted);
    delete ifst;
  }
  return converted;
}

// Specialization for the case where the input already is a Lattice:
// no conversion is needed, so the very same pointer is handed back
// (for efficiency).
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
  return ifst;
}
/// Writes the compact lattice `t` to `os`, in OpenFst binary format when
/// `binary` is true and in text (acceptor) format otherwise.
/// Returns the result of t.Write() in binary mode, and os.good() in text mode.
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &t) {
  if (!binary) {
    // Text-mode output.  We expect t.InputSymbols() and t.OutputSymbols() to
    // be NULL; the corresponding read routine would not work if the FST
    // actually had symbols attached.
    // Start with a newline so that, after the key, the first line of the FST
    // sits on its own line.
    os << '\n';
    const bool acceptor = true;
    const bool write_one = false;
    fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
                                               t.OutputSymbols(),
                                               NULL, acceptor, write_one, "\t");
    printer.Print(&os, "<unknown>");
    if (os.fail()) {
      KALDI_WARN << "Stream failure detected.";
    }
    // A second newline terminates the entry; the read routine detects it
    // (this is a Kaldi mechanism, not something in the original OpenFst code).
    os << '\n';
    return os.good();
  }
  // Binary mode: keep all write options at their defaults.  Normally these
  // lattices carry no isymbols/osymbols, so there is no point telling OpenFst
  // not to write them.
  fst::FstWriteOptions opts;
  return t.Write(os, opts);
}
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
  typedef LatticeArc Arc;
  typedef LatticeWeight Weight;
  typedef CompactLatticeArc CArc;
  typedef CompactLatticeWeight CWeight;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;

 public:
  // everything is static in this class.

  /** This function reads from the FST text format; it does not know in advance
      whether it's a Lattice or CompactLattice in the stream so it tries to
      read both formats until it becomes clear which is the correct one.

      Reading stops at the first empty line (the end-of-entry marker in the
      archive format) or at end of stream.

      @return A pair (Lattice*, CompactLattice*).  A member is NULL if the
      input could not be parsed in that format; both are NULL on error.
      Non-NULL members are heap-allocated and owned by the caller.

      State ids and next-state ids must be non-negative; a line containing a
      negative id is treated as a bad line, since such an id would otherwise
      be used to index the FST's state table.
  */
  static std::pair<Lattice*, CompactLattice*> ReadText(
      std::istream &is) {
    typedef std::pair<Lattice*, CompactLattice*> PairT;
    using std::string;
    using std::vector;
    Lattice *fst = new Lattice();
    CompactLattice *cfst = new CompactLattice();
    string line;
    size_t nline = 0;
    string separator = FLAGS_fst_field_separator + "\r\n";
    while (std::getline(is, line)) {
      nline++;
      vector<string> col;
      // on Windows we'll write in text and read in binary mode.
      SplitStringToVector(line, separator.c_str(), true, &col);
      if (col.size() == 0) break;  // Empty line is a signal to stop, in our
      // archive format.
      if (col.size() > 5) {
        KALDI_WARN << "Reading lattice: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      StateId s;
      // A negative source state is rejected here: it would pass the
      // "s >= NumStates()" growth loops below untouched and then be handed
      // to SetStart()/SetFinal()/AddArc() as an out-of-range state.
      if (!ConvertStringToInteger(col[0], &s) || s < 0) {
        KALDI_WARN << "FstCompiler: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      // Grow whichever FSTs are still candidates so that state s exists.
      if (fst)
        while (s >= fst->NumStates())
          fst->AddState();
      if (cfst)
        while (s >= cfst->NumStates())
          cfst->AddState();
      // The source state of the first line is the start state.
      if (nline == 1) {
        if (fst) fst->SetStart(s);
        if (cfst) cfst->SetStart(s);
      }

      if (fst) {  // we still have fst; try to read that arc.
        bool ok = true;
        Arc arc;
        Weight w;
        StateId d = s;
        switch (col.size()) {
          case 1:  // final state with weight One().
            fst->SetFinal(s, Weight::One());
            break;
          case 2:  // final state with explicit weight (Zero() allowed).
            if (!StrToWeight(col[1], true, &w)) ok = false;
            else fst->SetFinal(s, w);
            break;
          case 3:  // 3 columns not ok for Lattice format; it's not an acceptor.
            ok = false;
            break;
          case 4:  // arc without weight: state, next-state, ilabel, olabel.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                arc.nextstate >= 0 &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel);
            if (ok) {
              d = arc.nextstate;
              arc.weight = Weight::One();
              fst->AddArc(s, arc);
            }
            break;
          case 5:  // arc with weight (Zero() not allowed on arcs).
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                arc.nextstate >= 0 &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel) &&
                StrToWeight(col[4], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              fst->AddArc(s, arc);
            }
            break;
          default:
            ok = false;
        }
        // Make sure the destination state exists.
        while (d >= fst->NumStates())
          fst->AddState();
        if (!ok) {  // the line is not valid Lattice format; drop this candidate.
          delete fst;
          fst = NULL;
        }
      }
      if (cfst) {
        bool ok = true;
        CArc arc;
        CWeight w;
        StateId d = s;
        switch (col.size()) {
          case 1:  // final state with weight One().
            cfst->SetFinal(s, CWeight::One());
            break;
          case 2:  // final state with explicit weight (Zero() allowed).
            if (!StrToCWeight(col[1], true, &w)) ok = false;
            else cfst->SetFinal(s, w);
            break;
          case 3:  // compact-lattice is acceptor format: state, next-state, label.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                arc.nextstate >= 0 &&
                ConvertStringToInteger(col[2], &arc.ilabel);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              arc.weight = CWeight::One();
              cfst->AddArc(s, arc);
            }
            break;
          case 4:  // acceptor arc with weight (Zero() not allowed on arcs).
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                arc.nextstate >= 0 &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                StrToCWeight(col[3], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              cfst->AddArc(s, arc);
            }
            break;
          case 5: default:  // 5 columns is not valid for an acceptor.
            ok = false;
        }
        // Make sure the destination state exists.
        while (d >= cfst->NumStates())
          cfst->AddState();
        if (!ok) {  // not valid CompactLattice format; drop this candidate.
          delete cfst;
          cfst = NULL;
        }
      }
      if (!fst && !cfst) {
        KALDI_WARN << "Bad line in lattice text format: " << line;
        // read until we get an empty line, so at least we
        // have a chance to read the next one (although this might
        // be a bit futile since the calling code will get unhappy
        // about failing to read this one.
        while (std::getline(is, line)) {
          SplitStringToVector(line, separator.c_str(), true, &col);
          if (col.empty()) break;
        }
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
    }
    return PairT(fst, cfst);
  }

  /// Parses `s` into the lattice weight *w.  Returns false if stream
  /// extraction fails, or if the parsed weight equals Weight::Zero() while
  /// `allow_zero` is false.
  static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == Weight::Zero())) {
      return false;
    }
    return true;
  }

  /// Parses `s` into the compact-lattice weight *w.  Returns false if stream
  /// extraction fails, or if the parsed weight equals CWeight::Zero() while
  /// `allow_zero` is false.
  static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == CWeight::Zero())) {
      return false;
    }
    return true;
  }
};
/// Reads one lattice in text form from `is` and returns it as a
/// CompactLattice.  If the text parsed as a CompactLattice, that object is
/// returned; otherwise, if it parsed as a Lattice, it is converted.  Returns
/// NULL if neither format could be read.  The caller owns the result.
CompactLattice *ReadCompactLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> parsed = LatticeReader::ReadText(is);
  Lattice *as_lattice = parsed.first;
  CompactLattice *as_compact = parsed.second;
  if (as_compact != NULL) {
    // The compact form is what we want; discard the other candidate.
    delete as_lattice;
    return as_compact;
  }
  if (as_lattice != NULL) {
    // note: ConvertToCompactLattice frees its input.
    return ConvertToCompactLattice(as_lattice);
  }
  return NULL;
}
/// Reads one lattice in text form from `is` and returns it as a Lattice.
/// If the text parsed as a Lattice, that object is returned; otherwise, if it
/// parsed as a CompactLattice, it is converted.  Returns NULL if neither
/// format could be read.  The caller owns the result.
Lattice *ReadLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> parsed = LatticeReader::ReadText(is);
  Lattice *as_lattice = parsed.first;
  CompactLattice *as_compact = parsed.second;
  if (as_lattice != NULL) {
    // The regular form is what we want; discard the other candidate.
    delete as_compact;
    return as_lattice;
  }
  if (as_compact != NULL) {
    // note: ConvertToLattice frees its input.
    return ConvertToLattice(as_compact);
  }
  return NULL;
}
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat) {
KALDI_ASSERT(*clat == NULL);
if (binary) {
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
KALDI_WARN << "Reading compact lattice: error reading FST header.";
return false;
}
if (hdr.FstType() != "vector") {
KALDI_WARN << "Reading compact lattice: unsupported FST type: "
<< hdr.FstType();
return false;
}
fst::FstReadOptions ropts("<unspecified>",
&hdr);
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
typedef fst::LatticeWeightTpl<float> T3;
typedef fst::LatticeWeightTpl<double> T4;
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
CompactLattice *ans = NULL;
if (hdr.ArcType() == T1::Type()) {
ans = ConvertToCompactLattice(F1::Read(is, ropts));
} else if (hdr.ArcType() == T2::Type()) {
ans = ConvertToCompactLattice(F2::Read(is, ropts));
} else if (hdr.ArcType() == T3::Type()) {
ans = ConvertToCompactLattice(F3::Read(is, ropts));
} else if (hdr.ArcType() == T4::Type()) {
ans = ConvertToCompactLattice(F4::Read(is, ropts));
} else {
KALDI_WARN << "FST with arc type " << hdr.ArcType()
<< " cannot be converted to CompactLattice.\n";
return false;
}
if (ans == NULL) {
KALDI_WARN << "Error reading compact lattice (after reading header).";
return false;
}
*clat = ans;
return true;
} else {
// The next line would normally consume the \r on Windows, plus any
// extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') is.get(); // consume the newline.
else { // saw spaces but no newline.. this is not expected.
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
<< " at file position " << is.tellg();
return false;
}
*clat = ReadCompactLatticeText(is); // that routine will warn on error.
return (*clat != NULL);
}
}
/// Reads a compact lattice from `is` into this holder, replacing anything
/// currently stored.  Text versus binary is decided by peeking at the first
/// character.  Returns false (after a warning) on failure.
bool CompactLatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  const int first = is.peek();
  if (first == -1) {
    KALDI_WARN << "End of stream detected reading CompactLattice.";
    return false;
  }
  if (isspace(first)) {
    // The text form of the lattice begins with space (normally, '\n'), so
    // this means it's text (the binary form cannot begin with space because
    // it starts with the FST Type() which is not space).
    return ReadCompactLattice(is, false, &t_);
  }
  if (first == 214) {
    // 214 is first char of FST magic number, on little-endian machines
    // which is all we support (\326 octal).
    return ReadCompactLattice(is, true, &t_);
  }
  KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
             << " [non-space but no magic number detected], file pos is "
             << is.tellg();
  return false;
}
/// Writes the lattice `t` to `os`, in OpenFst binary format when `binary` is
/// true and in text (transducer, i.e. non-acceptor) format otherwise.
/// Returns the result of t.Write() in binary mode, and os.good() in text mode.
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
  if (!binary) {
    // Text-mode output.  We expect t.InputSymbols() and t.OutputSymbols() to
    // be NULL; the corresponding read routine would not work if the FST
    // actually had symbols attached.
    // Start with a newline so that, after the key, the first line of the FST
    // sits on its own line.
    os << '\n';
    const bool acceptor = false;
    const bool write_one = false;
    fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
                                        t.OutputSymbols(),
                                        NULL, acceptor, write_one, "\t");
    printer.Print(&os, "<unknown>");
    if (os.fail()) {
      KALDI_WARN << "Stream failure detected.";
    }
    // A second newline terminates the entry; the read routine detects it
    // (this is a Kaldi mechanism, not something in the original OpenFst code).
    os << '\n';
    return os.good();
  }
  // Binary mode: keep all write options at their defaults.  Normally these
  // lattices carry no isymbols/osymbols, so there is no point telling OpenFst
  // not to write them.
  fst::FstWriteOptions opts;
  return t.Write(os, opts);
}
/// Reads a lattice from `is` into *lat, which must be NULL on entry.
/// In binary mode the FST header decides which of the four supported
/// lattice arc types is read; the result is converted to Lattice if needed.
/// In text mode the leading newline is consumed and the text reader is used.
/// Returns false (after a warning) on failure.
bool ReadLattice(std::istream &is, bool binary,
                 Lattice **lat) {
  KALDI_ASSERT(*lat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);

    // The four weight types we know how to read, and the matching FST types.
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;

    Lattice *ans = NULL;
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to Lattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading lattice (after reading header).";
      return false;
    }
    *lat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get();  // consume the newline.
    else {  // saw spaces but no newline.. this is not expected.
      // Fixed: this message used to say "Reading compact lattice", copied
      // from ReadCompactLattice().
      KALDI_WARN << "Reading lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *lat = ReadLatticeText(is);  // that routine will warn on error.
    return (*lat != NULL);
  }
}
/* Since we don't write the binary headers for this type of holder,
   we use a different method to work out whether we're in binary mode:
   we peek at the first character of the stream.

   Reads a lattice from `is` into this holder, replacing anything currently
   stored.  Returns false (after a warning) on failure.
*/
bool LatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading Lattice.";
    return false;
  } else if (isspace(c)) {  // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is not
    // space).
    return ReadLattice(is, false, &t_);
  } else if (c != 214) {  // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    // Fixed: this message used to say "Reading compact lattice", copied
    // from CompactLatticeHolder::Read().
    KALDI_WARN << "Reading lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadLattice(is, true, &t_);
  }
}
} // end namespace kaldi
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
namespace kaldi {
// will import some things above...
// Core lattice types used throughout this file.
// LatticeWeight: the lattice weight template instantiated with BaseFloat.
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
// careful: kaldi::int32 is not always the same C type as fst::int32
// CompactLatticeWeight: the compact-lattice weight template instantiated
// with LatticeWeight and int32.
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
// Common-divisor helper type instantiated for CompactLatticeWeight's
// template arguments.
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
CompactLatticeWeightCommonDivisor;
// Arc types carrying the two weight types above.
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
// The lattice types proper: VectorFst over the respective arc type.
typedef fst::VectorFst<LatticeArc> Lattice;
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.

// Writes `clat` to `os` in binary or text form; returns false on failure.
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &clat);
// Writes `lat` to `os` in binary or text form; returns false on failure.
bool WriteLattice(std::ostream &os, bool binary,
const Lattice &lat);
// the following function requires that *clat be
// NULL when called.  On success *clat is set to the lattice that was read.
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat);
// the following function requires that *lat be
// NULL when called.  On success *lat is set to the lattice that was read.
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat);
class CompactLatticeHolder {
public:
typedef CompactLattice T;
CompactLatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteCompactLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(CompactLatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~CompactLatticeHolder() { Clear(); }
private:
T *t_;
};
class LatticeHolder {
public:
typedef Lattice T;
LatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(LatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const LatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~LatticeHolder() { Clear(); }
private:
T *t_;
};
// Table writer and reader types for Lattice, built on LatticeHolder.
typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
// Table writer and reader types for CompactLattice, built on
// CompactLatticeHolder.
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi
#endif // KALDI_LAT_KALDI_LATTICE_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册