未验证 提交 b584b969 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1400 from SmileGoat/feature_dev

[speechx]add linear spectrogram feature extractor
......@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
- ISC License
- Audio feature
* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
- zlib License
- ThreadPool
......@@ -39,15 +39,40 @@ FetchContent_Declare(
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
include_directories(${absl_SOURCE_DIR})
# libsndfile
#include(FetchContent)
#FetchContent_Declare(
# libsndfile
# GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
# GIT_TAG "1.0.31"
#)
#FetchContent_MakeAvailable(libsndfile)
# todo boost build
#include(FetchContent)
#FetchContent_Declare(
# Boost
# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
#)
#FetchContent_MakeAvailable(Boost)
#include_directories(${Boost_SOURCE_DIR})
set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib)
include(FetchContent)
FetchContent_Declare(
libsndfile
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
GIT_TAG "1.0.31"
kenlm
GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
)
FetchContent_MakeAvailable(libsndfile)
FetchContent_MakeAvailable(kenlm)
add_dependencies(kenlm Boost)
include_directories(${kenlm_SOURCE_DIR})
# gflags
FetchContent_Declare(
......@@ -65,7 +90,7 @@ FetchContent_Declare(
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR})
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
# gtest
FetchContent_Declare(googletest
......@@ -93,6 +118,22 @@ add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference)
include_directories("${PADDLE_LIB}/paddle/include")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
add_subdirectory(speechx)
#openblas
......@@ -121,4 +162,4 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target)
\ No newline at end of file
#add_dependencies(lib_name depend-target)
......@@ -4,11 +4,43 @@ project(speechx LANGUAGES CXX)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/kaldi
)
add_subdirectory(kaldi)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/utils
)
add_subdirectory(utils)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/frontend
)
add_subdirectory(frontend)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/nnet
)
add_subdirectory(nnet)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
#target_link_libraries(offline_decoder_main nnet decoder gflags glog)
......@@ -16,7 +16,7 @@
#include "kaldi/base/kaldi-types.h"
#include <limits.h>
#include <limits>
typedef float BaseFloat;
typedef double double64;
......@@ -35,7 +35,7 @@ typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64;
#else
typedef unsigned long long uint64;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <iostream>
#include <istream>
#include <fstream>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h"
#include "base/macros.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gflags/gflags.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
......@@ -16,8 +16,10 @@
namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#endif
} // namespace pp_speech
\ No newline at end of file
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
// 3. This notice may not be removed or altered from any source
// distribution.
// this code is from https://github.com/progschj/ThreadPool
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
}
#endif
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
int32 chunk_size,
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
CTCBeamSearchOptions opts;
CTCBeamSearch decoder(opts);
ModelOptions model_opts;
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
Decodable decodable();
decodable.SetNnet(nnet);
int32 chunk_size = 0;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
vector<Matrix<BaseFloat>> feature_chunks;
SplitFeature(feature, chunk_size, &feature_chunks);
for (auto feature_chunk : feature_chunks) {
decodable.FeedFeatures(feature_chunk);
decoder.InitDecoder();
decoder.AdvanceDecode(decodable, chunk_size);
}
decodable.InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(waveform);
linear_spectrogram.ReadFeats(&features);
feat_writer.Write(utt, features);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries(decoder kenlm)
\ No newline at end of file
#include "base/basic_types.h"
struct DecoderResult {
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
};
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded_(0),
root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta,
opts_.lm_path,
vocabulary_);
}
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2;
}
ResetPrefixes();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood);
max_frames--;
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
delete prefixes[i];
prefixes[i] = nullptr;
}
}
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech
\ No newline at end of file
#include "base/common.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once
namespace ppspeech {
struct CTCBeamSearchOptions {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions() :
dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {
}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch {
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace basr
\ No newline at end of file
../../../third_party/ctc_decoders
\ No newline at end of file
project(frontend)
add_library(frontend
normalizer.cc
linear_spectrogram.cc
)
target_link_libraries(frontend kaldi-matrix)
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the fbank feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
#incldue "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
#include "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FeatureExtractorInterface {
public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat;
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
//todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) {
if (input.empty()) return;
output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
}
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
}
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input);
}
void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i];
}
}
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real,
vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1));
}
real->push_back(v->at(1));
img->push_back(0);
return true;
}
// todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
}
waveform_.Resize(0);
}
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo
return;
}
// only for test, remove later
// todo: compute the feature frame by frame.
void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
VectorBase<kaldi::BaseFloat>* feature) {
base_extractor_->Read(feature);
}
// Compute spectrogram feat, only for test, remove later
// todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) {
int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * frame_shift;
if (num_samples < frame_length) {
return true;
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
if (NumpyFft(&v, &fft_real, &fft_img)) {
LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data";
return false;
}
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale;
} else {
feat[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14);
}
}
return true;
}
} // namespace ppspeech
\ No newline at end of file
#pragma once
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions():
frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
frame_opts.Register(opts);
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat);
void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feature);
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
std::vector<kaldi::BaseFloat>* real,
std::vector<kaldi::BaseFloat>* img) const;
kaldi::int32 fft_points_;
size_t dim_;
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the mfcc feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
\ No newline at end of file
#include "frontend/normalizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
opts_ = opts;
dim_ = 0;
}
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
}
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return;
Compute(waveform_, feat);
}
//todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) {
if (input.empty()) return;
assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i);
}
// square
for (auto &d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto &item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
}
/*
PPNormalizer::PPNormalizer(
const PPNormalizerOptions& opts,
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
}
void PPNormalizer::AcceptWavefrom(const Vector<BaseFloat>& input) {
}
void PPNormalizer::Read(Vector<BaseFloat>* feat) {
}
bool PPNormalizer::Compute(const Vector<BaseFloat>& input,
Vector<BaseFloat>>* feat) {
if ((input.Dim() % mean_.Dim()) == 0) {
LOG(ERROR) << "CMVN dimension is wrong!";
return false;
}
try {
int32 size = mean_.Dim();
feat->Resize(input.Dim());
for (int32 row_idx = 0; row_idx < j; ++row_idx) {
int32 base_idx = row_idx * size;
for (int32 idx = 0; idx < mean_.Dim(); ++idx) {
(*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx))* variance_(idx);
}
}
} catch(const std::exception& e) {
std::cerr << e.what() << '\n';
return false;
}
return true;
}*/
} // namespace ppspeech
\ No newline at end of file
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
struct DecibelNormalizerOptions {
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions() :
target_db(-20),
max_gain_db(300.0),
convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization");
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return 0; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private:
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_;
};
/*
struct NormalizerOptions {
std::string mean_std_path;
NormalizerOptions() :
mean_std_path("") {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("mean-std", &mean_std_path, "mean std file");
}
};
// todo refactor later (SmileGoat)
class PPNormalizer : public FeatureExtractorInterface {
public:
explicit PPNormalizer(const NormalizerOptions& opts,
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor);
~PPNormalizer() {}
virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
virtual size_t Dim() const;
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& input,
kaldi::Vector<kaldi::BaseFloat>>& feat);
private:
bool _initialized;
kaldi::Vector<float> mean_;
kaldi::Vector<float> variance_;
NormalizerOptions _opts;
};
*/
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// extract the window of kaldi feat.
// decoder/lattice-faster-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2018 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "decoder/lattice-faster-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {
// instantiate this class once for each thing you have to decode.
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
const FST &fst, const LatticeFasterDecoderConfig &config)
: fst_(&fst),
delete_fst_(false),
config_(config),
num_toks_(0),
token_pool_(config.memory_pool_tokens_block_size),
forward_link_pool_(config.memory_pool_links_block_size) {
config.Check();
toks_.SetSize(1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
const LatticeFasterDecoderConfig &config, FST *fst)
: fst_(fst),
delete_fst_(true),
config_(config),
num_toks_(0),
token_pool_(config.memory_pool_tokens_block_size),
forward_link_pool_(config.memory_pool_links_block_size) {
config.Check();
toks_.SetSize(1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::~LatticeFasterDecoderTpl() {
DeleteElems(toks_.Clear());
ClearActiveTokens();
if (delete_fst_) delete fst_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::InitDecoding() {
// clean up from last time:
DeleteElems(toks_.Clear());
cost_offsets_.clear();
ClearActiveTokens();
warned_ = false;
num_toks_ = 0;
decoding_finalized_ = false;
final_costs_.clear();
StateId start_state = fst_->Start();
KALDI_ASSERT(start_state != fst::kNoStateId);
active_toks_.resize(1);
Token *start_tok =
new (token_pool_.Allocate()) Token(0.0, 0.0, NULL, NULL, NULL);
active_toks_[0].toks = start_tok;
toks_.Insert(start_state, start_tok);
num_toks_++;
ProcessNonemitting(config_.beam);
}
// Returns true if any kind of traceback is available (not necessarily from
// a final state). It should only very rarely return false; this indicates
// an unusual search error.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::Decode(DecodableInterface *decodable) {
InitDecoding();
// We use 1-based indexing for frames in this decoder (if you view it in
// terms of features), but note that the decodable object uses zero-based
// numbering, which we have to correct for when we call it.
AdvanceDecoding(decodable);
FinalizeDecoding();
// Returns true if we have any kind of traceback available (not necessarily
// to the end state; query ReachedFinal() for that).
return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetBestPath(Lattice *olat,
bool use_final_probs) const {
Lattice raw_lat;
GetRawLattice(&raw_lat, use_final_probs);
ShortestPath(raw_lat, olat);
return (olat->NumStates() != 0);
}
// Outputs an FST corresponding to the raw, state-level lattice
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetRawLattice(
Lattice *ofst,
bool use_final_probs) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(decoding_finalized_ ? final_costs_ : final_costs_local);
if (!decoding_finalized_ && use_final_probs)
ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
// num-frames plus one (since frames are one-based, and we have
// an extra frame for the start-state).
int32 num_frames = active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
const int32 bucket_count = num_toks_/2 + 3;
unordered_map<Token*, StateId> tok_map(bucket_count);
// First create all states.
std::vector<Token*> token_list;
for (int32 f = 0; f <= num_frames; f++) {
if (active_toks_[f].toks == NULL) {
KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
TopSortTokens(active_toks_[f].toks, &token_list);
for (size_t i = 0; i < token_list.size(); i++)
if (token_list[i] != NULL)
tok_map[token_list[i]] = ofst->AddState();
}
// The next statement sets the start state of the output FST. Because we
// topologically sorted the tokens, state zero must be the start-state.
ofst->SetStart(0);
KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:"
<< tok_map.bucket_count() << " load:" << tok_map.load_factor()
<< " max:" << tok_map.max_load_factor();
// Now create all arcs.
for (int32 f = 0; f <= num_frames; f++) {
for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
StateId cur_state = tok_map[tok];
for (ForwardLinkT *l = tok->links;
l != NULL;
l = l->next) {
typename unordered_map<Token*, StateId>::const_iterator
iter = tok_map.find(l->next_tok);
StateId nextstate = iter->second;
KALDI_ASSERT(iter != tok_map.end());
BaseFloat cost_offset = 0.0;
if (l->ilabel != 0) { // emitting..
KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
cost_offset = cost_offsets_[f];
}
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(cur_state, arc);
}
if (f == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token*, BaseFloat>::const_iterator
iter = final_costs.find(tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
}
return (ofst->NumStates() > 0);
}
// This function is now deprecated, since now we do determinization from outside
// the LatticeFasterDecoder class. Outputs an FST corresponding to the
// lattice-determinized lattice (one path per word sequence).
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetLattice(CompactLattice *ofst,
bool use_final_probs) const {
Lattice raw_fst;
GetRawLattice(&raw_fst, use_final_probs);
Invert(&raw_fst); // make it so word labels are on the input.
// (in phase where we get backward-costs).
fst::ILabelCompare<LatticeArc> ilabel_comp;
ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
// lattice-determinization more efficient.
fst::DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = config_.det_opts.max_mem;
DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
Connect(ofst); // Remove unreachable states... there might be
// a small number of these, in some cases.
// Note: if something went wrong and the raw lattice was empty,
// we should still get to this point in the code without warnings or failures.
return (ofst->NumStates() != 0);
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks)
* config_.hash_ratio);
if (new_sz > toks_.Size()) {
toks_.SetSize(new_sz);
}
}
/*
A note on the definition of extra_cost.
extra_cost is used in pruning tokens, to save memory.
extra_cost can be thought of as a beta (backward) cost assuming
we had set the betas on currently-active tokens to all be the negative
of the alphas for those tokens. (So all currently active tokens would
be on (tied) best paths).
We can use the extra_cost to accurately prune away tokens that we know will
never appear in the lattice. If the extra_cost is greater than the desired
lattice beam, the token would provably never appear in the lattice, so we can
prune away the token.
(Note: we don't update all the extra_costs every time we update a frame; we
only do it every 'config_.prune_interval' frames).
*/
// FindOrAddToken either locates a token in hash of toks_,
// or if necessary inserts a new, empty token (i.e. with no forward links)
// for the current frame. [note: it's inserted if necessary into hash toks_
// and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).
template <typename FST, typename Token>
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
StateId state, int32 frame_plus_one, BaseFloat tot_cost,
Token *backpointer, bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
Elem *e_found = toks_.Insert(state, NULL);
if (e_found->val == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
// tokens on the currently final frame have zero extra_cost
// as any of them could end up
// on the winning path.
Token *new_tok = new (token_pool_.Allocate())
Token(tot_cost, extra_cost, NULL, toks, backpointer);
// NULL: no forward links yet
toks = new_tok;
num_toks_++;
e_found->val = new_tok;
if (changed) *changed = true;
return e_found;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
tok->tot_cost = tot_cost;
// SetBackpointer() just does tok->backpointer = backpointer in
// the case where Token == BackpointerToken, else nothing.
tok->SetBackpointer(backpointer);
// we don't allocate a new token, the old stays linked in active_toks_
// we only replace the tot_cost
// in the current frame, there are no forward links (and no extra_cost)
// only in ProcessNonemitting we have to delete forward links
// in case we visit a state for the second time
// those forward links, that lead to this replaced token before:
// they remain and will hopefully be pruned later (PruneForwardLinks...)
if (changed) *changed = true;
} else {
if (changed) *changed = false;
}
return e_found;
}
}
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
int32 frame_plus_one, bool *extra_costs_changed,
bool *links_pruned, BaseFloat delta) {
// delta is the amount by which the extra_costs must change
// If delta is larger, we'll tend to go back less far
// toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
*extra_costs_changed = false;
*links_pruned = false;
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen.
if (!warned_) {
KALDI_WARN << "No tokens alive [doing pruning].. warning first "
"time only for each utterance\n";
warned_ = true;
}
}
// We have to iterate until there is no more change, because the links
// are not guaranteed to be in topological order.
bool changed = true; // difference new minus old extra cost >= delta ?
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks;
tok != NULL; tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost for tok.
BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// tok_extra_cost is the best (min) of link_extra_cost of outgoing links
for (link = tok->links; link != NULL; ) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost = next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost)
- next_tok->tot_cost); // difference in brackets is >= 0
// link_exta_cost is the difference in score between the best paths
// through link source state and through link destination state
KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL) prev_link->next = next_link;
else tok->links = next_link;
forward_link_pool_.Free(link);
link = next_link; // advance link but leave prev_link the same.
*links_pruned = true;
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
if (link_extra_cost < -0.01)
KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost)
tok_extra_cost = link_extra_cost;
prev_link = link; // move to next link
link = link->next;
}
} // for all outgoing links
if (fabs(tok_extra_cost - tok->extra_cost) > delta)
changed = true; // difference new minus old is bigger than delta
tok->extra_cost = tok_extra_cost;
// will be +infinity or <= lattice_beam_.
// infinity indicates, that no forward link survived pruning
} // for all Token on active_toks_[frame]
if (changed) *extra_costs_changed = true;
// Note: it's theoretically possible that aggressive compiler
// optimizations could cause an infinite loop here for small delta and
// high-dynamic-range scores.
} // while changed
}
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
KALDI_ASSERT(!active_toks_.empty());
int32 frame_plus_one = active_toks_.size() - 1;
if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
KALDI_WARN << "No tokens alive at end of file";
typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
decoding_finalized_ = true;
// We call DeleteElems() as a nicety, not because it's really necessary;
// otherwise there would be a time, after calling PruneTokensForFrame() on the
// final frame, when toks_.GetList() or toks_.Clear() would contain pointers
// to nonexistent tokens.
DeleteElems(toks_.Clear());
// Now go through tokens on this frame, pruning forward links... may have to
// iterate a few times until there is no more change, because the list is not
// in topological order. This is a modified version of the code in
// PruneForwardLinks, but here we also take account of the final-probs.
bool changed = true;
BaseFloat delta = 1.0e-05;
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks;
tok != NULL; tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost. It has a term in it that corresponds
// to the "final-prob", so instead of initializing tok_extra_cost to infinity
// below we set it to the difference between the (score+final_prob) of this token,
// and the best such (score+final_prob).
BaseFloat final_cost;
if (final_costs_.empty()) {
final_cost = 0.0;
} else {
IterType iter = final_costs_.find(tok);
if (iter != final_costs_.end())
final_cost = iter->second;
else
final_cost = std::numeric_limits<BaseFloat>::infinity();
}
BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
// tok_extra_cost will be a "min" over either directly being final, or
// being indirectly final through other links, and the loop below may
// decrease its value:
for (link = tok->links; link != NULL; ) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost = next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost)
- next_tok->tot_cost);
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL) prev_link->next = next_link;
else tok->links = next_link;
forward_link_pool_.Free(link);
link = next_link; // advance link but leave prev_link the same.
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
if (link_extra_cost < -0.01)
KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost)
tok_extra_cost = link_extra_cost;
prev_link = link;
link = link->next;
}
}
// prune away tokens worse than lattice_beam above best path. This step
// was not necessary in the non-final case because then, this case
// showed up as having no forward links. Here, the tok_extra_cost has
// an extra component relating to the final-prob.
if (tok_extra_cost > config_.lattice_beam)
tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// to be pruned in PruneTokensForFrame
if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
changed = true;
tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
}
} // while changed
}
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::FinalRelativeCost() const {
if (!decoding_finalized_) {
BaseFloat relative_cost;
ComputeFinalCosts(NULL, &relative_cost, NULL);
return relative_cost;
} else {
// we're not allowed to call that function if FinalizeDecoding() has
// been called; return a cached value.
return final_relative_cost_;
}
}
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneTokensForFrame(int32 frame_plus_one) {
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
if (toks == NULL)
KALDI_WARN << "No tokens alive [doing pruning]";
Token *tok, *next_tok, *prev_tok = NULL;
for (tok = toks; tok != NULL; tok = next_tok) {
next_tok = tok->next;
if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
// token is unreachable from end of graph; (no forward links survived)
// excise tok from list and delete tok.
if (prev_tok != NULL) prev_tok->next = tok->next;
else toks = tok->next;
token_pool_.Free(tok);
num_toks_--;
} else { // fetch next Token
prev_tok = tok;
}
}
}
// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame before
// that. We go backwards through the frames and stop when we reach a point
// where the delta-costs are not changing (and the delta controls when we consider
// a cost to have "not changed").
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
int32 cur_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// The index "f" below represents a "frame plus one", i.e. you'd have to subtract
// one to get the corresponding index for the decodable object.
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
// Reason why we need to prune forward links in this situation:
// (1) we have never pruned them (new TokenList)
// (2) we have not yet pruned the forward links to the next f,
// after any of those tokens have changed their extra_cost.
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
if (extra_costs_changed && f > 0) // any token has changed extra_cost
active_toks_[f-1].must_prune_forward_links = true;
if (links_pruned) // any link was pruned
active_toks_[f].must_prune_tokens = true;
active_toks_[f].must_prune_forward_links = false; // job done
}
if (f+1 < cur_frame_plus_one && // except for last f (no forward links)
active_toks_[f+1].must_prune_tokens) {
PruneTokensForFrame(f+1);
active_toks_[f+1].must_prune_tokens = false;
}
}
KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
<< " to " << num_toks_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ComputeFinalCosts(
unordered_map<Token*, BaseFloat> *final_costs,
BaseFloat *final_relative_cost,
BaseFloat *final_best_cost) const {
KALDI_ASSERT(!decoding_finalized_);
if (final_costs != NULL)
final_costs->clear();
const Elem *final_toks = toks_.GetList();
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_cost = infinity,
best_cost_with_final = infinity;
while (final_toks != NULL) {
StateId state = final_toks->key;
Token *tok = final_toks->val;
const Elem *next = final_toks->tail;
BaseFloat final_cost = fst_->Final(state).Value();
BaseFloat cost = tok->tot_cost,
cost_with_final = cost + final_cost;
best_cost = std::min(cost, best_cost);
best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
if (final_costs != NULL && final_cost != infinity)
(*final_costs)[tok] = final_cost;
final_toks = next;
}
if (final_relative_cost != NULL) {
if (best_cost == infinity && best_cost_with_final == infinity) {
// Likely this will only happen if there are no tokens surviving.
// This seems the least bad way to handle it.
*final_relative_cost = infinity;
} else {
*final_relative_cost = best_cost_with_final - best_cost;
}
}
if (final_best_cost != NULL) {
if (best_cost_with_final != infinity) { // final-state exists.
*final_best_cost = best_cost_with_final;
} else { // no final-state exists.
*final_best_cost = best_cost;
}
}
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::AdvanceDecoding(DecodableInterface *decodable,
int32 max_num_frames) {
if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
// if the type 'FST' is the FST base-class, then see if the FST type of fst_
// is actually VectorFst or ConstFst. If so, call the AdvanceDecoding()
// function after casting *this to the more specific type.
if (fst_->Type() == "const") {
LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token>* >(this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
} else if (fst_->Type() == "vector") {
LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token>* >(this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
}
}
KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
"You must call InitDecoding() before AdvanceDecoding");
int32 num_frames_ready = decodable->NumFramesReady();
// num_frames_ready must be >= num_frames_decoded, or else
// the number of frames ready must have decreased (which doesn't
// make sense) or the decodable object changed between calls
// (which isn't allowed).
KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
int32 target_frames_decoded = num_frames_ready;
if (max_num_frames >= 0)
target_frames_decoded = std::min(target_frames_decoded,
NumFramesDecoded() + max_num_frames);
while (NumFramesDecoded() < target_frames_decoded) {
if (NumFramesDecoded() % config_.prune_interval == 0) {
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
}
BaseFloat cost_cutoff = ProcessEmitting(decodable);
ProcessNonemitting(cost_cutoff);
}
}
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::FinalizeDecoding() {
int32 final_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// PruneForwardLinksFinal() prunes final frame (with final-probs), and
// sets decoding_finalized_.
PruneForwardLinksFinal();
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
bool b1, b2; // values not used.
BaseFloat dontcare = 0.0; // delta of zero means we must always update
PruneForwardLinks(f, &b1, &b2, dontcare);
PruneTokensForFrame(f + 1);
}
PruneTokensForFrame(0);
KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin
<< " to " << num_toks_;
}
/// Gets the weight cutoff. Also counts the active tokens.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem) {
BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
// positive == high cost == bad.
size_t count = 0;
if (config_.max_active == std::numeric_limits<int32>::max() &&
config_.min_active == 0) {
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
return best_weight + config_.beam;
} else {
tmp_array_.clear();
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = e->val->tot_cost;
tmp_array_.push_back(w);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
BaseFloat beam_cutoff = best_weight + config_.beam,
min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
<< " is " << tmp_array_.size();
if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
std::nth_element(tmp_array_.begin(),
tmp_array_.begin() + config_.max_active,
tmp_array_.end());
max_active_cutoff = tmp_array_[config_.max_active];
}
if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
if (adaptive_beam)
*adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
return max_active_cutoff;
}
if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
if (config_.min_active == 0) min_active_cutoff = best_weight;
else {
std::nth_element(tmp_array_.begin(),
tmp_array_.begin() + config_.min_active,
tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
tmp_array_.begin() + config_.max_active :
tmp_array_.end());
min_active_cutoff = tmp_array_[config_.min_active];
}
}
if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
if (adaptive_beam)
*adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
return min_active_cutoff;
} else {
*adaptive_beam = config_.beam;
return beam_cutoff;
}
}
}
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
DecodableInterface *decodable) {
KALDI_ASSERT(active_toks_.size() > 0);
int32 frame = active_toks_.size() - 1; // frame is the frame-index
// (zero-based) used to get likelihoods
// from the decodable object.
active_toks_.resize(active_toks_.size() + 1);
Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
// in simple-decoder.h. Removes the Elems from
// being indexed in the hash in toks_.
Elem *best_elem = NULL;
BaseFloat adaptive_beam;
size_t tok_cnt;
BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
<< adaptive_beam;
PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
// pruning "online" before having seen all tokens
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
// dynamic range.
// First process the best token to get a hopefully
// reasonably tight bound on the next cutoff. The only
// products of the next block are "next_cutoff" and "cost_offset".
if (best_elem) {
StateId state = best_elem->key;
Token *tok = best_elem->val;
cost_offset = - tok->tot_cost;
for (fst::ArcIterator<FST> aiter(*fst_, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat new_weight = arc.weight.Value() + cost_offset -
decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost;
if (new_weight + adaptive_beam < next_cutoff)
next_cutoff = new_weight + adaptive_beam;
}
}
}
// Store the offset on the acoustic likelihoods that we're applying.
// Could just do cost_offsets_.push_back(cost_offset), but we
// do it this way as it's more robust to future code changes.
cost_offsets_.resize(frame + 1, 0.0);
cost_offsets_[frame] = cost_offset;
// the tokens are now owned here, in final_toks, and the hash is empty.
// 'owned' is a complex thing here; the point is we need to call DeleteElem
// on each elem 'e' to let toks_ know we're done with them.
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
// loop this way because we delete "e" as we go.
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) {
for (fst::ArcIterator<FST> aiter(*fst_, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat ac_cost = cost_offset -
decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value(),
cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
if (tot_cost >= next_cutoff) continue;
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff = tot_cost + adaptive_beam; // prune by best current token
// Note: the frame indexes into active_toks_ are one-based,
// hence the + 1.
Elem *e_next = FindOrAddToken(arc.nextstate,
frame + 1, tot_cost, tok, NULL);
// NULL: no change indicator needed
// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new (forward_link_pool_.Allocate())
ForwardLinkT(e_next->val, arc.ilabel, arc.olabel, graph_cost,
ac_cost, tok->links);
}
} // for all arcs
}
e_tail = e->tail;
toks_.Delete(e); // delete Elem
}
return next_cutoff;
}
// static inline
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
ForwardLinkT *l = tok->links, *m;
while (l != NULL) {
m = l->next;
forward_link_pool_.Free(l);
l = m;
}
tok->links = NULL;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
KALDI_ASSERT(!active_toks_.empty());
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
// Note: "frame" is the time-index we just processed, or -1 if
// we are processing the nonemitting transitions before the
// first frame (called from InitDecoding()).
// Processes nonemitting arcs for one frame. Propagates within toks_.
// Note-- this queue structure is not very optimal as
// it may cause us to process states unnecessarily (e.g. more than once),
// but in the baseline code, turning this vector into a set to fix this
// problem did not improve overall speed.
KALDI_ASSERT(queue_.empty());
if (toks_.GetList() == NULL) {
if (!warned_) {
KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
warned_ = true;
}
}
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
StateId state = e->key;
if (fst_->NumInputEpsilons(state) != 0)
queue_.push_back(e);
}
while (!queue_.empty()) {
const Elem *e = queue_.back();
queue_.pop_back();
StateId state = e->key;
Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen.
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost >= cutoff) // Don't bother processing successors.
continue;
// If "tok" has any existing forward links, delete them,
// because we're about to regenerate them. This is a kind
// of non-optimality (remember, this is the simple decoder),
// but since most states are emitting it's not a huge issue.
DeleteForwardLinks(tok); // necessary when re-visiting
tok->links = NULL;
for (fst::ArcIterator<FST> aiter(*fst_, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel == 0) { // propagate nonemitting only...
BaseFloat graph_cost = arc.weight.Value(),
tot_cost = cur_cost + graph_cost;
if (tot_cost < cutoff) {
bool changed;
Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
tok, &changed);
tok->links = new (forward_link_pool_.Allocate()) ForwardLinkT(
e_new->val, 0, arc.olabel, graph_cost, 0, tok->links);
// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
queue_.push_back(e_new);
}
}
} // for all arcs
} // while queue not empty
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
e_tail = e->tail;
toks_.Delete(e);
}
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin
for (size_t i = 0; i < active_toks_.size(); i++) {
// Delete all tokens alive on this frame, and any forward
// links they may have.
for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
DeleteForwardLinks(tok);
Token *next_tok = tok->next;
token_pool_.Free(tok);
num_toks_--;
tok = next_tok;
}
}
active_toks_.clear();
KALDI_ASSERT(num_toks_ == 0);
}
// static
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::TopSortTokens(
Token *tok_list, std::vector<Token*> *topsorted_list) {
unordered_map<Token*, int32> token2pos;
typedef typename unordered_map<Token*, int32>::iterator IterType;
int32 num_toks = 0;
for (Token *tok = tok_list; tok != NULL; tok = tok->next)
num_toks++;
int32 cur_pos = 0;
// We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
// This is likely to be in closer to topological order than
// if we had given them ascending order, because of the way
// new tokens are put at the front of the list.
for (Token *tok = tok_list; tok != NULL; tok = tok->next)
token2pos[tok] = num_toks - ++cur_pos;
unordered_set<Token*> reprocess;
for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
Token *tok = iter->first;
int32 pos = iter->second;
for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
if (link->ilabel == 0) {
// We only need to consider epsilon links, since non-epsilon links
// transition between frames and this function only needs to sort a list
// of tokens from a single frame.
IterType following_iter = token2pos.find(link->next_tok);
if (following_iter != token2pos.end()) { // another token on this frame,
// so must consider it.
int32 next_pos = following_iter->second;
if (next_pos < pos) { // reassign the position of the next Token.
following_iter->second = cur_pos++;
reprocess.insert(link->next_tok);
}
}
}
}
// In case we had previously assigned this token to be reprocessed, we can
// erase it from that set because it's "happy now" (we just processed it).
reprocess.erase(tok);
}
size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles.
for (loop_count = 0;
!reprocess.empty() && loop_count < max_loop; ++loop_count) {
std::vector<Token*> reprocess_vec;
for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
iter != reprocess.end(); ++iter)
reprocess_vec.push_back(*iter);
reprocess.clear();
for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
iter != reprocess_vec.end(); ++iter) {
Token *tok = *iter;
int32 pos = token2pos[tok];
// Repeat the processing we did above (for comments, see above).
for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
if (link->ilabel == 0) {
IterType following_iter = token2pos.find(link->next_tok);
if (following_iter != token2pos.end()) {
int32 next_pos = following_iter->second;
if (next_pos < pos) {
following_iter->second = cur_pos++;
reprocess.insert(link->next_tok);
}
}
}
}
}
}
KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
"graph (this is not allowed!)");
topsorted_list->clear();
topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between.
for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
(*topsorted_list)[iter->second] = iter->first;
}
// Instantiate the template for the combination of token types and FST types
// that we'll need.
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;
} // end namespace kaldi.
// decoder/lattice-faster-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include "decoder/grammar-fst.h"
#include "fst/fstlib.h"
#include "fst/memory.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
#include "util/stl-utils.h"
namespace kaldi {
struct LatticeFasterDecoderConfig {
BaseFloat beam;
int32 max_active;
int32 min_active;
BaseFloat lattice_beam;
int32 prune_interval;
bool determinize_lattice; // not inspected by this class... used in
// command-line program.
BaseFloat beam_delta;
BaseFloat hash_ratio;
// Note: we don't make prune_scale configurable on the command line, it's not
// a very important parameter. It affects the algorithm that prunes the
// tokens as we go.
BaseFloat prune_scale;
// Number of elements in the block for Token and ForwardLink memory
// pool allocation.
int32 memory_pool_tokens_block_size;
int32 memory_pool_links_block_size;
// Most of the options inside det_opts are not actually queried by the
// LatticeFasterDecoder class itself, but by the code that calls it, for
// example in the function DecodeUtteranceLatticeFaster.
fst::DeterminizeLatticePhonePrunedOptions det_opts;
LatticeFasterDecoderConfig()
: beam(16.0),
max_active(std::numeric_limits<int32>::max()),
min_active(200),
lattice_beam(10.0),
prune_interval(25),
determinize_lattice(true),
beam_delta(0.5),
hash_ratio(2.0),
prune_scale(0.1),
memory_pool_tokens_block_size(1 << 8),
memory_pool_links_block_size(1 << 8) {}
void Register(OptionsItf *opts) {
det_opts.Register(opts);
opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
"more accurate");
opts->Register("min-active", &min_active, "Decoder minimum #active states.");
opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, "
"and deeper lattices");
opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
"which to prune tokens");
opts->Register("determinize-lattice", &determinize_lattice, "If true, "
"determinize the lattice (lattice-determinization, keeping only "
"best pdf-sequence for each word-sequence).");
opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
"parameter is obscure and relates to a speedup in the way the "
"max-active constraint is applied. Larger is more accurate.");
opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
"control hash behavior");
opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
"Memory pool block size suggestion for storing tokens (in elements). "
"Smaller uses less memory but increases cache misses.");
opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
"Memory pool block size suggestion for storing links (in elements). "
"Smaller uses less memory but increases cache misses.");
}
void Check() const {
KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
&& min_active <= max_active
&& prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
&& prune_scale > 0.0 && prune_scale < 1.0);
}
};
namespace decoder {
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).
// ForwardLinks are the links from a token to a token on the next frame.
// or sometimes on the current frame (for input-epsilon links).
template <typename Token>
struct ForwardLink {
using Label = fst::StdArc::Label;
Token *next_tok; // the next token [or NULL if represents final-state]
Label ilabel; // ilabel on arc
Label olabel; // olabel on arc
BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.)
BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc
ForwardLink *next; // next in singly-linked list of forward arcs (arcs
// in the state-level lattice) from a token.
inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
BaseFloat graph_cost, BaseFloat acoustic_cost,
ForwardLink *next):
next_tok(next_tok), ilabel(ilabel), olabel(olabel),
graph_cost(graph_cost), acoustic_cost(acoustic_cost),
next(next) { }
};
struct StdToken {
using ForwardLinkT = ForwardLink<StdToken>;
using Token = StdToken;
// Standard token type for LatticeFasterDecoder. Each active HCLG
// (decoding-graph) state on each frame has one token.
// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;
// exta_cost is >= 0. After calling PruneForwardLinks, this equals the
// minimum difference between the cost of the best path that this link is a
// part of, and the cost of the absolute best path, under the assumption that
// any of the currently active states at the decoding front may eventually
// succeed (e.g. if you were to take the currently active states one by one
// and compute this difference, and then take the minimum).
BaseFloat extra_cost;
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;
//'next' is the next in the singly-linked list of tokens for this frame.
Token *next;
// This function does nothing and should be optimized out; it's needed
// so we can share the regular LatticeFasterDecoderTpl code and the code
// for LatticeFasterOnlineDecoder that supports fast traceback.
inline void SetBackpointer (Token *backpointer) { }
// This constructor just ignores the 'backpointer' argument. That argument is
// needed so that we can use the same decoder code for LatticeFasterDecoderTpl
// and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
// fast way to obtain the best path).
inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
Token *next, Token *backpointer):
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
};
struct BackpointerToken {
using ForwardLinkT = ForwardLink<BackpointerToken>;
using Token = BackpointerToken;
// BackpointerToken is like Token but also
// Standard token type for LatticeFasterDecoder. Each active HCLG
// (decoding-graph) state on each frame has one token.
// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;
// exta_cost is >= 0. After calling PruneForwardLinks, this equals
// the minimum difference between the cost of the best path, and the cost of
// this is on, and the cost of the absolute best path, under the assumption
// that any of the currently active states at the decoding front may
// eventually succeed (e.g. if you were to take the currently active states
// one by one and compute this difference, and then take the minimum).
BaseFloat extra_cost;
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;
//'next' is the next in the singly-linked list of tokens for this frame.
BackpointerToken *next;
// Best preceding BackpointerToken (could be a on this frame, connected to
// this via an epsilon transition, or on a previous frame). This is only
// required for an efficient GetBestPath function in
// LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
// (the "links" list is what stores the forward links, for that).
Token *backpointer;
inline void SetBackpointer (Token *backpointer) {
this->backpointer = backpointer;
}
inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
Token *next, Token *backpointer):
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
backpointer(backpointer) { }
};
} // namespace decoder
/** This is the "normal" lattice-generating decoder.
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
for more information.
The decoder is templated on the FST type and the token type. The token type
will normally be StdToken, but also may be BackpointerToken which is to support
quick lookup of the current best path (see lattice-faster-online-decoder.h)
The FST you invoke this decoder which is expected to equal
Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
FST == StdFst and it notices that the actual FST type is
fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
will internally cast itself to one that is templated on those more specific
types; this is an optimization for speed.
*/
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterDecoderTpl(const FST &fst,
const LatticeFasterDecoderConfig &config);
// This version of the constructor takes ownership of the fst, and will delete
// it when this object is destroyed.
LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
FST *fst);
void SetOptions(const LatticeFasterDecoderConfig &config) {
config_ = config;
}
const LatticeFasterDecoderConfig &GetOptions() const {
return config_;
}
~LatticeFasterDecoderTpl();
/// Decodes until there are no more frames left in the "decodable" object..
/// note, this may block waiting for input if the "decodable" object blocks.
/// Returns true if any kind of traceback is available (not necessarily from a
/// final state).
bool Decode(DecodableInterface *decodable);
/// says whether a final-state was active on the last frame. If it was not, the
/// lattice (or traceback) will end with states that are not final-states.
bool ReachedFinal() const {
return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
}
/// Outputs an FST corresponding to the single best path through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one. Note: this just calls GetRawLattice()
/// and figures out the shortest path.
bool GetBestPath(Lattice *ofst,
bool use_final_probs = true) const;
/// Outputs an FST corresponding to the raw, state-level
/// tracebacks. Returns true if result is nonempty.
/// If "use_final_probs" is true AND we reached the final-state
/// of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
/// The raw lattice will be topologically sorted.
///
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
/// which also supports a pruning beam, in case for some reason
/// you want it pruned tighter than the regular lattice beam.
/// We could put that here in future needed.
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
/// [Deprecated, users should now use GetRawLattice and determinize it
/// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
/// Outputs an FST corresponding to the lattice-determinized
/// lattice (one path per word sequence). Returns true if result is nonempty.
/// If "use_final_probs" is true AND we reached the final-state of the graph
/// then it will include those as final-probs, else it will treat all
/// final-probs as one.
bool GetLattice(CompactLattice *ofst,
bool use_final_probs = true) const;
/// InitDecoding initializes the decoding, and should only be used if you
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
/// call this. You can also call InitDecoding if you have already decoded an
/// utterance and want to start with a new utterance.
void InitDecoding();
/// This will decode until there are no more frames ready in the decodable
/// object. You can keep calling it each time more frames become available.
/// If max_num_frames is specified, it specifies the maximum number of frames
/// the function will decode before returning.
void AdvanceDecoding(DecodableInterface *decodable,
int32 max_num_frames = -1);
/// This function may be optionally called after AdvanceDecoding(), when you
/// do not plan to decode any further. It does an extra pruning step that
/// will help to prune the lattices output by GetLattice and (particularly)
/// GetRawLattice more completely, particularly toward the end of the
/// utterance. If you call this, you cannot call AdvanceDecoding again (it
/// will fail), and you cannot call GetLattice() and related functions with
/// use_final_probs = false. Used to be called PruneActiveTokensFinal().
void FinalizeDecoding();
/// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
/// more information. It returns the difference between the best (final-cost
/// plus cost) of any token on the final frame, and the best cost of any token
/// on the final frame. If it is infinity it means no final-states were
/// present on the final frame. It will usually be nonnegative. If it not
/// too positive (e.g. < 5 is my first guess, but this is not tested) you can
/// take it as a good indication that we reached the final-state with
/// reasonable likelihood.
BaseFloat FinalRelativeCost() const;
// Returns the number of frames decoded so far. The value returned changes
// whenever we call ProcessEmitting().
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
protected:
// we make things protected instead of private, as code in
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
// internals.
// Deletes the elements of the singly linked list tok->links.
void DeleteForwardLinks(Token *tok);
// head of per-frame list of Tokens (list is in topological order),
// and something saying whether we ever pruned it using PruneForwardLinks.
struct TokenList {
Token *toks;
bool must_prune_forward_links;
bool must_prune_tokens;
TokenList(): toks(NULL), must_prune_forward_links(true),
must_prune_tokens(true) { }
};
using Elem = typename HashList<StateId, Token*>::Elem;
// Equivalent to:
// struct Elem {
// StateId key;
// Token *val;
// Elem *tail;
// };
void PossiblyResizeHash(size_t num_toks);
// FindOrAddToken either locates a token in hash of toks_, or if necessary
// inserts a new, empty token (i.e. with no forward links) for the current
// frame. [note: it's inserted if necessary into hash toks_ and also into the
// singly linked list of tokens active on this frame (whose head is at
// active_toks_[frame]). The frame_plus_one argument is the acoustic frame
// index plus one, which is used to index into the active_toks_ array.
// Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
// token was newly created or the cost changed.
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
// hopefully be optimized out).
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, Token *backpointer,
bool *changed);
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
// delta is the amount by which the extra_costs must change
// before we set *extra_costs_changed = true.
// If delta is larger, we'll tend to go back less far
// toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
bool *links_pruned,
BaseFloat delta);
// This function computes the final-costs for tokens active on the final
// frame. It outputs to final-costs, if non-NULL, a map from the Token*
// pointer to the final-prob of the corresponding state, for all Tokens
// that correspond to states that have final-probs. This map will be
// empty if there were no final-probs. It outputs to
// final_relative_cost, if non-NULL, the difference between the best
// forward-cost including the final-prob cost, and the best forward-cost
// without including the final-prob cost (this will usually be positive), or
// infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
// outputs this quanitity]. It outputs to final_best_cost, if
// non-NULL, the lowest for any token t active on the final frame, of
// forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
// the graph of the state corresponding to token t, or the best of
// forward-cost[t] if there were no final-probs active on the final frame.
// You cannot call this after FinalizeDecoding() has been called; in that
// case you should get the answer from class-member variables.
void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
BaseFloat *final_relative_cost,
BaseFloat *final_best_cost) const;
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
void PruneForwardLinksFinal();
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
void PruneTokensForFrame(int32 frame_plus_one);
// Go backwards through still-alive tokens, pruning them if the
// forward+backward cost is more than lat_beam away from the best path. It's
// possible to prove that this is "correct" in the sense that we won't lose
// anything outside of lat_beam, regardless of what happens in the future.
// delta controls when it considers a cost to have changed enough to continue
// going backward and propagating the change. larger delta -> will recurse
// less far.
void PruneActiveTokens(BaseFloat delta);
/// Gets the weight cutoff. Also counts the active tokens.
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem);
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to
/// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
/// use.
BaseFloat ProcessEmitting(DecodableInterface *decodable);
/// Processes nonemitting (epsilon) arcs for one frame. Called after
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
/// preceding ProcessEmitting().
void ProcessNonemitting(BaseFloat cost_cutoff);
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
// more than one list (e.g. for current and previous frames), but only one of
// them at a time can be indexed by StateId. It is indexed by frame-index
// plus one, where the frame-index is zero-based, as used in decodable object.
// That is, the emitting probs of frame t are accounted for in tokens at
// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
// the graph.
HashList<StateId, Token*> toks_;
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
// frame (members of TokenList are toks, must_prune_forward_links,
// must_prune_tokens).
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// fst_ is a pointer to the FST we are decoding from.
const FST *fst_;
// delete_fst_ is true if the pointer fst_ needs to be deleted when this
// object is destroyed.
bool delete_fst_;
std::vector<BaseFloat> cost_offsets_; // This contains, for each
// frame, an offset that was added to the acoustic log-likelihoods on that
// frame in order to keep everything in a nice dynamic range i.e. close to
// zero, to reduce roundoff errors.
LatticeFasterDecoderConfig config_;
int32 num_toks_; // current total #toks allocated...
bool warned_;
/// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
/// calling this is optional]. If true, it's forbidden to decode more. Also,
/// if this is set, then the output of ComputeFinalCosts() is in the next
/// three variables. The reason we need to do this is that after
/// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
/// of the tokens on the last frame are freed, so we free the list from toks_
/// to avoid having dangling pointers hanging around.
bool decoding_finalized_;
/// For the meaning of the next 3 variables, see the comment for
/// decoding_finalized_ above., and ComputeFinalCosts().
unordered_map<Token*, BaseFloat> final_costs_;
BaseFloat final_relative_cost_;
BaseFloat final_best_cost_;
// Memory pools for storing tokens and forward links.
// We use it to decrease the work put on allocator and to move some of data
// together. Too small block sizes will result in more work to allocator but
// bigger ones increase the memory usage.
fst::MemoryPool<Token> token_pool_;
fst::MemoryPool<ForwardLinkT> forward_link_pool_;
// There are various cleanup tasks... the toks_ structure contains
// singly linked lists of Token pointers, where Elem is the list type.
// It also indexes them in a hash, indexed by state (this hash is only
// maintained for the most recent frame). toks_.Clear()
// deletes them from the hash and returns the list of Elems. The
// function DeleteElems calls toks_.Delete(elem) for each elem in
// the list, which returns ownership of the Elem to the toks_ structure
// for reuse, but does not delete the Token pointer. The Token pointers
// are reference-counted and are ultimately deleted in PruneTokensForFrame,
// but are also linked together on each frame by their own linked-list,
// using the "next" pointer. We delete them manually.
void DeleteElems(Elem *list);
// This function takes a singly linked list of tokens for a single frame, and
// outputs a list of them in topological order (it will crash if no such order
// can be found, which will typically be due to decoding graphs with epsilon
// cycles, which are not allowed). Note: the output list may contain NULLs,
// which the caller should pass over; it just happens to be more efficient for
// the algorithm to output a list that contains NULLs.
static void TopSortTokens(Token *tok_list,
std::vector<Token*> *topsorted_list);
void ClearActiveTokens();
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoder;
} // end namespace kaldi.
#endif
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include "decoder/lattice-faster-online-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
bool use_final_probs) const {
Lattice lat1;
{
Lattice raw_lat;
this->GetRawLattice(&raw_lat, use_final_probs);
ShortestPath(raw_lat, &lat1);
}
Lattice lat2;
GetBestPath(&lat2, use_final_probs);
BaseFloat delta = 0.1;
int32 num_paths = 1;
if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
KALDI_WARN << "Best-path test failed";
return false;
} else {
return true;
}
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
bool use_final_probs) const {
olat->DeleteStates();
BaseFloat final_graph_cost;
BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
if (iter.Done())
return false; // would have printed warning.
StateId state = olat->AddState();
olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
while (!iter.Done()) {
LatticeArc arc;
iter = TraceBackBestPath(iter, &arc);
arc.nextstate = state;
StateId new_state = olat->AddState();
olat->AddArc(new_state, arc);
state = new_state;
}
olat->SetStart(state);
return true;
}
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
bool use_final_probs,
BaseFloat *final_cost_out) const {
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "BestPathEnd() with use_final_probs == false";
KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
"You cannot call BestPathEnd if no frames were decoded.");
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
// Singly linked list of tokens on last frame (access list through "next"
// pointer).
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_final_cost = 0;
Token *best_tok = NULL;
for (Token *tok = this->active_toks_.back().toks;
tok != NULL; tok = tok->next) {
BaseFloat cost = tok->tot_cost, final_cost = 0.0;
if (use_final_probs && !final_costs.empty()) {
// if we are instructed to use final-probs, and any final tokens were
// active on final frame, include the final-prob in the cost of the token.
typename unordered_map<Token*, BaseFloat>::const_iterator
iter = final_costs.find(tok);
if (iter != final_costs.end()) {
final_cost = iter->second;
cost += final_cost;
} else {
cost = std::numeric_limits<BaseFloat>::infinity();
}
}
if (cost < best_cost) {
best_cost = cost;
best_tok = tok;
best_final_cost = final_cost;
}
}
if (best_tok == NULL) { // this should not happen, and is likely a code error or
// caused by infinities in likelihoods, but I'm not making
// it a fatal error for now.
KALDI_WARN << "No final token found.";
}
if (final_cost_out)
*final_cost_out = best_final_cost;
return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
BestPathIterator iter, LatticeArc *oarc) const {
KALDI_ASSERT(!iter.Done() && oarc != NULL);
Token *tok = static_cast<Token*>(iter.tok);
int32 cur_t = iter.frame, step_t = 0;
if (tok->backpointer != NULL) {
// retrieve the correct forward link(with the best link cost)
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
ForwardLinkT *link;
for (link = tok->backpointer->links;
link != NULL; link = link->next) {
if (link->next_tok == tok) { // this is a link to "tok"
BaseFloat graph_cost = link->graph_cost,
acoustic_cost = link->acoustic_cost;
BaseFloat cost = graph_cost + acoustic_cost;
if (cost < best_cost) {
oarc->ilabel = link->ilabel;
oarc->olabel = link->olabel;
if (link->ilabel != 0) {
KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
acoustic_cost -= this->cost_offsets_[cur_t];
step_t = -1;
} else {
step_t = 0;
}
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
best_cost = cost;
}
}
}
if (link == NULL &&
best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
KALDI_ERR << "Error tracing best-path back (likely "
<< "bug in token-pruning algorithm)";
}
} else {
oarc->ilabel = 0;
oarc->olabel = 0;
oarc->weight = LatticeWeight::One(); // zero costs.
}
return BestPathIterator(tok->backpointer, cur_t + step_t);
}
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
// num-frames plus one (since frames are one-based, and we have
// an extra frame for the start-state).
int32 num_frames = this->active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
for (int32 f = 0; f <= num_frames; f++) {
if (this->active_toks_[f].toks == NULL) {
KALDI_WARN << "No tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
}
unordered_map<Token*, StateId> tok_map;
std::queue<std::pair<Token*, int32> > tok_queue;
// First initialize the queue and states. Put the initial state on the queue;
// this is the last token in the list active_toks_[0].toks.
for (Token *tok = this->active_toks_[0].toks;
tok != NULL; tok = tok->next) {
if (tok->next == NULL) {
tok_map[tok] = ofst->AddState();
ofst->SetStart(tok_map[tok]);
std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
tok_queue.push(tok_pair);
}
}
// Next create states for "good" tokens
while (!tok_queue.empty()) {
std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
tok_queue.pop();
Token *cur_tok = cur_tok_pair.first;
int32 cur_frame = cur_tok_pair.second;
KALDI_ASSERT(cur_frame >= 0 &&
cur_frame <= this->cost_offsets_.size());
typename unordered_map<Token*, StateId>::const_iterator iter =
tok_map.find(cur_tok);
KALDI_ASSERT(iter != tok_map.end());
StateId cur_state = iter->second;
for (ForwardLinkT *l = cur_tok->links;
l != NULL;
l = l->next) {
Token *next_tok = l->next_tok;
if (next_tok->extra_cost < beam) {
// so both the current and the next token are good; create the arc
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
StateId nextstate;
if (tok_map.find(next_tok) == tok_map.end()) {
nextstate = tok_map[next_tok] = ofst->AddState();
tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
} else {
nextstate = tok_map[next_tok];
}
BaseFloat cost_offset = (l->ilabel != 0 ?
this->cost_offsets_[cur_frame] : 0);
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(cur_state, arc);
}
}
if (cur_frame == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token*, BaseFloat>::const_iterator iter =
final_costs.find(cur_tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi.
// decoder/lattice-faster-online-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.h, about how to maintain this
// file in sync with lattice-faster-decoder.h
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "decoder/lattice-faster-decoder.h"
namespace kaldi {
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
supports an efficient way to get the best path (see the function
BestPathEnd()), which is useful in endpointing and in situations where you
might want to frequently access the best path.
This is only templated on the FST type, since the Token type is required to
be BackpointerToken. Actually it only makes sense to instantiate
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
this child class.
*/
template <typename FST>
class LatticeFasterOnlineDecoderTpl:
public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Token = decoder::BackpointerToken;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterOnlineDecoderTpl(const FST &fst,
const LatticeFasterDecoderConfig &config):
LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
// This version of the initializer takes ownership of 'fst', and will delete
// it when this object is destroyed.
LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
FST *fst):
LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
struct BestPathIterator {
void *tok;
int32 frame;
// note, "frame" is the frame-index of the frame you'll get the
// transition-id for next time, if you call TraceBackBestPath on this
// iterator (assuming it's not an epsilon transition). Note that this
// is one less than you might reasonably expect, e.g. it's -1 for
// the nonemitting transitions before the first frame.
BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
bool Done() const { return tok == NULL; }
};
/// Outputs an FST corresponding to the single best path through the lattice.
/// This is quite efficient because it doesn't get the entire raw lattice and find
/// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
/// so it basically traces it back through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
bool GetBestPath(Lattice *ofst,
bool use_final_probs = true) const;
/// This function does a self-test of GetBestPath(). Returns true on
/// success; returns false and prints a warning on failure.
bool TestGetBestPath(bool use_final_probs = true) const;
/// This function returns an iterator that can be used to trace back
/// the best path. If use_final_probs == true and at least one final state
/// survived till the end, it will use the final-probs in working out the best
/// final Token, and will output the final cost to *final_cost (if non-NULL),
/// else it will use only the forward likelihood, and will put zero in
/// *final_cost (if non-NULL).
/// Requires that NumFramesDecoded() > 0.
BestPathIterator BestPathEnd(bool use_final_probs,
BaseFloat *final_cost = NULL) const;
/// This function can be used in conjunction with BestPathEnd() to trace back
/// the best path one link at a time (e.g. this can be useful in endpoint
/// detection). By "link" we mean a link in the graph; not all links cross
/// frame boundaries, but each time you see a nonzero ilabel you can interpret
/// that as a frame. The return value is the updated iterator. It outputs
/// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
/// while leaving its "nextstate" variable unchanged.
BestPathIterator TraceBackBestPath(
BestPathIterator iter, LatticeArc *arc) const;
/// Behaves the same as GetRawLattice but only processes tokens whose
/// extra_cost is smaller than the best-cost plus the specified beam.
/// It is only worthwhile to call this function if beam is less than
/// the lattice_beam specified in the config; otherwise, it would
/// return essentially the same thing as GetRawLattice, but more slowly.
bool GetRawLatticePruned(Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const;
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
};
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
} // end namespace kaldi.
#endif
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code. They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.
// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinzable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
typedef kaldi::int32 Int;
typedef typename Arc::Weight Weight;
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
for(int i = 0; i < 100; i++) {
RandFstOptions opts;
opts.n_states = 4;
opts.n_arcs = 10;
opts.n_final = 2;
opts.allow_empty = false;
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
opts.acyclic = true;
// to be exactly representable in float,
// or this test fails because numerical differences can cause symmetry in
// weights to be broken, which causes the wrong path to be chosen as far
// as the string part is concerned.
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
bool sorted = TopSort(fst);
KALDI_ASSERT(sorted);
ILabelCompare<Arc> ilabel_comp;
if (kaldi::Rand() % 2 == 0)
ArcSort(fst, ilabel_comp);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> det_fst;
try {
DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
// OK, now determinize it a different way and check equivalence.
// [note: it's not normal determinization, it's taking the best path
// for any input-symbol sequence....
VectorFst<Arc> pruned_fst(*fst);
if (pruned_fst.NumStates() != 0)
kaldi::PruneLattice(10.0, &pruned_fst);
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
std::cout << "Compact pruned FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
std::cout << "Compact version of determinized FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
if (ans)
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
} catch (...) {
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
}
delete fst;
}
}
// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
typedef typename Arc::Weight Weight;
RandFstOptions opts;
opts.acyclic = true;
for(int i = 0; i < 100; i++) {
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> ofst;
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
delete fst;
}
}
} // end namespace fst
int main() {
using namespace fst;
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
std::cout << "Tests succeeded\n";
}
// lat/determinize-lattice-pruned.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <climits>
#include "fstext/determinize-lattice.h" // for LatticeStringRepository
#include "fstext/fstext-utils.h"
#include "lat/lattice-functions.h" // for PruneLattice
#include "lat/minimize-lattice.h" // for minimization
#include "lat/push-lattice.h" // for minimization
#include "lat/determinize-lattice-pruned.h"
namespace fst {
using std::vector;
using std::pair;
using std::greater;
// class LatticeDeterminizerPruned is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// p.d.f. id's in the anticipated use of this code] It has a special requirement
// on the Weight type: that there should be a Compare function on the weights
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 >
// w2. This requires that there be a total order on the weights.
template<class Weight, class IntType> class LatticeDeterminizerPruned {
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence
// between our states and the states in ofst. If destroy == true, release memory as we go
// (but we cannot output again).
typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
typedef ArcTpl<CompactWeight> CompactArc; // arc in compact, acceptor form of lattice
typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
// Output to standard FST with CompactWeightTpl<Weight> as its weight type (the
// weight stores the original output-symbol strings). If destroy == true,
// release memory as we go (but we cannot output again).
void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
KALDI_ASSERT(determinized_);
typedef typename Arc::StateId StateId;
StateId nStates = static_cast<StateId>(output_states_.size());
if (destroy)
FreeMostMemory();
ofst->DeleteStates();
ofst->SetStart(kNoStateId);
if (nStates == 0) {
return;
}
for (StateId s = 0;s < nStates;s++) {
OutputStateId news = ofst->AddState();
KALDI_ASSERT(news == s);
}
ofst->SetStart(0);
// now process transitions.
for (StateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
OutputState &this_state = *(output_states_[this_state_id]);
vector<TempArc> &this_vec(this_state.arcs);
typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
for (;iter != end; ++iter) {
const TempArc &temp_arc(*iter);
CompactArc new_arc;
vector<Label> olabel_seq;
repository_.ConvertToVector(temp_arc.string, &olabel_seq);
CompactWeight weight(temp_arc.weight, olabel_seq);
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
ofst->SetFinal(this_state_id, weight);
} else { // is really an arc.
new_arc.nextstate = temp_arc.nextstate;
new_arc.ilabel = temp_arc.ilabel;
new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
new_arc.weight = weight; // includes string and weight.
ofst->AddArc(this_state_id, new_arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating memory,
// and we want to reduce the maximum amount ever allocated.
if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
}
if (destroy) {
FreeOutputStates();
repository_.Destroy();
}
}
// Output to standard FST with Weight as its weight type. We will create extra
// states to handle sequences of symbols on the output. If destroy == true,
// release memory as we go (but we cannot output again).
void Output(MutableFst<Arc> *ofst, bool destroy = true) {
// Outputs to standard fst.
OutputStateId nStates = static_cast<OutputStateId>(output_states_.size());
ofst->DeleteStates();
if (nStates == 0) {
ofst->SetStart(kNoStateId);
return;
}
if (destroy)
FreeMostMemory();
// Add basic states-- but we will add extra ones to account for strings on output.
for (OutputStateId s = 0; s< nStates;s++) {
OutputStateId news = ofst->AddState();
KALDI_ASSERT(news == s);
}
ofst->SetStart(0);
for (OutputStateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
OutputState &this_state = *(output_states_[this_state_id]);
vector<TempArc> &this_vec(this_state.arcs);
typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
vector<Label> seq;
repository_.ConvertToVector(temp_arc.string, &seq);
if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
// Make a sequence of states going to a final state, with the strings
// as labels. Put the weight on the first arc.
OutputStateId cur_state = this_state_id;
for (size_t i = 0; i < seq.size(); i++) {
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = 0; // epsilon.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.weight : Weight::One()));
} else { // Really an arc.
OutputStateId cur_state = this_state_id;
// Have to be careful with this integer comparison (i+1 < seq.size()) because unsigned.
// i < seq.size()-1 could fail for zero-length sequences.
for (size_t i = 0; i+1 < seq.size();i++) {
// for all but the last element of seq, create new state.
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = (i == 0 ? temp_arc.ilabel : 0); // put ilabel on first element of seq.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
// Add the final arc in the sequence.
Arc arc;
arc.nextstate = temp_arc.nextstate;
arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
arc.olabel = (seq.size() > 0 ? seq.back() : 0);
ofst->AddArc(cur_state, arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating memory
if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
}
if (destroy) {
FreeOutputStates();
repository_.Destroy();
}
}
// Initializer. After initializing the object you will typically
// call Determinize() and then call one of the Output functions.
// Note: ifst.Copy() will generally do a
// shallow copy. We do it like this for memory safety, rather than
// keeping a reference or pointer to ifst_.
LatticeDeterminizerPruned(const ExpandedFst<Arc> &ifst,
double beam,
DeterminizeLatticePrunedOptions opts):
num_arcs_(0), num_elems_(0), ifst_(ifst.Copy()), beam_(beam), opts_(opts),
equal_(opts_.delta), determinized_(false),
minimal_hash_(3, hasher_, equal_), initial_hash_(3, hasher_, equal_) {
KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
// work correctly otherwise.
}
void FreeOutputStates() {
for (size_t i = 0; i < output_states_.size(); i++)
delete output_states_[i];
vector<OutputState*> temp;
temp.swap(output_states_);
}
// frees all memory except the info (in output_states_[ ]->arcs)
// that we need to output the FST.
void FreeMostMemory() {
if (ifst_) {
delete ifst_;
ifst_ = NULL;
}
{ MinimalSubsetHash tmp; tmp.swap(minimal_hash_); }
for (size_t i = 0; i < output_states_.size(); i++) {
vector<Element> empty_subset;
empty_subset.swap(output_states_[i]->minimal_subset);
}
for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
iter != initial_hash_.end(); ++iter)
delete iter->first;
{ InitialSubsetHash tmp; tmp.swap(initial_hash_); }
{ vector<char> tmp; tmp.swap(isymbol_or_final_); }
{ // Free up the queue. I'm not sure how to make sure all
// the memory is really freed (no swap() function)... doesn't really
// matter much though.
while (!queue_.empty()) {
Task *t = queue_.top();
delete t;
queue_.pop();
}
}
{ vector<pair<Label, Element> > tmp; tmp.swap(all_elems_tmp_); }
}
~LatticeDeterminizerPruned() {
FreeMostMemory();
FreeOutputStates();
// rest is deleted by destructors.
}
void RebuildRepository() { // rebuild the string repository,
// freeing stuff we don't need.. we call this when memory usage
// passes a supplied threshold. We need to accumulate all the
// strings we need the repository to "remember", then tell it
// to clean the repository.
std::vector<StringId> needed_strings;
for (size_t i = 0; i < output_states_.size(); i++) {
AddStrings(output_states_[i]->minimal_subset, &needed_strings);
for (size_t j = 0; j < output_states_[i]->arcs.size(); j++)
needed_strings.push_back(output_states_[i]->arcs[j].string);
}
{ // the queue doesn't allow us access to the underlying vector,
// so we have to resort to a temporary collection.
std::vector<Task*> tasks;
while (!queue_.empty()) {
Task *task = queue_.top();
queue_.pop();
tasks.push_back(task);
AddStrings(task->subset, &needed_strings);
}
for (size_t i = 0; i < tasks.size(); i++)
queue_.push(tasks[i]);
}
// the following loop covers strings present in initial_hash_.
for (typename InitialSubsetHash::const_iterator
iter = initial_hash_.begin();
iter != initial_hash_.end(); ++iter) {
const vector<Element> &vec = *(iter->first);
Element elem = iter->second;
AddStrings(vec, &needed_strings);
needed_strings.push_back(elem.string);
}
std::sort(needed_strings.begin(), needed_strings.end());
needed_strings.erase(std::unique(needed_strings.begin(),
needed_strings.end()),
needed_strings.end()); // uniq the strings.
KALDI_LOG << "Rebuilding repository.";
repository_.Rebuild(needed_strings);
}
bool CheckMemoryUsage() {
int32 repo_size = repository_.MemSize(),
arcs_size = num_arcs_ * sizeof(TempArc),
elems_size = num_elems_ * sizeof(Element),
total_size = repo_size + arcs_size + elems_size;
if (opts_.max_mem > 0 && total_size > opts_.max_mem) { // We passed the memory threshold.
// This is usually due to the repository getting large, so we
// clean this out.
RebuildRepository();
int32 new_repo_size = repository_.MemSize(),
new_total_size = new_repo_size + arcs_size + elems_size;
KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository shrank from "
<< repo_size << " to " << new_repo_size << " bytes (approximately)";
if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
// Rebuilding didn't help enough-- we need a margin to stop
// having to rebuild too often. We'll just return to the user at
// this point, with a partial lattice that's pruned tighter than
// the specified beam. Here we figure out what the effective
// beam was.
double effective_beam = beam_;
if (!queue_.empty()) { // Note: queue should probably not be empty; we're
// just being paranoid here.
Task *task = queue_.top();
double total_weight = backward_costs_[ifst_->Start()]; // best weight of FST.
effective_beam = task->priority_cost - total_weight;
}
KALDI_WARN << "Did not reach requested beam in determinize-lattice: "
<< "size exceeds maximum " << opts_.max_mem
<< " bytes; (repo,arcs,elems) = (" << repo_size << ","
<< arcs_size << "," << elems_size
<< "), after rebuilding, repo size was " << new_repo_size
<< ", effective beam was " << effective_beam
<< " vs. requested beam " << beam_;
return false;
}
}
return true;
}
bool Determinize(double *effective_beam) {
KALDI_ASSERT(!determinized_);
// This determinizes the input fst but leaves it in the "special format"
// in "output_arcs_". Must be called after Initialize(). To get the
// output, call one of the Output routines.
InitializeDeterminization(); // some start-up tasks.
while (!queue_.empty()) {
Task *task = queue_.top();
// Note: the queue contains only tasks that are "within the beam".
// We also have to check whether we have reached one of the user-specified
// maximums, of estimated memory, arcs, or states. The condition for
// ending is:
// num-states is more than user specified, OR
// num-arcs is more than user specified, OR
// memory passed a user-specified threshold and cleanup failed
// to get it below that threshold.
size_t num_states = output_states_.size();
if ((opts_.max_states > 0 && num_states > opts_.max_states) ||
(opts_.max_arcs > 0 && num_arcs_ > opts_.max_arcs) ||
(num_states % 10 == 0 && !CheckMemoryUsage())) { // note: at some point
// it was num_states % 100, not num_states % 10, but I encountered an example
// where memory was exhausted before we reached state #100.
KALDI_VLOG(1) << "Lattice determinization terminated but not "
<< " because of lattice-beam. (#states, #arcs) is ( "
<< output_states_.size() << ", " << num_arcs_
<< " ), versus limits ( " << opts_.max_states << ", "
<< opts_.max_arcs << " ) (else, may be memory limit).";
break;
// we terminate the determinization here-- whatever we already expanded is
// what we'll return... because we expanded stuff in order of total
// (forward-backward) weight, the stuff we returned first is the most
// important.
}
queue_.pop();
ProcessTransition(task->state, task->label, &(task->subset));
delete task;
}
determinized_ = true;
if (effective_beam != NULL) {
if (queue_.empty()) *effective_beam = beam_;
else
*effective_beam = queue_.top()->priority_cost -
backward_costs_[ifst_->Start()];
}
return (queue_.empty()); // return success if queue was empty, i.e. we processed
// all tasks and did not break out of the loop early due to reaching a memory,
// arc or state limit.
}
private:
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId; // use this when we don't know if it's input or output.
typedef typename Arc::StateId InputStateId; // state in the input FST.
typedef typename Arc::StateId OutputStateId; // same as above but distinguish
// states in output Fst.
typedef LatticeStringRepository<IntType> StringRepositoryType;
typedef const typename StringRepositoryType::Entry* StringId;
// Element of a subset [of original states]
struct Element {
StateId state; // use StateId as this is usually InputStateId but in one case
// OutputStateId.
StringId string;
Weight weight;
bool operator != (const Element &other) const {
return (state != other.state || string != other.string ||
weight != other.weight);
}
// This operator is only intended for the priority_queue in the function
// EpsilonClosure().
bool operator > (const Element &other) const {
return state > other.state;
}
// This operator is only intended to support sorting in EpsilonClosure()
bool operator < (const Element &other) const {
return state < other.state;
}
};
// Arcs in the format we temporarily create in this class (a representation, essentially of
// a Gallic Fst).
struct TempArc {
Label ilabel;
StringId string; // Look it up in the StringRepository, it's a sequence of Labels.
OutputStateId nextstate; // or kNoState for final weights.
Weight weight;
};
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that is
// order-dependent. However the weights are not included in the hashing function--
// we hash subsets that differ only in weight to the same key. This is not optimal
// in terms of the O(N) performance but typically if we have a lot of determinized
// states that differ only in weight then the input probably was pathological in some way,
// or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple cases.
// Instead we apply the delta when comparing subsets for equality, and allow a small
// difference.
class SubsetKey {
public:
size_t operator ()(const vector<Element> * subset) const { // hashes only the state and string.
size_t hash = 0, factor = 1;
for (typename vector<Element>::const_iterator iter= subset->begin(); iter != subset->end(); ++iter) {
hash *= factor;
hash += iter->state + reinterpret_cast<size_t>(iter->string);
factor *= 23531; // these numbers are primes.
}
return hash;
}
};
// This is the equality operator on subsets. It checks for exact match on state-id
// and string, and approximate match on weights.
class SubsetEqual {
public:
bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
size_t sz = s1->size();
KALDI_ASSERT(sz>=0);
if (sz != s2->size()) return false;
typename vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(), iter2=s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state ||
iter1->string != iter2->string ||
! ApproxEqual(iter1->weight, iter2->weight, delta_)) return false;
}
return true;
}
float delta_;
SubsetEqual(float delta): delta_(delta) {}
SubsetEqual(): delta_(kDelta) {}
};
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates {
public:
bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
size_t sz = s1->size();
KALDI_ASSERT(sz>=0);
if (sz != s2->size()) return false;
typename vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(), iter2=s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state) return false;
}
return true;
}
};
// Define the hash type we use to map subsets (in minimal
// representation) to OutputStateId.
typedef unordered_map<const vector<Element>*, OutputStateId,
SubsetKey, SubsetEqual> MinimalSubsetHash;
// Define the hash type we use to map subsets (in initial
// representation) to OutputStateId, together with an
// extra weight. [note: we interpret the Element.state in here
// as an OutputStateId even though it's declared as InputStateId;
// these types are the same anyway].
typedef unordered_map<const vector<Element>*, Element,
SubsetKey, SubsetEqual> InitialSubsetHash;
// converts the representation of the subset from canonical (all states) to
// minimal (only states with output symbols on arcs leaving them, and final
// states). Output is not necessarily normalized, even if input_subset was.
void ConvertToMinimal(vector<Element> *subset) {
KALDI_ASSERT(!subset->empty());
typename vector<Element>::iterator cur_in = subset->begin(),
cur_out = subset->begin(), end = subset->end();
while (cur_in != end) {
if(IsIsymbolOrFinal(cur_in->state)) { // keep it...
*cur_out = *cur_in;
cur_out++;
}
cur_in++;
}
subset->resize(cur_out - subset->begin());
}
// Takes a minimal, normalized subset, and converts it to an OutputStateId.
// Involves a hash lookup, and possibly adding a new OutputStateId.
// If it creates a new OutputStateId, it creates a new record for it, works
// out its final-weight, and puts stuff on the queue relating to its
// transitions.
OutputStateId MinimalToStateId(const vector<Element> &subset,
const double forward_cost) {
typename MinimalSubsetHash::const_iterator iter
= minimal_hash_.find(&subset);
if (iter != minimal_hash_.end()) { // Found a matching subset.
OutputStateId state_id = iter->second;
const OutputState &state = *(output_states_[state_id]);
// Below is just a check that the algorithm is working...
if (forward_cost < state.forward_cost - 0.1) {
// for large weights, this check could fail due to roundoff.
KALDI_WARN << "New cost is less (check the difference is small) "
<< forward_cost << ", "
<< state.forward_cost;
}
return state_id;
}
OutputStateId state_id = static_cast<OutputStateId>(output_states_.size());
OutputState *new_state = new OutputState(subset, forward_cost);
minimal_hash_[&(new_state->minimal_subset)] = state_id;
output_states_.push_back(new_state);
num_elems_ += subset.size();
// Note: in the previous algorithm, we pushed the new state-id onto the queue
// at this point. Here, the queue happens elsewhere, and we directly process
// the state (which result in stuff getting added to the queue).
ProcessFinal(state_id); // will work out the final-prob.
ProcessTransitions(state_id); // will process transitions and add stuff to the queue.
return state_id;
}
// Given a normalized initial subset of elements (i.e. before epsilon closure),
// compute the corresponding output-state.
OutputStateId InitialToStateId(const vector<Element> &subset_in,
double forward_cost,
Weight *remaining_weight,
StringId *common_prefix) {
typename InitialSubsetHash::const_iterator iter
= initial_hash_.find(&subset_in);
if (iter != initial_hash_.end()) { // Found a matching subset.
const Element &elem = iter->second;
*remaining_weight = elem.weight;
*common_prefix = elem.string;
if (elem.weight == Weight::Zero())
KALDI_WARN << "Zero weight!";
return elem.state;
}
// else no matching subset-- have to work it out.
vector<Element> subset(subset_in);
// Follow through epsilons. Will add no duplicate states. note: after
// EpsilonClosure, it is the same as "canonical" subset, except not
// normalized (actually we never compute the normalized canonical subset,
// only the normalized minimal one).
EpsilonClosure(&subset); // follow epsilons.
ConvertToMinimal(&subset); // remove all but emitting and final states.
Element elem; // will be used to store remaining weight and string, and
// OutputStateId, in initial_hash_;
NormalizeSubset(&subset, &elem.weight, &elem.string); // normalize subset; put
// common string and weight in "elem". The subset is now a minimal,
// normalized subset.
forward_cost += ConvertToCost(elem.weight);
OutputStateId ans = MinimalToStateId(subset, forward_cost);
*remaining_weight = elem.weight;
*common_prefix = elem.string;
if (elem.weight == Weight::Zero())
KALDI_WARN << "Zero weight!";
// Before returning "ans", add the initial subset to the hash,
// so that we can bypass the epsilon-closure etc., next time
// we process the same initial subset.
vector<Element> *initial_subset_ptr = new vector<Element>(subset_in);
elem.state = ans;
initial_hash_[initial_subset_ptr] = elem;
num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
return ans;
}
// returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
// to the ordering we defined on strings for the CompactLatticeWeightTpl.
// see function
// inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
// const CompactLatticeWeightTpl<WeightType,IntType> &w2)
// in lattice-weight.h.
// this is the same as that, but optimized for our data structures.
inline int Compare(const Weight &a_w, StringId a_str,
const Weight &b_w, StringId b_str) const {
int weight_comp = fst::Compare(a_w, b_w);
if (weight_comp != 0) return weight_comp;
// now comparing strings.
if (a_str == b_str) return 0;
vector<IntType> a_vec, b_vec;
repository_.ConvertToVector(a_str, &a_vec);
repository_.ConvertToVector(b_str, &b_vec);
// First compare their lengths.
int a_len = a_vec.size(), b_len = b_vec.size();
// use opposite order on the string lengths (c.f. Compare in
// lattice-weight.h)
if (a_len > b_len) return -1;
else if (a_len < b_len) return 1;
for(int i = 0; i < a_len; i++) {
if (a_vec[i] < b_vec[i]) return -1;
else if (a_vec[i] > b_vec[i]) return 1;
}
KALDI_ASSERT(0); // because we checked if a_str == b_str above, shouldn't reach here
return 0;
}
// This function computes epsilon closure of subset of states by following epsilon links.
// Called by InitialToStateId and Initialize.
// Has no side effects except on the string repository. The "output_subset" is not
// necessarily normalized (in the sense of there being no common substring), unless
// input_subset was.
void EpsilonClosure(vector<Element> *subset) {
// at input, subset must have only one example of each StateId. [will still
// be so at output]. This function follows input-epsilons, and augments the
// subset accordingly.
std::priority_queue<Element, vector<Element>, greater<Element> > queue;
unordered_map<InputStateId, Element> cur_subset;
typedef typename unordered_map<InputStateId, Element>::iterator MapIter;
typedef typename vector<Element>::const_iterator VecIter;
for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
queue.push(*iter);
cur_subset[iter->state] = *iter;
}
// find whether input fst is known to be sorted on input label.
bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
bool replaced_elems = false; // relates to an optimization, see below.
int counter = 0; // stops infinite loops here for non-lattice-determinizable input
// (e.g. input with negative-cost epsilon loops); useful in testing.
while (queue.size() != 0) {
Element elem = queue.top();
queue.pop();
// The next if-statement is a kind of optimization. It's to prevent us
// unnecessarily repeating the processing of a state. "cur_subset" always
// contains only one Element with a particular state. The issue is that
// whenever we modify the Element corresponding to that state in "cur_subset",
// both the new (optimal) and old (less-optimal) Element will still be in
// "queue". The next if-statement stops us from wasting compute by
// processing the old Element.
if (replaced_elems && cur_subset[elem.state] != elem)
continue;
if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
KALDI_ERR << "Lattice determinization aborted since looped more than "
<< opts_.max_loop << " times during epsilon closure.";
}
for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (sorted && arc.ilabel != 0) break; // Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
if (arc.ilabel == 0
&& arc.weight != Weight::Zero()) { // Epsilon transition.
Element next_elem;
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
// next_elem.string is not set up yet... create it only
// when we know we need it (this is an optimization)
MapIter iter = cur_subset.find(next_elem.state);
if (iter == cur_subset.end()) {
// was no such StateId: insert and add to queue.
next_elem.string = (arc.olabel == 0 ? elem.string :
repository_.Successor(elem.string, arc.olabel));
cur_subset[next_elem.state] = next_elem;
queue.push(next_elem);
} else {
// was not inserted because one already there. In normal
// determinization we'd add the weights. Here, we find which one
// has the better weight, and keep its corresponding string.
int comp = fst::Compare(next_elem.weight, iter->second.weight);
if (comp == 0) { // A tie on weights. This should be a rare case;
// we don't optimize for it.
next_elem.string = (arc.olabel == 0 ? elem.string :
repository_.Successor(elem.string,
arc.olabel));
comp = Compare(next_elem.weight, next_elem.string,
iter->second.weight, iter->second.string);
}
if(comp == 1) { // next_elem is better, so use its (weight, string)
next_elem.string = (arc.olabel == 0 ? elem.string :
repository_.Successor(elem.string, arc.olabel));
iter->second.string = next_elem.string;
iter->second.weight = next_elem.weight;
queue.push(next_elem);
replaced_elems = true;
}
// else it is the same or worse, so use original one.
}
}
}
}
{ // copy cur_subset to subset.
subset->clear();
subset->reserve(cur_subset.size());
MapIter iter = cur_subset.begin(), end = cur_subset.end();
for (; iter != end; ++iter) subset->push_back(iter->second);
// sort by state ID, because the subset hash function is order-dependent(see SubsetKey)
std::sort(subset->begin(), subset->end());
}
}
// This function works out the final-weight of the determinized state.
// called by ProcessSubset.
// Has no side effects except on the variable repository_, and
// output_states_[output_state_id].arcs
void ProcessFinal(OutputStateId output_state_id) {
OutputState &state = *(output_states_[output_state_id]);
const vector<Element> &minimal_subset = state.minimal_subset;
// processes final-weights for this subset. state.minimal_subset_ may be
// empty if the graphs is not connected/trimmed, I think, do don't check
// that it's nonempty.
StringId final_string = repository_.EmptyString(); // set it to keep the
// compiler happy; if it doesn't get set in the loop, we won't use the value anyway.
Weight final_weight = Weight::Zero();
bool is_final = false;
typename vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
StringId this_final_string = elem.string;
if (this_final_weight != Weight::Zero() &&
(!is_final || Compare(this_final_weight, this_final_string,
final_weight, final_string) == 1)) { // the new
// (weight, string) pair is more in semiring than our current
// one.
is_final = true;
final_weight = this_final_weight;
final_string = this_final_string;
}
}
if (is_final &&
ConvertToCost(final_weight) + state.forward_cost <= cutoff_) {
// store final weights in TempArc structure, just like a transition.
// Note: we only store the final-weight if it's inside the pruning beam, hence
// the stuff with Compare.
TempArc temp_arc;
temp_arc.ilabel = 0;
temp_arc.nextstate = kNoStateId; // special marker meaning "final weight".
temp_arc.string = final_string;
temp_arc.weight = final_weight;
state.arcs.push_back(temp_arc);
num_arcs_++;
}
}
// NormalizeSubset normalizes the subset "elems" by
// removing any common string prefix (putting it in common_str),
// and dividing by the total weight (putting it in tot_weight).
void NormalizeSubset(vector<Element> *elems,
Weight *tot_weight,
StringId *common_str) {
if(elems->empty()) { // just set common_str, tot_weight
// to defaults and return...
KALDI_WARN << "empty subset";
*common_str = repository_.EmptyString();
*tot_weight = Weight::Zero();
return;
}
size_t size = elems->size();
vector<IntType> common_prefix;
repository_.ConvertToVector((*elems)[0].string, &common_prefix);
Weight weight = (*elems)[0].weight;
for(size_t i = 1; i < size; i++) {
weight = Plus(weight, (*elems)[i].weight);
repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
}
KALDI_ASSERT(weight != Weight::Zero()); // we made sure to ignore arcs with zero
// weights on them, so we shouldn't have zero here.
size_t prefix_len = common_prefix.size();
for(size_t i = 0; i < size; i++) {
(*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
(*elems)[i].string =
repository_.RemovePrefix((*elems)[i].string, prefix_len);
}
*common_str = repository_.ConvertFromVector(common_prefix);
*tot_weight = weight;
}
// Take a subset of Elements that is sorted on state, and
// merge any Elements that have the same state (taking the best
// (weight, string) pair in the semiring).
void MakeSubsetUnique(vector<Element> *subset) {
typedef typename vector<Element>::iterator IterType;
// This KALDI_ASSERT is designed to fail (usually) if the subset is not sorted on
// state.
KALDI_ASSERT(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
size_t num_out = 0;
// Merge elements with same state-id
while (cur_in != end) { // while we have more elements to process.
// At this point, cur_out points to location of next place we want to put an element,
// cur_in points to location of next element we want to process.
if (cur_in != cur_out) *cur_out = *cur_in;
cur_in++;
while (cur_in != end && cur_in->state == cur_out->state) {
if (Compare(cur_in->weight, cur_in->string,
cur_out->weight, cur_out->string) == 1) {
// if *cur_in > *cur_out in semiring, then take *cur_in.
cur_out->string = cur_in->string;
cur_out->weight = cur_in->weight;
}
cur_in++;
}
cur_out++;
num_out++;
}
subset->resize(num_out);
}
// ProcessTransition was called from "ProcessTransitions" in the non-pruned
// code, but now we in effect put the calls to ProcessTransition on a priority
// queue, and it now gets called directly from Determinize(). This function
// processes a transition from state "ostate_id". The set "subset" of Elements
// represents a set of next-states with associated weights and strings, each
// one arising from an arc from some state in a determinized-state; the
// next-states are unique (there is only one Entry assocated with each)
void ProcessTransition(OutputStateId ostate_id, Label ilabel, vector<Element> *subset) {
double forward_cost = output_states_[ostate_id]->forward_cost;
StringId common_str;
Weight tot_weight;
NormalizeSubset(subset, &tot_weight, &common_str);
forward_cost += ConvertToCost(tot_weight);
OutputStateId nextstate;
{
Weight next_tot_weight;
StringId next_common_str;
nextstate = InitialToStateId(*subset,
forward_cost,
&next_tot_weight,
&next_common_str);
common_str = repository_.Concatenate(common_str, next_common_str);
tot_weight = Times(tot_weight, next_tot_weight);
}
// Now add an arc to the next state (would have been created if necessary by
// InitialToStateId).
TempArc temp_arc;
temp_arc.ilabel = ilabel;
temp_arc.nextstate = nextstate;
temp_arc.string = common_str;
temp_arc.weight = tot_weight;
output_states_[ostate_id]->arcs.push_back(temp_arc); // record the arc.
num_arcs_++;
}
// "less than" operator for pair<Label, Element>. Used in ProcessTransitions.
// Lexicographical order, which only compares the state when ordering the
// "Element" member of the pair.
class PairComparator {
public:
inline bool operator () (const pair<Label, Element> &p1, const pair<Label, Element> &p2) {
if (p1.first < p2.first) return true;
else if (p1.first > p2.first) return false;
else {
return p1.second.state < p2.second.state;
}
}
};
// ProcessTransitions processes emitting transitions (transitions with
// ilabels) out of this subset of states. It actualy only creates records
// ("Task") that get added to the queue. The transitions will be processed in
// priority order from Determinize(). This function soes not consider final
// states. Partitions the emitting transitions up by ilabel (by sorting on
// ilabel), and for each unique ilabel, it creates a Task record that contains
// the information we need to process the transition.
void ProcessTransitions(OutputStateId output_state_id) {
const vector<Element> &minimal_subset = output_states_[output_state_id]->minimal_subset;
// it's possible that minimal_subset could be empty if there are
// unreachable parts of the graph, so don't check that it's nonempty.
vector<pair<Label, Element> > &all_elems(all_elems_tmp_); // use class member
// to avoid memory allocation/deallocation.
{
// Push back into "all_elems", elements corresponding to all
// non-epsilon-input transitions out of all states in "minimal_subset".
typename vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
for (;iter != end; ++iter) {
const Element &elem = *iter;
for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); ! aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0
&& arc.weight != Weight::Zero()) { // Non-epsilon transition -- ignore epsilons here.
pair<Label, Element> this_pr;
this_pr.first = arc.ilabel;
Element &next_elem(this_pr.second);
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
if (arc.olabel == 0) // output epsilon
next_elem.string = elem.string;
else
next_elem.string = repository_.Successor(elem.string, arc.olabel);
all_elems.push_back(this_pr);
}
}
}
}
PairComparator pc;
std::sort(all_elems.begin(), all_elems.end(), pc);
// now sorted first on input label, then on state.
typedef typename vector<pair<Label, Element> >::const_iterator PairIter;
PairIter cur = all_elems.begin(), end = all_elems.end();
while (cur != end) {
// The old code (non-pruned) called ProcessTransition; here, instead,
// we'll put the calls into a priority queue.
Task *task = new Task;
// Process ranges that share the same input symbol.
Label ilabel = cur->first;
task->state = output_state_id;
task->priority_cost = std::numeric_limits<double>::infinity();
task->label = ilabel;
while (cur != end && cur->first == ilabel) {
task->subset.push_back(cur->second);
const Element &element = cur->second;
// Note: we'll later include the term "forward_cost" in the
// priority_cost.
task->priority_cost = std::min(task->priority_cost,
ConvertToCost(element.weight) +
backward_costs_[element.state]);
cur++;
}
// After the command below, the "priority_cost" is a value comparable to
// the total-weight of the input FST, like a total-path weight... of
// course, it will typically be less (in the semiring) than that.
// note: we represent it just as a double.
task->priority_cost += output_states_[output_state_id]->forward_cost;
if (task->priority_cost > cutoff_) {
// This task would never get done as it's past the pruning cutoff.
delete task;
} else {
MakeSubsetUnique(&(task->subset)); // remove duplicate Elements with the same state.
queue_.push(task); // Push the task onto the queue. The queue keeps it
// in prioritized order, so we always process the one with the "best"
// weight (highest in the semiring).
{ // this is a check.
double best_cost = backward_costs_[ifst_->Start()],
tolerance = 0.01 + 1.0e-04 * std::abs(best_cost);
if (task->priority_cost < best_cost - tolerance) {
KALDI_WARN << "Cost below best cost was encountered:"
<< task->priority_cost << " < " << best_cost;
}
}
}
}
all_elems.clear(); // as it's a reference to a class variable; we want it to stay
// empty.
}
bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
// of the input FST either is final or has an osymbol on an arc out of it.
// Uses the vector isymbol_or_final_ as a cache for this info.
KALDI_ASSERT(state >= 0);
if (isymbol_or_final_.size() <= state)
isymbol_or_final_.resize(state+1, static_cast<char>(OSF_UNKNOWN));
if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
return false;
else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
return true;
// else work it out...
isymbol_or_final_[state] = static_cast<char>(OSF_NO);
if (ifst_->Final(state) != Weight::Zero())
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
return true;
}
}
return IsIsymbolOrFinal(state); // will only recurse once.
}
void ComputeBackwardWeight() {
// Sets up the backward_costs_ array, and the cutoff_ variable.
KALDI_ASSERT(beam_ > 0);
// Only handle the toplogically sorted case.
backward_costs_.resize(ifst_->NumStates());
for (StateId s = ifst_->NumStates() - 1; s >= 0; s--) {
double &cost = backward_costs_[s];
cost = ConvertToCost(ifst_->Final(s));
for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
cost = std::min(cost,
ConvertToCost(arc.weight) + backward_costs_[arc.nextstate]);
}
}
if (ifst_->Start() == kNoStateId) return; // we'll be returning
// an empty FST.
double best_cost = backward_costs_[ifst_->Start()];
if (best_cost == std::numeric_limits<double>::infinity())
KALDI_WARN << "Total weight of input lattice is zero.";
cutoff_ = best_cost + beam_;
}
void InitializeDeterminization() {
// We insist that the input lattice be topologically sorted. This is not a
// fundamental limitation of the algorithm (which in principle should be
// applicable to even cyclic FSTs), but it helps us more efficiently
// compute the backward_costs_ array. There may be some other reason we
// require this, that escapes me at the moment.
KALDI_ASSERT(ifst_->Properties(kTopSorted, true) != 0);
ComputeBackwardWeight();
#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
if(ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
// states in ifst_, it might be a bit more efficient
// to pre-size the hashes so we're not constantly rebuilding them.
StateId num_states =
down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(ifst_)->NumStates();
minimal_hash_.rehash(num_states/2 + 3);
initial_hash_.rehash(num_states/2 + 3);
}
#endif
InputStateId start_id = ifst_->Start();
if (start_id != kNoStateId) {
/* Create determinized-state corresponding to the start state....
Unlike all the other states, we don't "normalize" the representation
of this determinized-state before we put it into minimal_hash_. This is actually
what we want, as otherwise we'd have problems dealing with any extra weight
and string and might have to create a "super-initial" state which would make
the output nondeterministic. Normalization is only needed to make the
determinized output more minimal anyway, it's not needed for correctness.
Note, we don't put anything in the initial_hash_. The initial_hash_ is only
a lookaside buffer anyway, so this isn't a problem-- it will get populated
later if it needs to be.
*/
vector<Element> subset(1);
subset[0].state = start_id;
subset[0].weight = Weight::One();
subset[0].string = repository_.EmptyString(); // Id of empty sequence.
EpsilonClosure(&subset); // follow through epsilon-input links
ConvertToMinimal(&subset); // remove all but final states and
// states with input-labels on arcs out of them.
// Weight::One() is the "forward-weight" of this determinized state...
// i.e. the minimal cost from the start of the determinized FST to this
// state [One() because it's the start state].
OutputState *initial_state = new OutputState(subset, 0);
KALDI_ASSERT(output_states_.empty());
output_states_.push_back(initial_state);
num_elems_ += subset.size();
OutputStateId initial_state_id = 0;
minimal_hash_[&(initial_state->minimal_subset)] = initial_state_id;
ProcessFinal(initial_state_id);
ProcessTransitions(initial_state_id); // this will add tasks to
// the queue, which we'll start processing in Determinize().
}
}
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizerPruned);
struct OutputState {
vector<Element> minimal_subset;
vector<TempArc> arcs; // arcs out of the state-- those that have been processed.
// Note: the final-weight is included here with kNoStateId as the state id. We
// always process the final-weight regardless of the beam; when producing the
// output we may have to ignore some of these.
double forward_cost; // Represents minimal cost from start-state
// to this state. Used in prioritization of tasks, and pruning.
// Note: we know this minimal cost from when we first create the OutputState;
// this is because of the priority-queue we use, that ensures that the
// "best" path into the state will be expanded first.
OutputState(const vector<Element> &minimal_subset,
double forward_cost): minimal_subset(minimal_subset),
forward_cost(forward_cost) { }
};
vector<OutputState*> output_states_; // All the info about the output states.
int num_arcs_; // keep track of memory usage: number of arcs in output_states_[ ]->arcs
int num_elems_; // keep track of memory usage: number of elems in output_states_ and
// the keys of initial_hash_
const ExpandedFst<Arc> *ifst_;
std::vector<double> backward_costs_; // This vector stores, for every state in ifst_,
// the minimal cost to the end-state (i.e. the sum of weights; they are guaranteed to
// have "take-the-minimum" semantics). We get the double from the ConvertToCost()
// function on the lattice weights.
double beam_;
double cutoff_; // beam plus total-weight of input (and note, the weight is
// guaranteed to be "tropical-like" so the sum does represent a min-cost.
DeterminizeLatticePrunedOptions opts_;
SubsetKey hasher_; // object that computes keys-- has no data members.
SubsetEqual equal_; // object that compares subsets-- only data member is delta_.
bool determinized_; // set to true when user called Determinize(); used to make
// sure this object is used correctly.
MinimalSubsetHash minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
// representation" (only include final and states and states with
// nonzero ilabel on arc out of them. Owns the pointers
// in its keys.
InitialSubsetHash initial_hash_; // hash from Subset to Element, which
// represents the OutputStateId together
// with an extra weight and string. Subset
// is "initial representation". The extra
// weight and string is needed because after
// we convert to minimal representation and
// normalize, there may be an extra weight
// and string. Owns the pointers
// in its keys.
struct Task {
OutputStateId state; // State from which we're processing the transition.
Label label; // Label on the transition we're processing out of this state.
vector<Element> subset; // Weighted subset of states (with strings)-- not normalized.
double priority_cost; // Cost used in deciding priority of tasks. Note:
// we assume there is a ConvertToCost() function that converts the semiring to double.
};
struct TaskCompare {
inline int operator() (const Task *t1, const Task *t2) {
// view this like operator <, which is the default template parameter
// to std::priority_queue.
// returns true if t1 is worse than t2.
return (t1->priority_cost > t2->priority_cost);
}
};
// This priority queue contains "Task"s to be processed; these correspond
// to transitions out of determinized states. We process these in priority
// order according to the best weight of any path passing through these
// determinized states... it's possible to work this out.
std::priority_queue<Task*, vector<Task*>, TaskCompare> queue_;
vector<pair<Label, Element> > all_elems_tmp_; // temporary vector used in ProcessTransitions.
enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
vector<char> isymbol_or_final_; // A kind of cache; it says whether
// each state is (emitting or final) where emitting means it has at least one
// non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
LatticeStringRepository<IntType> repository_; // defines a compact and fast way of
// storing sequences of labels.
void AddStrings(const vector<Element> &vec,
vector<StringId> *needed_strings) {
for (typename std::vector<Element>::const_iterator iter = vec.begin();
iter != vec.end(); ++iter)
needed_strings->push_back(iter->string);
}
};
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > >*ofst,
DeterminizeLatticePrunedOptions opts) {
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
if (ifst.NumStates() == 0) {
ofst->DeleteStates();
return true;
}
KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
int32 max_num_iters = 10; // avoid the potential for infinite loops if
// retrying.
VectorFst<ArcTpl<Weight> > temp_fst;
for (int32 iter = 0; iter < max_num_iters; iter++) {
LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
beam, opts);
double effective_beam;
bool ans = det.Determinize(&effective_beam);
// if it returns false it will typically still produce reasonable output,
// just with a narrower beam than "beam". If the user specifies an infinite
// beam we don't do this beam-narrowing.
if (effective_beam >= beam * opts.retry_cutoff ||
beam == std::numeric_limits<double>::infinity() ||
iter + 1 == max_num_iters) {
det.Output(ofst);
return ans;
} else {
// The code below to set "beam" is a heuristic.
// If effective_beam is very small, we want to reduce by a lot.
// But never change the beam by more than a factor of two.
if (effective_beam < 0.0) effective_beam = 0.0;
double new_beam = beam * sqrt(effective_beam / beam);
if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
beam = new_beam;
if (iter == 0) temp_fst = ifst;
kaldi::PruneLattice(beam, &temp_fst);
KALDI_LOG << "Pruned state-level lattice with beam " << beam
<< " and retrying determinization with that beam.";
}
}
return false; // Suppress compiler warning; this code is unreachable.
}
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight>
bool DeterminizeLatticePruned(const ExpandedFst<ArcTpl<Weight> > &ifst,
double beam,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticePrunedOptions opts) {
typedef int32 IntType;
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
if (ifst.NumStates() == 0) {
ofst->DeleteStates();
return true;
}
int32 max_num_iters = 10; // avoid the potential for infinite loops if
// retrying.
VectorFst<ArcTpl<Weight> > temp_fst;
for (int32 iter = 0; iter < max_num_iters; iter++) {
LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
beam, opts);
double effective_beam;
bool ans = det.Determinize(&effective_beam);
// if it returns false it will typically still
// produce reasonable output, just with a
// narrower beam than "beam".
if (effective_beam >= beam * opts.retry_cutoff ||
iter + 1 == max_num_iters) {
det.Output(ofst);
return ans;
} else {
// The code below to set "beam" is a heuristic.
// If effective_beam is very small, we want to reduce by a lot.
// But never change the beam by more than a factor of two.
if (effective_beam < 0)
effective_beam = 0;
double new_beam = beam * sqrt(effective_beam / beam);
if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
KALDI_WARN << "Effective beam " << effective_beam << " was less than beam "
<< beam << " * cutoff " << opts.retry_cutoff << ", pruning raw "
<< "lattice with new beam " << new_beam << " and retrying.";
beam = new_beam;
if (iter == 0) temp_fst = ifst;
kaldi::PruneLattice(beam, &temp_fst);
}
}
return false; // Suppress compiler warning; this code is unreachable.
}
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *fst) {
// Define some types.
typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
// Work out the first phone symbol. This is more related to the phone
// insertion function, so we put it here and make it the returning value of
// DeterminizeLatticeInsertPhones().
Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1;
// Insert phones here.
for (StateIterator<MutableFst<Arc> > siter(*fst);
!siter.Done(); siter.Next()) {
StateId state = siter.Value();
if (state == fst->Start())
continue;
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
!aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
// Note: the words are on the input symbol side and transition-id's are on
// the output symbol side.
if ((arc.olabel != 0)
&& (trans_model.TransitionIdIsStartOfPhone(arc.olabel))
&& (!trans_model.IsSelfLoop(arc.olabel))) {
Label phone =
static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel));
// Skips <eps>.
KALDI_ASSERT(phone != 0);
if (arc.ilabel == 0) {
// If there is no word on the arc, insert the phone directly.
arc.ilabel = first_phone_label + phone;
} else {
// Otherwise, add an additional arc.
StateId additional_state = fst->AddState();
StateId next_state = arc.nextstate;
arc.nextstate = additional_state;
fst->AddArc(additional_state,
Arc(first_phone_label + phone, 0,
Weight::One(), next_state));
}
}
aiter.SetValue(arc);
}
}
return first_phone_label;
}
template<class Weight>
void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst) {
// Define some types.
typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
// Delete phones here.
for (StateIterator<MutableFst<Arc> > siter(*fst);
!siter.Done(); siter.Next()) {
StateId state = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
!aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
if (arc.ilabel >= first_phone_label)
arc.ilabel = 0;
aiter.SetValue(arc);
}
}
}
// instantiate for type LatticeWeight
template
void DeterminizeLatticeDeletePhones(
ArcTpl<kaldi::LatticeWeight>::Label first_phone_label,
MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst);
/** This function does a first pass determinization with phone symbols inserted
at phone boundary. It uses a transition model to work out the transition-id
to phone map. First, phones will be inserted into the word level lattice.
Second, determinization will be applied on top of the phone + word lattice.
Finally, the inserted phones will be removed, converting the lattice back to
a word level lattice. The output lattice of this pass is not deterministic,
since we remove the phone symbols as a last step. It is supposed to be
followed by another pass of determinization at the word level. It could also
be useful for some other applications such as fMLLR estimation, confidence
estimation, discriminative training, etc.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePrunedFirstPass(
const kaldi::TransitionInformation &trans_model,
double beam,
MutableFst<ArcTpl<Weight> > *fst,
const DeterminizeLatticePrunedOptions &opts) {
// First, insert the phones.
typename ArcTpl<Weight>::Label first_phone_label =
DeterminizeLatticeInsertPhones(trans_model, fst);
TopSort(fst);
// Second, do determinization with phone inserted.
bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts);
// Finally, remove the inserted phones.
DeterminizeLatticeDeletePhones(first_phone_label, fst);
TopSort(fst);
return ans;
}
// "Destructive" version of DeterminizeLatticePhonePruned() where the input
// lattice might be modified.
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *ifst,
double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts) {
// Returning status.
bool ans = true;
// Make sure at least one of opts.phone_determinize and opts.word_determinize
// is not false, otherwise calling this function doesn't make any sense.
if ((opts.phone_determinize || opts.word_determinize) == false) {
KALDI_WARN << "Both --phone-determinize and --word-determinize are set to "
<< "false, copying lattice without determinization.";
// We are expecting the words on the input side.
ConvertLattice<Weight, IntType>(*ifst, ofst, false);
return ans;
}
// Determinization options.
DeterminizeLatticePrunedOptions det_opts;
det_opts.delta = opts.delta;
det_opts.max_mem = opts.max_mem;
// If --phone-determinize is true, do the determinization on phone + word
// lattices.
if (opts.phone_determinize) {
KALDI_VLOG(3) << "Doing first pass of determinization on phone + word "
<< "lattices.";
ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>(
trans_model, beam, ifst, det_opts) && ans;
// If --word-determinize is false, we've finished the job and return here.
if (!opts.word_determinize) {
// We are expecting the words on the input side.
ConvertLattice<Weight, IntType>(*ifst, ofst, false);
return ans;
}
}
// If --word-determinize is true, do the determinization on word lattices.
if (opts.word_determinize) {
KALDI_VLOG(3) << "Doing second pass of determinization on word lattices.";
ans = DeterminizeLatticePruned<Weight, IntType>(
*ifst, beam, ofst, det_opts) && ans;
}
// If --minimize is true, push and minimize after determinization.
if (opts.minimize) {
KALDI_VLOG(3) << "Pushing and minimizing on word lattices.";
ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans;
ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans;
ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans;
}
return ans;
}
// Normal verson of DeterminizeLatticePhonePruned(), where the input lattice
// will be kept as unchanged.
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst,
double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts) {
VectorFst<ArcTpl<Weight> > temp_fst(ifst);
return DeterminizeLatticePhonePruned(trans_model, &temp_fst,
beam, ofst, opts);
}
bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double beam,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts) {
bool ans = true;
Invert(ifst);
if (ifst->Properties(fst::kTopSorted, true) == 0) {
if (!TopSort(ifst)) {
// Cannot topologically sort the lattice -- determinization will fail.
KALDI_ERR << "Topological sorting of state-level lattice failed (probably"
<< " your lexicon has empty words or your LM has epsilon cycles"
<< ").";
}
}
ILabelCompare<kaldi::LatticeArc> ilabel_comp;
ArcSort(ifst, ilabel_comp);
ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
trans_model, ifst, beam, ofst, opts);
Connect(ofst);
return ans;
}
// Instantiate the templates for the types we might need.
// Note: there are actually four templates, each of which
// we instantiate for a single type.
template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
const ExpandedFst<kaldi::LatticeArc> &ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePrunedOptions opts);
template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
const ExpandedFst<kaldi::LatticeArc> &ifst,
double prune,
MutableFst<kaldi::LatticeArc> *ofst,
DeterminizeLatticePrunedOptions opts);
template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<kaldi::LatticeArc> &ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts);
template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts);
}
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
#include "itf/transition-information.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
somtimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
struct DeterminizeLatticePrunedOptions {
float delta; // A small offset used to measure equality of weights.
int max_mem; // If >0, determinization will fail and return false
// when the algorithm's (approximate) memory consumption crosses this threshold.
int max_loop; // If >0, can be used to detect non-determinizable input
// (a case that wouldn't be caught by max_mem).
int max_states;
int max_arcs;
float retry_cutoff;
DeterminizeLatticePrunedOptions(): delta(kDelta),
max_mem(-1),
max_loop(-1),
max_states(-1),
max_arcs(-1),
retry_cutoff(0.5) { }
void Register (kaldi::OptionsItf *opts) {
opts->Register("delta", &delta, "Tolerance used in determinization");
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
"determinization (real usage might be many times this)");
opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
"output FST (total, not per state");
opts->Register("max-states", &max_states, "Maximum number of arcs in output "
"FST (total, not per state");
opts->Register("max-loop", &max_loop, "Option used to detect a particular "
"type of determinization failure, typically due to invalid input "
"(e.g., negative-cost loops)");
opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
"lattice and retrying determinization: if effective-beam < "
"retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
"ever getting empty output for long segments.");
}
};
struct DeterminizeLatticePhonePrunedOptions {
// delta: a small offset used to measure equality of weights.
float delta;
// max_mem: if > 0, determinization will fail and return false when the
// algorithm's (approximate) memory consumption crosses this threshold.
int max_mem;
// phone_determinize: if true, do a first pass determinization on both phones
// and words.
bool phone_determinize;
// word_determinize: if true, do a second pass determinization on words only.
bool word_determinize;
// minimize: if true, push and minimize after determinization.
bool minimize;
DeterminizeLatticePhonePrunedOptions(): delta(kDelta),
max_mem(50000000),
phone_determinize(true),
word_determinize(true),
minimize(false) {}
void Register (kaldi::OptionsItf *opts) {
opts->Register("delta", &delta, "Tolerance used in determinization");
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
"determinization (real usage might be many times this).");
opts->Register("phone-determinize", &phone_determinize, "If true, do an "
"initial pass of determinization on both phones and words (see"
" also --word-determinize)");
opts->Register("word-determinize", &word_determinize, "If true, do a second "
"pass of determinization on words only (see also "
"--phone-determinize)");
opts->Register("minimize", &minimize, "If true, push and minimize after "
"determinization.");
}
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter. The input FST must be topologically sorted in order
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
Returns true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
template<class Weight>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
where the output sequences are encoded using the CompactLatticeArcTpl template
(i.e. the sequences of output symbols are represented directly as strings The input
FST must be topologically sorted in order for the algorithm to work. For efficiency
it is recommended to sort the ilabel for the input FST as well.
Returns true on normal success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The
returning value is the starting index of the phone label. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (returning_value + original_phone_label) in the new lattice. The
returning value will be used by DeterminizeLatticeDeletePhones() where it
works out the phones according to this value.
*/
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *fst);
/** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label because
when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones()
*/
template<class Weight>
void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst);
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set
true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copying the lattices without determinization.
Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling
TopSort() Invert() and ArcSort() for you.
Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the
output side.
This function can be used as the top-level interface to all the determinization
code.
*/
bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#endif
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {
/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
if (!ifst) return NULL;
CompactLattice *ofst = new CompactLattice();
ConvertLattice(*ifst, ofst);
delete ifst;
return ofst;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
return ifst;
}
/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
if (!ifst) return NULL;
Lattice *ofst = new Lattice();
ConvertLattice(*ifst, ofst);
delete ifst;
return ofst;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
return ifst;
}
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &t) {
if (binary) {
fst::FstWriteOptions opts;
// Leave all the options default. Normally these lattices wouldn't have any
// osymbols/isymbols so no point directing it not to write them (who knows what
// we'd want to if we had them).
return t.Write(os, opts);
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached.
// Write a newline after the key, so the first line of the FST appears
// on its own line.
os << '\n';
bool acceptor = true, write_one = false;
fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
t.OutputSymbols(),
NULL, acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail())
KALDI_WARN << "Stream failure detected.";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not somethig in the original
// OpenFst code].
os << '\n';
return os.good();
}
}
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
typedef LatticeArc Arc;
typedef LatticeWeight Weight;
typedef CompactLatticeArc CArc;
typedef CompactLatticeWeight CWeight;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
public:
// everything is static in this class.
/** This function reads from the FST text format; it does not know in advance
whether it's a Lattice or CompactLattice in the stream so it tries to
read both formats until it becomes clear which is the correct one.
*/
static std::pair<Lattice*, CompactLattice*> ReadText(
std::istream &is) {
typedef std::pair<Lattice*, CompactLattice*> PairT;
using std::string;
using std::vector;
Lattice *fst = new Lattice();
CompactLattice *cfst = new CompactLattice();
string line;
size_t nline = 0;
string separator = FLAGS_fst_field_separator + "\r\n";
while (std::getline(is, line)) {
nline++;
vector<string> col;
// on Windows we'll write in text and read in binary mode.
SplitStringToVector(line, separator.c_str(), true, &col);
if (col.size() == 0) break; // Empty line is a signal to stop, in our
// archive format.
if (col.size() > 5) {
KALDI_WARN << "Reading lattice: bad line in FST: " << line;
delete fst;
delete cfst;
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
StateId s;
if (!ConvertStringToInteger(col[0], &s)) {
KALDI_WARN << "FstCompiler: bad line in FST: " << line;
delete fst;
delete cfst;
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
if (fst)
while (s >= fst->NumStates())
fst->AddState();
if (cfst)
while (s >= cfst->NumStates())
cfst->AddState();
if (nline == 1) {
if (fst) fst->SetStart(s);
if (cfst) cfst->SetStart(s);
}
if (fst) { // we still have fst; try to read that arc.
bool ok = true;
Arc arc;
Weight w;
StateId d = s;
switch (col.size()) {
case 1 :
fst->SetFinal(s, Weight::One());
break;
case 2:
if (!StrToWeight(col[1], true, &w)) ok = false;
else fst->SetFinal(s, w);
break;
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
ok = false;
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel);
if (ok) {
d = arc.nextstate;
arc.weight = Weight::One();
fst->AddArc(s, arc);
}
break;
case 5:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel) &&
StrToWeight(col[4], false, &arc.weight);
if (ok) {
d = arc.nextstate;
fst->AddArc(s, arc);
}
break;
default:
ok = false;
}
while (d >= fst->NumStates())
fst->AddState();
if (!ok) {
delete fst;
fst = NULL;
}
}
if (cfst) {
bool ok = true;
CArc arc;
CWeight w;
StateId d = s;
switch (col.size()) {
case 1 :
cfst->SetFinal(s, CWeight::One());
break;
case 2:
if (!StrToCWeight(col[1], true, &w)) ok = false;
else cfst->SetFinal(s, w);
break;
case 3: // compact-lattice is acceptor format: state, next-state, label.
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel);
if (ok) {
d = arc.nextstate;
arc.olabel = arc.ilabel;
arc.weight = CWeight::One();
cfst->AddArc(s, arc);
}
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
StrToCWeight(col[3], false, &arc.weight);
if (ok) {
d = arc.nextstate;
arc.olabel = arc.ilabel;
cfst->AddArc(s, arc);
}
break;
case 5: default:
ok = false;
}
while (d >= cfst->NumStates())
cfst->AddState();
if (!ok) {
delete cfst;
cfst = NULL;
}
}
if (!fst && !cfst) {
KALDI_WARN << "Bad line in lattice text format: " << line;
// read until we get an empty line, so at least we
// have a chance to read the next one (although this might
// be a bit futile since the calling code will get unhappy
// about failing to read this one.
while (std::getline(is, line)) {
SplitStringToVector(line, separator.c_str(), true, &col);
if (col.empty()) break;
}
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
}
return PairT(fst, cfst);
}
static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
std::istringstream strm(s);
strm >> *w;
if (!strm || (!allow_zero && *w == Weight::Zero())) {
return false;
}
return true;
}
static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
std::istringstream strm(s);
strm >> *w;
if (!strm || (!allow_zero && *w == CWeight::Zero())) {
return false;
}
return true;
}
};
CompactLattice *ReadCompactLatticeText(std::istream &is) {
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
if (lat_pair.second != NULL) {
delete lat_pair.first;
return lat_pair.second;
} else if (lat_pair.first != NULL) {
// note: ConvertToCompactLattice frees its input.
return ConvertToCompactLattice(lat_pair.first);
} else {
return NULL;
}
}
Lattice *ReadLatticeText(std::istream &is) {
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
if (lat_pair.first != NULL) {
delete lat_pair.second;
return lat_pair.first;
} else if (lat_pair.second != NULL) {
// note: ConvertToLattice frees its input.
return ConvertToLattice(lat_pair.second);
} else {
return NULL;
}
}
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat) {
KALDI_ASSERT(*clat == NULL);
if (binary) {
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
KALDI_WARN << "Reading compact lattice: error reading FST header.";
return false;
}
if (hdr.FstType() != "vector") {
KALDI_WARN << "Reading compact lattice: unsupported FST type: "
<< hdr.FstType();
return false;
}
fst::FstReadOptions ropts("<unspecified>",
&hdr);
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
typedef fst::LatticeWeightTpl<float> T3;
typedef fst::LatticeWeightTpl<double> T4;
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
CompactLattice *ans = NULL;
if (hdr.ArcType() == T1::Type()) {
ans = ConvertToCompactLattice(F1::Read(is, ropts));
} else if (hdr.ArcType() == T2::Type()) {
ans = ConvertToCompactLattice(F2::Read(is, ropts));
} else if (hdr.ArcType() == T3::Type()) {
ans = ConvertToCompactLattice(F3::Read(is, ropts));
} else if (hdr.ArcType() == T4::Type()) {
ans = ConvertToCompactLattice(F4::Read(is, ropts));
} else {
KALDI_WARN << "FST with arc type " << hdr.ArcType()
<< " cannot be converted to CompactLattice.\n";
return false;
}
if (ans == NULL) {
KALDI_WARN << "Error reading compact lattice (after reading header).";
return false;
}
*clat = ans;
return true;
} else {
// The next line would normally consume the \r on Windows, plus any
// extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') is.get(); // consume the newline.
else { // saw spaces but no newline.. this is not expected.
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
<< " at file position " << is.tellg();
return false;
}
*clat = ReadCompactLatticeText(is); // that routine will warn on error.
return (*clat != NULL);
}
}
bool CompactLatticeHolder::Read(std::istream &is) {
Clear(); // in case anything currently stored.
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading CompactLattice.";
return false;
} else if (isspace(c)) { // The text form of the lattice begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is not
// space).
return ReadCompactLattice(is, false, &t_);
} else if (c != 214) { // 214 is first char of FST magic number,
// on little-endian machines which is all we support (\326 octal)
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
<< " [non-space but no magic number detected], file pos is "
<< is.tellg();
return false;
} else {
return ReadCompactLattice(is, true, &t_);
}
}
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
if (binary) {
fst::FstWriteOptions opts;
// Leave all the options default. Normally these lattices wouldn't have any
// osymbols/isymbols so no point directing it not to write them (who knows what
// we'd want to do if we had them).
return t.Write(os, opts);
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached.
// Write a newline after the key, so the first line of the FST appears
// on its own line.
os << '\n';
bool acceptor = false, write_one = false;
fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
t.OutputSymbols(),
NULL, acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail())
KALDI_WARN << "Stream failure detected.";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not somethig in the original
// OpenFst code].
os << '\n';
return os.good();
}
}
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat) {
KALDI_ASSERT(*lat == NULL);
if (binary) {
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
KALDI_WARN << "Reading lattice: error reading FST header.";
return false;
}
if (hdr.FstType() != "vector") {
KALDI_WARN << "Reading lattice: unsupported FST type: "
<< hdr.FstType();
return false;
}
fst::FstReadOptions ropts("<unspecified>",
&hdr);
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
typedef fst::LatticeWeightTpl<float> T3;
typedef fst::LatticeWeightTpl<double> T4;
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
Lattice *ans = NULL;
if (hdr.ArcType() == T1::Type()) {
ans = ConvertToLattice(F1::Read(is, ropts));
} else if (hdr.ArcType() == T2::Type()) {
ans = ConvertToLattice(F2::Read(is, ropts));
} else if (hdr.ArcType() == T3::Type()) {
ans = ConvertToLattice(F3::Read(is, ropts));
} else if (hdr.ArcType() == T4::Type()) {
ans = ConvertToLattice(F4::Read(is, ropts));
} else {
KALDI_WARN << "FST with arc type " << hdr.ArcType()
<< " cannot be converted to Lattice.\n";
return false;
}
if (ans == NULL) {
KALDI_WARN << "Error reading lattice (after reading header).";
return false;
}
*lat = ans;
return true;
} else {
// The next line would normally consume the \r on Windows, plus any
// extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') is.get(); // consume the newline.
else { // saw spaces but no newline.. this is not expected.
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
<< " at file position " << is.tellg();
return false;
}
*lat = ReadLatticeText(is); // that routine will warn on error.
return (*lat != NULL);
}
}
/* Since we don't write the binary headers for this type of holder,
we use a different method to work out whether we're in binary mode.
*/
bool LatticeHolder::Read(std::istream &is) {
Clear(); // in case anything currently stored.
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading Lattice.";
return false;
} else if (isspace(c)) { // The text form of the lattice begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is not
// space).
return ReadLattice(is, false, &t_);
} else if (c != 214) { // 214 is first char of FST magic number,
// on little-endian machines which is all we support (\326 octal)
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
<< " [non-space but no magic number detected], file pos is "
<< is.tellg();
return false;
} else {
return ReadLattice(is, true, &t_);
}
}
} // end namespace kaldi
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
namespace kaldi {
// will import some things above...
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
// careful: kaldi::int32 is not always the same C type as fst::int32
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
CompactLatticeWeightCommonDivisor;
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
typedef fst::VectorFst<LatticeArc> Lattice;
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &clat);
bool WriteLattice(std::ostream &os, bool binary,
const Lattice &lat);
// the following function requires that *clat be
// NULL when called.
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat);
// the following function requires that *lat be
// NULL when called.
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat);
class CompactLatticeHolder {
public:
typedef CompactLattice T;
CompactLatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteCompactLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(CompactLatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~CompactLatticeHolder() { Clear(); }
private:
T *t_;
};
class LatticeHolder {
public:
typedef Lattice T;
LatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(LatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const LatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~LatticeHolder() { Clear(); }
private:
T *t_;
};
typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi
#endif // KALDI_LAT_KALDI_LATTICE_H_
// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
// Bagher BabaAli
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-math.h"
#include "lat/lattice-functions.h"
namespace kaldi {
using std::map;
using std::vector;
void GetPerFrameAcousticCosts(const Lattice &nbest,
Vector<BaseFloat> *per_frame_loglikes) {
using namespace fst;
typedef Lattice::Arc::Weight Weight;
vector<BaseFloat> loglikes;
int32 cur_state = nbest.Start();
int32 prev_frame = -1;
BaseFloat eps_acwt = 0.0;
while(1) {
Weight w = nbest.Final(cur_state);
if (w != Weight::Zero()) {
KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
if (per_frame_loglikes != NULL) {
SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
Vector<BaseFloat> vec(subvec);
*per_frame_loglikes = vec;
}
break;
} else {
KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
fst::ArcIterator<Lattice> iter(nbest, cur_state);
const Lattice::Arc &arc = iter.Value();
BaseFloat acwt = arc.weight.Value2();
if (arc.ilabel != 0) {
if (eps_acwt > 0) {
acwt += eps_acwt;
eps_acwt = 0.0;
}
loglikes.push_back(acwt);
prev_frame++;
} else if (acwt == acwt){
if (prev_frame > -1) {
loglikes[prev_frame] += acwt;
} else {
eps_acwt += acwt;
}
}
cur_state = arc.nextstate;
}
}
}
int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
if (!lat.Properties(fst::kTopSorted, true))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
times->clear();
times->resize(num_states, -1);
(*times)[0] = 0;
for (int32 state = 0; state < num_states; state++) {
int32 cur_time = (*times)[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc &arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon input label on arc
// next time instance
if ((*times)[arc.nextstate] == -1) {
(*times)[arc.nextstate] = cur_time + 1;
} else {
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
}
} else { // epsilon input label on arc
// Same time instance
if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time;
else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
}
}
}
return (*std::max_element(times->begin(), times->end()));
}
int32 CompactLatticeStateTimes(const CompactLattice &lat,
vector<int32> *times) {
if (!lat.Properties(fst::kTopSorted, true))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
times->clear();
times->resize(num_states, -1);
(*times)[0] = 0;
int32 utt_len = -1;
for (int32 state = 0; state < num_states; state++) {
int32 cur_time = (*times)[state];
for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
int32 arc_len = static_cast<int32>(arc.weight.String().size());
if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time + arc_len;
else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
}
if (lat.Final(state) != CompactLatticeWeight::Zero()) {
int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
if (utt_len == -1) utt_len = this_utt_len;
else {
if (this_utt_len != utt_len) {
KALDI_WARN << "Utterance does not "
"seem to have a consistent length.";
utt_len = std::max(utt_len, this_utt_len);
}
}
}
}
if (utt_len == -1) {
KALDI_WARN << "Utterance does not have a final-state.";
return 0;
}
return utt_len;
}
bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
vector<double> *alpha) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
//Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*alpha).resize(0);
(*alpha).resize(num_states, kLogZeroDouble);
// Now propagate alphas forward. Note that we don't acount the weight of the
// final state to alpha[final_state] -- we acount it to beta[final_state];
(*alpha)[0] = 0.0;
for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s];
for (ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() +
arc.weight.Weight().Value2());
(*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
this_alpha + arc_like);
}
}
return true;
}
bool ComputeCompactLatticeBetas(const CompactLattice &clat,
vector<double> *beta) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
// Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*beta).resize(0);
(*beta).resize(num_states, kLogZeroDouble);
// Now propagate betas backward. Note that beta[final_state] contains the
// weight of the final state in the lattice -- compare that with alpha.
for (StateId s = num_states-1; s >= 0; s--) {
Weight f = clat.Final(s);
double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
for (ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() +
arc.weight.Weight().Value2());
double arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta);
}
(*beta)[s] = this_beta;
}
return true;
}
template<class LatType> // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) {
typedef typename LatType::Arc Arc;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
KALDI_ASSERT(beam > 0.0);
if (!lat->Properties(fst::kTopSorted, true)) {
if (fst::TopSort(lat) == false) {
KALDI_WARN << "Cycles detected in lattice";
return false;
}
}
// We assume states before "start" are not reachable, since
// the lattice is topologically sorted.
int32 start = lat->Start();
int32 num_states = lat->NumStates();
if (num_states == 0) return false;
std::vector<double> forward_cost(num_states,
std::numeric_limits<double>::infinity()); // viterbi forward.
forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be
// less than this.
double best_final_cost = std::numeric_limits<double>::infinity();
// Update the forward probs.
// Thanks to Jing Zheng for finding a bug here.
for (int32 state = 0; state < num_states; state++) {
double this_forward_cost = forward_cost[state];
for (fst::ArcIterator<LatType> aiter(*lat, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc(aiter.Value());
StateId nextstate = arc.nextstate;
KALDI_ASSERT(nextstate > state && nextstate < num_states);
double next_forward_cost = this_forward_cost +
ConvertToCost(arc.weight);
if (forward_cost[nextstate] > next_forward_cost)
forward_cost[nextstate] = next_forward_cost;
}
Weight final_weight = lat->Final(state);
double this_final_cost = this_forward_cost +
ConvertToCost(final_weight);
if (this_final_cost < best_final_cost)
best_final_cost = this_final_cost;
}
int32 bad_state = lat->AddState(); // this state is not final.
double cutoff = best_final_cost + beam;
// Go backwards updating the backward probs (which share memory with the
// forward probs), and pruning arcs and deleting final-probs. We prune arcs
// by making them point to the non-final state "bad_state". We'll then use
// Trim() to remove unnecessary arcs and states. [this is just easier than
// doing it ourselves.]
std::vector<double> &backward_cost(forward_cost);
for (int32 state = num_states - 1; state >= 0; state--) {
double this_forward_cost = forward_cost[state];
double this_backward_cost = ConvertToCost(lat->Final(state));
if (this_backward_cost + this_forward_cost > cutoff
&& this_backward_cost != std::numeric_limits<double>::infinity())
lat->SetFinal(state, Weight::Zero());
for (fst::MutableArcIterator<LatType> aiter(lat, state);
!aiter.Done();
aiter.Next()) {
Arc arc(aiter.Value());
StateId nextstate = arc.nextstate;
KALDI_ASSERT(nextstate > state && nextstate < num_states);
double arc_cost = ConvertToCost(arc.weight),
arc_backward_cost = arc_cost + backward_cost[nextstate],
this_fb_cost = this_forward_cost + arc_backward_cost;
if (arc_backward_cost < this_backward_cost)
this_backward_cost = arc_backward_cost;
if (this_fb_cost > cutoff) { // Prune the arc.
arc.nextstate = bad_state;
aiter.SetValue(arc);
}
}
backward_cost[state] = this_backward_cost;
}
fst::Connect(lat);
return (lat->NumStates() > 0);
}
// instantiate the template for lattice and CompactLattice.
template bool PruneLattice(BaseFloat beam, Lattice *lat);
template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
double *acoustic_like_sum) {
// Note, Posterior is defined as follows: Indexed [frame], then a list
// of (transition-id, posterior-probability) pairs.
// typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
using namespace fst;
typedef Lattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
if (acoustic_like_sum) *acoustic_like_sum = 0.0;
// Make sure the lattice is topologically sorted.
if (lat.Properties(fst::kTopSorted, true) == 0)
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
std::vector<double> alpha(num_states, kLogZeroDouble);
std::vector<double> &beta(alpha); // we re-use the same memory for
// this, but it's semantically distinct so we name it differently.
double tot_forward_prob = kLogZeroDouble;
post->clear();
post->resize(max_time);
alpha[0] = 0.0;
// Propagate alphas forward.
for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight);
alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
}
Weight f = lat.Final(s);
if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2());
tot_forward_prob = LogAdd(tot_forward_prob, final_like);
KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)");
}
}
for (StateId s = num_states-1; s >= 0; s--) {
Weight f = lat.Final(s);
double this_beta = -(f.Value1() + f.Value2());
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta);
int32 transition_id = arc.ilabel;
// The following "if" is an optimization to avoid un-needed exp().
if (transition_id != 0 || acoustic_like_sum != NULL) {
double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
if (transition_id != 0) // Arc has a transition-id on it [not epsilon]
(*post)[state_times[s]].push_back(std::make_pair(transition_id,
static_cast<kaldi::BaseFloat>(posterior)));
if (acoustic_like_sum != NULL)
*acoustic_like_sum -= posterior * arc.weight.Value2();
}
}
if (acoustic_like_sum != NULL && f != Weight::Zero()) {
double final_logprob = - ConvertToCost(f),
posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
*acoustic_like_sum -= posterior * f.Value2();
}
beta[s] = this_beta;
}
double tot_backward_prob = beta[0];
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
// Now combine any posteriors with the same transition-id.
for (int32 t = 0; t < max_time; t++)
MergePairVectorSumming(&((*post)[t]));
return tot_backward_prob;
}
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
const vector<int32> &silence_phones,
vector< std::set<int32> > *active_phones) {
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
vector<int32> state_times;
int32 num_states = lat.NumStates();
int32 max_time = LatticeStateTimes(lat, &state_times);
active_phones->clear();
active_phones->resize(max_time);
for (int32 state = 0; state < num_states; state++) {
int32 cur_time = state_times[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc &arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
if (!std::binary_search(silence_phones.begin(),
silence_phones.end(), phone))
(*active_phones)[cur_time].insert(phone);
}
} // end looping over arcs
} // end looping over states
}
void ConvertLatticeToPhones(const TransitionInformation &trans,
Lattice *lat) {
typedef LatticeArc Arc;
int32 num_states = lat->NumStates();
for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
Arc arc(aiter.Value());
arc.olabel = 0; // remove any word.
if ((arc.ilabel != 0) // has a transition-id on input..
&& (trans.TransitionIdIsStartOfPhone(arc.ilabel))
&& (!trans.IsSelfLoop(arc.ilabel))) {
// && trans.IsFinal(arc.ilabel)) // there is one of these per phone...
arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
}
aiter.SetValue(arc);
} // end looping over arcs
} // end looping over states
}
static inline double LogAddOrMax(bool viterbi, double a, double b) {
if (viterbi)
return std::max(a, b);
else
return LogAdd(a, b);
}
template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
bool viterbi,
vector<double> *alpha,
vector<double> *beta) {
typedef typename LatticeType::Arc Arc;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
StateId num_states = lat.NumStates();
KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
KALDI_ASSERT(lat.Start() == 0);
alpha->clear();
beta->clear();
alpha->resize(num_states, kLogZeroDouble);
beta->resize(num_states, kLogZeroDouble);
double tot_forward_prob = kLogZeroDouble;
(*alpha)[0] = 0.0;
// Propagate alphas forward.
for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s];
for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight);
(*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
this_alpha + arc_like);
}
Weight f = lat.Final(s);
if (f != Weight::Zero()) {
double final_like = this_alpha - ConvertToCost(f);
tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
}
}
for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
double this_beta = -ConvertToCost(lat.Final(s));
for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight),
arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
}
(*beta)[s] = this_beta;
}
double tot_backward_prob = (*beta)[lat.Start()];
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
// Split the difference when returning... they should be the same.
return 0.5 * (tot_backward_prob + tot_forward_prob);
}
// instantiate the template for Lattice and CompactLattice
template
double ComputeLatticeAlphasAndBetas(const Lattice &lat,
bool viterbi,
vector<double> *alpha,
vector<double> *beta);
template
double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
bool viterbi,
vector<double> *alpha,
vector<double> *beta);
/// This is used in CompactLatticeLimitDepth.
struct LatticeArcRecord {
BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
// minus the overall best-cost of the lattice.
CompactLatticeArc::StateId state; // state in the lattice.
size_t arc; // arc index within the state.
bool operator < (const LatticeArcRecord &other) const {
return logprob < other.logprob;
}
};
void CompactLatticeLimitDepth(int32 max_depth_per_frame,
CompactLattice *clat) {
typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
if (clat->Start() == fst::kNoStateId) {
KALDI_WARN << "Limiting depth of empty lattice.";
return;
}
if (clat->Properties(fst::kTopSorted, true) == 0) {
if (!TopSort(clat))
KALDI_ERR << "Topological sorting of lattice failed.";
}
vector<int32> state_times;
int32 T = CompactLatticeStateTimes(*clat, &state_times);
// The alpha and beta quantities here are "viterbi" alphas and beta.
std::vector<double> alpha;
std::vector<double> beta;
bool viterbi = true;
double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
&alpha, &beta);
std::vector<std::vector<LatticeArcRecord> > arc_records(T);
StateId num_states = clat->NumStates();
for (StateId s = 0; s < num_states; s++) {
for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
LatticeArcRecord arc_record;
arc_record.state = s;
arc_record.arc = aiter.Position();
arc_record.logprob =
(alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
- best_prob;
KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
for (int32 t = start_t; t < start_t + num_frames; t++) {
KALDI_ASSERT(t < T);
arc_records[t].push_back(arc_record);
}
}
}
StateId dead_state = clat->AddState(); // A non-coaccesible state which we use
// to remove arcs (make them end
// there).
size_t max_depth = max_depth_per_frame;
for (int32 t = 0; t < T; t++) {
size_t size = arc_records[t].size();
if (size > max_depth) {
// we sort from worst to best, so we keep the later-numbered ones,
// and delete the lower-numbered ones.
size_t cutoff = size - max_depth;
std::nth_element(arc_records[t].begin(),
arc_records[t].begin() + cutoff,
arc_records[t].end());
for (size_t index = 0; index < cutoff; index++) {
LatticeArcRecord record(arc_records[t][index]);
fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
aiter.Seek(record.arc);
Arc arc = aiter.Value();
if (arc.nextstate != dead_state) { // not already killed.
arc.nextstate = dead_state;
aiter.SetValue(arc);
}
}
}
}
Connect(clat);
TopSortCompactLatticeIfNeeded(clat);
}
void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
if (clat->Properties(fst::kTopSorted, true) == 0) {
if (fst::TopSort(clat) == false) {
KALDI_ERR << "Topological sorting failed";
}
}
}
void TopSortLatticeIfNeeded(Lattice *lat) {
if (lat->Properties(fst::kTopSorted, true) == 0) {
if (fst::TopSort(lat) == false) {
KALDI_ERR << "Topological sorting failed";
}
}
}
/// Returns the depth of the lattice, defined as the average number of
/// arcs crossing any given frame. Returns 1 for empty lattices.
/// Requires that input is topologically sorted.
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
int32 *num_frames) {
typedef CompactLattice::Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
<< "sorted.";
}
if (clat.Start() == fst::kNoStateId) {
*num_frames = 0;
return 1.0;
}
size_t num_arc_frames = 0;
int32 t;
{
vector<int32> state_times;
t = CompactLatticeStateTimes(clat, &state_times);
}
if (num_frames != NULL)
*num_frames = t;
for (StateId s = 0; s < clat.NumStates(); s++) {
for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
num_arc_frames += arc.weight.String().size();
}
num_arc_frames += clat.Final(s).String().size();
}
return num_arc_frames / static_cast<BaseFloat>(t);
}
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
std::vector<int32> *depth_per_frame) {
typedef CompactLattice::Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
<< "topologically sorted.";
}
if (clat.Start() == fst::kNoStateId) {
depth_per_frame->clear();
return;
}
vector<int32> state_times;
int32 T = CompactLatticeStateTimes(clat, &state_times);
depth_per_frame->clear();
if (T <= 0) {
return;
} else {
depth_per_frame->resize(T, 0);
for (StateId s = 0; s < clat.NumStates(); s++) {
int32 start_time = state_times[s];
for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value();
int32 len = arc.weight.String().size();
for (int32 t = start_time; t < start_time + len; t++) {
KALDI_ASSERT(t < T);
(*depth_per_frame)[t]++;
}
}
int32 final_len = clat.Final(s).String().size();
for (int32 t = start_time; t < start_time + final_len; t++) {
KALDI_ASSERT(t < T);
(*depth_per_frame)[t]++;
}
}
}
}
void ConvertCompactLatticeToPhones(const TransitionInformation &trans,
CompactLattice *clat) {
typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight;
int32 num_states = clat->NumStates();
for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
!aiter.Done();
aiter.Next()) {
Arc arc(aiter.Value());
std::vector<int32> phone_seq;
const std::vector<int32> &tid_seq = arc.weight.String();
for (std::vector<int32>::const_iterator iter = tid_seq.begin();
iter != tid_seq.end(); ++iter) {
if (trans.IsFinal(*iter))// note: there is one of these per phone...
phone_seq.push_back(trans.TransitionIdToPhone(*iter));
}
arc.weight.SetString(phone_seq);
aiter.SetValue(arc);
} // end looping over arcs
Weight f = clat->Final(state);
if (f != Weight::Zero()) {
std::vector<int32> phone_seq;
const std::vector<int32> &tid_seq = f.String();
for (std::vector<int32>::const_iterator iter = tid_seq.begin();
iter != tid_seq.end(); ++iter) {
if (trans.IsFinal(*iter))// note: there is one of these per phone...
phone_seq.push_back(trans.TransitionIdToPhone(*iter));
}
f.SetString(phone_seq);
clat->SetFinal(state, f);
}
} // end looping over states
}
bool LatticeBoost(const TransitionInformation &trans,
const std::vector<int32> &alignment,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat) {
TopSortLatticeIfNeeded(lat);
// get all stored properties (test==false means don't test if not known).
uint64 props = lat->Properties(fst::kFstProperties,
false);
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
vector<int32> state_times;
int32 num_states = lat->NumStates();
int32 num_frames = LatticeStateTimes(*lat, &state_times);
KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
for (int32 state = 0; state < num_states; state++) {
int32 cur_time = state_times[state];
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
KALDI_WARN << "Lattice has out-of-range transition-ids: "
<< "lattice/model mismatch?";
return false;
}
int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
BaseFloat frame_error;
if (phone == ref_phone) {
frame_error = 0.0;
} else { // an error...
if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
frame_error = max_silence_error;
else
frame_error = 1.0;
}
BaseFloat delta_cost = -b * frame_error; // negative cost if
// frame is wrong, to boost likelihood of arcs with errors on them.
// Add this cost to the graph part.
arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
aiter.SetValue(arc);
}
}
}
// All we changed is the weights, so any properties that were
// known before, are still known, except for whether or not the
// lattice was weighted.
lat->SetProperties(props,
~(fst::kWeighted|fst::kUnweighted));
return true;
}
BaseFloat LatticeForwardBackwardMpeVariants(
const TransitionInformation &trans,
const std::vector<int32> &silence_phones,
const Lattice &lat,
const std::vector<int32> &num_ali,
std::string criterion,
bool one_silence_class,
Posterior *post) {
using namespace fst;
typedef Lattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
bool is_mpfe = (criterion == "mpfe");
if (lat.Properties(fst::kTopSorted, true) == 0)
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
std::vector<double> alpha(num_states, kLogZeroDouble),
alpha_smbr(num_states, 0), //forward variable for sMBR
beta(num_states, kLogZeroDouble),
beta_smbr(num_states, 0); //backward variable for sMBR
double tot_forward_prob = kLogZeroDouble;
double tot_forward_score = 0;
post->clear();
post->resize(max_time);
alpha[0] = 0.0;
// First Pass Forward,
for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight);
alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
}
Weight f = lat.Final(s);
if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2());
tot_forward_prob = LogAdd(tot_forward_prob, final_like);
KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)");
}
}
// First Pass Backward,
for (StateId s = num_states-1; s >= 0; s--) {
Weight f = lat.Final(s);
double this_beta = -(f.Value1() + f.Value2());
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta);
}
beta[s] = this_beta;
}
// First Pass Forward-Backward Check
double tot_backward_prob = beta[0];
// may loose the condition somehow here 1e-6 (was 1e-8)
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
alpha_smbr[0] = 0.0;
// Second Pass Forward, calculate forward for MPFE/SMBR
for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight);
double frame_acc = 0.0;
if (arc.ilabel != 0) {
int32 cur_time = state_times[s];
int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
bool phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(),
phone),
ref_phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(),
ref_phone),
both_sil = phone_is_sil && ref_phone_is_sil;
if (!is_mpfe) { // smbr.
int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
if (!one_silence_class) // old behavior
frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
else
frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
} else {
if (!one_silence_class) // old behavior
frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
else
frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
}
}
double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
}
Weight f = lat.Final(s);
if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2());
double arc_scale = Exp(final_like - tot_forward_prob);
tot_forward_score += arc_scale * alpha_smbr[s];
KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)");
}
}
// Second Pass Backward, collect Mpe style posteriors
for (StateId s = num_states-1; s >= 0; s--) {
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like;
double frame_acc = 0.0;
int32 transition_id = arc.ilabel;
if (arc.ilabel != 0) {
int32 cur_time = state_times[s];
int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
bool phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(), phone),
ref_phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(),
ref_phone),
both_sil = phone_is_sil && ref_phone_is_sil;
if (!is_mpfe) { // smbr.
int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
if (!one_silence_class) // old behavior
frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
else
frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
} else {
if (!one_silence_class) // old behavior
frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
else
frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
}
}
double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
// check arc_scale NAN,
// this is to prevent partial paths in Lattices
// i.e., paths don't survive to the final state
if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
if (transition_id != 0) { // Arc has a transition-id on it [not epsilon]
double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate]
- tot_forward_score;
double posterior_smbr = posterior * acc_diff;
(*post)[state_times[s]].push_back(std::make_pair(transition_id,
static_cast<BaseFloat>(posterior_smbr)));
}
}
}
//Second Pass Forward Backward check
double tot_backward_score = beta_smbr[0]; // Initial state id == 0
// may loose the condition somehow here 1e-5/1e-4
if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
<< ", while total backward score = " << tot_backward_score;
}
// Output the computed posteriors
for (int32 t = 0; t < max_time; t++)
MergePairVectorSumming(&((*post)[t]));
return tot_forward_score;
}
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
std::vector<int32> *words,
std::vector<int32> *begin_times,
std::vector<int32> *lengths) {
words->clear();
begin_times->clear();
lengths->clear();
typedef CompactLattice::Arc Arc;
typedef Arc::Label Label;
typedef CompactLattice::StateId StateId;
typedef CompactLattice::Weight Weight;
using namespace fst;
StateId state = clat.Start();
int32 cur_time = 0;
if (state == kNoStateId) {
KALDI_WARN << "Empty lattice.";
return false;
}
while (1) {
Weight final = clat.Final(state);
size_t num_arcs = clat.NumArcs(state);
if (final != Weight::Zero()) {
if (num_arcs != 0) {
KALDI_WARN << "Lattice is not linear.";
return false;
}
if (! final.String().empty()) {
KALDI_WARN << "Lattice has alignments on final-weight: probably "
"was not word-aligned (alignments will be approximate)";
}
return true;
} else {
if (num_arcs != 1) {
KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
return false;
}
fst::ArcIterator<CompactLattice> aiter(clat, state);
const Arc &arc = aiter.Value();
Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// Also note: word_id may be zero; we output it anyway.
int32 length = arc.weight.String().size();
words->push_back(word_id);
begin_times->push_back(cur_time);
lengths->push_back(length);
cur_time += length;
state = arc.nextstate;
}
}
}
void CompactLatticeShortestPath(const CompactLattice &clat,
CompactLattice *shortest_path) {
using namespace fst;
if (clat.Properties(fst::kTopSorted, true) == 0) {
CompactLattice clat_copy(clat);
if (!TopSort(&clat_copy))
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
CompactLatticeShortestPath(clat_copy, shortest_path);
return;
}
// Now we can assume it's topologically sorted.
shortest_path->DeleteStates();
if (clat.Start() == kNoStateId) return;
typedef CompactLatticeArc Arc;
typedef Arc::StateId StateId;
typedef CompactLatticeWeight Weight;
vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
StateId superfinal = clat.NumStates();
for (StateId s = 0; s <= clat.NumStates(); s++) {
best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
best_cost_and_pred[s].second = fst::kNoStateId;
}
best_cost_and_pred[clat.Start()].first = 0;
for (StateId s = 0; s < clat.NumStates(); s++) {
double my_cost = best_cost_and_pred[s].first;
for (ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_cost = ConvertToCost(arc.weight),
next_cost = my_cost + arc_cost;
if (next_cost < best_cost_and_pred[arc.nextstate].first) {
best_cost_and_pred[arc.nextstate].first = next_cost;
best_cost_and_pred[arc.nextstate].second = s;
}
}
double final_cost = ConvertToCost(clat.Final(s)),
tot_final = my_cost + final_cost;
if (tot_final < best_cost_and_pred[superfinal].first) {
best_cost_and_pred[superfinal].first = tot_final;
best_cost_and_pred[superfinal].second = s;
}
}
std::vector<StateId> states; // states on best path.
StateId cur_state = superfinal, start_state = clat.Start();
while (cur_state != start_state) {
StateId prev_state = best_cost_and_pred[cur_state].second;
if (prev_state == kNoStateId) {
KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
return; // return empty best-path.
}
states.push_back(prev_state);
KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
cur_state = prev_state;
}
std::reverse(states.begin(), states.end());
for (size_t i = 0; i < states.size(); i++)
shortest_path->AddState();
for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
if (s == 0) shortest_path->SetStart(s);
if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state.
bool have_arc = false;
Arc cur_arc;
for (ArcIterator<CompactLattice> aiter(clat, states[s]);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.nextstate == states[s+1]) {
if (!have_arc ||
ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
cur_arc = arc;
have_arc = true;
}
}
}
KALDI_ASSERT(have_arc && "Code error.");
shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
cur_arc.weight, s+1));
} else { // final-prob.
shortest_path->SetFinal(s, clat.Final(states[s]));
}
}
}
void ExpandCompactLattice(const CompactLattice &clat,
double epsilon,
CompactLattice *expand_clat) {
using namespace fst;
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
typedef std::pair<StateId, StateId> StatePair;
typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
typedef MapType::iterator IterType;
if (clat.Start() == kNoStateId) return;
// Make sure the input lattice is topologically sorted.
if (clat.Properties(kTopSorted, true) == 0) {
CompactLattice clat_copy(clat);
KALDI_LOG << "Topsort this lattice.";
if (!TopSort(&clat_copy))
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
ExpandCompactLattice(clat_copy, epsilon, expand_clat);
return;
}
// Compute backward logprobs betas for the expanded lattice.
// Note: the backward logprobs in the original lattice <clat> and the
// expanded lattice <expand_clat> are the same.
int32 num_states = clat.NumStates();
std::vector<double> beta(num_states, kLogZeroDouble);
ComputeCompactLatticeBetas(clat, &beta);
double tot_backward_logprob = beta[0];
std::vector<double> alpha;
alpha.push_back(0.0);
expand_clat->DeleteStates();
MapType state_map; // Map from state pair (orig_state, copy_state) to
// copy_state, where orig_state is a state in the original lattice, and
// copy_state is its corresponding one in the expanded lattice.
unordered_map<StateId, StateId> states; // Map from orig_state to its
// copy_state for states with incoming arcs' posteriors <= epsilon.
std::queue<StatePair> state_queue;
// Set start state in the expanded lattice.
StateId start_state = expand_clat->AddState();
expand_clat->SetStart(start_state);
StatePair start_pair(clat.Start(), start_state);
state_queue.push(start_pair);
std::pair<IterType, bool> result =
state_map.insert(std::make_pair(start_pair, start_state));
KALDI_ASSERT(result.second == true);
// Expand <clat> and update forward logprobs alphas in <expand_clat>.
while (!state_queue.empty()) {
StatePair s = state_queue.front();
StateId s1 = s.first,
s2 = s.second;
state_queue.pop();
Weight f = clat.Final(s1);
if (f != Weight::Zero()) {
KALDI_ASSERT(state_map.find(s) != state_map.end());
expand_clat->SetFinal(state_map[s], f);
}
for (ArcIterator<CompactLattice> aiter(clat, s1);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
StateId orig_state = arc.nextstate;
double arc_like = -ConvertToCost(arc.weight),
this_alpha = alpha[s2] + arc_like,
arc_post = Exp(this_alpha + beta[orig_state] -
tot_backward_logprob);
// Generate the expanded lattice.
StateId copy_state;
if (arc_post > epsilon) {
copy_state = expand_clat->AddState();
StatePair next_pair(orig_state, copy_state);
std::pair<IterType, bool> result =
state_map.insert(std::make_pair(next_pair, copy_state));
KALDI_ASSERT(result.second == true);
state_queue.push(next_pair);
} else {
unordered_map<StateId, StateId>::iterator iter = states.find(orig_state);
if (iter == states.end() ) { // The counterpart state of orig_state
// has not been created in <expand_clat> yet.
copy_state = expand_clat->AddState();
StatePair next_pair(orig_state, copy_state);
std::pair<IterType, bool> result =
state_map.insert(std::make_pair(next_pair, copy_state));
KALDI_ASSERT(result.second == true);
state_queue.push(next_pair);
states[orig_state] = copy_state;
} else {
copy_state = iter->second;
}
}
// Create an arc from state_map[s] to copy_state in the expanded lattice.
expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight,
copy_state));
// Compute forward logprobs alpha for the expanded lattice.
if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha
// for copy_state in <expand_clat>.
alpha.push_back(this_alpha);
} else { // Accumulate alpha.
alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
}
}
} // end while
}
void CompactLatticeBestCostsAndTracebacks(
const CompactLattice &clat,
CostTraceType *forward_best_cost_and_pred,
CostTraceType *backward_best_cost_and_pred) {
// typedef the arc, weight types
typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
forward_best_cost_and_pred->clear();
backward_best_cost_and_pred->clear();
forward_best_cost_and_pred->resize(clat.NumStates());
backward_best_cost_and_pred->resize(clat.NumStates());
// Initialize the cost and predecessor state for each state.
for (StateId s = 0; s < clat.NumStates(); s++) {
(*forward_best_cost_and_pred)[s].first =
std::numeric_limits<double>::infinity();
(*backward_best_cost_and_pred)[s].first =
std::numeric_limits<double>::infinity();
(*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
(*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
}
StateId start_state = clat.Start();
(*forward_best_cost_and_pred)[start_state].first = 0;
// Transverse the lattice forwardly to compute the best cost from the start
// state to each state and the best predecessor state of each state.
for (StateId s = 0; s < clat.NumStates(); s++) {
double cur_cost = (*forward_best_cost_and_pred)[s].first;
for (fst::ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double next_cost = cur_cost + ConvertToCost(arc.weight);
if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
(*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
(*forward_best_cost_and_pred)[arc.nextstate].second = s;
}
}
}
// Transverse the lattice backwardly to compute the best cost from a final
// state to each state and the best predecessor state of each state.
for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
double this_cost = ConvertToCost(clat.Final(s));
for (fst::ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first +
ConvertToCost(arc.weight);
if (next_cost < this_cost) {
this_cost = next_cost;
(*backward_best_cost_and_pred)[s].second = arc.nextstate;
}
}
(*backward_best_cost_and_pred)[s].first = this_cost;
}
}
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
CompactLattice *clat) {
if (clat->Start() == fst::kNoStateId) return;
// Make sure the input lattice is topologically sorted.
if (clat->Properties(fst::kTopSorted, true) == 0) {
KALDI_LOG << "Topsort this lattice.";
if (!TopSort(clat))
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
AddNnlmScoreToCompactLattice(nnlm_scores, clat);
return;
}
// typedef the arc, weight types
typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
typedef std::pair<int32, int32> StatePair;
int32 num_states = clat->NumStates();
unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
for (StateId s = 0; s < num_states; s++) {
for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
Arc arc(aiter.Value());
StatePair arc_index = std::make_pair(static_cast<int32>(s),
static_cast<int32>(arc.nextstate));
MapT::const_iterator it = nnlm_scores.find(arc_index);
double nnlm_score;
if (it != nnlm_scores.end())
nnlm_score = it->second;
else
KALDI_ERR << "Some arc does not have neural language model score.";
if (arc.ilabel != 0) { // if there is a word on this arc
LatticeWeight weight = arc.weight.Weight();
// Add associated neural LM score to each arc.
weight.SetValue1(weight.Value1() + nnlm_score);
arc.weight.SetWeight(weight);
aiter.SetValue(arc);
}
Weight clat_final = clat->Final(arc.nextstate);
StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
// Add neural LM scores to each final state only once.
if (clat_final != CompactLatticeWeight::Zero() &&
final_state_check.find(final_pair) == final_state_check.end()) {
MapT::const_iterator final_it = nnlm_scores.find(final_pair);
double final_nnlm_score = 0.0;
if (final_it != nnlm_scores.end())
final_nnlm_score = final_it->second;
// Add neural LM scores to the final weight.
Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
final_nnlm_score,
clat_final.Weight().Value2()),
clat_final.String());
clat->SetFinal(arc.nextstate, final_weight);
final_state_check[final_pair] = true;
}
} // end looping over arcs
} // end looping over states
}
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
CompactLattice *clat) {
typedef CompactLatticeArc Arc;
int32 num_states = clat->NumStates();
//scan the lattice
for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
!aiter.Done(); aiter.Next()) {
Arc arc(aiter.Value());
if (arc.ilabel != 0) { // if there is a word on this arc
LatticeWeight weight = arc.weight.Weight();
// add word insertion penalty to lattice
weight.SetValue1( weight.Value1() + word_ins_penalty);
arc.weight.SetWeight(weight);
aiter.SetValue(arc);
}
} // end looping over arcs
} // end looping over states
}
struct ClatRescoreTuple {
ClatRescoreTuple(int32 state, int32 arc, int32 tid):
state_id(state), arc_id(arc), tid(tid) { }
int32 state_id;
int32 arc_id;
int32 tid;
};
/** RescoreCompactLatticeInternal is the internal code for both
RescoreCompactLattice and RescoreCompatLatticeSpeedup. For
RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0.
*/
bool RescoreCompactLatticeInternal(
const TransitionInformation *tmodel,
BaseFloat speedup_factor,
DecodableInterface *decodable,
CompactLattice *clat) {
KALDI_ASSERT(speedup_factor >= 1.0);
if (clat->NumStates() == 0) {
KALDI_WARN << "Rescoring empty lattice";
return false;
}
if (!clat->Properties(fst::kTopSorted, true)) {
if (fst::TopSort(clat) == false) {
KALDI_WARN << "Cycles detected in lattice.";
return false;
}
}
std::vector<int32> state_times;
int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
int32 num_states = clat->NumStates();
KALDI_ASSERT(num_states == state_times.size());
for (size_t state = 0; state < num_states; state++) {
KALDI_ASSERT(state_times[state] >= 0);
int32 t = state_times[state];
int32 arc_id = 0;
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
!aiter.Done(); aiter.Next(), arc_id++) {
CompactLatticeArc arc = aiter.Value();
std::vector<int32> arc_string = arc.weight.String();
for (size_t offset = 0; offset < arc_string.size(); offset++) {
if (t < utt_len) { // end state may be past this..
int32 tid = arc_string[offset];
time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid));
} else {
if (t != utt_len) {
KALDI_WARN << "There appears to be lattice/feature mismatch, "
<< "aborting.";
return false;
}
}
}
}
if (clat->Final(state) != CompactLatticeWeight::Zero()) {
arc_id = -1;
std::vector<int32> arc_string = clat->Final(state).String();
for (size_t offset = 0; offset < arc_string.size(); offset++) {
KALDI_ASSERT(t + offset < utt_len); // already checked in
// CompactLatticeStateTimes, so would be code error.
time_to_state[t+offset].push_back(
ClatRescoreTuple(state, arc_id, arc_string[offset]));
}
}
}
for (int32 t = 0; t < utt_len; t++) {
if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
KALDI_WARN << "Features are too short for lattice: utt-len is "
<< utt_len << ", " << t << " is last frame";
return false;
}
// frame_scale is the scale we put on the computed acoustic probs for this
// frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
// the "speedup" code). For frames with multiple pdf-ids it will be one.
// For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
// with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
// we can avoid computing the probabilities.
BaseFloat frame_scale = 1.0;
KALDI_ASSERT(!time_to_state[t].empty());
if (tmodel != NULL) {
int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
bool frame_has_multiple_pdfs = false;
for (size_t i = 1; i < time_to_state[t].size(); i++) {
if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
frame_has_multiple_pdfs = true;
break;
}
}
if (frame_has_multiple_pdfs) {
frame_scale = 1.0;
} else {
if (WithProb(1.0 / speedup_factor)) {
frame_scale = speedup_factor;
} else {
frame_scale = 0.0;
}
}
if (frame_scale == 0.0)
continue; // the code below would be pointless.
}
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i].state_id;
int32 arc_id = time_to_state[t][i].arc_id;
int32 tid = time_to_state[t][i].tid;
if (arc_id == -1) { // Final state
// Access the trans_id
CompactLatticeWeight curr_clat_weight = clat->Final(state);
// Calculate likelihood
BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// update weight
CompactLatticeWeight new_clat_weight = curr_clat_weight;
LatticeWeight new_lat_weight = new_clat_weight.Weight();
new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2());
new_clat_weight.SetWeight(new_lat_weight);
clat->SetFinal(state, new_clat_weight);
} else {
fst::MutableArcIterator<CompactLattice> aiter(clat, state);
aiter.Seek(arc_id);
CompactLatticeArc arc = aiter.Value();
// Calculate likelihood
BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// update weight
LatticeWeight new_weight = arc.weight.Weight();
new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
arc.weight.SetWeight(new_weight);
aiter.SetValue(arc);
}
}
}
return true;
}
bool RescoreCompactLatticeSpeedup(
const TransitionInformation &tmodel,
BaseFloat speedup_factor,
DecodableInterface *decodable,
CompactLattice *clat) {
return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat);
}
bool RescoreCompactLattice(DecodableInterface *decodable,
CompactLattice *clat) {
return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
}
bool RescoreLattice(DecodableInterface *decodable,
Lattice *lat) {
if (lat->NumStates() == 0) {
KALDI_WARN << "Rescoring empty lattice";
return false;
}
if (!lat->Properties(fst::kTopSorted, true)) {
if (fst::TopSort(lat) == false) {
KALDI_WARN << "Cycles detected in lattice.";
return false;
}
}
std::vector<int32> state_times;
int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
std::vector<std::vector<int32> > time_to_state(utt_len );
int32 num_states = lat->NumStates();
KALDI_ASSERT(num_states == state_times.size());
for (size_t state = 0; state < num_states; state++) {
int32 t = state_times[state];
// Don't check t >= 0 because non-accessible states could have t = -1.
KALDI_ASSERT(t <= utt_len);
if (t >= 0 && t < utt_len)
time_to_state[t].push_back(state);
}
for (int32 t = 0; t < utt_len; t++) {
if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
KALDI_WARN << "Features are too short for lattice: utt-len is "
<< utt_len << ", " << t << " is last frame";
return false;
}
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i];
for (fst::MutableArcIterator<Lattice> aiter(lat, state);
!aiter.Done(); aiter.Next()) {
LatticeArc arc = aiter.Value();
if (arc.ilabel != 0) {
int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
// have to be a transition-id, just whatever the Decodable
// object is expecting, but it's normally a transition-id.
BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
arc.weight.SetValue2(-log_like + arc.weight.Value2());
aiter.SetValue(arc);
}
}
}
}
return true;
}
int32 LongestSentenceLength(const Lattice &lat) {
typedef Lattice::Arc Arc;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
if (lat.Properties(fst::kTopSorted, true) == 0) {
Lattice lat_copy(lat);
if (!TopSort(&lat_copy))
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
return LongestSentenceLength(lat_copy);
}
std::vector<int32> max_length(lat.NumStates(), 0);
int32 lattice_max_length = 0;
for (StateId s = 0; s < lat.NumStates(); s++) {
int32 this_max_length = max_length[s];
for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
bool arc_has_word = (arc.olabel != 0);
StateId nextstate = arc.nextstate;
KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
if (arc_has_word) {
// A lattice should ideally not have cycles anyway; a cycle with a word
// on is something very bad.
KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
max_length[nextstate] = std::max(max_length[nextstate],
this_max_length + 1);
} else {
max_length[nextstate] = std::max(max_length[nextstate],
this_max_length);
}
}
if (lat.Final(s) != LatticeWeight::Zero())
lattice_max_length = std::max(lattice_max_length, max_length[s]);
}
return lattice_max_length;
}
int32 LongestSentenceLength(const CompactLattice &clat) {
typedef CompactLattice::Arc Arc;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) {
CompactLattice clat_copy(clat);
if (!TopSort(&clat_copy))
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
return LongestSentenceLength(clat_copy);
}
std::vector<int32> max_length(clat.NumStates(), 0);
int32 lattice_max_length = 0;
for (StateId s = 0; s < clat.NumStates(); s++) {
int32 this_max_length = max_length[s];
for (fst::ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
// also note: for normal CompactLattice, e.g. as produced by
// determinization, all arcs will have nonzero labels, but the user might
// decide to remplace some of the labels with zero for some reason, and we
// want to support this.
StateId nextstate = arc.nextstate;
KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
if (arc_has_word)
max_length[nextstate] = std::max(max_length[nextstate],
this_max_length + 1);
else
max_length[nextstate] = std::max(max_length[nextstate],
this_max_length);
}
if (clat.Final(s) != CompactLatticeWeight::Zero())
lattice_max_length = std::max(lattice_max_length, max_length[s]);
}
return lattice_max_length;
}
void ComposeCompactLatticeDeterministic(
const CompactLattice& clat,
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
CompactLattice* composed_clat) {
// StdFst::Arc and CompactLatticeArc has the same StateId type.
typedef fst::StdArc::StateId StateId;
typedef fst::StdArc::Weight Weight1;
typedef CompactLatticeArc::Weight Weight2;
typedef std::pair<StateId, StateId> StatePair;
typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
typedef MapType::iterator IterType;
// Empties the output FST.
KALDI_ASSERT(composed_clat != NULL);
composed_clat->DeleteStates();
MapType state_map;
std::queue<StatePair> state_queue;
// Sets start state in <composed_clat>.
StateId start_state = composed_clat->AddState();
StatePair start_pair(clat.Start(), det_fst->Start());
composed_clat->SetStart(start_state);
state_queue.push(start_pair);
std::pair<IterType, bool> result =
state_map.insert(std::make_pair(start_pair, start_state));
KALDI_ASSERT(result.second == true);
// Starts composition here.
while (!state_queue.empty()) {
// Gets the first state in the queue.
StatePair s = state_queue.front();
StateId s1 = s.first;
StateId s2 = s.second;
state_queue.pop();
Weight2 clat_final = clat.Final(s1);
if (clat_final.Weight().Value1() !=
std::numeric_limits<BaseFloat>::infinity()) {
// Test for whether the final-prob of state s1 was zero.
Weight1 det_fst_final = det_fst->Final(s2);
if (det_fst_final.Value() !=
std::numeric_limits<BaseFloat>::infinity()) {
// Test for whether the final-prob of state s2 was zero. If neither
// source-state final prob was zero, then we should create final state
// in fst_composed. We compute the product manually since this is more
// efficient.
Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
det_fst_final.Value(),
clat_final.Weight().Value2()),
clat_final.String());
// we can assume final_weight is not Zero(), since neither of
// the sources was zero.
KALDI_ASSERT(state_map.find(s) != state_map.end());
composed_clat->SetFinal(state_map[s], final_weight);
}
}
// Loops over pair of edges at s1 and s2.
for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
!aiter.Done(); aiter.Next()) {
const CompactLatticeArc& arc1 = aiter.Value();
fst::StdArc arc2;
StateId next_state1 = arc1.nextstate, next_state2;
bool matched = false;
if (arc1.olabel == 0) {
// If the symbol on <arc1> is <epsilon>, we transit to the next state
// for <clat>, but keep <det_fst> at the current state.
matched = true;
next_state2 = s2;
} else {
// Otherwise try to find the matched arc in <det_fst>.
matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
if (matched) {
next_state2 = arc2.nextstate;
}
}
// If matched arc is found in <det_fst>, then we have to add new arcs to
// <composed_clat>.
if (matched) {
StatePair next_state_pair(next_state1, next_state2);
IterType siter = state_map.find(next_state_pair);
StateId next_state;
// Adds composed state to <state_map>.
if (siter == state_map.end()) {
// If the composed state has not been created yet, create it.
next_state = composed_clat->AddState();
std::pair<const StatePair, StateId> next_state_map(next_state_pair,
next_state);
std::pair<IterType, bool> result = state_map.insert(next_state_map);
KALDI_ASSERT(result.second);
state_queue.push(next_state_pair);
} else {
// If the composed state is already in <state_map>, we can directly
// use that.
next_state = siter->second;
}
// Adds arc to <composed_clat>.
if (arc1.olabel == 0) {
composed_clat->AddArc(state_map[s],
CompactLatticeArc(arc1.ilabel, 0,
arc1.weight, next_state));
} else {
Weight2 composed_weight(
LatticeWeight(arc1.weight.Weight().Value1() +
arc2.weight.Value(),
arc1.weight.Weight().Value2()),
arc1.weight.String());
composed_clat->AddArc(state_map[s],
CompactLatticeArc(arc1.ilabel, arc2.olabel,
composed_weight, next_state));
}
}
}
}
fst::Connect(composed_clat);
}
void ComputeAcousticScoresMap(
const Lattice &lat,
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > *acoustic_scores) {
// typedef the arc, weight types
typedef Lattice::Arc Arc;
typedef Arc::Weight LatticeWeight;
typedef Arc::StateId StateId;
acoustic_scores->clear();
std::vector<int32> state_times;
LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
KALDI_ASSERT(lat.Start() == 0);
for (StateId s = 0; s < lat.NumStates(); s++) {
int32 t = state_times[s];
for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
const LatticeWeight &weight = arc.weight;
int32 tid = arc.ilabel;
if (tid != 0) {
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid));
if (it == acoustic_scores->end()) {
acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
std::make_pair(weight.Value2(), 1)));
} else {
if (it->second.second == 2
&& it->second.first / it->second.second != weight.Value2()) {
KALDI_VLOG(2) << "Transitions on the same frame have different "
<< "acoustic costs for tid " << tid << "; "
<< it->second.first / it->second.second
<< " vs " << weight.Value2();
}
it->second.first += weight.Value2();
it->second.second++;
}
} else {
// Arcs with epsilon input label (tid) must have 0 acoustic cost
KALDI_ASSERT(weight.Value2() == 0);
}
}
LatticeWeight f = lat.Final(s);
if (f != LatticeWeight::Zero()) {
// Final acoustic cost must be 0 as we are reading from
// non-determinized, non-compact lattice
KALDI_ASSERT(f.Value2() == 0.0);
}
}
}
void ReplaceAcousticScoresFromMap(
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > &acoustic_scores,
Lattice *lat) {
// typedef the arc, weight types
typedef Lattice::Arc Arc;
typedef Arc::Weight LatticeWeight;
typedef Arc::StateId StateId;
TopSortLatticeIfNeeded(lat);
std::vector<int32> state_times;
LatticeStateTimes(*lat, &state_times);
KALDI_ASSERT(lat->Start() == 0);
for (StateId s = 0; s < lat->NumStates(); s++) {
int32 t = state_times[s];
for (fst::MutableArcIterator<Lattice> aiter(lat, s);
!aiter.Done(); aiter.Next()) {
Arc arc(aiter.Value());
int32 tid = arc.ilabel;
if (tid != 0) {
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid));
if (it == acoustic_scores.end()) {
KALDI_ERR << "Could not find tid " << tid << " at time " << t
<< " in the acoustic scores map.";
} else {
arc.weight.SetValue2(it->second.first / it->second.second);
}
} else {
// For epsilon arcs, set acoustic cost to 0.0
arc.weight.SetValue2(0.0);
}
aiter.SetValue(arc);
}
LatticeWeight f = lat->Final(s);
if (f != LatticeWeight::Zero()) {
// Set final acoustic cost to 0.0
f.SetValue2(0.0);
lat->SetFinal(s, f);
}
}
}
} // namespace kaldi
// lat/lattice-functions.h
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
// Bagher BabaAli
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "itf/transition-information.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {
// Redundant with the typedef in hmm/posterior.h. We want functions
// using the Posterior type to be usable without a dependency on the
// hmm library.
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
/**
This function extracts the per-frame log likelihoods from a linear
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
The dimension of *per_frame_loglikes will be set to the
number of input symbols in 'nbest'. The elements of
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice
weights, which represent the acoustic costs; you may want to scale this
vector afterward by -1/acoustic_scale to get the original loglikes.
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
(and this should not normally be the case in situations where it makes
sense to call this function), they will be included to the cost of the
preceding input symbol, or the following input symbol for input-epsilons
encountered prior to any input symbol. If 'nbest' has no input symbols,
'per_frame_loglikes' will be set to the empty vector.
**/
void GetPerFrameAcousticCosts(const Lattice &nbest,
Vector<BaseFloat> *per_frame_loglikes);
/// This function iterates over the states of a topologically sorted lattice and
/// counts the time instance corresponding to each state. The times are returned
/// in a vector of integers 'times' which is resized to have a size equal to the
/// number of states in the lattice. The function also returns the maximum time
/// in the lattice (this will equal the number of frames in the file).
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
/// As LatticeStateTimes, but in the CompactLattice format. Note: must
/// be topologically sorted. Returns length of the utterance in frames, which
/// might not be the same as the maximum time in the lattice, due to frames
/// in the final-prob.
int32 CompactLatticeStateTimes(const CompactLattice &clat,
std::vector<int32> *times);
/// This function does the forward-backward over lattices and computes the
/// posterior probabilities of the arcs. It returns the total log-probability
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
/// on each frame.
/// If the pointer "acoustic_like_sum" is provided, this value is set to
/// the sum over the arcs, of the posterior of the arc times the
/// acoustic likelihood [i.e. negated acoustic score] on that link.
/// This is used in combination with other quantities to work out
/// the objective function in MMI discriminative training.
BaseFloat LatticeForwardBackward(const Lattice &lat,
Posterior *arc_post,
double *acoustic_like_sum = NULL);
// This function is something similar to LatticeForwardBackward(), but it is on
// the CompactLattice lattice format. Also we only need the alpha in the forward
// path, not the posteriors.
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
std::vector<double> *alpha);
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
// the backward path here.
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
std::vector<double> *beta);
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
// best-path negated cost) Note: in either case, the alphas and betas are
// negated costs. Requires that lat be topologically sorted. This code
// will work for either CompactLattice or Lattice.
template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
bool viterbi,
std::vector<double> *alpha,
std::vector<double> *beta);
/// Topologically sort the compact lattice if not already topologically sorted.
/// Will crash if the lattice cannot be topologically sorted.
void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
/// Topologically sort the lattice if not already topologically sorted.
/// Will crash if lattice cannot be topologically sorted.
void TopSortLatticeIfNeeded(Lattice *clat);
/// Returns the depth of the lattice, defined as the average number of arcs (or
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
/// Requires that clat is topologically sorted!
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
int32 *num_frames = NULL);
/// This function returns, for each frame, the number of arcs crossing that
/// frame.
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
std::vector<int32> *depth_per_frame);
/// This function limits the depth of the lattice, per frame: that means, it
/// does not allow more than a specified number of arcs active on any given
/// frame. This can be used to reduce the size of the "very deep" portions of
/// the lattice.
void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
CompactLattice *clat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// outputs for each frame the set of phones active on that frame. If
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
/// phones in this list.
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
const std::vector<int32> &sil_phones,
std::vector<std::set<int32> > *active_phones);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the output symbols (presumably words), with phones; we
/// use the TransitionModel to work out the phone sequence. Note
/// that the phone labels are not exactly aligned with the phone
/// boundaries. We put a phone label to coincide with any transition
/// to the final, nonemitting state of a phone (this state always exists,
/// we ensure this in HmmTopology::Check()). This would be the last
/// transition-id in the phone if reordering is not done (but typically
/// we do reorder).
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
Lattice *lat);
/// Prunes a lattice or compact lattice. Returns true on success, false if
/// there was some kind of failure.
template<class LatticeType>
bool PruneLattice(BaseFloat beam, LatticeType *lat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the sequences of transition-ids with sequences of phones.
/// Note that this is different from ConvertLatticeToPhones, in that
/// we replace the transition-ids not the words.
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model,
CompactLattice *clat);
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
/// There is a frame error if a particular transition-id on a particular frame
/// corresponds to a phone not matching transcription's alignment for that frame.
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
/// The TransitionInformation is used to map transition-ids in the lattice
/// input-side to phones; the phones appearing in
/// "silence_phones" are treated specially in that we replace the frame error f
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
/// For the normal recipe, max_silence_error would be zero.
/// Returns true on success, false if there was some kind of mismatch.
/// At input, silence_phones must be sorted and unique.
bool LatticeBoost(const TransitionInformation &trans,
const std::vector<int32> &alignment,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat);
/**
This function implements either the MPFE (minimum phone frame error) or SMBR
(state-level minimum bayes risk) forward-backward, depending on whether
"criterion" is "mpfe" or "smbr". It returns the MPFE
criterion of SMBR criterion for this utterance, and outputs the posteriors (which
may be positive or negative) into "post".
@param [in] trans The transition model. Used to map the
transition-ids to phones or pdfs.
@param [in] silence_phones A list of integer ids of silence phones. The
silence frames i.e. the frames where num_ali
corresponds to a silence phones are treated specially.
The behavior is determined by 'one_silence_class'
being false (traditional behavior) or true.
Usually in our setup, several phones including
the silence, vocalized noise, non-spoken noise
and unk are treated as "silence phones"
@param [in] lat The denominator lattice
@param [in] num_ali The numerator alignment
@param [in] criterion The objective function. Must be "mpfe" or "smbr"
for MPFE (minimum phone frame error) or sMBR
(state minimum bayes risk) training.
@param [in] one_silence_class Determines how the silence frames are treated.
Setting this to false gives the old traditional behavior,
where the silence frames (according to num_ali) are
treated as incorrect. However, this means that the
insertions are not penalized by the objective.
Setting this to true gives the new behaviour, where we
treat silence as any other phone, except that all pdfs
of silence phones are collapsed into a single class for
the frame-error computation. This can possible reduce
the insertions in the trained model. This is closer to
the WER metric that we actually care about, since WER is
generally computed after filtering out noises, but
does penalize insertions.
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
pseudo log-likelihoods of states at each frame.
*/
BaseFloat LatticeForwardBackwardMpeVariants(
const TransitionInformation &trans,
const std::vector<int32> &silence_phones,
const Lattice &lat,
const std::vector<int32> &num_ali,
std::string criterion,
bool one_silence_class,
Posterior *post);
/// This function takes a CompactLattice that should only contain a single
/// linear sequence (e.g. derived from lattice-1best), and that should have been
/// processed so that the arcs in the CompactLattice align correctly with the
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
/// same size, which give, for each word in the lattice (in sequence), the word
/// label and the begin time and length in frames. This is done even for zero
/// (epsilon) words, generally corresponding to optional silence-- if you don't
/// want them, just ignore them in the output.
/// This function will print a warning and return false, if the lattice
/// did not have the correct format (e.g. if it is empty or it is not
/// linear).
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
std::vector<int32> *words,
std::vector<int32> *begin_times,
std::vector<int32> *lengths);
/// A form of the shortest-path/best-path algorithm that's specially coded for
/// CompactLattice. Requires that clat be acyclic.
void CompactLatticeShortestPath(const CompactLattice &clat,
CompactLattice *shortest_path);
/// This function expands a CompactLattice to ensure high-probability paths
/// have unique histories. Arcs with posteriors larger than epsilon get splitted.
void ExpandCompactLattice(const CompactLattice &clat,
double epsilon,
CompactLattice *expand_clat);
/// For each state, compute forward and backward best (viterbi) costs and its
/// traceback states (for generating best paths later). The forward best cost
/// for a state is the cost of the best path from the start state to the state.
/// The traceback state of this state is its predecessor state in the best path.
/// The backward best cost for a state is the cost of the best path from the
/// state to a final one. Its traceback state is the successor state in the best
/// path in the forward direction.
/// Note: final weights of states are in backward_best_cost_and_pred.
/// Requires the input CompactLattice clat be acyclic.
typedef std::vector<std::pair<double,
CompactLatticeArc::StateId> > CostTraceType;
void CompactLatticeBestCostsAndTracebacks(
const CompactLattice &clat,
CostTraceType *forward_best_cost_and_pred,
CostTraceType *backward_best_cost_and_pred);
/// This function adds estimated neural language model scores of words in a
/// minimal list of hypotheses that covers a lattice, to the graph scores on the
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
CompactLattice *clat);
/// This function add the word insertion penalty to graph score of each word
/// in the compact lattice
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (typically some kind of mismatched inputs).
bool RescoreCompactLattice(DecodableInterface *decodable,
CompactLattice *clat);
/// This function returns the number of words in the longest sentence in a
/// CompactLattice (i.e. the the maximum of any path, of the count of
/// olabels on that path).
int32 LongestSentenceLength(const Lattice &lat);
/// This function returns the number of words in the longest sentence in a
/// CompactLattice, i.e. the the maximum of any path, of the count of
/// labels on that path... note, in CompactLattice, the ilabels and olabels
/// are identical because it is an acceptor.
int32 LongestSentenceLength(const CompactLattice &lat);
/// This function is like RescoreCompactLattice, but it is modified to avoid
/// computing probabilities on most frames where all the pdf-ids are the same.
/// (it needs the transition-model to work out whether two transition-ids map to
/// the same pdf-id, and it assumes that the lattice has transition-ids on it).
/// The naive thing would be to just set all probabilities to zero on frames
/// where all the pdf-ids are the same (because this value won't affect the
/// lattice posterior). But this would become confusing when we compute
/// corpus-level diagnostics such as the MMI objective function. Instead,
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
/// speedup_factor) we compute those likelihoods and multiply them by
/// speedup_factor; otherwise we set them to zero. This gives the right
/// expected probability so our corpus-level diagnostics will be about right.
bool RescoreCompactLatticeSpeedup(
const TransitionInformation &tmodel,
BaseFloat speedup_factor,
DecodableInterface *decodable,
CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (e.g. some kind of mismatched inputs).
/// The input labels, if nonzero, are interpreted as transition-ids or whatever
/// other index the Decodable object expects.
bool RescoreLattice(DecodableInterface *decodable,
Lattice *lat);
/// This function Composes a CompactLattice format lattice with a
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
/// CompactLattice format lattice. The first element (the one that corresponds
/// to LM weight) in CompactLatticeWeight is used for composition.
///
/// Note that the DeterministicOnDemandFst interface is not "const", therefore
/// we cannot use "const" for <det_fst>.
void ComposeCompactLatticeDeterministic(
const CompactLattice& clat,
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
CompactLattice* composed_clat);
/// This function computes the mapping from the pair
/// (frame-index, transition-id) to the pair
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
/// transition-id in that frame.
/// frame-index in the lattice.
/// This function is useful for retaining the acoustic scores in a
/// non-compact lattice after a process like determinization where the
/// frame-level acoustic scores are typically lost.
/// The function ReplaceAcousticScoresFromMap is used to restore the
/// acoustic scores computed by this function.
///
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
/// function will crash.
/// @param [out] acoustic_scores
/// Pointer to a map from the pair (frame-index,
/// transition-id) to a pair (sum-of-acoustic-scores,
/// num-of-occurences).
/// Usually the acoustic scores for a pdf-id (and hence
/// transition-id) on a frame will be the same for all the
/// occurences of the pdf-id in that frame.
/// But if not, we will take the average of the acoustic
/// scores. Hence, we store both the sum-of-acoustic-scores
/// and the num-of-occurences of the transition-id in that
/// frame.
void ComputeAcousticScoresMap(
const Lattice &lat,
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > *acoustic_scores);
/// This function restores acoustic scores computed using the function
/// ComputeAcousticScoresMap into the lattice.
///
/// @param [in] acoustic_scores
/// A map from the pair (frame-index, transition-id) to a
/// pair (sum-of-acoustic-scores, num-of-occurences) of
/// the occurences of the transition-id in that frame.
/// See the comments for ComputeAcousticScoresMap for
/// details.
/// @param [out] lat Pointer to the output lattice.
void ReplaceAcousticScoresFromMap(
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > &acoustic_scores,
Lattice *lat);
} // namespace kaldi
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
aux_source_directory(. DIR_LIB_SRCS)
add_library(nnet STATIC ${DIR_LIB_SRCS})
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Mirko Hannemann; Go Vivace Inc.;
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ITF_DECODABLE_ITF_H_
#define KALDI_ITF_DECODABLE_ITF_H_ 1
#include "base/kaldi-common.h"
namespace kaldi {
/// @ingroup Interfaces
/// @{
/**
DecodableInterface provides a link between the (acoustic-modeling and
feature-processing) code and the decoder. The idea is to make this
interface as small as possible, and to make it as agnostic as possible about
the form of the acoustic model (e.g. don't assume the probabilities are a
function of just a vector of floats), and about the decoder (e.g. don't
assume it accesses frames in strict left-to-right order). For normal
models, without on-line operation, the "decodable" sub-class will just be a
wrapper around a matrix of features and an acoustic model, and it will
answer the question 'what is the acoustic likelihood for this index and this
frame?'.
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
// Process this frame
}
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
\code{.cc}
while (num_frames_decoded_ < decodable.NumFramesReady()) {
// Decode one more frame [increments num_frames_decoded_]
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
For truly online decoding, the "old" online decodable objects in ../online/
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
The "new" online decodable objects in ../online2/ return the number of frames
currently accessible if you call NumFramesReady(). You will likely not need
to call IsLastFrame(), but we implement it to only return true for the last
frame of the file once we've decided to terminate decoding.
*/
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
return -1;
}
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
virtual ~DecodableInterface() {}
};
/// @}
} // namespace Kaldi
#endif // KALDI_ITF_DECODABLE_ITF_H_
#include "nnet/decodable.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
frontend_(NULL),
nnet_(nnet),
finished_(false),
frames_ready_(0) {
}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
}
//Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();
}
} // namespace ppspeech
\ No newline at end of file
#include "nnet/decodable-itf.h"
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h"
#include "nnet/nnet_interface.h"
namespace ppspeech {
struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Reset();
void InputFinished() { finished_ = true; }
private:
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
bool finished_;
int32 frames_ready_;
};
} // namespace ppspeech
#pragma once
#include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
namespace ppspeech {
class NnetInterface {
public:
virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
};
} // namespace ppspeech
\ No newline at end of file
#include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h"
namespace ppspeech {
using std::vector;
using std::string;
using std::shared_ptr;
using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ", ");
std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ", ");
assert(cache_shapes.size() == cache_names.size());
for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "- ");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(),
std::back_inserter(cur_shape),
[](const std::string& s) {
return atoi(s.c_str());
});
cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout);
}
}
PaddleNnet::PaddleNnet(const ModelOptions& opts) {
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) {
config.EnableUseGpu(500, 0);
}
config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding) {
config.DisableFCPadding();
}
if (opts.enable_profile) {
config.EnableProfile();
}
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed";
}
pool_usages.resize(opts.thread_num);
std::fill(pool_usages.begin(), pool_usages.end(), false);
LOG(INFO) << "load paddle model success";
LOG(INFO) << "start to check the predictor input and output names";
LOG(INFO) << "input names: " << opts.input_names;
LOG(INFO) << "output names: " << opts.output_names;
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames();
assert(input_names_vec.size() == model_input_names.size());
for (size_t i = 0; i < model_input_names.size(); i++) {
assert(input_names_vec[i] == model_input_names[i]);
}
std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]);
}
ReleasePredictor(predictor);
InitCacheEncouts(opts);
}
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
paddle_infer::Predictor* predictor = nullptr;
std::lock_guard<std::mutex> guard(pool_mutex);
int pred_id = 0;
while (pred_id < pool_usages.size()) {
if (pool_usages[pred_id] == false) {
predictor = pool->Retrive(pred_id);
break;
}
++pred_id;
}
if (predictor) {
pool_usages[pred_id] = true;
predictor_to_thread_id[predictor] = pred_id;
LOG(INFO) << pred_id << " predictor create success";
} else {
LOG(INFO) << "Failed to get predictor from pool !!!";
}
return predictor;
}
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
LOG(INFO) << "attempt to releae a predictor";
std::lock_guard<std::mutex> guard(pool_mutex);
auto iter = predictor_to_thread_id.find(predictor);
if (iter == predictor_to_thread_id.end()) {
LOG(INFO) << "there is no such predictor";
return 0;
}
LOG(INFO) << iter->second << " predictor will be release";
pool_usages[iter->second] = false;
predictor_to_thread_id.erase(predictor);
LOG(INFO) << "release success";
return 0;
}
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
auto iter = cache_names_idx_.find(name);
if (iter == cache_names_idx_.end()) {
return nullptr;
}
assert(iter->second < cache_encouts_.size());
return cache_encouts_[iter->second];
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
// 1. 得到所有的 input tensor 的名称
int row = features.NumRows();
int col = features.NumCols();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col=" << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(features.Data());
// 3. 输入每个音频帧数
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// 输入流式的缓存数据
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
}
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// 5. 得到最后的输出结果
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
inferences->Resize(row, col);
output_tensor->CopyToCpu(inferences->Data());
ReleasePredictor(predictor);
}
} // namespace ppspeech
\ No newline at end of file
#pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h"
#include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include <numeric>
namespace ppspeech {
struct ModelOptions {
std::string model_path;
std::string params_path;
int thread_num;
bool use_gpu;
bool switch_ir_optim;
std::string input_names;
std::string output_names;
std::string cache_names;
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
ModelOptions() :
model_path("model/final.zip"),
params_path("model/avg_1.jit.pdmodel"),
thread_num(2),
use_gpu(false),
input_names("audio"),
output_names("probs"),
cache_names("enouts"),
cache_shape("1-1-1"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {
}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("model-params", &params_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option");
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option");
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option");
}
};
template<typename T>
class Tensor {
public:
Tensor() {
}
Tensor(const std::vector<int>& shape) :
_shape(shape) {
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0);
}
void reshape(const std::vector<int>& shape) {
_shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
_data.resize(data_size, 0);
}
const std::vector<int>& get_shape() const {
return _shape;
}
std::vector<T>& get_data() {
return _data;
}
private:
std::vector<int> _shape;
std::vector<T> _data;
};
class PaddleNnet : public NnetInterface {
public:
PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
private:
paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor);
std::unique_ptr<paddle_infer::services::PredictorPool> pool;
std::vector<bool> pool_usages;
std::mutex pool_mutex;
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
public:
DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
};
} // namespace ppspeech
#include "utils/file_utils.h"
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* vocabulary) {
std::ifstream file_in(filename);
if (!file_in) {
std::cerr << "please input a valid file" << std::endl;
return false;
}
std::string line;
while (std::getline(file_in, line)) {
vocabulary->emplace_back(line);
}
return true;
}
#include "base/common.h"
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册