未验证 提交 fdc189a3 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1599 from SmileGoat/add_tlg

[Speechx] add tlg decoder
...@@ -48,7 +48,7 @@ wer=./aishell_wer ...@@ -48,7 +48,7 @@ wer=./aishell_wer
nj=40 nj=40
export GLOG_logtostderr=1 export GLOG_logtostderr=1
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj #./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data data=$PWD/data
# 3. gen linear feat # 3. gen linear feat
...@@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ ...@@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > $label_file cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
offline_decoder_sliding_chunk_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph
if [ ! -d $ ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip -d aishell_graph.zip
fi
# 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
offline_wfst_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
local/compute-wer.py --char=1 --v=1 $label_file $text > $wer cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
tail $wer local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
\ No newline at end of file
...@@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_ ...@@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc) add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
......
...@@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); ...@@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model"); DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=5) downsampling module."); "receptive field of two CNN(kernel=5) downsampling module.");
...@@ -45,7 +45,6 @@ using kaldi::BaseFloat; ...@@ -45,7 +45,6 @@ using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test ds2 online decoder by feeding speech feature // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
...@@ -63,7 +62,6 @@ int main(int argc, char* argv[]) { ...@@ -63,7 +62,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path; LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;
...@@ -138,10 +136,16 @@ int main(int argc, char* argv[]) { ...@@ -138,10 +136,16 @@ int main(int argc, char* argv[]) {
} }
std::string result; std::string result;
result = decoder.GetFinalBestPath(); result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done; ++num_done;
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table;
LOG(INFO) << "graph path: " << graph_path;
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts;
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
...@@ -6,5 +6,6 @@ add_library(decoder STATIC ...@@ -6,5 +6,6 @@ add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst) target_link_libraries(decoder PUBLIC kenlm utils fst)
...@@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( ...@@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode(
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob; vector<BaseFloat> frame_prob;
bool flag = bool flag =
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break; if (flag == false) break;
likelihood.push_back(frame_prob); likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h" #include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#pragma once #pragma once
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++;
}
void TLGDecoder::Reset() {
InitDecoder();
return;
}
std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
}
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
namespace ppspeech {
struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
};
class TLGDecoder {
public:
explicit TLGDecoder(TLGDecoderOptions opts);
void InitDecoder();
void Decode();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
};
} // namespace ppspeech
\ No newline at end of file
...@@ -4,3 +4,6 @@ add_subdirectory(base) ...@@ -4,3 +4,6 @@ add_subdirectory(base)
add_subdirectory(util) add_subdirectory(util)
add_subdirectory(feat) add_subdirectory(feat)
add_subdirectory(matrix) add_subdirectory(matrix)
add_subdirectory(lat)
add_subdirectory(fstext)
add_subdirectory(decoder)
add_library(kaldi-decoder
lattice-faster-decoder.cc
lattice-faster-online-decoder.cc
)
target_link_libraries(kaldi-decoder PUBLIC kaldi-lat)
...@@ -121,7 +121,7 @@ class DecodableInterface { ...@@ -121,7 +121,7 @@ class DecodableInterface {
/// decoding-from-matrix setting where we want to allow the last delta or /// decoding-from-matrix setting where we want to allow the last delta or
/// LDA /// LDA
/// features to be flushed out for compatibility with the baseline setup. /// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0; virtual bool IsLastFrame(int32 frame) = 0;
/// The call NumFramesReady() will return the number of frames currently /// The call NumFramesReady() will return the number of frames currently
/// available /// available
...@@ -143,7 +143,7 @@ class DecodableInterface { ...@@ -143,7 +143,7 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
virtual bool FrameLogLikelihood( virtual bool FrameLikelihood(
int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0; int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0;
......
...@@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken> ...@@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >; template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >; template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>; template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;
} // end namespace kaldi. } // end namespace kaldi.
...@@ -23,11 +23,10 @@ ...@@ -23,11 +23,10 @@
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include "decoder/grammar-fst.h"
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "fst/memory.h" #include "fst/memory.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h" #include "decoder/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h" #include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
#include "util/hash-list.h" #include "util/hash-list.h"
......
...@@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned( ...@@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >; //template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >; //template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi. } // end namespace kaldi.
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include "util/stl-utils.h" #include "util/stl-utils.h"
#include "util/hash-list.h" #include "util/hash-list.h"
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "itf/decodable-itf.h" #include "decoder/decodable-itf.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h" #include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
......
add_library(kaldi-fstext
kaldi-fst-io.cc
)
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
// fstext/determinize-lattice-inl.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
// Do not include this file directly. It is included by determinize-lattice.h
#include <algorithm>
#include <climits>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace fst {
// This class maps back and forth from/to integer id's to sequences of strings.
// used in determinization algorithm. It is constructed in such a way that
// finding the string-id of the successor of (string, next-label) has constant
// time.
// Note: class IntType, typically int32, is the type of the element in the
// string (typically a template argument of the CompactLatticeWeightTpl).
template <class IntType>
class LatticeStringRepository {
public:
struct Entry {
const Entry *parent; // NULL for empty string.
IntType i;
inline bool operator==(const Entry &other) const {
return (parent == other.parent && i == other.i);
}
Entry() {}
Entry(const Entry &e) : parent(e.parent), i(e.i) {}
};
// Note: all Entry* pointers returned in function calls are
// owned by the repository itself, not by the caller!
// Interface guarantees empty string is NULL.
inline const Entry *EmptyString() { return NULL; }
// Returns string of "parent" with i appended. Pointer
// owned by repository
const Entry *Successor(const Entry *parent, IntType i) {
new_entry_->parent = parent;
new_entry_->i = i;
std::pair<typename SetType::iterator, bool> pr = set_.insert(new_entry_);
if (pr.second) { // Was successfully inserted (was not there). We need to
// replace the element we inserted, which resides on the
// stack, with one from the heap.
const Entry *ans = new_entry_;
new_entry_ = new Entry();
return ans;
} else { // Was not inserted because an equivalent Entry already
// existed.
return *pr.first;
}
}
const Entry *Concatenate(const Entry *a, const Entry *b) {
if (a == NULL)
return b;
else if (b == NULL)
return a;
std::vector<IntType> v;
ConvertToVector(b, &v);
const Entry *ans = a;
for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]);
return ans;
}
const Entry *CommonPrefix(const Entry *a, const Entry *b) {
std::vector<IntType> a_vec, b_vec;
ConvertToVector(a, &a_vec);
ConvertToVector(b, &b_vec);
const Entry *ans = NULL;
for (size_t i = 0;
i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++)
ans = Successor(ans, a_vec[i]);
return ans;
}
// removes any elements from b that are not part of
// a common prefix with a.
void ReduceToCommonPrefix(const Entry *a, std::vector<IntType> *b) {
size_t a_size = Size(a), b_size = b->size();
while (a_size > b_size) {
a = a->parent;
a_size--;
}
if (b_size > a_size) b_size = a_size;
typename std::vector<IntType>::iterator b_begin = b->begin();
while (a_size != 0) {
if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1;
a = a->parent;
a_size--;
}
if (b_size != b->size()) b->resize(b_size);
}
// removes the first n elements of a.
const Entry *RemovePrefix(const Entry *a, size_t n) {
if (n == 0) return a;
std::vector<IntType> a_vec;
ConvertToVector(a, &a_vec);
assert(a_vec.size() >= n);
const Entry *ans = NULL;
for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]);
return ans;
}
// Returns true if a is a prefix of b. If a is prefix of b,
// time taken is |b| - |a|. Else, time taken is |b|.
bool IsPrefixOf(const Entry *a, const Entry *b) const {
if (a == NULL) return true; // empty string prefix of all.
if (a == b) return true;
if (b == NULL) return false;
return IsPrefixOf(a, b->parent);
}
inline size_t Size(const Entry *entry) const {
size_t ans = 0;
while (entry != NULL) {
ans++;
entry = entry->parent;
}
return ans;
}
void ConvertToVector(const Entry *entry, std::vector<IntType> *out) const {
size_t length = Size(entry);
out->resize(length);
if (entry != NULL) {
typename std::vector<IntType>::reverse_iterator iter = out->rbegin();
while (entry != NULL) {
*iter = entry->i;
entry = entry->parent;
++iter;
}
}
}
const Entry *ConvertFromVector(const std::vector<IntType> &vec) {
const Entry *e = NULL;
for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]);
return e;
}
LatticeStringRepository() { new_entry_ = new Entry; }
void Destroy() {
for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
++iter)
delete *iter;
SetType tmp;
tmp.swap(set_);
if (new_entry_) {
delete new_entry_;
new_entry_ = NULL;
}
}
// Rebuild will rebuild this object, guaranteeing only
// to preserve the Entry values that are in the vector pointed
// to (this list does not have to be unique). The point of
// this is to save memory.
void Rebuild(const std::vector<const Entry *> &to_keep) {
SetType tmp_set;
for (typename std::vector<const Entry *>::const_iterator iter =
to_keep.begin();
iter != to_keep.end(); ++iter)
RebuildHelper(*iter, &tmp_set);
// Now delete all elems not in tmp_set.
for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
++iter) {
if (tmp_set.count(*iter) == 0)
delete (*iter); // delete the Entry; not needed.
}
set_.swap(tmp_set);
}
~LatticeStringRepository() { Destroy(); }
int32 MemSize() const {
return set_.size() * sizeof(Entry) * 2; // this is a lower bound
// on the size this structure might take.
}
private:
class EntryKey { // Hash function object.
public:
inline size_t operator()(const Entry *entry) const {
size_t prime = 49109;
return static_cast<size_t>(entry->i) +
prime * reinterpret_cast<size_t>(entry->parent);
}
};
class EntryEqual {
public:
inline bool operator()(const Entry *e1, const Entry *e2) const {
return (*e1 == *e2);
}
};
typedef std::unordered_set<const Entry *, EntryKey, EntryEqual> SetType;
void RebuildHelper(const Entry *to_add, SetType *tmp_set) {
while (true) {
if (to_add == NULL) return;
typename SetType::iterator iter = tmp_set->find(to_add);
if (iter == tmp_set->end()) { // not in tmp_set.
tmp_set->insert(to_add);
to_add = to_add->parent; // and loop.
} else {
return;
}
}
}
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository);
Entry *new_entry_; // We always have a pre-allocated Entry ready to use,
// to avoid unnecessary news and deletes.
SetType set_;
};
// class LatticeDeterminizer is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// p.d.f. id's in the anticipated use of this code] It has a special requirement
// on the Weight type: that there should be a Compare function on the weights
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1
// > w2. This requires that there be a total order on the weights.
template <class Weight, class IntType>
class LatticeDeterminizer {
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
// correspondence between our states and the states in ofst. If destroy ==
// true, release memory as we go (but we cannot output again).
typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
typedef ArcTpl<CompactWeight>
CompactArc; // arc in compact, acceptor form of lattice
typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
// Output to standard FST with CompactWeightTpl<Weight> as its weight type
// (the weight stores the original output-symbol strings). If destroy ==
// true, release memory as we go (but we cannot output again).
void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
assert(determinized_);
typedef typename Arc::StateId StateId;
StateId nStates = static_cast<StateId>(output_arcs_.size());
if (destroy) FreeMostMemory();
ofst->DeleteStates();
ofst->SetStart(kNoStateId);
if (nStates == 0) {
return;
}
for (StateId s = 0; s < nStates; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
// now process transitions.
for (StateId this_state = 0; this_state < nStates; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
CompactArc new_arc;
std::vector<Label> seq;
repository_.ConvertToVector(temp_arc.string, &seq);
CompactWeight weight(temp_arc.weight, seq);
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
ofst->SetFinal(this_state, weight);
} else { // is really an arc.
new_arc.nextstate = temp_arc.nextstate;
new_arc.ilabel = temp_arc.ilabel;
new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
new_arc.weight = weight; // includes string and weight.
ofst->AddArc(this_state, new_arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
std::swap(temp, this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
std::swap(temp, output_arcs_);
}
}
// Output to standard FST with Weight as its weight type. We will create
// extra states to handle sequences of symbols on the output. If destroy ==
// true, release memory as we go (but we cannot output again).
void Output(MutableFst<Arc> *ofst, bool destroy = true) {
// Outputs to standard fst.
OutputStateId nStates = static_cast<OutputStateId>(output_arcs_.size());
ofst->DeleteStates();
if (nStates == 0) {
ofst->SetStart(kNoStateId);
return;
}
if (destroy) FreeMostMemory();
// Add basic states-- but we will add extra ones to account for strings on
// output.
for (OutputStateId s = 0; s < nStates; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
for (OutputStateId this_state = 0; this_state < nStates; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
std::vector<Label> seq;
repository_.ConvertToVector(temp_arc.string, &seq);
if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
// Make a sequence of states going to a final state, with the strings
// as labels. Put the weight on the first arc.
OutputStateId cur_state = this_state;
for (size_t i = 0; i < seq.size(); i++) {
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = 0; // epsilon.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state,
(seq.size() == 0 ? temp_arc.weight : Weight::One()));
} else { // Really an arc.
OutputStateId cur_state = this_state;
// Have to be careful with this integer comparison (i+1 < seq.size())
// because unsigned. i < seq.size()-1 could fail for zero-length
// sequences.
for (size_t i = 0; i + 1 < seq.size(); i++) {
// for all but the last element of seq, create new state.
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = (i == 0 ? temp_arc.ilabel
: 0); // put ilabel on first element of seq.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
// Add the final arc in the sequence.
Arc arc;
arc.nextstate = temp_arc.nextstate;
arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
arc.olabel = (seq.size() > 0 ? seq.back() : 0);
ofst->AddArc(cur_state, arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
temp.swap(this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
temp.swap(output_arcs_);
repository_.Destroy();
}
}
// Initializer. After initializing the object you will typically
// call Determinize() and then call one of the Output functions.
// Note: ifst.Copy() will generally do a
// shallow copy. We do it like this for memory safety, rather than
// keeping a reference or pointer to ifst_.
LatticeDeterminizer(const Fst<Arc> &ifst, DeterminizeLatticeOptions opts)
: num_arcs_(0),
num_elems_(0),
ifst_(ifst.Copy()),
opts_(opts),
equal_(opts_.delta),
determinized_(false),
minimal_hash_(3, hasher_, equal_),
initial_hash_(3, hasher_, equal_) {
KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
// work correctly otherwise.
}
// frees all except output_arcs_, which contains the important info
// we need to output the FST.
void FreeMostMemory() {
if (ifst_) {
delete ifst_;
ifst_ = NULL;
}
for (typename MinimalSubsetHash::iterator iter = minimal_hash_.begin();
iter != minimal_hash_.end(); ++iter)
delete iter->first;
{
MinimalSubsetHash tmp;
tmp.swap(minimal_hash_);
}
for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
iter != initial_hash_.end(); ++iter)
delete iter->first;
{
InitialSubsetHash tmp;
tmp.swap(initial_hash_);
}
{
std::vector<std::vector<Element> *> output_states_tmp;
output_states_tmp.swap(output_states_);
}
{
std::vector<char> tmp;
tmp.swap(isymbol_or_final_);
}
{
std::vector<OutputStateId> tmp;
tmp.swap(queue_);
}
{
std::vector<std::pair<Label, Element> > tmp;
tmp.swap(all_elems_tmp_);
}
}
~LatticeDeterminizer() {
FreeMostMemory(); // rest is deleted by destructors.
}
void RebuildRepository() { // rebuild the string repository,
// freeing stuff we don't need.. we call this when memory usage
// passes a supplied threshold. We need to accumulate all the
// strings we need the repository to "remember", then tell it
// to clean the repository.
std::vector<StringId> needed_strings;
for (size_t i = 0; i < output_arcs_.size(); i++)
for (size_t j = 0; j < output_arcs_[i].size(); j++)
needed_strings.push_back(output_arcs_[i][j].string);
// the following loop covers strings present in minimal_hash_
// which are also accessible via output_states_.
for (size_t i = 0; i < output_states_.size(); i++)
for (size_t j = 0; j < output_states_[i]->size(); j++)
needed_strings.push_back((*(output_states_[i]))[j].string);
// the following loop covers strings present in initial_hash_.
for (typename InitialSubsetHash::const_iterator iter =
initial_hash_.begin();
iter != initial_hash_.end(); ++iter) {
const std::vector<Element> &vec = *(iter->first);
Element elem = iter->second;
for (size_t i = 0; i < vec.size(); i++)
needed_strings.push_back(vec[i].string);
needed_strings.push_back(elem.string);
}
std::sort(needed_strings.begin(), needed_strings.end());
needed_strings.erase(
std::unique(needed_strings.begin(), needed_strings.end()),
needed_strings.end()); // uniq the strings.
repository_.Rebuild(needed_strings);
}
bool CheckMemoryUsage() {
int32 repo_size = repository_.MemSize(),
arcs_size = num_arcs_ * sizeof(TempArc),
elems_size = num_elems_ * sizeof(Element),
total_size = repo_size + arcs_size + elems_size;
if (opts_.max_mem > 0 &&
total_size > opts_.max_mem) { // We passed the memory threshold.
// This is usually due to the repository getting large, so we
// clean this out.
RebuildRepository();
int32 new_repo_size = repository_.MemSize(),
new_total_size = new_repo_size + arcs_size + elems_size;
KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository "
"shrank from "
<< repo_size << " to " << new_repo_size
<< " bytes (approximately)";
if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
// Rebuilding didn't help enough-- we need a margin to stop
// having to rebuild too often.
KALDI_WARN << "Failure in determinize-lattice: size exceeds maximum "
<< opts_.max_mem << " bytes; (repo,arcs,elems) = ("
<< repo_size << "," << arcs_size << "," << elems_size
<< "), after rebuilding, repo size was " << new_repo_size;
return false;
}
}
return true;
}
// Returns true on success. Can fail for out-of-memory
// or max-states related reasons.
bool Determinize(bool *debug_ptr) {
assert(!determinized_);
// This determinizes the input fst but leaves it in the "special format"
// in "output_arcs_". Must be called after Initialize(). To get the
// output, call one of the Output routines.
try {
InitializeDeterminization(); // some start-up tasks.
while (!queue_.empty()) {
OutputStateId out_state = queue_.back();
queue_.pop_back();
ProcessState(out_state);
if (debug_ptr && *debug_ptr) Debug(); // will exit.
if (!CheckMemoryUsage()) return false;
}
return (determinized_ = true);
} catch (const std::bad_alloc &) {
int32 repo_size = repository_.MemSize(),
arcs_size = num_arcs_ * sizeof(TempArc),
elems_size = num_elems_ * sizeof(Element),
total_size = repo_size + arcs_size + elems_size;
KALDI_WARN
<< "Memory allocation error doing lattice determinization; using "
<< total_size << " bytes (max = " << opts_.max_mem
<< " (repo,arcs,elems) = (" << repo_size << "," << arcs_size << ","
<< elems_size << ")";
return (determinized_ = false);
} catch (const std::runtime_error &) {
KALDI_WARN << "Caught exception doing lattice determinization";
return (determinized_ = false);
}
}
private:
typedef typename Arc::Label Label;
typedef typename Arc::StateId
StateId; // use this when we don't know if it's input or output.
typedef typename Arc::StateId InputStateId; // state in the input FST.
typedef typename Arc::StateId OutputStateId; // same as above but distinguish
// states in output Fst.
typedef LatticeStringRepository<IntType> StringRepositoryType;
typedef const typename StringRepositoryType::Entry *StringId;
// Element of a subset [of original states]
struct Element {
StateId state; // use StateId as this is usually InputStateId but in one
// case OutputStateId.
StringId string;
Weight weight;
bool operator!=(const Element &other) const {
return (state != other.state || string != other.string ||
weight != other.weight);
}
// This operator is only intended to support sorting in EpsilonClosure()
bool operator<(const Element &other) const { return state < other.state; }
};
// Arcs in the format we temporarily create in this class (a representation,
// essentially of a Gallic Fst).
struct TempArc {
Label ilabel;
StringId string; // Look it up in the StringRepository, it's a sequence of
// Labels.
OutputStateId nextstate; // or kNoState for final weights.
Weight weight;
};
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that
// is order-dependent. However the weights are not included in the hashing
// function-- we hash subsets that differ only in weight to the same key. This
// is not optimal in terms of the O(N) performance but typically if we have a
// lot of determinized states that differ only in weight then the input
// probably was pathological in some way, or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple
// cases.
// Instead we apply the delta when comparing subsets for equality, and allow a
// small difference.
class SubsetKey {
public:
size_t operator()(const std::vector<Element> *subset)
const { // hashes only the state and string.
size_t hash = 0, factor = 1;
for (typename std::vector<Element>::const_iterator iter = subset->begin();
iter != subset->end(); ++iter) {
hash *= factor;
hash += iter->state + reinterpret_cast<size_t>(iter->string);
factor *= 23531; // these numbers are primes.
}
return hash;
}
};
// This is the equality operator on subsets. It checks for exact match on
// state-id and string, and approximate match on weights.
class SubsetEqual {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state || iter1->string != iter2->string ||
!ApproxEqual(iter1->weight, iter2->weight, delta_))
return false;
}
return true;
}
float delta_;
explicit SubsetEqual(float delta) : delta_(delta) {}
SubsetEqual() : delta_(kDelta) {}
};
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state) return false;
}
return true;
}
};
// Define the hash type we use to map subsets (in minimal
// representation) to OutputStateId.
typedef std::unordered_map<const std::vector<Element> *, OutputStateId,
SubsetKey, SubsetEqual>
MinimalSubsetHash;
// Define the hash type we use to map subsets (in initial
// representation) to OutputStateId, together with an
// extra weight. [note: we interpret the Element.state in here
// as an OutputStateId even though it's declared as InputStateId;
// these types are the same anyway].
typedef std::unordered_map<const std::vector<Element> *, Element, SubsetKey,
SubsetEqual>
InitialSubsetHash;
// converts the representation of the subset from canonical (all states) to
// minimal (only states with output symbols on arcs leaving them, and final
// states). Output is not necessarily normalized, even if input_subset was.
void ConvertToMinimal(std::vector<Element> *subset) {
assert(!subset->empty());
typename std::vector<Element>::iterator cur_in = subset->begin(),
cur_out = subset->begin(),
end = subset->end();
while (cur_in != end) {
if (IsIsymbolOrFinal(cur_in->state)) { // keep it...
*cur_out = *cur_in;
cur_out++;
}
cur_in++;
}
subset->resize(cur_out - subset->begin());
}
// Takes a minimal, normalized subset, and converts it to an OutputStateId.
// Involves a hash lookup, and possibly adding a new OutputStateId.
// If it creates a new OutputStateId, it adds it to the queue.
OutputStateId MinimalToStateId(const std::vector<Element> &subset) {
typename MinimalSubsetHash::const_iterator iter =
minimal_hash_.find(&subset);
if (iter != minimal_hash_.end()) // Found a matching subset.
return iter->second;
OutputStateId ans = static_cast<OutputStateId>(output_arcs_.size());
std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
output_states_.push_back(subset_ptr);
num_elems_ += subset_ptr->size();
output_arcs_.push_back(std::vector<TempArc>());
minimal_hash_[subset_ptr] = ans;
queue_.push_back(ans);
return ans;
}
// Given a normalized initial subset of elements (i.e. before epsilon
// closure), compute the corresponding output-state.
OutputStateId InitialToStateId(const std::vector<Element> &subset_in,
Weight *remaining_weight,
StringId *common_prefix) {
typename InitialSubsetHash::const_iterator iter =
initial_hash_.find(&subset_in);
if (iter != initial_hash_.end()) { // Found a matching subset.
const Element &elem = iter->second;
*remaining_weight = elem.weight;
*common_prefix = elem.string;
if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
return elem.state;
}
// else no matching subset-- have to work it out.
std::vector<Element> subset(subset_in);
// Follow through epsilons. Will add no duplicate states. note: after
// EpsilonClosure, it is the same as "canonical" subset, except not
// normalized (actually we never compute the normalized canonical subset,
// only the normalized minimal one).
EpsilonClosure(&subset); // follow epsilons.
ConvertToMinimal(&subset); // remove all but emitting and final states.
Element elem; // will be used to store remaining weight and string, and
// OutputStateId, in initial_hash_;
NormalizeSubset(&subset, &elem.weight,
&elem.string); // normalize subset; put
// common string and weight in "elem". The subset is now a minimal,
// normalized subset.
OutputStateId ans = MinimalToStateId(subset);
*remaining_weight = elem.weight;
*common_prefix = elem.string;
if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
// Before returning "ans", add the initial subset to the hash,
// so that we can bypass the epsilon-closure etc., next time
// we process the same initial subset.
std::vector<Element> *initial_subset_ptr =
new std::vector<Element>(subset_in);
elem.state = ans;
initial_hash_[initial_subset_ptr] = elem;
num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
return ans;
}
// returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
// to the ordering we defined on strings for the CompactLatticeWeightTpl.
// see function
// inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
// const CompactLatticeWeightTpl<WeightType,IntType> &w2)
// in lattice-weight.h.
// this is the same as that, but optimized for our data structures.
inline int Compare(const Weight &a_w, StringId a_str, const Weight &b_w,
StringId b_str) const {
int weight_comp = fst::Compare(a_w, b_w);
if (weight_comp != 0) return weight_comp;
// now comparing strings.
if (a_str == b_str) return 0;
std::vector<IntType> a_vec, b_vec;
repository_.ConvertToVector(a_str, &a_vec);
repository_.ConvertToVector(b_str, &b_vec);
// First compare their lengths.
int a_len = a_vec.size(), b_len = b_vec.size();
// use opposite order on the string lengths (c.f. Compare in
// lattice-weight.h)
if (a_len > b_len)
return -1;
else if (a_len < b_len)
return 1;
for (int i = 0; i < a_len; i++) {
if (a_vec[i] < b_vec[i])
return -1;
else if (a_vec[i] > b_vec[i])
return 1;
}
assert(
0); // because we checked if a_str == b_str above, shouldn't reach here
return 0;
}
// This function computes epsilon closure of subset of states by following
// epsilon links. Called by InitialToStateId and Initialize. Has no side
// effects except on the string repository. The "output_subset" is not
// necessarily normalized (in the sense of there being no common substring),
// unless input_subset was.
void EpsilonClosure(std::vector<Element> *subset) {
// at input, subset must have only one example of each StateId. [will still
// be so at output]. This function follows input-epsilons, and augments the
// subset accordingly.
std::deque<Element> queue;
std::unordered_map<InputStateId, Element> cur_subset;
typedef
typename std::unordered_map<InputStateId, Element>::iterator MapIter;
typedef typename std::vector<Element>::const_iterator VecIter;
for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
queue.push_back(*iter);
cur_subset[iter->state] = *iter;
}
// find whether input fst is known to be sorted on input label.
bool sorted =
((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
bool replaced_elems = false; // relates to an optimization, see below.
int counter =
0; // stops infinite loops here for non-lattice-determinizable input;
// useful in testing.
while (queue.size() != 0) {
Element elem = queue.front();
queue.pop_front();
// The next if-statement is a kind of optimization. It's to prevent us
// unnecessarily repeating the processing of a state. "cur_subset" always
// contains only one Element with a particular state. The issue is that
// whenever we modify the Element corresponding to that state in
// "cur_subset", both the new (optimal) and old (less-optimal) Element
// will still be in "queue". The next if-statement stops us from wasting
// compute by processing the old Element.
if (replaced_elems && cur_subset[elem.state] != elem) continue;
if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
KALDI_ERR << "Lattice determinization aborted since looped more than "
<< opts_.max_loop << " times during epsilon closure";
}
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (sorted && arc.ilabel != 0)
break; // Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
if (arc.ilabel == 0 &&
arc.weight != Weight::Zero()) { // Epsilon transition.
Element next_elem;
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
// now must append strings
if (arc.olabel == 0)
next_elem.string = elem.string;
else
next_elem.string = repository_.Successor(elem.string, arc.olabel);
MapIter iter = cur_subset.find(next_elem.state);
if (iter == cur_subset.end()) {
// was no such StateId: insert and add to queue.
cur_subset[next_elem.state] = next_elem;
queue.push_back(next_elem);
} else {
// was not inserted because one already there. In normal
// determinization we'd add the weights. Here, we find which one
// has the better weight, and keep its corresponding string.
int comp = Compare(next_elem.weight, next_elem.string,
iter->second.weight, iter->second.string);
if (comp ==
1) { // next_elem is better, so use its (weight, string)
iter->second.string = next_elem.string;
iter->second.weight = next_elem.weight;
queue.push_back(next_elem);
replaced_elems = true;
}
// else it is the same or worse, so use original one.
}
}
}
}
{ // copy cur_subset to subset.
subset->clear();
subset->reserve(cur_subset.size());
MapIter iter = cur_subset.begin(), end = cur_subset.end();
for (; iter != end; ++iter) subset->push_back(iter->second);
// sort by state ID, because the subset hash function is
// order-dependent(see SubsetKey)
std::sort(subset->begin(), subset->end());
}
}
// This function works out the final-weight of the determinized state.
// called by ProcessSubset.
// Has no side effects except on the variable repository_, and output_arcs_.
void ProcessFinal(OutputStateId output_state) {
const std::vector<Element> &minimal_subset =
*(output_states_[output_state]);
// processes final-weights for this subset.
// minimal_subset may be empty if the graphs is not connected/trimmed, I
// think, do don't check that it's nonempty.
bool is_final = false;
StringId final_string = NULL; // = NULL to keep compiler happy.
Weight final_weight = Weight::Zero();
typename std::vector<Element>::const_iterator iter = minimal_subset.begin(),
end = minimal_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
StringId this_final_string = elem.string;
if (this_final_weight != Weight::Zero() &&
(!is_final || Compare(this_final_weight, this_final_string,
final_weight, final_string) == 1)) { // the new
// (weight, string) pair is more in semiring than our current
// one.
is_final = true;
final_weight = this_final_weight;
final_string = this_final_string;
}
}
if (is_final) {
// store final weights in TempArc structure, just like a transition.
TempArc temp_arc;
temp_arc.ilabel = 0;
temp_arc.nextstate =
kNoStateId; // special marker meaning "final weight".
temp_arc.string = final_string;
temp_arc.weight = final_weight;
output_arcs_[output_state].push_back(temp_arc);
num_arcs_++;
}
}
// NormalizeSubset normalizes the subset "elems" by
// removing any common string prefix (putting it in common_str),
// and dividing by the total weight (putting it in tot_weight).
void NormalizeSubset(std::vector<Element> *elems, Weight *tot_weight,
StringId *common_str) {
if (elems->empty()) { // just set common_str, tot_weight
KALDI_WARN << "[empty subset]"; // TEMP
// to defaults and return...
*common_str = repository_.EmptyString();
*tot_weight = Weight::Zero();
return;
}
size_t size = elems->size();
std::vector<IntType> common_prefix;
repository_.ConvertToVector((*elems)[0].string, &common_prefix);
Weight weight = (*elems)[0].weight;
for (size_t i = 1; i < size; i++) {
weight = Plus(weight, (*elems)[i].weight);
repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
}
assert(weight != Weight::Zero()); // we made sure to ignore arcs with zero
// weights on them, so we shouldn't have zero here.
size_t prefix_len = common_prefix.size();
for (size_t i = 0; i < size; i++) {
(*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
(*elems)[i].string =
repository_.RemovePrefix((*elems)[i].string, prefix_len);
}
*common_str = repository_.ConvertFromVector(common_prefix);
*tot_weight = weight;
}
// Take a subset of Elements that is sorted on state, and
// merge any Elements that have the same state (taking the best
// (weight, string) pair in the semiring).
void MakeSubsetUnique(std::vector<Element> *subset) {
typedef typename std::vector<Element>::iterator IterType;
// This assert is designed to fail (usually) if the subset is not sorted on
// state.
assert(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
size_t num_out = 0;
// Merge elements with same state-id
while (cur_in != end) { // while we have more elements to process.
// At this point, cur_out points to location of next place we want to put
// an element, cur_in points to location of next element we want to
// process.
if (cur_in != cur_out) *cur_out = *cur_in;
cur_in++;
while (cur_in != end && cur_in->state == cur_out->state) {
if (Compare(cur_in->weight, cur_in->string, cur_out->weight,
cur_out->string) == 1) {
// if *cur_in > *cur_out in semiring, then take *cur_in.
cur_out->string = cur_in->string;
cur_out->weight = cur_in->weight;
}
cur_in++;
}
cur_out++;
num_out++;
}
subset->resize(num_out);
}
// ProcessTransition is called from "ProcessTransitions". Broken out for
// clarity. Processes a transition from state "state". The set of Elements
// represents a set of next-states with associated weights and strings, each
// one arising from an arc from some state in a determinized-state; the
// next-states are not necessarily unique (i.e. there may be >1 entry
// associated with each), and any such sets of Elements have to be merged
// within this routine (we take the [weight, string] pair that's better in the
// semiring).
void ProcessTransition(OutputStateId state, Label ilabel,
std::vector<Element> *subset) {
MakeSubsetUnique(subset); // remove duplicates with the same state.
StringId common_str;
Weight tot_weight;
NormalizeSubset(subset, &tot_weight, &common_str);
OutputStateId nextstate;
{
Weight next_tot_weight;
StringId next_common_str;
nextstate = InitialToStateId(*subset, &next_tot_weight, &next_common_str);
common_str = repository_.Concatenate(common_str, next_common_str);
tot_weight = Times(tot_weight, next_tot_weight);
}
// Now add an arc to the next state (would have been created if necessary by
// InitialToStateId).
TempArc temp_arc;
temp_arc.ilabel = ilabel;
temp_arc.nextstate = nextstate;
temp_arc.string = common_str;
temp_arc.weight = tot_weight;
output_arcs_[state].push_back(temp_arc); // record the arc.
num_arcs_++;
}
// "less than" operator for pair<Label, Element>. Used in
// ProcessTransitions. Lexicographical order, which only compares the state
// when ordering the "Element" member of the pair.
class PairComparator {
public:
inline bool operator()(const std::pair<Label, Element> &p1,
const std::pair<Label, Element> &p2) {
if (p1.first < p2.first) {
return true;
} else if (p1.first > p2.first) {
return false;
} else {
return p1.second.state < p2.second.state;
}
}
};
// ProcessTransitions processes emitting transitions (transitions
// with ilabels) out of this subset of states.
// Does not consider final states. Breaks the emitting transitions up by
// ilabel, and creates a new transition in the determinized FST for each
// unique ilabel. Does this by creating a big vector of pairs <Label, Element>
// and then sorting them using a lexicographical ordering, and calling
// ProcessTransition for each range with the same ilabel. Side effects on
// repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
void ProcessTransitions(OutputStateId output_state) {
const std::vector<Element> &minimal_subset =
*(output_states_[output_state]);
// it's possible that minimal_subset could be empty if there are
// unreachable parts of the graph, so don't check that it's nonempty.
std::vector<std::pair<Label, Element> > &all_elems(
all_elems_tmp_); // use class member
// to avoid memory allocation/deallocation.
{
// Push back into "all_elems", elements corresponding to all
// non-epsilon-input transitions out of all states in "minimal_subset".
typename std::vector<Element>::const_iterator iter =
minimal_subset.begin(),
end = minimal_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0 &&
arc.weight != Weight::Zero()) { // Non-epsilon transition --
// ignore epsilons here.
std::pair<Label, Element> this_pr;
this_pr.first = arc.ilabel;
Element &next_elem(this_pr.second);
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
if (arc.olabel == 0) // output epsilon
next_elem.string = elem.string;
else
next_elem.string = repository_.Successor(elem.string, arc.olabel);
all_elems.push_back(this_pr);
}
}
}
}
PairComparator pc;
std::sort(all_elems.begin(), all_elems.end(), pc);
// now sorted first on input label, then on state.
typedef typename std::vector<std::pair<Label, Element> >::const_iterator
PairIter;
PairIter cur = all_elems.begin(), end = all_elems.end();
std::vector<Element> this_subset;
while (cur != end) {
// Process ranges that share the same input symbol.
Label ilabel = cur->first;
this_subset.clear();
while (cur != end && cur->first == ilabel) {
this_subset.push_back(cur->second);
cur++;
}
// We now have a subset for this ilabel.
assert(!this_subset.empty()); // temp.
ProcessTransition(output_state, ilabel, &this_subset);
}
all_elems.clear(); // as it's a class variable-- want it to stay
// emtpy.
}
// ProcessState does the processing of a determinized state, i.e. it creates
// transitions out of it and the final-probability if any.
void ProcessState(OutputStateId output_state) {
ProcessFinal(output_state);
ProcessTransitions(output_state);
}
void Debug() { // this function called if you send a signal
// SIGUSR1 to the process (and it's caught by the handler in
// fstdeterminizestar). It prints out some traceback
// info and exits.
KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
// free up memory from the hash as we need a little memory
{
MinimalSubsetHash hash_tmp;
hash_tmp.swap(minimal_hash_);
}
if (output_arcs_.size() <= 2) {
KALDI_ERR << "Nothing to trace back";
}
size_t max_state = output_arcs_.size() - 2; // Don't take the last
// one as we might be halfway into constructing it.
std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
for (size_t i = 0; i < max_state; i++) {
for (size_t j = 0; j < output_arcs_[i].size(); j++) {
OutputStateId nextstate = output_arcs_[i][j].nextstate;
// Always find an earlier-numbered predecessor; this
// is always possible because of the way the algorithm
// works.
if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
}
}
std::vector<std::pair<Label, StringId> > traceback;
// 'traceback' is a pair of (ilabel, olabel-seq).
OutputStateId cur_state = max_state; // A recently constructed state.
while (cur_state != 0 && cur_state != kNoStateId) {
OutputStateId last_state = predecessor[cur_state];
std::pair<Label, StringId> p;
size_t i;
for (i = 0; i < output_arcs_[last_state].size(); i++) {
if (output_arcs_[last_state][i].nextstate == cur_state) {
p.first = output_arcs_[last_state][i].ilabel;
p.second = output_arcs_[last_state][i].string;
traceback.push_back(p);
break;
}
}
KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
cur_state = last_state;
}
if (cur_state == kNoStateId)
KALDI_WARN << "Traceback did not reach start state "
<< "(possibly debug-code error)";
std::stringstream ss;
ss << "Traceback follows in format "
<< "ilabel (olabel olabel) ilabel (olabel) ... :";
for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
ss << ' ' << traceback[i].first << " ( ";
std::vector<Label> seq;
repository_.ConvertToVector(traceback[i].second, &seq);
for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
ss << ')';
}
KALDI_ERR << ss.str();
}
bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
// of the input FST either is final or has an osymbol on an arc out of it.
// Uses the vector isymbol_or_final_ as a cache for this info.
assert(state >= 0);
if (isymbol_or_final_.size() <= state)
isymbol_or_final_.resize(state + 1, static_cast<char>(OSF_UNKNOWN));
if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
return false;
else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
return true;
// else work it out...
isymbol_or_final_[state] = static_cast<char>(OSF_NO);
if (ifst_->Final(state) != Weight::Zero())
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
for (ArcIterator<Fst<Arc> > aiter(*ifst_, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
return true;
}
}
return IsIsymbolOrFinal(state); // will only recurse once.
}
void InitializeDeterminization() {
if (ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
// states in ifst_, it might be a bit more efficient
// to pre-size the hashes so we're not constantly rebuilding them.
#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
StateId num_states =
down_cast<const ExpandedFst<Arc> *, const Fst<Arc> >(ifst_)
->NumStates();
minimal_hash_.rehash(num_states / 2 + 3);
initial_hash_.rehash(num_states / 2 + 3);
#endif
}
InputStateId start_id = ifst_->Start();
if (start_id != kNoStateId) {
/* Insert determinized-state corresponding to the start state into hash
and queue. Unlike all the other states, we don't "normalize" the
representation of this determinized-state before we put it into
minimal_hash_. This is actually what we want, as otherwise we'd have
problems dealing with any extra weight and string and might have to
create a "super-initial" state which would make the output
nondeterministic. Normalization is only needed to make the
determinized output more minimal anyway, it's not needed for
correctness. Note, we don't put anything in the initial_hash_. The
initial_hash_ is only a lookaside buffer anyway, so this isn't a
problem-- it will get populated later if it needs to be.
*/
Element elem;
elem.state = start_id;
elem.weight = Weight::One();
elem.string = repository_.EmptyString(); // Id of empty sequence.
std::vector<Element> subset;
subset.push_back(elem);
EpsilonClosure(&subset); // follow through epsilon-inputs links
ConvertToMinimal(&subset); // remove all but final states and
// states with input-labels on arcs out of them.
std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
assert(output_arcs_.empty() && output_states_.empty());
// add the new state...
output_states_.push_back(subset_ptr);
output_arcs_.push_back(std::vector<TempArc>());
OutputStateId initial_state = 0;
minimal_hash_[subset_ptr] = initial_state;
queue_.push_back(initial_state);
}
}
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizer);
std::vector<std::vector<Element> *>
output_states_; // maps from output state to
// minimal representation [normalized].
// View pointers as owned in
// minimal_hash_.
std::vector<std::vector<TempArc> >
output_arcs_; // essentially an FST in our format.
int num_arcs_; // keep track of memory usage: number of arcs in output_arcs_
int num_elems_; // keep track of memory usage: number of elems in
// output_states_
const Fst<Arc> *ifst_;
DeterminizeLatticeOptions opts_;
SubsetKey hasher_; // object that computes keys-- has no data members.
SubsetEqual
equal_; // object that compares subsets-- only data member is delta_.
bool determinized_; // set to true when user called Determinize(); used to
// make
// sure this object is used correctly.
MinimalSubsetHash
minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
// representation" (only include final and states and
// states with nonzero ilabel on arc out of them. Owns
// the pointers in its keys.
InitialSubsetHash initial_hash_; // hash from Subset to Element, which
// represents the OutputStateId together
// with an extra weight and string. Subset
// is "initial representation". The extra
// weight and string is needed because after
// we convert to minimal representation and
// normalize, there may be an extra weight
// and string. Owns the pointers
// in its keys.
std::vector<OutputStateId>
queue_; // Queue of output-states to process. Starts with
// state 0, and increases and then (hopefully) decreases in length during
// determinization. LIFO queue (queue discipline doesn't really matter).
std::vector<std::pair<Label, Element> >
all_elems_tmp_; // temporary vector used in ProcessTransitions.
enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
std::vector<char> isymbol_or_final_; // A kind of cache; it says whether
// each state is (emitting or final) where emitting means it has at least one
// non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
LatticeStringRepository<IntType>
repository_; // defines a compact and fast way of
// storing sequences of labels.
};
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
template <class Weight, class IntType>
bool DeterminizeLattice(const Fst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticeOptions opts, bool *debug_ptr) {
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
LatticeDeterminizer<Weight, IntType> det(ifst, opts);
if (!det.Determinize(debug_ptr)) return false;
det.Output(ofst);
return true;
}
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
template <class Weight, class IntType>
bool DeterminizeLattice(
const Fst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticeOptions opts, bool *debug_ptr) {
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
LatticeDeterminizer<Weight, IntType> det(ifst, opts);
if (!det.Determinize(debug_ptr)) return false;
det.Output(ofst);
return true;
}
} // namespace fst
#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
// fstext/determinize-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
#define KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice.cc
/*
DeterminizeLattice implements a special form of determinization
with epsilon removal, optimized for a phase of lattice generation.
Its input is an FST with weight-type BaseWeightType (usually a pair of
floats, with a lexicographical type of order, such as
LatticeWeightTpl<float>). Typically this would be a state-level lattice, with
input symbols equal to words, and output-symbols equal to p.d.f's (so like
the inverse of HCLG). Imagine representing this as an acceptor of type
CompactLatticeWeightTpl<float>, in which the input/output symbols are words,
and the weights contain the original weights together with strings (with zero
or one symbol in them) containing the original output labels (the p.d.f.'s).
We determinize this using acceptor determinization with epsilon removal.
Remember (from lattice-weight.h) that CompactLatticeWeightTpl has a special
kind of semiring where we always take the string corresponding to the best
cost (of type BaseWeightType), and discard the other. This corresponds to
taking the best output-label sequence (of p.d.f.'s) for each input-label
sequence (of words). We couldn't use the Gallic weight for this, or it would
die as soon as it detected that the input FST was non-functional. In our
case, any acyclic FST (and many cyclic ones) can be determinized. We assume
that there is a function Compare(const BaseWeightType &a, const
BaseWeightType &b) that returns (-1, 0, 1) according to whether (a < b, a ==
b, a > b) in the total order on the BaseWeightType... this information should
be the same as NaturalLess would give, but it's more efficient to do it this
way. You can define this for things like TropicalWeight if you need to
instantiate this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and
the ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its
(end-state, weight) pairs, this will be a valid and more compact
representation, and will lead to a smaller set of determinized states (like
early minimization). Call this collection of (end-state, weight) pairs the
"minimal representation". As a mechanism to reduce compute, we can also
consider another representation. In the determinization algorithm, we start
off with a set of (begin-state, weight) pairs (where the "begin-states" are
initial or have a label on the transition into them), and the "canonical
representation" consists of the epsilon-closure of this set (i.e. follow
epsilons). Call this set of (begin-state, weight) pairs, appropriately
normalized, the "initial representation". If two initial representations are
the same, the "canonical representation" and hence the "minimal
representation" will be the same. We can use this to reduce compute. Note
that if two initial representations are different, this does not preclude the
other representations from being the same.
*/
struct DeterminizeLatticeOptions {
float delta; // A small offset used to measure equality of weights.
int max_mem; // If >0, determinization will fail and return false
// when the algorithm's (approximate) memory consumption crosses this
// threshold.
int max_loop; // If >0, can be used to detect non-determinizable input
// (a case that wouldn't be caught by max_mem).
DeterminizeLatticeOptions() : delta(kDelta), max_mem(-1), max_loop(-1) {}
};
/**
This function implements the normal version of DeterminizeLattice, in which
the output strings are represented using sequences of arcs, where all but
the first one has an epsilon on the input side. The debug_ptr argument is
an optional pointer to a bool that, if it becomes true while the algorithm
is executing, the algorithm will print a traceback and terminate (used in
fstdeterminizestar.cc debug non-terminating determinization). More
efficient if ifst is arc-sorted on input label. If the number of arcs gets
more than max_states, it will throw std::runtime_error (otherwise this code
does not use exceptions). This is mainly useful for debug. */
template <class Weight, class IntType>
bool DeterminizeLattice(
const Fst<ArcTpl<Weight> > &ifst, MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
bool *debug_ptr = NULL);
/* This is a version of DeterminizeLattice with a slightly more "natural"
output format, where the output sequences are encoded using the
CompactLatticeArcTpl template (i.e. the sequences of output symbols are
represented directly as strings) More efficient if ifst is arc-sorted on
input label. If the #arcs gets more than max_arcs, it will throw
std::runtime_error (otherwise this code does not use exceptions). This is
mainly useful for debug.
*/
template <class Weight, class IntType>
bool DeterminizeLattice(
const Fst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
bool *debug_ptr = NULL);
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#include "fstext/determinize-lattice-inl.h"
#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
// fstext/determinize-star-inl.h
// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky
// 2015 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
#define KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
// Do not include this file directly. It is included by determinize-star.h
#include <algorithm>
#include <climits>
#include <deque>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using std::unordered_map;
#include "base/kaldi-error.h"
namespace fst {
// This class maps back and forth from/to integer id's to sequences of strings.
// used in determinization algorithm.
template <class Label, class StringId>
class StringRepository {
// Label and StringId are both integer types, possibly the same.
// This is a utility that maps back and forth between a vector<Label> and
// StringId representation of sequences of Labels. It is to save memory, and
// to save compute. We treat sequences of length zero and one separately, for
// efficiency.
public:
class VectorKey { // Hash function object.
public:
size_t operator()(const std::vector<Label> *vec) const {
assert(vec != NULL);
size_t hash = 0, factor = 1;
for (typename std::vector<Label>::const_iterator it = vec->begin();
it != vec->end(); it++) {
hash += factor * (*it);
factor *= 103333; // just an arbitrary prime number.
}
return hash;
}
};
class VectorEqual { // Equality-operator function object.
public:
size_t operator()(const std::vector<Label> *vec1,
const std::vector<Label> *vec2) const {
return (*vec1 == *vec2);
}
};
typedef unordered_map<const std::vector<Label> *, StringId, VectorKey,
VectorEqual>
MapType;
StringId IdOfEmpty() { return no_symbol; }
StringId IdOfLabel(Label l) {
if (l >= 0 && l <= (Label)single_symbol_range) {
return l + single_symbol_start;
} else {
// l is out of the allowed range so we have to treat it as a sequence of
// length one. Should be v. rare.
std::vector<Label> v;
v.push_back(l);
return IdOfSeqInternal(v);
}
}
StringId IdOfSeq(
const std::vector<Label> &v) { // also works for sizes 0 and 1.
size_t sz = v.size();
if (sz == 0)
return no_symbol;
else if (v.size() == 1)
return IdOfLabel(v[0]);
else
return IdOfSeqInternal(v);
}
inline bool IsEmptyString(StringId id) { return id == no_symbol; }
void SeqOfId(StringId id, std::vector<Label> *v) {
if (id == no_symbol) {
v->clear();
} else if (id >= single_symbol_start) {
v->resize(1);
(*v)[0] = id - single_symbol_start;
} else {
assert(static_cast<size_t>(id) < vec_.size());
*v = *(vec_[id]);
}
}
StringId RemovePrefix(StringId id, size_t prefix_len) {
if (prefix_len == 0) {
return id;
} else {
std::vector<Label> v;
SeqOfId(id, &v);
size_t sz = v.size();
assert(sz >= prefix_len);
std::vector<Label> v_noprefix(sz - prefix_len);
for (size_t i = 0; i < sz - prefix_len; i++)
v_noprefix[i] = v[i + prefix_len];
return IdOfSeq(v_noprefix);
}
}
StringRepository() {
// The following are really just constants but don't want to complicate
// compilation so make them class variables. Due to the brokenness of
// <limits>, they can't be accessed as constants.
string_end = (std::numeric_limits<StringId>::max() / 2) -
1; // all hash values must be <= this.
no_symbol = (std::numeric_limits<StringId>::max() /
2); // reserved for empty sequence.
single_symbol_start = (std::numeric_limits<StringId>::max() / 2) + 1;
single_symbol_range =
std::numeric_limits<StringId>::max() - single_symbol_start;
}
void Destroy() {
for (typename std::vector<std::vector<Label> *>::iterator iter =
vec_.begin();
iter != vec_.end(); ++iter)
delete *iter;
std::vector<std::vector<Label> *> tmp_vec;
tmp_vec.swap(vec_);
MapType tmp_map;
tmp_map.swap(map_);
}
~StringRepository() { Destroy(); }
private:
KALDI_DISALLOW_COPY_AND_ASSIGN(StringRepository);
StringId IdOfSeqInternal(const std::vector<Label> &v) {
typename MapType::iterator iter = map_.find(&v);
if (iter != map_.end()) {
return iter->second;
} else { // must add it to map.
StringId this_id = (StringId)vec_.size();
std::vector<Label> *v_new = new std::vector<Label>(v);
vec_.push_back(v_new);
map_[v_new] = this_id;
assert(this_id < string_end); // or we used up the labels.
return this_id;
}
}
std::vector<std::vector<Label> *> vec_;
MapType map_;
static const StringId string_start =
(StringId)0; // This must not change. It's assumed.
StringId string_end; // = (numeric_limits<StringId>::max() / 2) - 1; // all
// hash values must be <= this.
StringId no_symbol; // = (numeric_limits<StringId>::max() / 2); // reserved
// for empty sequence.
StringId
single_symbol_start; // = (numeric_limits<StringId>::max() / 2) + 1;
StringId single_symbol_range; // = numeric_limits<StringId>::max() -
// single_symbol_start;
};
template <class F>
class DeterminizerStar {
typedef typename F::Arc Arc;
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
// correspondence between our states and the states in ofst. If destroy ==
// true, release memory as we go (but we cannot output again).
void Output(MutableFst<GallicArc<Arc> > *ofst, bool destroy = true);
// Output to standard FST. We will create extra states to handle sequences of
// symbols on the output. If destroy == true, release memory as we go (but we
// cannot output again).
void Output(MutableFst<Arc> *ofst, bool destroy = true);
// Initializer. After initializing the object you will typically call
// Determinize() and then one of the Output functions.
DeterminizerStar(const Fst<Arc> &ifst, float delta = kDelta,
int max_states = -1, bool allow_partial = false)
: ifst_(ifst.Copy()),
delta_(delta),
max_states_(max_states),
determinized_(false),
allow_partial_(allow_partial),
is_partial_(false),
equal_(delta),
hash_(ifst.Properties(kExpanded, false)
? down_cast<const ExpandedFst<Arc> *, const Fst<Arc> >(&ifst)
->NumStates() /
2 +
3
: 20,
hasher_, equal_),
epsilon_closure_(ifst_, max_states, &repository_, delta) {}
void Determinize(bool *debug_ptr) {
assert(!determinized_);
// This determinizes the input fst but leaves it in the "special format"
// in "output_arcs_".
InputStateId start_id = ifst_->Start();
if (start_id == kNoStateId) {
determinized_ = true;
return; // Nothing to do.
} else { // Insert start state into hash and queue.
Element elem;
elem.state = start_id;
elem.weight = Weight::One();
elem.string = repository_.IdOfEmpty(); // Id of empty sequence.
std::vector<Element> vec;
vec.push_back(elem);
OutputStateId cur_id = SubsetToStateId(vec);
assert(cur_id == 0 && "Do not call Determinize twice.");
}
while (!Q_.empty()) {
std::pair<std::vector<Element> *, OutputStateId> cur_pair = Q_.front();
Q_.pop_front();
ProcessSubset(cur_pair);
if (debug_ptr && *debug_ptr) Debug(); // will exit.
if (max_states_ > 0 && output_arcs_.size() > max_states_) {
if (allow_partial_ == false) {
KALDI_ERR << "Determinization aborted since passed " << max_states_
<< " states";
} else {
KALDI_WARN << "Determinization terminated since passed "
<< max_states_
<< " states, partial results will be generated";
is_partial_ = true;
break;
}
}
}
determinized_ = true;
}
bool IsPartial() { return is_partial_; }
// frees all except output_arcs_, which contains the important info
// we need to output.
void FreeMostMemory() {
if (ifst_) {
delete ifst_;
ifst_ = NULL;
}
for (typename SubsetHash::iterator iter = hash_.begin();
iter != hash_.end(); ++iter)
delete iter->first;
SubsetHash tmp;
tmp.swap(hash_);
}
~DeterminizerStar() { FreeMostMemory(); }
private:
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId InputStateId;
typedef typename Arc::StateId
OutputStateId; // same as above but distinguish states in output Fst.
typedef typename Arc::Label StringId; // Id type used in the StringRepository
typedef StringRepository<Label, StringId> StringRepositoryType;
// Element of a subset [of original states]
struct Element {
InputStateId state;
StringId string;
Weight weight;
bool operator!=(const Element &other) const {
return (state != other.state || string != other.string ||
weight != other.weight);
}
};
// Arcs in the format we temporarily create in this class (a representation,
// essentially of a Gallic Fst).
struct TempArc {
Label ilabel;
StringId ostring; // Look it up in the StringRepository, it's a sequence of
// Labels.
OutputStateId nextstate; // or kNoState for final weights.
Weight weight;
};
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that
// is order-dependent. However the weights are not included in the hashing
// function-- we hash subsets that differ only in weight to the same key. This
// is not optimal in terms of the O(N) performance but typically if we have a
// lot of determinized states that differ only in weight then the input
// probably was pathological in some way, or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple
// cases.
// Instead we apply the delta when comparing subsets for equality, and allow a
// small difference.
class SubsetKey {
public:
size_t operator()(const std::vector<Element> *subset)
const { // hashes only the state and string.
size_t hash = 0, factor = 1;
for (typename std::vector<Element>::const_iterator iter = subset->begin();
iter != subset->end(); ++iter) {
hash *= factor;
hash += iter->state + 103333 * iter->string;
factor *= 23531; // these numbers are primes.
}
return hash;
}
};
// This is the equality operator on subsets. It checks for exact match on
// state-id and string, and approximate match on weights.
class SubsetEqual {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state || iter1->string != iter2->string ||
!ApproxEqual(iter1->weight, iter2->weight, delta_))
return false;
}
return true;
}
float delta_;
explicit SubsetEqual(float delta) : delta_(delta) {}
SubsetEqual() : delta_(kDelta) {}
};
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state) return false;
}
return true;
}
};
// Define the hash type we use to store subsets.
typedef unordered_map<const std::vector<Element> *, OutputStateId, SubsetKey,
SubsetEqual>
SubsetHash;
class EpsilonClosure {
public:
EpsilonClosure(const Fst<Arc> *ifst, int max_states,
StringRepository<Label, StringId> *repository, float delta)
: ifst_(ifst),
max_states_(max_states),
repository_(repository),
delta_(delta) {}
// This function computes epsilon closure of subset of states by following
// epsilon links. Called by ProcessSubset. Has no side effects except on the
// repository.
void GetEpsilonClosure(const std::vector<Element> &input_subset,
std::vector<Element> *output_subset);
private:
struct EpsilonClosureInfo {
EpsilonClosureInfo() {}
EpsilonClosureInfo(const Element &e, const Weight &w, bool i)
: element(e), weight_to_process(w), in_queue(i) {}
// the weight in the Element struct is the total current weight
// that has been processed already
Element element;
// this stores the weight that we haven't processed (propagated)
Weight weight_to_process;
// whether "this" struct is in the queue
// we store the info here so that we don't have to look it up every time
bool in_queue;
bool operator<(const EpsilonClosureInfo &other) const {
return this->element.state < other.element.state;
}
};
// to further speed up EpsilonClosure() computation, we have 2 queues
// the 2nd queue is used when we first iterate over the input set -
// if queue_2_.empty() then we directly set output_set equal to input_set
// and return immediately
// Since Epsilon arcs are relatively rare, this way we could efficiently
// detect the epsilon-free case, without having to waste our computation
// e.g. allocating the EpsilonClosureInfo structure; this also lets us do a
// level-by-level traversal, which could avoid some (unfortunately not all)
// duplicate computation if epsilons form a DAG that is not a tree
//
// We put the queues here for better efficiency for memory allocation
std::deque<typename Arc::StateId> queue_;
std::vector<Element> queue_2_;
// the following 2 structures together form our *virtual "map"*
// basically we need a map from state_id to EpsilonClosureInfo that operates
// in O(1) time, while still takes relatively small mem, and this does it
// well for efficiency we don't clear id_to_index_ of its outdated
// information As a result each time we do a look-up, we need to check if
// (ecinfo_[id_to_index_[id]].element.state == id) Yet this is still faster
// than using a std::map<StateId, EpsilonClosureInfo>
std::vector<int> id_to_index_;
// unlike id_to_index_, we clear the content of ecinfo_ each time we call
// EpsilonClosure(). This needed because we need an efficient way to
// traverse the virtual map - it is just too costly to traverse the
// id_to_index_ vector.
std::vector<EpsilonClosureInfo> ecinfo_;
// Add one element (elem) into cur_subset
// it also adds the necessary stuff to queue_, set the correct weight
void AddOneElement(const Element &elem, const Weight &unprocessed_weight);
// Sub-routine that we call in EpsilonClosure()
// It takes the current "unprocessed_weight" and propagate it to the
// states accessible from elem.state by an epsilon arc
// and add the results to cur_subset.
// save_to_queue_2 is set true when we iterate over the initial subset
// - then we save it to queue_2 s.t. if it's empty, we directly return
// the input set
void ExpandOneElement(const Element &elem, bool sorted,
const Weight &unprocessed_weight,
bool save_to_queue_2 = false);
// no pointers below would take the ownership
const Fst<Arc> *ifst_;
int max_states_;
StringRepository<Label, StringId> *repository_;
float delta_;
};
// This function works out the final-weight of the determinized state.
// called by ProcessSubset.
// Has no side effects except on the variable repository_, and output_arcs_.
void ProcessFinal(const std::vector<Element> &closed_subset,
OutputStateId state) {
// processes final-weights for this subset.
bool is_final = false;
StringId final_string = 0; // = 0 to keep compiler happy.
Weight final_weight =
Weight::One(); // This value will never be accessed, and
// we just set it to avoid spurious compiler warnings. We avoid setting it
// to Zero() because floating-point infinities can sometimes generate
// interrupts and slow things down.
typename std::vector<Element>::const_iterator iter = closed_subset.begin(),
end = closed_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
Weight this_final_weight = ifst_->Final(elem.state);
if (this_final_weight != Weight::Zero()) {
if (!is_final) { // first final-weight
final_string = elem.string;
final_weight = Times(elem.weight, this_final_weight);
is_final = true;
} else { // already have one.
if (final_string != elem.string) {
KALDI_ERR << "FST was not functional -> not determinizable";
}
final_weight =
Plus(final_weight, Times(elem.weight, this_final_weight));
}
}
}
if (is_final) {
// store final weights in TempArc structure, just like a transition.
TempArc temp_arc;
temp_arc.ilabel = 0;
temp_arc.nextstate =
kNoStateId; // special marker meaning "final weight".
temp_arc.ostring = final_string;
temp_arc.weight = final_weight;
output_arcs_[state].push_back(temp_arc);
}
}
// ProcessTransition is called from "ProcessTransitions". Broken out for
// clarity. Has side effects on output_arcs_, and (via SubsetToStateId), Q_
// and hash_.
void ProcessTransition(OutputStateId state, Label ilabel,
std::vector<Element> *subset);
// "less than" operator for pair<Label, Element>. Used in
// ProcessTransitions. Lexicographical order, with comparing the state only
// for "Element".
class PairComparator {
public:
inline bool operator()(const std::pair<Label, Element> &p1,
const std::pair<Label, Element> &p2) {
if (p1.first < p2.first) {
return true;
} else if (p1.first > p2.first) {
return false;
} else {
return p1.second.state < p2.second.state;
}
}
};
// ProcessTransitions handles transitions out of this subset of states.
// Ignores epsilon transitions (epsilon closure already handled that).
// Does not consider final states. Breaks the transitions up by ilabel,
// and creates a new transition in determinized FST, for each ilabel.
// Does this by creating a big vector of pairs <Label, Element> and then
// sorting them using a lexicographical ordering, and calling
// ProcessTransition for each range with the same ilabel. Side effects on
// repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
void ProcessTransitions(const std::vector<Element> &closed_subset,
OutputStateId state) {
std::vector<std::pair<Label, Element> > all_elems;
{ // Push back into "all_elems", elements corresponding to all
// non-epsilon-input transitions
// out of all states in "closed_subset".
typename std::vector<Element>::const_iterator iter =
closed_subset.begin(),
end = closed_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel !=
0) { // Non-epsilon transition -- ignore epsilons here.
std::pair<Label, Element> this_pr;
this_pr.first = arc.ilabel;
Element &next_elem(this_pr.second);
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
if (arc.olabel == 0) { // output epsilon-- this is simple case so
// handle separately for efficiency
next_elem.string = elem.string;
} else {
std::vector<Label> seq;
repository_.SeqOfId(elem.string, &seq);
seq.push_back(arc.olabel);
next_elem.string = repository_.IdOfSeq(seq);
}
all_elems.push_back(this_pr);
}
}
}
}
PairComparator pc;
std::sort(all_elems.begin(), all_elems.end(), pc);
// now sorted first on input label, then on state.
typedef typename std::vector<std::pair<Label, Element> >::const_iterator
PairIter;
PairIter cur = all_elems.begin(), end = all_elems.end();
std::vector<Element> this_subset;
while (cur != end) {
// Process ranges that share the same input symbol.
Label ilabel = cur->first;
this_subset.clear();
while (cur != end && cur->first == ilabel) {
this_subset.push_back(cur->second);
cur++;
}
// We now have a subset for this ilabel.
ProcessTransition(state, ilabel, &this_subset);
}
}
// SubsetToStateId converts a subset (vector of Elements) to a StateId in the
// output fst. This is a hash lookup; if no such state exists, it adds a new
// state to the hash and adds a new pair to the queue. Side effects on hash_
// and Q_, and on output_arcs_ [just affects the size].
OutputStateId SubsetToStateId(
const std::vector<Element> &subset) { // may add the subset to the queue.
typedef typename SubsetHash::iterator IterType;
IterType iter = hash_.find(&subset);
if (iter == hash_.end()) { // was not there.
std::vector<Element> *new_subset = new std::vector<Element>(subset);
OutputStateId new_state_id = (OutputStateId)output_arcs_.size();
bool ans =
hash_
.insert(std::pair<const std::vector<Element> *, OutputStateId>(
new_subset, new_state_id))
.second;
assert(ans);
output_arcs_.push_back(std::vector<TempArc>());
if (allow_partial_ == false) {
// If --allow-partial is not requested, we do the old way.
Q_.push_front(std::pair<std::vector<Element> *, OutputStateId>(
new_subset, new_state_id));
} else {
// If --allow-partial is requested, we do breadth first search. This
// ensures that when we return partial results, we return the states
// that are reachable by the fewest steps from the start state.
Q_.push_back(std::pair<std::vector<Element> *, OutputStateId>(
new_subset, new_state_id));
}
return new_state_id;
} else {
return iter->second; // the OutputStateId.
}
}
// ProcessSubset does the processing of a determinized state, i.e. it creates
// transitions out of it and adds new determinized states to the queue if
// necessary. The first stage is "EpsilonClosure" (follow epsilons to get a
// possibly larger set of (states, weights)). After that we ignore epsilons.
// We process the final-weight of the state, and then handle transitions out
// (this may add more determinized states to the queue).
void ProcessSubset(
const std::pair<std::vector<Element> *, OutputStateId> &pair) {
const std::vector<Element> *subset = pair.first;
OutputStateId state = pair.second;
std::vector<Element> closed_subset; // subset after epsilon closure.
epsilon_closure_.GetEpsilonClosure(*subset, &closed_subset);
// Now follow non-epsilon arcs [and also process final states]
ProcessFinal(closed_subset, state);
// Now handle transitions out of these states.
ProcessTransitions(closed_subset, state);
}
void Debug();
KALDI_DISALLOW_COPY_AND_ASSIGN(DeterminizerStar);
std::deque<std::pair<std::vector<Element> *, OutputStateId> >
Q_; // queue of subsets to be processed.
std::vector<std::vector<TempArc> >
output_arcs_; // essentially an FST in our format.
const Fst<Arc> *ifst_;
float delta_;
int max_states_;
bool determinized_; // used to check usage.
bool allow_partial_; // output paritial results or not
bool is_partial_; // if we get partial results or not
SubsetKey hasher_; // object that computes keys-- has no data members.
SubsetEqual
equal_; // object that compares subsets-- only data member is delta_.
SubsetHash hash_; // hash from Subset to StateId in final Fst.
StringRepository<Label, StringId>
repository_; // associate integer id's with sequences of labels.
EpsilonClosure epsilon_closure_;
};
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<typename F::Arc> *ofst, float delta,
bool *debug_ptr, int max_states, bool allow_partial) {
ofst->SetOutputSymbols(ifst.OutputSymbols());
ofst->SetInputSymbols(ifst.InputSymbols());
DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
det.Determinize(debug_ptr);
det.Output(ofst);
return det.IsPartial();
}
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<GallicArc<typename F::Arc> > *ofst, float delta,
bool *debug_ptr, int max_states, bool allow_partial) {
ofst->SetOutputSymbols(ifst.InputSymbols());
ofst->SetInputSymbols(ifst.InputSymbols());
DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
det.Determinize(debug_ptr);
det.Output(ofst);
return det.IsPartial();
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::GetEpsilonClosure(
const std::vector<Element> &input_subset,
std::vector<Element> *output_subset) {
ecinfo_.resize(0);
size_t size = input_subset.size();
// find whether input fst is known to be sorted in input label.
bool sorted =
((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
// size is still the input_subset.size()
for (size_t i = 0; i < size; i++) {
ExpandOneElement(input_subset[i], sorted, input_subset[i].weight, true);
}
size_t s = queue_2_.size();
if (s == 0) {
*output_subset = input_subset;
return;
} else {
// queue_2 not empty. Need to create the vector<info>
for (size_t i = 0; i < size; i++) {
// the weight has not been processed yet,
// so put all of them in the "weight_to_process"
ecinfo_.push_back(
EpsilonClosureInfo(input_subset[i], input_subset[i].weight, false));
ecinfo_.back().element.weight = Weight::Zero(); // clear the weight
if (id_to_index_.size() < input_subset[i].state + 1) {
id_to_index_.resize(2 * input_subset[i].state + 1, -1);
}
id_to_index_[input_subset[i].state] = ecinfo_.size() - 1;
}
}
{
Element elem;
elem.weight = Weight::Zero();
for (size_t i = 0; i < s; i++) {
elem.state = queue_2_[i].state;
elem.string = queue_2_[i].string;
AddOneElement(elem, queue_2_[i].weight);
}
queue_2_.resize(0);
}
int counter = 0; // relates to max-states option, used for test.
while (!queue_.empty()) {
InputStateId id = queue_.front();
// no need to check validity of the index
// since anything in the queue we are sure they're in the "virtual set"
int index = id_to_index_[id];
EpsilonClosureInfo &info = ecinfo_[index];
Element &elem = info.element;
Weight unprocessed_weight = info.weight_to_process;
elem.weight = Plus(elem.weight, unprocessed_weight);
info.weight_to_process = Weight::Zero();
info.in_queue = false;
queue_.pop_front();
if (max_states_ > 0 && counter++ > max_states_) {
KALDI_ERR << "Determinization aborted since looped more than "
<< max_states_ << " times during epsilon closure";
}
// generally we need to be careful about iterator-invalidation problem
// here we pass a reference (elem), which could be an issue.
// In the beginning of ExpandOneElement, we make a copy of elem.string
// to avoid that issue
ExpandOneElement(elem, sorted, unprocessed_weight);
}
{
// this sorting is based on StateId
sort(ecinfo_.begin(), ecinfo_.end());
output_subset->clear();
size = ecinfo_.size();
output_subset->reserve(size);
for (size_t i = 0; i < size; i++) {
EpsilonClosureInfo &info = ecinfo_[i];
if (info.weight_to_process != Weight::Zero()) {
info.element.weight = Plus(info.element.weight, info.weight_to_process);
}
output_subset->push_back(info.element);
}
}
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::AddOneElement(
const Element &elem, const Weight &unprocessed_weight) {
// first we try to find the element info in the ecinfo_ vector
int index = -1;
if (elem.state < id_to_index_.size()) {
index = id_to_index_[elem.state];
}
if (index != -1) {
if (index >= ecinfo_.size()) {
index = -1;
} else if (ecinfo_[index].element.state != elem.state) {
// since ecinfo_ might store outdated information, we need to check
index = -1;
}
}
if (index == -1) {
// was no such StateId: insert and add to queue.
ecinfo_.push_back(EpsilonClosureInfo(elem, unprocessed_weight, true));
size_t size = id_to_index_.size();
if (size < elem.state + 1) {
// double the size to reduce memory operations
id_to_index_.resize(2 * elem.state + 1, -1);
}
id_to_index_[elem.state] = ecinfo_.size() - 1;
queue_.push_back(elem.state);
} else { // one is already there. Add weights.
EpsilonClosureInfo &info = ecinfo_[index];
if (info.element.string != elem.string) {
// Non-functional FST.
std::ostringstream ss;
ss << "FST was not functional -> not determinizable.";
{ // Print some debugging information. Can be helpful to debug
// the inputs when FSTs are mysteriously non-functional.
std::vector<Label> tmp_seq;
repository_->SeqOfId(info.element.string, &tmp_seq);
ss << "\nFirst string:";
for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
ss << "\nSecond string:";
repository_->SeqOfId(elem.string, &tmp_seq);
for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
}
KALDI_ERR << ss.str();
}
info.weight_to_process = Plus(info.weight_to_process, unprocessed_weight);
if (!info.in_queue) {
// this is because the code in "else" below: the
// iter->second.weight_to_process might not be Zero()
Weight weight = Plus(info.element.weight, info.weight_to_process);
// What is done below is, we propagate the weight (by adding them
// to the queue only when the change is big enough;
// otherwise we just store the weight, until before returning
// we add the element.weight and weight_to_process together
if (!ApproxEqual(weight, info.element.weight, delta_)) {
// add extra part of weight to queue.
info.in_queue = true;
queue_.push_back(elem.state);
}
}
}
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::ExpandOneElement(
const Element &elem, bool sorted, const Weight &unprocessed_weight,
bool save_to_queue_2) {
StringId str =
elem.string; // copy it here because there is an iterator-
// - invalidation problem (it really happens for some FSTs)
// now we are going to propagate the "unprocessed_weight"
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (sorted && arc.ilabel > 0) {
break;
// Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
}
if (arc.ilabel != 0) {
continue; // we only process epsilons here
}
Element next_elem;
next_elem.state = arc.nextstate;
next_elem.weight = Weight::Zero();
Weight next_unprocessed_weight = Times(unprocessed_weight, arc.weight);
// now must append strings
if (arc.olabel == 0) {
next_elem.string = str;
} else {
std::vector<Label> seq;
repository_->SeqOfId(str, &seq);
if (arc.olabel != 0) seq.push_back(arc.olabel);
next_elem.string = repository_->IdOfSeq(seq);
}
if (save_to_queue_2) {
next_elem.weight = next_unprocessed_weight;
queue_2_.push_back(next_elem);
} else {
AddOneElement(next_elem, next_unprocessed_weight);
}
}
}
template <class F>
void DeterminizerStar<F>::Output(MutableFst<GallicArc<Arc> > *ofst,
bool destroy) {
assert(determinized_);
if (destroy) determinized_ = false;
typedef GallicWeight<Label, Weight> ThisGallicWeight;
typedef typename Arc::StateId StateId;
if (destroy) FreeMostMemory();
StateId nStates = static_cast<StateId>(output_arcs_.size());
ofst->DeleteStates();
ofst->SetStart(kNoStateId);
if (nStates == 0) {
return;
}
for (StateId s = 0; s < nStates; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
// now process transitions.
for (StateId this_state = 0; this_state < nStates; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
GallicArc<Arc> new_arc;
std::vector<Label> seq;
repository_.SeqOfId(temp_arc.ostring, &seq);
StringWeight<Label, STRING_LEFT> string_weight;
for (size_t i = 0; i < seq.size(); i++) string_weight.PushBack(seq[i]);
ThisGallicWeight gallic_weight(string_weight, temp_arc.weight);
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
ofst->SetFinal(this_state, gallic_weight);
} else { // is really an arc.
new_arc.nextstate = temp_arc.nextstate;
new_arc.ilabel = temp_arc.ilabel;
new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
new_arc.weight = gallic_weight; // includes string and weight.
ofst->AddArc(this_state, new_arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
temp.swap(this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
temp.swap(output_arcs_);
}
}
template <class F>
void DeterminizerStar<F>::Output(MutableFst<Arc> *ofst, bool destroy) {
assert(determinized_);
if (destroy) determinized_ = false;
// Outputs to standard fst.
OutputStateId num_states = static_cast<OutputStateId>(output_arcs_.size());
if (destroy) FreeMostMemory();
ofst->DeleteStates();
if (num_states == 0) {
ofst->SetStart(kNoStateId);
return;
}
// Add basic states-- but will add extra ones to account for strings on
// output.
for (OutputStateId s = 0; s < num_states; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
for (OutputStateId this_state = 0; this_state < num_states; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
std::vector<Label> seq;
repository_.SeqOfId(temp_arc.ostring, &seq);
if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
// Make a sequence of states going to a final state, with the strings as
// labels. Put the weight on the first arc.
OutputStateId cur_state = this_state;
for (size_t i = 0; i < seq.size(); i++) {
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = 0; // epsilon.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state,
(seq.size() == 0 ? temp_arc.weight : Weight::One()));
} else { // Really an arc.
OutputStateId cur_state = this_state;
// Have to be careful with this integer comparison (i+1 < seq.size())
// because unsigned. i < seq.size()-1 could fail for zero-length
// sequences.
for (size_t i = 0; i + 1 < seq.size(); i++) {
// for all but the last element of seq, create new state.
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = (i == 0 ? temp_arc.ilabel
: 0); // put ilabel on first element of seq.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
// Add the final arc in the sequence.
Arc arc;
arc.nextstate = temp_arc.nextstate;
arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
arc.olabel = (seq.size() > 0 ? seq.back() : 0);
ofst->AddArc(cur_state, arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
temp.swap(this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
temp.swap(output_arcs_);
repository_.Destroy();
}
}
template <class F>
void DeterminizerStar<F>::ProcessTransition(OutputStateId state, Label ilabel,
std::vector<Element> *subset) {
// At input, "subset" may contain duplicates for a given dest state (but in
// sorted order). This function removes duplicates from "subset", normalizes
// it, and adds a transition to the dest. state (possibly affecting Q_ and
// hash_, if state did not exist).
typedef typename std::vector<Element>::iterator IterType;
{ // This block makes the subset have one unique Element per state, adding
// the weights.
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
size_t num_out = 0;
// Merge elements with same state-id
while (cur_in != end) { // while we have more elements to process.
// At this point, cur_out points to location of next place we want to put
// an element, cur_in points to location of next element we want to
// process.
if (cur_in != cur_out) *cur_out = *cur_in;
cur_in++;
while (cur_in != end &&
cur_in->state == cur_out->state) { // merge elements.
if (cur_in->string != cur_out->string) {
KALDI_ERR << "FST was not functional -> not determinizable";
}
cur_out->weight = Plus(cur_out->weight, cur_in->weight);
cur_in++;
}
cur_out++;
num_out++;
}
subset->resize(num_out);
}
StringId common_str;
Weight tot_weight;
{ // This block computes common_str and tot_weight (essentially: the common
// divisor)
// and removes them from the elements.
std::vector<Label> seq;
IterType begin = subset->begin(), iter, end = subset->end();
{ // This block computes "seq", which is the common prefix, and
// "common_str",
// which is the StringId version of "seq".
std::vector<Label> tmp_seq;
for (iter = begin; iter != end; ++iter) {
if (iter == begin) {
repository_.SeqOfId(iter->string, &seq);
} else {
repository_.SeqOfId(iter->string, &tmp_seq);
if (tmp_seq.size() < seq.size())
seq.resize(tmp_seq.size()); // size of shortest one.
for (size_t i = 0; i < seq.size();
i++) // seq.size() is the shorter one at this point.
if (tmp_seq[i] != seq[i]) seq.resize(i);
}
if (seq.size() == 0) break; // will not get any prefix.
}
common_str = repository_.IdOfSeq(seq);
}
{ // This block computes "tot_weight".
iter = begin;
tot_weight = iter->weight;
for (++iter; iter != end; ++iter)
tot_weight = Plus(tot_weight, iter->weight);
}
// Now divide out common stuff from elements.
size_t prefix_len = seq.size();
for (iter = begin; iter != end; ++iter) {
iter->weight = Divide(iter->weight, tot_weight);
iter->string = repository_.RemovePrefix(iter->string, prefix_len);
}
}
// Now add an arc to the state that the subset represents.
// We may create a new state id for this (in SubsetToStateId).
TempArc temp_arc;
temp_arc.ilabel = ilabel;
temp_arc.nextstate =
SubsetToStateId(*subset); // may or may not really add the subset.
temp_arc.ostring = common_str;
temp_arc.weight = tot_weight;
output_arcs_[state].push_back(temp_arc); // record the arc.
}
template <class F>
void DeterminizerStar<F>::Debug() {
// this function called if you send a signal
// SIGUSR1 to the process (and it's caught by the handler in
// fstdeterminizestar). It prints out some traceback
// info and exits.
KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
// free up memory from the hash as we need a little memory
{
SubsetHash hash_tmp;
std::swap(hash_tmp, hash_);
}
if (output_arcs_.size() <= 2) {
KALDI_ERR << "Nothing to trace back";
}
size_t max_state = output_arcs_.size() - 2; // don't take the last
// one as we might be halfway into constructing it.
std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
for (size_t i = 0; i < max_state; i++) {
for (size_t j = 0; j < output_arcs_[i].size(); j++) {
OutputStateId nextstate = output_arcs_[i][j].nextstate;
// Always find an earlier-numbered predecessor; this
// is always possible because of the way the algorithm
// works.
if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
}
}
std::vector<std::pair<Label, StringId> > traceback;
// 'traceback' is a pair of (ilabel, olabel-seq).
OutputStateId cur_state = max_state; // A recently constructed state.
while (cur_state != 0 && cur_state != kNoStateId) {
OutputStateId last_state = predecessor[cur_state];
std::pair<Label, StringId> p;
size_t i;
for (i = 0; i < output_arcs_[last_state].size(); i++) {
if (output_arcs_[last_state][i].nextstate == cur_state) {
p.first = output_arcs_[last_state][i].ilabel;
p.second = output_arcs_[last_state][i].ostring;
traceback.push_back(p);
break;
}
}
KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
cur_state = last_state;
}
if (cur_state == kNoStateId)
KALDI_WARN << "Traceback did not reach start state "
<< "(possibly debug-code error)";
std::stringstream ss;
ss << "Traceback follows in format "
<< "ilabel (olabel olabel) ilabel (olabel) ... :";
for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
ss << ' ' << traceback[i].first << " ( ";
std::vector<Label> seq;
repository_.SeqOfId(traceback[i].second, &seq);
for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
ss << ')';
}
KALDI_ERR << ss.str();
}
} // namespace fst
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
// fstext/determinize-star.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Guoguo Chen
// 2015 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_H_
#define KALDI_FSTEXT_DETERMINIZE_STAR_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept> // this algorithm uses exceptions
#include <vector>
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-star.cc
/*
DeterminizeStar implements determinization with epsilon removal, which we
distinguish with a star.
We define a determinized* FST as one in which no state has more than one
transition with the same input-label. Epsilon input labels are not allowed
except starting from states that have exactly one arc exiting them (and are
not final). [In the normal definition of determinized, epsilon-input labels
are not allowed at all, whereas in Mohri's definition, epsilons are treated
as ordinary symbols]. The determinized* definition is intended to simulate
the effect of allowing strings of output symbols at each state.
The algorithm implemented here takes an Fst<Arc>, and a pointer to a
MutableFst<Arc> where it puts its output. The weight type is assumed to be a
float-weight. It does epsilon removal and determinization.
This algorithm may fail if the input has epsilon cycles under
certain circumstances (i.e. the semiring is non-idempotent, e.g. the log
semiring, or there are negative cost epsilon cycles).
This implementation is much less fancy than the one in fst/determinize.h, and
does not have an "on-demand" version.
The algorithm is a fairly normal determinization algorithm. We keep in
memory the subsets of states, together with their leftover strings and their
weights. The only difference is we detect input epsilon transitions and
treat them "specially".
*/
// This algorithm will be slightly faster if you sort the input fst on input
// label.
/**
This function implements the normal version of DeterminizeStar, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. The debug_ptr argument is an
optional pointer to a bool that, if it becomes true while the algorithm is
executing, the algorithm will print a traceback and terminate (used in
fstdeterminizestar.cc debug non-terminating determinization).
If max_states is positive, it will stop determinization and throw an
exception as soon as the max-states is reached. This can be useful in test.
If allow_partial is true, the algorithm will output partial results when the
specified max_states is reached (when larger than zero), instead of throwing
out an error.
Caution, the return status is un-intuitive: this function will return false
if determinization completed normally, and true if it was stopped early by
reaching the 'max-states' limit, and a partial FST was generated.
*/
template <class F>
bool DeterminizeStar(F &ifst, MutableFst<typename F::Arc> *ofst, // NOLINT
float delta = kDelta, bool *debug_ptr = NULL,
int max_states = -1, bool allow_partial = false);
/* This is a version of DeterminizeStar with a slightly more "natural" output
format, where the output sequences are encoded using the GallicArc (i.e. the
output symbols are strings. If max_states is positive, it will stop
determinization and throw an exception as soon as the max-states is reached.
This can be useful in test. If allow_partial is true, the algorithm will
output partial results when the specified max_states is reached (when larger
than zero), instead of throwing out an error.
Caution, the return status is un-intuitive: this function will return false
if determinization completed normally, and true if it was stopped early by
reaching the 'max-states' limit, and a partial FST was generated.
*/
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<GallicArc<typename F::Arc> > *ofst,
float delta = kDelta, bool *debug_ptr = NULL,
int max_states = -1, bool allow_partial = false);
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#include "fstext/determinize-star-inl.h"
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_H_
// fstext/fstext-lib.h
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (author:
// Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_FSTEXT_LIB_H_
#define KALDI_FSTEXT_FSTEXT_LIB_H_
#include "fst/fstlib.h"
#include "fstext/determinize-lattice.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "fstext/lattice-utils.h"
#include "fstext/lattice-weight.h"
#include "fstext/pre-determinize.h"
#include "fstext/table-matcher.h"
#endif // KALDI_FSTEXT_FSTEXT_LIB_H_
// fstext/fstext-utils-inl.h
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
// Daniel Povey)
// 2014 Telepoint Global Hosting Service, LLC. (Author: David
// Snyder)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
#define KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
#include <algorithm>
#include <cstring>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "base/kaldi-common.h"
#include "fstext/determinize-star.h"
#include "fstext/pre-determinize.h"
#include "util/const-integer-set.h"
#include "util/kaldi-io.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
namespace fst {
template <class Arc>
typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc> &fst) {
typename Arc::Label ans = 0;
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
typename Arc::StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
ans = std::max(ans, arc.olabel);
}
}
return ans;
}
template <class Arc>
typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc> &fst) {
typename Arc::Label ans = 0;
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
typename Arc::StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
ans = std::max(ans, arc.ilabel);
}
}
return ans;
}
template <class Arc>
typename Arc::StateId NumArcs(const ExpandedFst<Arc> &fst) {
typedef typename Arc::StateId StateId;
StateId num_arcs = 0;
for (StateId s = 0; s < fst.NumStates(); s++) num_arcs += fst.NumArcs(s);
return num_arcs;
}
template <class Arc, class I>
void GetOutputSymbols(const Fst<Arc> &fst, bool include_eps,
std::vector<I> *symbols) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
std::set<I> all_syms;
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
typename Arc::StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
all_syms.insert(arc.olabel);
}
}
// Remove epsilon, if instructed.
if (!include_eps && !all_syms.empty() && *all_syms.begin() == 0)
all_syms.erase(0);
KALDI_ASSERT(symbols != NULL);
kaldi::CopySetToVector(all_syms, symbols);
}
template <class Arc, class I>
void GetInputSymbols(const Fst<Arc> &fst, bool include_eps,
std::vector<I> *symbols) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
unordered_set<I> all_syms;
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
typename Arc::StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
all_syms.insert(arc.ilabel);
}
}
// Remove epsilon, if instructed.
if (!include_eps && all_syms.count(0) != 0) all_syms.erase(0);
KALDI_ASSERT(symbols != NULL);
kaldi::CopySetToVector(all_syms, symbols);
std::sort(symbols->begin(), symbols->end());
}
template <class Arc, class I>
class RemoveSomeInputSymbolsMapper {
public:
Arc operator()(const Arc &arc_in) {
Arc ans = arc_in;
if (to_remove_set_.count(ans.ilabel) != 0)
ans.ilabel = 0; // remove this symbol
return ans;
}
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; }
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
uint64 Properties(uint64 props) const {
// remove the following as we don't know now if any of them are true.
uint64 to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
kNonIDeterministic | kNoEpsilons | kNoIEpsilons |
kILabelSorted | kNotILabelSorted;
return props & ~to_remove;
}
explicit RemoveSomeInputSymbolsMapper(const std::vector<I> &to_remove)
: to_remove_set_(to_remove) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
assert(to_remove_set_.count(0) == 0); // makes no sense to remove epsilon.
}
private:
kaldi::ConstIntegerSet<I> to_remove_set_;
};
template <class Arc, class I>
using LookaheadFst = ArcMapFst<Arc, Arc, RemoveSomeInputSymbolsMapper<Arc, I> >;
// Lookahead composition is used for optimized online
// composition of FSTs during decoding. See
// nnet3/nnet3-latgen-faster-lookahead.cc. For details of compose filters
// see DefaultLookAhead in fst/compose.h
template <class Arc, class I>
LookaheadFst<Arc, I> *LookaheadComposeFst(const Fst<Arc> &ifst1,
const Fst<Arc> &ifst2,
const std::vector<I> &to_remove) {
fst::CacheOptions cache_opts(true, 1 << 25LL);
fst::CacheOptions cache_opts_map(true, 0);
fst::ArcMapFstOptions arcmap_opts(cache_opts);
RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
return new LookaheadFst<Arc, I>(ComposeFst<Arc>(ifst1, ifst2, cache_opts),
mapper, arcmap_opts);
}
template <class Arc, class I>
void RemoveSomeInputSymbols(const std::vector<I> &to_remove,
MutableFst<Arc> *fst) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
Map(fst, mapper);
}
template <class Arc, class I>
class MapInputSymbolsMapper {
public:
Arc operator()(const Arc &arc_in) {
Arc ans = arc_in;
if (ans.ilabel > 0 && ans.ilabel < static_cast<typename Arc::Label>(
(*symbol_mapping_).size()))
ans.ilabel = (*symbol_mapping_)[ans.ilabel];
return ans;
}
MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
uint64 Properties(uint64 props) const { // Not tested.
bool remove_epsilons =
(symbol_mapping_->size() > 0 && (*symbol_mapping_)[0] != 0);
bool add_epsilons = (symbol_mapping_->size() > 1 &&
*std::min_element(symbol_mapping_->begin() + 1,
symbol_mapping_->end()) == 0);
// remove the following as we don't know now if any of them are true.
uint64 props_to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
kNonIDeterministic | kILabelSorted |
kNotILabelSorted;
if (remove_epsilons) props_to_remove |= kEpsilons | kIEpsilons;
if (add_epsilons) props_to_remove |= kNoEpsilons | kNoIEpsilons;
uint64 props_to_add = 0;
if (remove_epsilons && !add_epsilons)
props_to_add |= kNoEpsilons | kNoIEpsilons;
return (props & ~props_to_remove) | props_to_add;
}
// initialize with copy = false only if the "to_remove" argument will not be
// deleted in the lifetime of this object.
MapInputSymbolsMapper(const std::vector<I> &to_remove, bool copy) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
if (copy)
symbol_mapping_ = new std::vector<I>(to_remove);
else
symbol_mapping_ = &to_remove;
owned = copy;
}
~MapInputSymbolsMapper() {
if (owned && symbol_mapping_ != NULL) delete symbol_mapping_;
}
private:
bool owned;
const std::vector<I> *symbol_mapping_;
};
template <class Arc, class I>
void MapInputSymbols(const std::vector<I> &symbol_mapping,
MutableFst<Arc> *fst) {
KALDI_ASSERT_IS_INTEGER_TYPE(I);
// false == don't copy the "symbol_mapping", retain pointer--
// safe since short-lived object.
MapInputSymbolsMapper<Arc, I> mapper(symbol_mapping, false);
Map(fst, mapper);
}
template <class Arc, class I>
bool GetLinearSymbolSequence(const Fst<Arc> &fst, std::vector<I> *isymbols_out,
std::vector<I> *osymbols_out,
typename Arc::Weight *tot_weight_out) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
Weight tot_weight = Weight::One();
std::vector<I> ilabel_seq;
std::vector<I> olabel_seq;
StateId cur_state = fst.Start();
if (cur_state == kNoStateId) { // empty sequence.
if (isymbols_out != NULL) isymbols_out->clear();
if (osymbols_out != NULL) osymbols_out->clear();
if (tot_weight_out != NULL) *tot_weight_out = Weight::Zero();
return true;
}
while (1) {
Weight w = fst.Final(cur_state);
if (w != Weight::Zero()) { // is final..
tot_weight = Times(w, tot_weight);
if (fst.NumArcs(cur_state) != 0) return false;
if (isymbols_out != NULL) *isymbols_out = ilabel_seq;
if (osymbols_out != NULL) *osymbols_out = olabel_seq;
if (tot_weight_out != NULL) *tot_weight_out = tot_weight;
return true;
} else {
if (fst.NumArcs(cur_state) != 1) return false;
ArcIterator<Fst<Arc> > iter(fst, cur_state); // get the only arc.
const Arc &arc = iter.Value();
tot_weight = Times(arc.weight, tot_weight);
if (arc.ilabel != 0) ilabel_seq.push_back(arc.ilabel);
if (arc.olabel != 0) olabel_seq.push_back(arc.olabel);
cur_state = arc.nextstate;
}
}
}
// see fstext-utils.h for comment.
template <class Arc>
void ConvertNbestToVector(const Fst<Arc> &fst,
std::vector<VectorFst<Arc> > *fsts_out) {
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
fsts_out->clear();
StateId start_state = fst.Start();
if (start_state == kNoStateId) return; // No output.
size_t n_arcs = fst.NumArcs(start_state);
bool start_is_final = (fst.Final(start_state) != Weight::Zero());
fsts_out->reserve(n_arcs + (start_is_final ? 1 : 0));
if (start_is_final) {
fsts_out->resize(fsts_out->size() + 1);
StateId start_state_out = fsts_out->back().AddState();
fsts_out->back().SetFinal(start_state_out, fst.Final(start_state));
}
for (ArcIterator<Fst<Arc> > start_aiter(fst, start_state);
!start_aiter.Done(); start_aiter.Next()) {
fsts_out->resize(fsts_out->size() + 1);
VectorFst<Arc> &ofst = fsts_out->back();
const Arc &first_arc = start_aiter.Value();
StateId cur_state = start_state, cur_ostate = ofst.AddState();
ofst.SetStart(cur_ostate);
StateId next_ostate = ofst.AddState();
ofst.AddArc(cur_ostate, Arc(first_arc.ilabel, first_arc.olabel,
first_arc.weight, next_ostate));
cur_state = first_arc.nextstate;
cur_ostate = next_ostate;
while (1) {
size_t this_n_arcs = fst.NumArcs(cur_state);
KALDI_ASSERT(this_n_arcs <= 1); // or it violates our assumptions
// about the input.
if (this_n_arcs == 1) {
KALDI_ASSERT(fst.Final(cur_state) == Weight::Zero());
// or problem with ShortestPath.
ArcIterator<Fst<Arc> > aiter(fst, cur_state);
const Arc &arc = aiter.Value();
next_ostate = ofst.AddState();
ofst.AddArc(cur_ostate,
Arc(arc.ilabel, arc.olabel, arc.weight, next_ostate));
cur_state = arc.nextstate;
cur_ostate = next_ostate;
} else {
KALDI_ASSERT(fst.Final(cur_state) != Weight::Zero());
// or problem with ShortestPath.
ofst.SetFinal(cur_ostate, fst.Final(cur_state));
break;
}
}
}
}
// see fstext-utils.sh for comment.
template <class Arc>
void NbestAsFsts(const Fst<Arc> &fst, size_t n,
std::vector<VectorFst<Arc> > *fsts_out) {
KALDI_ASSERT(n > 0);
KALDI_ASSERT(fsts_out != NULL);
VectorFst<Arc> nbest_fst;
ShortestPath(fst, &nbest_fst, n);
ConvertNbestToVector(nbest_fst, fsts_out);
}
template <class Arc, class I>
void MakeLinearAcceptorWithAlternatives(
const std::vector<std::vector<I> > &labels, MutableFst<Arc> *ofst) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
ofst->DeleteStates();
StateId cur_state = ofst->AddState();
ofst->SetStart(cur_state);
for (size_t i = 0; i < labels.size(); i++) {
KALDI_ASSERT(labels[i].size() != 0);
StateId next_state = ofst->AddState();
for (size_t j = 0; j < labels[i].size(); j++) {
Arc arc(labels[i][j], labels[i][j], Weight::One(), next_state);
ofst->AddArc(cur_state, arc);
}
cur_state = next_state;
}
ofst->SetFinal(cur_state, Weight::One());
}
template <class Arc, class I>
void MakeLinearAcceptor(const std::vector<I> &labels, MutableFst<Arc> *ofst) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
ofst->DeleteStates();
StateId cur_state = ofst->AddState();
ofst->SetStart(cur_state);
for (size_t i = 0; i < labels.size(); i++) {
StateId next_state = ofst->AddState();
Arc arc(labels[i], labels[i], Weight::One(), next_state);
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state, Weight::One());
}
template <class I>
void GetSymbols(const SymbolTable &symtab, bool include_eps,
std::vector<I> *syms_out) {
KALDI_ASSERT(syms_out != NULL);
syms_out->clear();
for (SymbolTableIterator iter(symtab); !iter.Done(); iter.Next()) {
if (include_eps || iter.Value() != 0) {
syms_out->push_back(iter.Value());
KALDI_ASSERT(syms_out->back() ==
iter.Value()); // an integer-range thing.
}
}
}
template <class Arc>
void SafeDeterminizeWrapper(MutableFst<Arc> *ifst, MutableFst<Arc> *ofst,
float delta) {
typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
std::vector<typename Arc::Label> extra_syms;
PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
DeterminizeStar(*ifst, ofst, delta);
RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
}
template <class Arc>
void SafeDeterminizeMinimizeWrapper(MutableFst<Arc> *ifst, VectorFst<Arc> *ofst,
float delta) {
typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
std::vector<typename Arc::Label> extra_syms;
PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
DeterminizeStar(*ifst, ofst, delta);
RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
RemoveEpsLocal(ofst); // this is "safe" and will never hurt.
MinimizeEncoded(ofst, delta);
}
inline void DeterminizeStarInLog(VectorFst<StdArc> *fst, float delta,
bool *debug_ptr, int max_states) {
// DeterminizeStarInLog determinizes 'fst' in the log semiring, using
// the DeterminizeStar algorithm (which also removes epsilons).
ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
VectorFst<LogArc> *fst_log =
new VectorFst<LogArc>; // Want to determinize in log semiring.
Cast(*fst, fst_log);
VectorFst<StdArc> tmp;
*fst = tmp; // make fst empty to free up memory. [actually may make no
// difference..]
VectorFst<LogArc> *fst_det_log = new VectorFst<LogArc>;
DeterminizeStar(*fst_log, fst_det_log, delta, debug_ptr, max_states);
Cast(*fst_det_log, fst);
delete fst_log;
delete fst_det_log;
}
inline void DeterminizeInLog(VectorFst<StdArc> *fst) {
// DeterminizeInLog determinizes 'fst' in the log semiring.
ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
VectorFst<LogArc> *fst_log =
new VectorFst<LogArc>; // Want to determinize in log semiring.
Cast(*fst, fst_log);
VectorFst<StdArc> tmp;
*fst = tmp; // make fst empty to free up memory. [actually may make no
// difference..]
VectorFst<LogArc> *fst_det_log = new VectorFst<LogArc>;
Determinize(*fst_log, fst_det_log);
Cast(*fst_det_log, fst);
delete fst_log;
delete fst_det_log;
}
// make it inline to avoid having to put it in a .cc file.
// destructive algorithm (changes ifst as well as ofst).
inline void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc> *ifst,
VectorFst<StdArc> *ofst,
float delta) {
VectorFst<LogArc> *ifst_log =
new VectorFst<LogArc>; // Want to determinize in log semiring.
Cast(*ifst, ifst_log);
VectorFst<LogArc> *ofst_log = new VectorFst<LogArc>;
SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
Cast(*ofst_log, ofst);
delete ifst_log;
delete ofst_log;
RemoveEpsLocal(ofst); // this is "safe" and will never hurt. Do this in
// tropical, which is important.
MinimizeEncoded(ofst, delta); // Non-deterministic minimization will fail in
// log semiring so do it with StdARc.
}
inline void SafeDeterminizeWrapperInLog(VectorFst<StdArc> *ifst,
VectorFst<StdArc> *ofst, float delta) {
VectorFst<LogArc> *ifst_log =
new VectorFst<LogArc>; // Want to determinize in log semiring.
Cast(*ifst, ifst_log);
VectorFst<LogArc> *ofst_log = new VectorFst<LogArc>;
SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
Cast(*ofst_log, ofst);
delete ifst_log;
delete ofst_log;
}
template <class Arc>
void RemoveWeights(MutableFst<Arc> *ifst) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
for (StateIterator<MutableFst<Arc> > siter(*ifst); !siter.Done();
siter.Next()) {
StateId s = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(ifst, s); !aiter.Done();
aiter.Next()) {
Arc arc(aiter.Value());
arc.weight = Weight::One();
aiter.SetValue(arc);
}
if (ifst->Final(s) != Weight::Zero()) ifst->SetFinal(s, Weight::One());
}
ifst->SetProperties(kUnweighted, kUnweighted);
}
// Used in PrecedingInputSymbolsAreSame (non-functor version), and
// similar routines.
template <class T>
struct IdentityFunction {
typedef T Arg;
typedef T Result;
T operator()(const T &t) const { return t; }
};
template <class Arc>
bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc> &fst) {
IdentityFunction<typename Arc::Label> f;
return PrecedingInputSymbolsAreSameClass(start_is_epsilon, fst, f);
}
template <class Arc, class F> // F is functor type from labels to classes.
bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon,
const Fst<Arc> &fst, const F &f) {
typedef typename F::Result ClassType;
typedef typename Arc::StateId StateId;
std::vector<ClassType> classes;
ClassType noClass = f(kNoLabel);
if (start_is_epsilon) {
StateId start_state = fst.Start();
if (start_state < 0 || start_state == kNoStateId)
return true; // empty fst-- doesn't matter.
classes.resize(start_state + 1, noClass);
classes[start_state] = 0;
}
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (classes.size() <= arc.nextstate)
classes.resize(arc.nextstate + 1, noClass);
if (classes[arc.nextstate] == noClass)
classes[arc.nextstate] = f(arc.ilabel);
else if (classes[arc.nextstate] != f(arc.ilabel))
return false;
}
}
return true;
}
template <class Arc>
bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc> &fst) {
IdentityFunction<typename Arc::Label> f;
return FollowingInputSymbolsAreSameClass(end_is_epsilon, fst, f);
}
template <class Arc, class F>
bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc> &fst,
const F &f) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
typedef typename F::Result ClassType;
const ClassType noClass = f(kNoLabel), epsClass = f(0);
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
ClassType c = noClass;
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (c == noClass)
c = f(arc.ilabel);
else if (c != f(arc.ilabel))
return false;
}
if (end_is_epsilon && c != noClass && c != epsClass &&
fst.Final(s) != Weight::Zero())
return false;
}
return true;
}
template <class Arc>
void MakePrecedingInputSymbolsSame(bool start_is_epsilon,
MutableFst<Arc> *fst) {
IdentityFunction<typename Arc::Label> f;
MakePrecedingInputSymbolsSameClass(start_is_epsilon, fst, f);
}
template <class Arc, class F>
void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon,
MutableFst<Arc> *fst, const F &f) {
typedef typename F::Result ClassType;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
std::vector<ClassType> classes;
ClassType noClass = f(kNoLabel);
ClassType epsClass = f(0);
if (start_is_epsilon) { // treat having-start-state as epsilon in-transition.
StateId start_state = fst->Start();
if (start_state < 0 || start_state == kNoStateId) // empty FST.
return;
classes.resize(start_state + 1, noClass);
classes[start_state] = epsClass;
}
// Find bad states (states with multiple input-symbols into them).
std::set<StateId> bad_states; // states that we need to change.
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (classes.size() <= static_cast<size_t>(arc.nextstate))
classes.resize(arc.nextstate + 1, noClass);
if (classes[arc.nextstate] == noClass)
classes[arc.nextstate] = f(arc.ilabel);
else if (classes[arc.nextstate] != f(arc.ilabel))
bad_states.insert(arc.nextstate);
}
}
if (bad_states.empty()) return; // Nothing to do.
kaldi::ConstIntegerSet<StateId> bad_states_ciset(
bad_states); // faster lookup.
// Work out list of arcs we have to change as (state, arc-offset).
// Can't do the actual changes in this pass, since we have to add new
// states which invalidates the iterators.
std::vector<std::pair<StateId, size_t> > arcs_to_change;
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0 && bad_states_ciset.count(arc.nextstate) != 0)
arcs_to_change.push_back(std::make_pair(s, aiter.Position()));
}
}
KALDI_ASSERT(!arcs_to_change.empty()); // since !bad_states.empty().
std::map<std::pair<StateId, ClassType>, StateId> state_map;
// state_map is a map from (bad-state, input-symbol-class) to dummy-state.
for (size_t i = 0; i < arcs_to_change.size(); i++) {
StateId s = arcs_to_change[i].first;
ArcIterator<MutableFst<Arc> > aiter(*fst, s);
aiter.Seek(arcs_to_change[i].second);
Arc arc = aiter.Value();
// Transition is non-eps transition to "bad" state. Introduce new state (or
// find existing one).
std::pair<StateId, ClassType> p(arc.nextstate, f(arc.ilabel));
if (state_map.count(p) == 0) {
StateId newstate = state_map[p] = fst->AddState();
fst->AddArc(newstate, Arc(0, 0, Weight::One(), arc.nextstate));
}
StateId dst_state = state_map[p];
arc.nextstate = dst_state;
// Initialize the MutableArcIterator only now, as the call to NewState()
// may have invalidated the first arc iterator.
MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
maiter.Seek(arcs_to_change[i].second);
maiter.SetValue(arc);
}
}
template <class Arc>
void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc> *fst) {
IdentityFunction<typename Arc::Label> f;
MakeFollowingInputSymbolsSameClass(end_is_epsilon, fst, f);
}
template <class Arc, class F>
void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon,
MutableFst<Arc> *fst, const F &f) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
typedef typename F::Result ClassType;
std::vector<StateId> bad_states;
ClassType noClass = f(kNoLabel);
ClassType epsClass = f(0);
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
ClassType c = noClass;
bool bad = false;
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (c == noClass) {
c = f(arc.ilabel);
} else if (c != f(arc.ilabel)) {
bad = true;
break;
}
}
if (end_is_epsilon && c != noClass && c != epsClass &&
fst->Final(s) != Weight::Zero())
bad = true;
if (bad) bad_states.push_back(s);
}
std::vector<Arc> my_arcs;
for (size_t i = 0; i < bad_states.size(); i++) {
StateId s = bad_states[i];
my_arcs.clear();
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
aiter.Next())
my_arcs.push_back(aiter.Value());
for (size_t j = 0; j < my_arcs.size(); j++) {
Arc &arc = my_arcs[j];
if (arc.ilabel != 0) {
StateId newstate = fst->AddState();
// Create a new state for each non-eps arc in original FST, out of each
// bad state. Not as optimal as it could be, but does avoid some
// complicated weight-pushing issues in which, to maintain
// stochasticity, we would have to know which semiring we want to
// maintain stochasticity in.
fst->AddArc(newstate, Arc(arc.ilabel, 0, Weight::One(), arc.nextstate));
MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
maiter.Seek(j);
maiter.SetValue(Arc(0, arc.olabel, arc.weight, newstate));
}
}
}
}
template <class Arc>
VectorFst<Arc> *MakeLoopFst(const std::vector<const ExpandedFst<Arc> *> &fsts) {
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
VectorFst<Arc> *ans = new VectorFst<Arc>;
StateId loop_state = ans->AddState(); // = 0.
ans->SetStart(loop_state);
ans->SetFinal(loop_state, Weight::One());
// "cache" is used as an optimization when some of the pointers in "fsts"
// may have the same value.
unordered_map<const ExpandedFst<Arc> *, Arc> cache;
for (Label i = 0; i < static_cast<Label>(fsts.size()); i++) {
const ExpandedFst<Arc> *fst = fsts[i];
if (fst == NULL) continue;
{ // optimization with cache: helpful if some members of "fsts" may
// contain the same pointer value (e.g. in GetHTransducer).
typename unordered_map<const ExpandedFst<Arc> *, Arc>::iterator iter =
cache.find(fst);
if (iter != cache.end()) {
Arc arc = iter->second;
arc.olabel = i;
ans->AddArc(0, arc);
continue;
}
}
KALDI_ASSERT(fst->Properties(kAcceptor, true) ==
kAcceptor); // expect acceptor.
StateId fst_num_states = fst->NumStates();
StateId fst_start_state = fst->Start();
if (fst_start_state == kNoStateId) continue; // empty fst.
bool share_start_state =
fst->Properties(kInitialAcyclic, true) == kInitialAcyclic &&
fst->NumArcs(fst_start_state) == 1 &&
fst->Final(fst_start_state) == Weight::Zero();
std::vector<StateId> state_map(fst_num_states); // fst state -> ans state
for (StateId s = 0; s < fst_num_states; s++) {
if (s == fst_start_state && share_start_state)
state_map[s] = loop_state;
else
state_map[s] = ans->AddState();
}
if (!share_start_state) {
Arc arc(0, i, Weight::One(), state_map[fst_start_state]);
cache[fst] = arc;
ans->AddArc(0, arc);
}
for (StateId s = 0; s < fst_num_states; s++) {
// Add arcs out of state s.
for (ArcIterator<ExpandedFst<Arc> > aiter(*fst, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
Label olabel = (s == fst_start_state && share_start_state ? i : 0);
Arc newarc(arc.ilabel, olabel, arc.weight, state_map[arc.nextstate]);
ans->AddArc(state_map[s], newarc);
if (s == fst_start_state && share_start_state) cache[fst] = newarc;
}
if (fst->Final(s) != Weight::Zero()) {
KALDI_ASSERT(!(s == fst_start_state && share_start_state));
ans->AddArc(state_map[s], Arc(0, 0, fst->Final(s), loop_state));
}
}
}
return ans;
}
template <class Arc>
void ClearSymbols(bool clear_input, bool clear_output, MutableFst<Arc> *fst) {
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
typename Arc::StateId s = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
aiter.Next()) {
Arc arc = aiter.Value();
bool change = false;
if (clear_input && arc.ilabel != 0) {
arc.ilabel = 0;
change = true;
}
if (clear_output && arc.olabel != 0) {
arc.olabel = 0;
change = true;
}
if (change) {
aiter.SetValue(arc);
}
}
}
}
template <class Arc>
void ApplyProbabilityScale(float scale, MutableFst<Arc> *fst) {
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
StateId s = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
aiter.Next()) {
Arc arc = aiter.Value();
arc.weight = Weight(arc.weight.Value() * scale);
aiter.SetValue(arc);
}
if (fst->Final(s) != Weight::Zero())
fst->SetFinal(s, Weight(fst->Final(s).Value() * scale));
}
}
// return arc-offset of self-loop with ilabel (or -1 if none exists).
// if more than one such self-loop, pick first one.
template <class Arc>
ssize_t FindSelfLoopWithILabel(const Fst<Arc> &fst, typename Arc::StateId s) {
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next())
if (aiter.Value().nextstate == s && aiter.Value().ilabel != 0)
return static_cast<ssize_t>(aiter.Position());
return static_cast<ssize_t>(-1);
}
template <class Arc>
bool EqualAlign(const Fst<Arc> &ifst, typename Arc::StateId length,
int rand_seed, MutableFst<Arc> *ofst, int num_retries) {
srand(rand_seed);
KALDI_ASSERT(ofst->NumStates() == 0); // make sure ofst empty.
// make sure all states can reach final-state (or this algorithm may enter
// infinite loop.
KALDI_ASSERT(ifst.Properties(kCoAccessible, true) == kCoAccessible);
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
if (ifst.Start() == kNoStateId) {
KALDI_WARN << "Empty input fst.";
return false;
}
// First select path through ifst.
std::vector<StateId> path;
std::vector<size_t> arc_offsets; // arc taken out of each state.
std::vector<int> nof_ilabels;
StateId num_ilabels = 0;
int retry_no = 0;
// Under normal circumstances, this will be one-pass-only process
// Multiple tries might be needed in special cases, typically when
// the number of frames is close to number of transitions from
// the start node to the final node. It usually happens for really
// short utterances
do {
num_ilabels = 0;
arc_offsets.clear();
path.clear();
path.push_back(ifst.Start());
while (1) {
// Select either an arc or final-prob.
StateId s = path.back();
size_t num_arcs = ifst.NumArcs(s);
size_t num_arcs_tot = num_arcs;
if (ifst.Final(s) != Weight::Zero()) num_arcs_tot++;
// kaldi::RandInt is a bit like Rand(), but gets around situations
// where RAND_MAX is very small.
// Change this to Rand() % num_arcs_tot if compile issues arise
size_t arc_offset =
static_cast<size_t>(kaldi::RandInt(0, num_arcs_tot - 1));
if (arc_offset < num_arcs) { // an actual arc.
ArcIterator<Fst<Arc> > aiter(ifst, s);
aiter.Seek(arc_offset);
const Arc &arc = aiter.Value();
if (arc.nextstate == s) {
continue; // don't take this self-loop arc
} else {
arc_offsets.push_back(arc_offset);
path.push_back(arc.nextstate);
if (arc.ilabel != 0) num_ilabels++;
}
} else {
break; // Chose final-prob.
}
}
nof_ilabels.push_back(num_ilabels);
} while ((++retry_no < num_retries) && (num_ilabels > length));
if (num_ilabels > length) {
std::stringstream ilabel_vec;
std::copy(nof_ilabels.begin(), nof_ilabels.end(),
std::ostream_iterator<int>(ilabel_vec, ","));
std::string s = ilabel_vec.str();
s.erase(s.end() - 1);
KALDI_WARN << "EqualAlign: the randomly constructed paths lengths: " << s;
KALDI_WARN << "EqualAlign: utterance has too few frames " << length
<< " to align.";
return false; // can't make it shorter by adding self-loops!.
}
StateId num_self_loops = 0;
std::vector<ssize_t> self_loop_offsets(path.size());
for (size_t i = 0; i < path.size(); i++)
if ((self_loop_offsets[i] = FindSelfLoopWithILabel(ifst, path[i])) !=
static_cast<ssize_t>(-1))
num_self_loops++;
if (num_self_loops == 0 && num_ilabels < length) {
KALDI_WARN << "No self-loops on chosen path; cannot match length.";
return false; // no self-loops to make it longer.
}
StateId num_extra = length - num_ilabels; // Number of self-loops we need.
StateId min_num_loops = 0;
if (num_extra != 0)
min_num_loops = num_extra / num_self_loops; // prevent div by zero.
StateId num_with_one_more_loop = num_extra - (min_num_loops * num_self_loops);
KALDI_ASSERT(num_with_one_more_loop < num_self_loops || num_self_loops == 0);
ofst->AddState();
ofst->SetStart(0);
StateId cur_state = 0;
StateId counter = 0; // tell us when we should stop adding one more loop.
for (size_t i = 0; i < path.size(); i++) {
// First, add any self-loops that are necessary.
StateId num_loops = 0;
if (self_loop_offsets[i] != static_cast<ssize_t>(-1)) {
num_loops = min_num_loops + (counter < num_with_one_more_loop ? 1 : 0);
counter++;
}
for (StateId j = 0; j < num_loops; j++) {
ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
aiter.Seek(self_loop_offsets[i]);
Arc arc = aiter.Value();
KALDI_ASSERT(arc.nextstate == path[i] &&
arc.ilabel != 0); // make sure self-loop with ilabel.
StateId next_state = ofst->AddState();
ofst->AddArc(cur_state,
Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
cur_state = next_state;
}
if (i + 1 < path.size()) { // add forward transition.
ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
aiter.Seek(arc_offsets[i]);
Arc arc = aiter.Value();
KALDI_ASSERT(arc.nextstate == path[i + 1]);
StateId next_state = ofst->AddState();
ofst->AddArc(cur_state,
Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
cur_state = next_state;
} else { // add final-prob.
Weight weight = ifst.Final(path[i]);
KALDI_ASSERT(weight != Weight::Zero());
ofst->SetFinal(cur_state, weight);
}
}
return true;
}
// This function identifies two types of useless arcs:
// those where arc A and arc B both go from state X to
// state Y with the same input symbol (remove the one
// with smaller probability, or an arbitrary one if they
// are the same); and those where A is an arc from state X
// to state X, with epsilon input symbol [remove A].
// Only works for tropical (not log) semiring as it uses
// NaturalLess.
template <class Arc>
void RemoveUselessArcs(MutableFst<Arc> *fst) {
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
NaturalLess<Weight> nl;
StateId non_coacc_state = kNoStateId;
size_t num_arcs_removed = 0, tot_arcs = 0;
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
std::vector<size_t> arcs_to_delete;
std::vector<Arc> arcs;
// pair2arclist lets us look up the arcs
std::map<std::pair<Label, StateId>, std::vector<size_t> > pair2arclist;
StateId state = siter.Value();
for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); !aiter.Done();
aiter.Next()) {
size_t pos = arcs.size();
const Arc &arc = aiter.Value();
arcs.push_back(arc);
pair2arclist[std::make_pair(arc.ilabel, arc.nextstate)].push_back(pos);
}
typename std::map<std::pair<Label, StateId>, std::vector<size_t> >::iterator
iter = pair2arclist.begin(),
end = pair2arclist.end();
for (; iter != end; ++iter) {
const std::vector<size_t> &poslist = iter->second;
if (poslist.size() > 1) { // >1 arc with same ilabel, dest-state
size_t best_pos = poslist[0];
Weight best_weight = arcs[best_pos].weight;
for (size_t j = 1; j < poslist.size(); j++) {
size_t pos = poslist[j];
Weight this_weight = arcs[pos].weight;
if (nl(this_weight,
best_weight)) { // NaturalLess seems to be somehow
// "backwards".
best_weight = this_weight; // found a better one.
best_pos = pos;
}
}
for (size_t j = 0; j < poslist.size(); j++)
if (poslist[j] != best_pos) arcs_to_delete.push_back(poslist[j]);
} else {
KALDI_ASSERT(poslist.size() == 1);
size_t pos = poslist[0];
Arc &arc = arcs[pos];
if (arc.ilabel == 0 && arc.nextstate == state)
arcs_to_delete.push_back(pos);
}
}
tot_arcs += arcs.size();
if (arcs_to_delete.size() != 0) {
num_arcs_removed += arcs_to_delete.size();
if (non_coacc_state == kNoStateId) non_coacc_state = fst->AddState();
MutableArcIterator<MutableFst<Arc> > maiter(fst, state);
for (size_t j = 0; j < arcs_to_delete.size(); j++) {
size_t pos = arcs_to_delete[j];
maiter.Seek(pos);
arcs[pos].nextstate = non_coacc_state;
maiter.SetValue(arcs[pos]);
}
}
}
if (non_coacc_state != kNoStateId) Connect(fst);
KALDI_VLOG(1) << "removed " << num_arcs_removed << " of " << tot_arcs
<< "arcs.";
}
template <class Arc>
void PhiCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
typename Arc::Label phi_label, MutableFst<Arc> *ofst) {
KALDI_ASSERT(phi_label !=
kNoLabel); // just use regular compose in this case.
typedef Fst<Arc> F;
typedef PhiMatcher<SortedMatcher<F> > PM;
CacheOptions base_opts;
base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
// The matcher for fst1 doesn't matter; we'll use fst2's matcher.
ComposeFstImplOptions<SortedMatcher<F>, PM> impl_opts(base_opts);
// the false below is something called phi_loop which is something I don't
// fully understand, but I don't think we want it.
// These pointers are taken ownership of, by ComposeFst.
PM *phi_matcher = new PM(fst2, MATCH_INPUT, phi_label, false);
SortedMatcher<F> *sorted_matcher =
new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
// not to use this matcher, as this would mean we would
// not follow phi transitions.
impl_opts.matcher1 = sorted_matcher;
impl_opts.matcher2 = phi_matcher;
*ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
Connect(ofst);
}
template <class Arc>
void PropagateFinalInternal(typename Arc::Label phi_label,
typename Arc::StateId s, MutableFst<Arc> *fst) {
typedef typename Arc::Weight Weight;
if (fst->Final(s) == Weight::Zero()) {
// search for phi transition. We assume there
// is just one-- phi nondeterminism is not allowed
// anyway.
int num_phis = 0;
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel == phi_label) {
num_phis++;
if (arc.nextstate == s) continue; // don't expect
// phi loops but ignore them anyway.
// If this recurses infinitely, it means there
// are loops of phi transitions, which there should
// not be in a normal backoff LM. We could make this
// routine work for this case, but currently there is
// no need.
PropagateFinalInternal(phi_label, arc.nextstate, fst);
if (fst->Final(arc.nextstate) != Weight::Zero())
fst->SetFinal(s, Times(fst->Final(arc.nextstate), arc.weight));
}
KALDI_ASSERT(num_phis <= 1 && "Phi nondeterminism found");
}
}
}
template <class Arc>
void PropagateFinal(typename Arc::Label phi_label, MutableFst<Arc> *fst) {
typedef typename Arc::StateId StateId;
if (fst->Properties(kIEpsilons, true)) // just warn.
KALDI_WARN << "PropagateFinal: this may not work as desired "
"since your FST has input epsilons.";
StateId num_states = fst->NumStates();
for (StateId s = 0; s < num_states; s++)
PropagateFinalInternal(phi_label, s, fst);
}
template <class Arc>
void RhoCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
typename Arc::Label rho_label, MutableFst<Arc> *ofst) {
KALDI_ASSERT(rho_label !=
kNoLabel); // just use regular compose in this case.
typedef Fst<Arc> F;
typedef RhoMatcher<SortedMatcher<F> > RM;
CacheOptions base_opts;
base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
// The matcher for fst1 doesn't matter; we'll use fst2's matcher.
ComposeFstImplOptions<SortedMatcher<F>, RM> impl_opts(base_opts);
// the false below is something called rho_loop which is something I don't
// fully understand, but I don't think we want it.
// These pointers are taken ownership of, by ComposeFst.
RM *rho_matcher = new RM(fst2, MATCH_INPUT, rho_label);
SortedMatcher<F> *sorted_matcher =
new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
// not to use this matcher, as this would mean we would
// not follow rho transitions.
impl_opts.matcher1 = sorted_matcher;
impl_opts.matcher2 = rho_matcher;
*ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
Connect(ofst);
}
// Declare an override of the template below.
template <>
inline bool IsStochasticFst(const Fst<LogArc> &fst, float delta,
LogArc::Weight *min_sum, LogArc::Weight *max_sum);
// Will override this for LogArc where NaturalLess will not work.
template <class Arc>
inline bool IsStochasticFst(const Fst<Arc> &fst, float delta,
typename Arc::Weight *min_sum,
typename Arc::Weight *max_sum) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
NaturalLess<Weight> nl;
bool first_time = true;
bool ans = true;
if (min_sum) *min_sum = Arc::Weight::One();
if (max_sum) *max_sum = Arc::Weight::One();
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
Weight sum = fst.Final(s);
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
sum = Plus(sum, arc.weight);
}
if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
if (first_time) {
first_time = false;
if (max_sum) *max_sum = sum;
if (min_sum) *min_sum = sum;
} else {
if (max_sum && nl(*max_sum, sum)) *max_sum = sum;
if (min_sum && nl(sum, *min_sum)) *min_sum = sum;
}
}
if (first_time) { // just avoid NaNs if FST was empty.
if (max_sum) *max_sum = Weight::One();
if (min_sum) *min_sum = Weight::One();
}
return ans;
}
// Overriding template for LogArc as NaturalLess does not work there.
template <>
inline bool IsStochasticFst(const Fst<LogArc> &fst, float delta,
LogArc::Weight *min_sum, LogArc::Weight *max_sum) {
typedef LogArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
bool first_time = true;
bool ans = true;
if (min_sum) *min_sum = LogArc::Weight::One();
if (max_sum) *max_sum = LogArc::Weight::One();
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
Weight sum = fst.Final(s);
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
sum = Plus(sum, arc.weight);
}
if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
if (first_time) {
first_time = false;
if (max_sum) *max_sum = sum;
if (min_sum) *min_sum = sum;
} else {
// note that max and min are reversed from their normal
// meanings here (max and min w.r.t. the underlying probabilities).
if (max_sum && sum.Value() < max_sum->Value()) *max_sum = sum;
if (min_sum && sum.Value() > min_sum->Value()) *min_sum = sum;
}
}
if (first_time) { // just avoid NaNs if FST was empty.
if (max_sum) *max_sum = Weight::One();
if (min_sum) *min_sum = Weight::One();
}
return ans;
}
// Tests whether a tropical FST is stochastic in the log
// semiring. (casts it and does the check.)
// This function deals with the generic fst.
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>.
// Otherwise, it will be died with an error.
inline bool IsStochasticFstInLog(const Fst<StdArc> &fst, float delta,
StdArc::Weight *min_sum,
StdArc::Weight *max_sum) {
bool ans = false;
LogArc::Weight log_min = LogArc::Weight::One(),
log_max = LogArc::Weight::Zero();
if (fst.Type() == "const") {
ConstFst<LogArc> logfst;
Cast(dynamic_cast<const ConstFst<StdArc> &>(fst), &logfst);
ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
} else if (fst.Type() == "vector") {
VectorFst<LogArc> logfst;
Cast(dynamic_cast<const VectorFst<StdArc> &>(fst), &logfst);
ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
} else {
KALDI_ERR << "This version currently supports ConstFst<StdArc> "
<< "or VectorFst<StdArc>";
}
if (min_sum) *min_sum = StdArc::Weight(log_min.Value());
if (max_sum) *max_sum = StdArc::Weight(log_max.Value());
return ans;
}
} // namespace fst.
#endif // KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
// fstext/fstext-utils.h
// Copyright 2009-2011 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// 2014 Telepoint Global Hosting Service, LLC. (Author: David
// Snyder)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_FSTEXT_UTILS_H_
#define KALDI_FSTEXT_FSTEXT_UTILS_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/determinize-star.h"
#include "fstext/remove-eps-local.h"
#include "base/kaldi-common.h" // for error reporting macros.
#include "util/text-utils.h" // for SplitStringToVector
#include "fst/script/print-impl.h"
namespace fst {
/// Returns the highest numbered output symbol id of the FST (or zero
/// for an empty FST.
template <class Arc>
typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc> &fst);
/// Returns the highest numbered input symbol id of the FST (or zero
/// for an empty FST.
template <class Arc>
typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc> &fst);
/// Returns the total number of arcs in an FST.
template <class Arc>
typename Arc::StateId NumArcs(const ExpandedFst<Arc> &fst);
/// GetInputSymbols gets the list of symbols on the input of fst
/// (including epsilon, if include_eps == true), as a sorted, unique
/// list.
template <class Arc, class I>
void GetInputSymbols(const Fst<Arc> &fst, bool include_eps,
std::vector<I> *symbols);
/// GetOutputSymbols gets the list of symbols on the output of fst
/// (including epsilon, if include_eps == true)
template <class Arc, class I>
void GetOutputSymbols(const Fst<Arc> &fst, bool include_eps,
std::vector<I> *symbols);
/// ClearSymbols sets all the symbols on the input and/or
/// output side of the FST to zero, as specified.
/// It does not alter the symbol tables.
template <class Arc>
void ClearSymbols(bool clear_input, bool clear_output, MutableFst<Arc> *fst);
template <class I>
void GetSymbols(const SymbolTable &symtab, bool include_eps,
std::vector<I> *syms_out);
inline void DeterminizeStarInLog(VectorFst<StdArc> *fst, float delta = kDelta,
bool *debug_ptr = NULL, int max_states = -1);
// e.g. of using this function: PushInLog<REWEIGHT_TO_INITIAL>(fst,
// kPushWeights|kPushLabels);
template <ReweightType rtype> // == REWEIGHT_TO_{INITIAL, FINAL}
void PushInLog(VectorFst<StdArc> *fst, uint32 ptype, float delta = kDelta) {
// PushInLog pushes the FST
// and returns a new pushed FST (labels and weights pushed to the left).
VectorFst<LogArc> *fst_log =
new VectorFst<LogArc>; // Want to determinize in log semiring.
Cast(*fst, fst_log);
VectorFst<StdArc> tmp;
*fst = tmp; // free up memory.
VectorFst<LogArc> *fst_pushed_log = new VectorFst<LogArc>;
Push<LogArc, rtype>(*fst_log, fst_pushed_log, ptype, delta);
Cast(*fst_pushed_log, fst);
delete fst_log;
delete fst_pushed_log;
}
// Minimizes after encoding; applicable to all FSTs. It is like what you get
// from the Minimize() function, except it will not push the weights, or the
// symbols. This is better for our recipes, as we avoid ever pushing the
// weights. However, it will only minimize optimally if your graphs are such
// that the symbols are as far to the left as they can go, and the weights
// in combinable paths are the same... hard to formalize this, but it's
// something that is satisified by our normal FSTs.
template <class Arc>
void MinimizeEncoded(VectorFst<Arc> *fst, float delta = kDelta) {
Map(fst, QuantizeMapper<Arc>(delta));
EncodeMapper<Arc> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
Encode(fst, &encoder);
internal::AcceptorMinimize(fst);
Decode(fst, encoder);
}
/// GetLinearSymbolSequence gets the symbol sequence from a linear FST.
/// If the FST is not just a linear sequence, it returns false. If it is
/// a linear sequence (including the empty FST), it returns true. In this
/// case it outputs the symbol
/// sequences as "isymbols_out" and "osymbols_out" (removing epsilons), and
/// the total weight as "tot_weight". The total weight will be Weight::Zero()
/// if the FST is empty. If any of the output pointers are NULL, it does not
/// create that output.
template <class Arc, class I>
bool GetLinearSymbolSequence(const Fst<Arc> &fst, std::vector<I> *isymbols_out,
std::vector<I> *osymbols_out,
typename Arc::Weight *tot_weight_out);
/// This function converts an FST with a special structure, which is
/// output by the OpenFst functions ShortestPath and RandGen, and converts
/// them into a std::vector of separate FSTs. This special structure is that
/// the only state that has more than one (arcs-out or final-prob) is the
/// start state. fsts_out is resized to the appropriate size.
template <class Arc>
void ConvertNbestToVector(const Fst<Arc> &fst,
std::vector<VectorFst<Arc> > *fsts_out);
/// Takes the n-shortest-paths (using ShortestPath), but outputs
/// the result as a vector of up to n fsts. This function will
/// size the "fsts_out" vector to however many paths it got
/// (which will not exceed n). n must be >= 1.
template <class Arc>
void NbestAsFsts(const Fst<Arc> &fst, size_t n,
std::vector<VectorFst<Arc> > *fsts_out);
/// Creates unweighted linear acceptor from symbol sequence.
template <class Arc, class I>
void MakeLinearAcceptor(const std::vector<I> &labels, MutableFst<Arc> *ofst);
/// Creates an unweighted acceptor with a linear structure, with alternatives
/// at each position. Epsilon is treated like a normal symbol here.
/// Each position in "labels" must have at least one alternative.
template <class Arc, class I>
void MakeLinearAcceptorWithAlternatives(
const std::vector<std::vector<I> > &labels, MutableFst<Arc> *ofst);
/// Does PreDeterminize and DeterminizeStar and then removes the disambiguation
/// symbols. This is a form of determinization that will never blow up. Note
/// that ifst is non-const and can be considered to be destroyed by this
/// operation.
/// Does not do epsilon removal (RemoveEpsLocal)-- this is so it's safe to cast
/// to log and do this, and maintain equivalence in tropical.
template <class Arc>
void SafeDeterminizeWrapper(MutableFst<Arc> *ifst, MutableFst<Arc> *ofst,
float delta = kDelta);
/// SafeDeterminizeMinimizeWapper is as SafeDeterminizeWrapper except that it
/// also minimizes (encoded minimization, which is safe). This algorithm will
/// destroy "ifst".
template <class Arc>
void SafeDeterminizeMinimizeWrapper(MutableFst<Arc> *ifst, VectorFst<Arc> *ofst,
float delta = kDelta);
/// SafeDeterminizeMinimizeWapperInLog is as SafeDeterminizeMinimizeWrapper
/// except it first casts tothe log semiring.
void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc> *ifst,
VectorFst<StdArc> *ofst,
float delta = kDelta);
/// RemoveSomeInputSymbols removes any symbol that appears in "to_remove", from
/// the input side of the FST, replacing them with epsilon.
template <class Arc, class I>
void RemoveSomeInputSymbols(const std::vector<I> &to_remove,
MutableFst<Arc> *fst);
// MapInputSymbols will replace any input symbol i that is between 0 and
// symbol_map.size()-1, with symbol_map[i]. It removes the input symbol
// table of the FST.
template <class Arc, class I>
void MapInputSymbols(const std::vector<I> &symbol_map, MutableFst<Arc> *fst);
template <class Arc>
void RemoveWeights(MutableFst<Arc> *fst);
/// Returns true if and only if the FST is such that the input symbols
/// on arcs entering any given state all have the same value.
/// if "start_is_epsilon", treat start-state as an epsilon input arc
/// [i.e. ensure only epsilon can enter start-state].
template <class Arc>
bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc> &fst);
/// This is as PrecedingInputSymbolsAreSame, but with a functor f that maps
/// labels to classes. The function tests whether the symbols preceding any
/// given state are in the same class. Formally, f is of a type F that has an
/// operator of type F::Result F::operator() (F::Arg a) const; where F::Result
/// is an integer type and F::Arc can be constructed from Arc::Label. this must
/// apply to valid labels and also to kNoLabel (so we can have a marker for the
/// invalid labels.
template <class Arc, class F>
bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon,
const Fst<Arc> &fst, const F &f);
/// Returns true if and only if the FST is such that the input symbols
/// on arcs exiting any given state all have the same value.
/// If end_is_epsilon, treat end-state as an epsilon output arc [i.e. ensure
/// end-states cannot have non-epsilon output transitions.]
template <class Arc>
bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc> &fst);
template <class Arc, class F>
bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc> &fst,
const F &f);
/// MakePrecedingInputSymbolsSame ensures that all arcs entering any given fst
/// state have the same input symbol. It does this by detecting states
/// that have differing input symbols going in, and inserting, for each of
/// the preceding arcs with non-epsilon input symbol, a new dummy state that
/// has an epsilon link to the fst state.
/// If "start_is_epsilon", ensure that start-state can have only epsilon-links
/// into it.
template <class Arc>
void MakePrecedingInputSymbolsSame(bool start_is_epsilon, MutableFst<Arc> *fst);
/// As MakePrecedingInputSymbolsSame, but takes a functor object that maps
/// labels to classes.
template <class Arc, class F>
void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon,
MutableFst<Arc> *fst, const F &f);
/// MakeFollowingInputSymbolsSame ensures that all arcs exiting any given fst
/// state have the same input symbol. It does this by detecting states that
/// have differing input symbols on arcs that exit it, and inserting, for each
/// of the following arcs with non-epsilon input symbol, a new dummy state that
/// has an input-epsilon link from the fst state. The output symbol and weight
/// stay on the link to the dummy state (in order to keep the FST
/// output-deterministic and stochastic, if it already was). If end_is_epsilon,
/// treat "being a final-state" like having an epsilon output link.
template <class Arc>
void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc> *fst);
/// As MakeFollowingInputSymbolsSame, but takes a functor object that maps
/// labels to classes.
template <class Arc, class F>
void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon,
MutableFst<Arc> *fst, const F &f);
/// MakeLoopFst creates an FST that has a state that is both initial and
/// final (weight == Weight::One()), and for each non-NULL pointer fsts[i],
/// it has an arc out whose output-symbol is i and which goes to a
/// sub-graph whose input language is equivalent to fsts[i], where the
/// final-state becomes a transition to the loop-state. Each fst in "fsts"
/// should be an acceptor. The fst MakeLoopFst returns is output-deterministic,
/// but not output-epsilon free necessarily, and arcs are sorted on output
/// label. Note: if some of the pointers in the input vector "fsts" have the
/// same value, "MakeLoopFst" uses this to speed up the computation.
/// Formally: suppose I is the set of indexes i such that fsts[i] != NULL.
/// Let L[i] be the language that the acceptor fsts[i] accepts.
/// Let the language K be the set of input-output pairs i:l such
/// that i in I and l in L[i]. Then the FST returned by MakeLoopFst
/// accepts the language K*, where * is the Kleene closure (CLOSURE_STAR)
/// of K.
/// We could have implemented this via a combination of "project",
/// "concat", "union" and "closure". But that FST would have been
/// less well optimized and would have a lot of final-states.
template <class Arc>
VectorFst<Arc> *MakeLoopFst(const std::vector<const ExpandedFst<Arc> *> &fsts);
/// ApplyProbabilityScale is applicable to FSTs in the log or tropical semiring.
/// It multiplies the arc and final weights by "scale" [this is not the Mul
/// operation of the semiring, it's actual multiplication, which is equivalent
/// to taking a power in the semiring].
template <class Arc>
void ApplyProbabilityScale(float scale, MutableFst<Arc> *fst);
/// EqualAlign is similar to RandGen, but it generates a sequence with exactly
/// "length" input symbols. It returns true on success, false on failure
/// (failure is partly random but should never happen in practice for normal
/// speech models.) It generates a random path through the input FST, finds out
/// which subset of the states it visits along the way have self-loops with
/// inupt symbols on them, and outputs a path with exactly enough self-loops to
/// have the requested number of input symbols. Note that EqualAlign does not
/// use the probabilities on the FST. It just uses equal probabilities in the
/// first stage of selection (since the output will anyway not be a truly random
/// sample from the FST). The input fst "ifst" must be connected or this may
/// enter an infinite loop.
template <class Arc>
bool EqualAlign(const Fst<Arc> &ifst, typename Arc::StateId length,
int rand_seed, MutableFst<Arc> *ofst, int num_retries = 10);
// RemoveUselessArcs removes arcs such that there is no input symbol
// sequence for which the best path through the FST would contain
// those arcs [for these purposes, epsilon is not treated as a real symbol].
// This is mainly geared towards decoding-graph FSTs which may contain
// transitions that have less likely words on them that would never be
// taken. We do not claim that this algorithm removes all such arcs;
// it just does the best job it can.
// Only works for tropical (not log) semiring as it uses
// NaturalLess.
template <class Arc>
void RemoveUselessArcs(MutableFst<Arc> *fst);
// PhiCompose is a version of composition where
// the right hand FST (fst2) is treated as a backoff
// LM, with the phi symbol (e.g. #0) treated as a
// "failure transition", only taken when we don't
// have a match for the requested symbol.
template <class Arc>
void PhiCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
typename Arc::Label phi_label, MutableFst<Arc> *fst);
// PropagateFinal propagates final-probs through
// "phi" transitions (note that here, phi_label may
// be epsilon if you want). If you have a backoff LM
// with special symbols ("phi") on the backoff arcs
// instead of epsilon, you may use PhiCompose to compose
// with it, but this won't do the right thing w.r.t.
// final probabilities. You should first call PropagateFinal
// on the FST with phi's i it (fst2 in PhiCompose above),
// to fix this. If a state does not have a final-prob,
// but has a phi transition, it makes the state's final-prob
// (phi-prob * final-prob-of-dest-state), and does this
// recursively i.e. follows phi transitions on the dest state
// first. It behaves as if there were a super-final state
// with a special symbol leading to it, from each currently
// final state. Note that this may not behave as desired
// if there are epsilons in your FST; it might be better
// to remove those before calling this function.
template <class Arc>
void PropagateFinal(typename Arc::Label phi_label, MutableFst<Arc> *fst);
// PhiCompose is a version of composition where
// the right hand FST (fst2) has speciall "rho transitions"
// which are taken whenever no normal transition matches; these
// transitions will be rewritten with whatever symbol was on
// the first FST.
template <class Arc>
void RhoCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
typename Arc::Label rho_label, MutableFst<Arc> *fst);
/** This function returns true if, in the semiring of the FST, the sum (within
the semiring) of all the arcs out of each state in the FST is one, to within
delta. After MakeStochasticFst, this should be true (for a connected FST).
@param fst [in] the FST that we are testing.
@param delta [in] the tolerance to within which we test equality to 1.
@param min_sum [out] if non, NULL, contents will be set to the minimum sum
of weights.
@param max_sum [out] if non, NULL, contents will be set to the maximum sum
of weights.
@return Returns true if the FST is stochastic, and false otherwise.
*/
template <class Arc>
bool IsStochasticFst(const Fst<Arc> &fst,
float delta = kDelta, // kDelta = 1.0/1024.0 by default.
typename Arc::Weight *min_sum = NULL,
typename Arc::Weight *max_sum = NULL);
// IsStochasticFstInLog makes sure it's stochastic after casting to log.
inline bool IsStochasticFstInLog(
const Fst<StdArc> &fst,
float delta = kDelta, // kDelta = 1.0/1024.0 by default.
StdArc::Weight *min_sum = NULL, StdArc::Weight *max_sum = NULL);
} // end namespace fst
#include "fstext/fstext-utils-inl.h"
#endif // KALDI_FSTEXT_FSTEXT_UTILS_H_
// fstext/kaldi-fst-io-inl.h
// Copyright 2009-2011 Microsoft Corporation
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_KALDI_FST_IO_INL_H_
#define KALDI_FSTEXT_KALDI_FST_IO_INL_H_
#include <string>
#include <vector>
#include "util/text-utils.h"
namespace fst {
template <class Arc>
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst<Arc> &t) {
bool ok;
if (binary) {
// Binary-mode writing.
ok = t.Write(os, FstWriteOptions());
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached. Write a
// newline to start the FST; in a table, the first line of the FST will
// appear on its own line.
os << '\n';
bool acceptor = false, write_one = false;
FstPrinter<Arc> printer(t, t.InputSymbols(), t.OutputSymbols(), NULL,
acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail()) KALDI_ERR << "Stream failure detected writing FST to stream";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not something in the original
// OpenFst code].
os << '\n';
ok = os.good();
}
if (!ok) {
KALDI_ERR << "Error writing FST to stream";
}
}
// Utility function used in ReadFstKaldi
template <class W>
inline bool StrToWeight(const std::string &s, bool allow_zero, W *w) {
std::istringstream strm(s);
strm >> *w;
if (strm.fail() || (!allow_zero && *w == W::Zero())) {
return false;
}
return true;
}
template <class Arc>
void ReadFstKaldi(std::istream &is, bool binary, VectorFst<Arc> *fst) {
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
if (binary) {
// We don't have access to the filename here, so write [unknown].
VectorFst<Arc> *ans =
VectorFst<Arc>::Read(is, fst::FstReadOptions(std::string("[unknown]")));
if (ans == NULL) {
KALDI_ERR << "Error reading FST from stream.";
}
*fst = *ans; // shallow copy.
delete ans;
} else {
// Consume the \r on Windows, the \n that the text-form FST format starts
// with, and any extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') {
is.get(); // consume the newline.
} else { // saw spaces but no newline.. this is not expected.
KALDI_ERR << "Reading FST: unexpected sequence of spaces "
<< " at file position " << is.tellg();
}
using kaldi::ConvertStringToInteger;
using kaldi::SplitStringToIntegers;
using std::string;
using std::vector;
fst->DeleteStates();
string line;
size_t nline = 0;
string separator = FLAGS_fst_field_separator + "\r\n";
while (std::getline(is, line)) {
nline++;
vector<string> col;
// on Windows we'll write in text and read in binary mode.
kaldi::SplitStringToVector(line, separator.c_str(), true, &col);
if (col.size() == 0) break; // Empty line is a signal to stop, in our
// archive format.
if (col.size() > 5) {
KALDI_ERR << "Bad line in FST: " << line;
}
StateId s;
if (!ConvertStringToInteger(col[0], &s)) {
KALDI_ERR << "Bad line in FST: " << line;
}
while (s >= fst->NumStates()) fst->AddState();
if (nline == 1) fst->SetStart(s);
bool ok = true;
Arc arc;
Weight w;
StateId d = s;
switch (col.size()) {
case 1:
fst->SetFinal(s, Weight::One());
break;
case 2:
if (!StrToWeight(col[1], true, &w))
ok = false;
else
fst->SetFinal(s, w);
break;
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
ok = false;
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel);
if (ok) {
d = arc.nextstate;
arc.weight = Weight::One();
fst->AddArc(s, arc);
}
break;
case 5:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel) &&
StrToWeight(col[4], false, &arc.weight);
if (ok) {
d = arc.nextstate;
fst->AddArc(s, arc);
}
break;
default:
ok = false;
}
while (d >= fst->NumStates()) fst->AddState();
if (!ok) KALDI_ERR << "Bad line in FST: " << line;
}
}
}
template <class Arc> // static
bool VectorFstTplHolder<Arc>::Write(std::ostream &os, bool binary, const T &t) {
try {
WriteFstKaldi(os, binary, t);
return true;
} catch (...) {
return false;
}
}
template <class Arc> // static
bool VectorFstTplHolder<Arc>::Read(std::istream &is) {
Clear();
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading Fst";
return false;
} else if (isspace(c)) { // The text form of the FST begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is
// not space).
try {
t_ = new VectorFst<Arc>();
ReadFstKaldi(is, false, t_);
} catch (...) {
Clear();
return false;
}
} else { // reading a binary FST.
try {
t_ = new VectorFst<Arc>();
ReadFstKaldi(is, true, t_);
} catch (...) {
Clear();
return false;
}
}
return true;
}
} // namespace fst.
#endif // KALDI_FSTEXT_KALDI_FST_IO_INL_H_
// fstext/kaldi-fst-io.cc
// Copyright 2009-2011 Microsoft Corporation
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "fstext/kaldi-fst-io.h"
#include <string>
#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "util/kaldi-io.h"
namespace fst {
VectorFst<StdArc> *ReadFstKaldi(std::string rxfilename) {
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
// for compatibility with OpenFst conventions.
kaldi::Input ki(rxfilename);
fst::FstHeader hdr;
if (!hdr.Read(ki.Stream(), rxfilename))
KALDI_ERR << "Reading FST: error reading FST header from "
<< kaldi::PrintableRxfilename(rxfilename);
FstReadOptions ropts("<unspecified>", &hdr);
VectorFst<StdArc> *fst = VectorFst<StdArc>::Read(ki.Stream(), ropts);
if (!fst)
KALDI_ERR << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename);
return fst;
}
// Register const fst to load it automatically. Other types like
// olabel_lookahead or ngram or compact_fst should be registered
// through OpenFst registration API.
static fst::FstRegisterer<VectorFst<StdArc>> VectorFst_StdArc_registerer;
static fst::FstRegisterer<ConstFst<StdArc>> ConstFst_StdArc_registerer;
Fst<StdArc> *ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err) {
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
// for compatibility with OpenFst conventions.
kaldi::Input ki(rxfilename);
fst::FstHeader hdr;
// Read FstHeader which contains the type of FST
if (!hdr.Read(ki.Stream(), rxfilename)) {
if (throw_on_err) {
KALDI_ERR << "Reading FST: error reading FST header from "
<< kaldi::PrintableRxfilename(rxfilename);
} else {
KALDI_WARN << "We fail to read FST header from "
<< kaldi::PrintableRxfilename(rxfilename)
<< ". A NULL pointer is returned.";
return NULL;
}
}
// Check the type of Arc
if (hdr.ArcType() != fst::StdArc::Type()) {
if (throw_on_err) {
KALDI_ERR << "FST with arc type " << hdr.ArcType()
<< " is not supported.";
} else {
KALDI_WARN << "Fst with arc type" << hdr.ArcType()
<< " is not supported. A NULL pointer is returned.";
return NULL;
}
}
// Read the FST
FstReadOptions ropts("<unspecified>", &hdr);
Fst<StdArc> *fst = Fst<StdArc>::Read(ki.Stream(), ropts);
if (!fst) {
if (throw_on_err) {
KALDI_ERR << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename);
} else {
KALDI_WARN << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename)
<< ". A NULL pointer is returned.";
return NULL;
}
}
return fst;
}
VectorFst<StdArc> *CastOrConvertToVectorFst(Fst<StdArc> *fst) {
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>
std::string real_type = fst->Type();
KALDI_ASSERT(real_type == "vector" || real_type == "const");
if (real_type == "vector") {
return dynamic_cast<VectorFst<StdArc> *>(fst);
} else {
// As the 'fst' can't cast to VectorFst, we create a new
// VectorFst<StdArc> initialized by 'fst', and delete 'fst'.
VectorFst<StdArc> *new_fst = new VectorFst<StdArc>(*fst);
delete fst;
return new_fst;
}
}
void ReadFstKaldi(std::string rxfilename, fst::StdVectorFst *ofst) {
fst::StdVectorFst *fst = ReadFstKaldi(rxfilename);
*ofst = *fst;
delete fst;
}
void WriteFstKaldi(const VectorFst<StdArc> &fst, std::string wxfilename) {
if (wxfilename == "") wxfilename = "-"; // interpret "" as stdout,
// for compatibility with OpenFst conventions.
bool write_binary = true, write_header = false;
kaldi::Output ko(wxfilename, write_binary, write_header);
FstWriteOptions wopts(kaldi::PrintableWxfilename(wxfilename));
fst.Write(ko.Stream(), wopts);
}
fst::VectorFst<fst::StdArc> *ReadAndPrepareLmFst(std::string rxfilename) {
// ReadFstKaldi() will die with exception on failure.
fst::VectorFst<fst::StdArc> *ans = fst::ReadFstKaldi(rxfilename);
if (ans->Properties(fst::kAcceptor, true) == 0) {
// If it's not already an acceptor, project on the output, i.e. copy olabels
// to ilabels. Generally the G.fst's on disk will have the disambiguation
// symbol #0 on the input symbols of the backoff arc, and projection will
// replace them with epsilons which is what is on the output symbols of
// those arcs.
fst::Project(ans, fst::PROJECT_OUTPUT);
}
if (ans->Properties(fst::kILabelSorted, true) == 0) {
// Make sure LM is sorted on ilabel.
fst::ILabelCompare<fst::StdArc> ilabel_comp;
fst::ArcSort(ans, ilabel_comp);
}
return ans;
}
} // end namespace fst
// fstext/kaldi-fst-io.h
// Copyright 2009-2011 Microsoft Corporation
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_KALDI_FST_IO_H_
#define KALDI_FSTEXT_KALDI_FST_IO_H_
#include <string>
#include <utility>
#include "fst/fst-decl.h"
#include "fst/fstlib.h"
#include "fst/script/print-impl.h"
#include "base/kaldi-common.h"
// Some functions for writing Fsts.
// I/O for FSTs is a bit of a mess, and not very well integrated with Kaldi's
// generic I/O mechanisms, because we want files containing just FSTs to
// be readable by OpenFST's native binaries, which is not compatible
// with the normal \0B header that identifies Kaldi files as containing
// binary data.
// So use the functions here with your eyes open, and with caution!
namespace fst {
// Read a binary FST using Kaldi I/O mechanisms (pipes, etc.)
// On error returns NULL. Only supports VectorFst and exists
// mainly for backward code compabibility.
VectorFst<StdArc> *ReadFstKaldi(std::string rxfilename);
// Read a binary FST using Kaldi I/O mechanisms (pipes, etc.)
// If it can't read the FST, if throw_on_err == true it throws using KALDI_ERR;
// otherwise it prints a warning and returns. Note:this
// doesn't support the text-mode option that we generally like to support.
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>
// (const-fst can give better performance for decoding). Other
// types could be also loaded if registered inside OpenFst.
Fst<StdArc> *ReadFstKaldiGeneric(std::string rxfilename,
bool throw_on_err = true);
// This function attempts to dynamic_cast the pointer 'fst' (which will likely
// have been returned by ReadFstGeneric()), to the more derived
// type VectorFst<StdArc>. If this succeeds, it returns the same pointer;
// if it fails, it converts the FST type (by creating a new VectorFst<stdArc>
// initialized by 'fst'), prints a warning, and deletes 'fst'.
VectorFst<StdArc> *CastOrConvertToVectorFst(Fst<StdArc> *fst);
// Version of ReadFstKaldi() that writes to a pointer. Assumes
// the FST is binary with no binary marker. Crashes on error.
void ReadFstKaldi(std::string rxfilename, VectorFst<StdArc> *ofst);
// Write an FST using Kaldi I/O mechanisms (pipes, etc.)
// On error, throws using KALDI_ERR. For use only in code in fstbin/,
// as it doesn't support the text-mode option.
void WriteFstKaldi(const VectorFst<StdArc> &fst, std::string wxfilename);
// This is a more general Kaldi-type-IO mechanism of writing FSTs to
// streams, supporting binary or text-mode writing. (note: we just
// write the integers, symbol tables are not supported).
// On error, throws using KALDI_ERR.
template <class Arc>
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst<Arc> &fst);
// A generic Kaldi-type-IO mechanism of reading FSTs from streams,
// supporting binary or text-mode reading/writing.
template <class Arc>
void ReadFstKaldi(std::istream &is, bool binary, VectorFst<Arc> *fst);
// Read an FST file for LM (G.fst) and make it an acceptor,
// and make sure it is sorted on labels
fst::VectorFst<fst::StdArc> *ReadAndPrepareLmFst(std::string rxfilename);
// This is a Holder class with T = VectorFst<Arc>, that meets the requirements
// of a Holder class as described in ../util/kaldi-holder.h. This enables us to
// read/write collections of FSTs indexed by strings, using the Table concept (
// see ../util/kaldi-table.h).
// Originally it was only templated on T = VectorFst<StdArc>, but as the keyword
// spotting stuff introduced more types of FSTs, we made it also templated on
// the arc.
template <class Arc>
class VectorFstTplHolder {
public:
typedef VectorFst<Arc> T;
VectorFstTplHolder() : t_(NULL) {}
static bool Write(std::ostream &os, bool binary, const T &t);
void Copy(const T &t) { // copies it into the holder.
Clear();
t_ = new T(t);
}
// Reads into the holder.
bool Read(std::istream &is);
// It's potentially a binary format, so must read in binary mode (linefeed
// translation will corrupt the file. We don't know till we open the file if
// it's really binary, so we need to read in binary mode to be on the safe
// side. Extra linefeeds won't matter, the text-mode reading code ignores
// them.
static bool IsReadInBinary() { return true; }
T &Value() {
// code error if !t_.
if (!t_) KALDI_ERR << "VectorFstTplHolder::Value() called wrongly.";
return *t_;
}
void Clear() {
if (t_) {
delete t_;
t_ = NULL;
}
}
void Swap(VectorFstTplHolder<Arc> *other) { std::swap(t_, other->t_); }
bool ExtractRange(const VectorFstTplHolder<Arc> &other,
const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~VectorFstTplHolder() { Clear(); }
// No destructor. Assignment and
// copy constructor take their default implementations.
private:
KALDI_DISALLOW_COPY_AND_ASSIGN(VectorFstTplHolder);
T *t_;
};
// Now make the original VectorFstHolder as the typedef of
// VectorFstHolder<StdArc>.
typedef VectorFstTplHolder<StdArc> VectorFstHolder;
} // end namespace fst
#include "fstext/kaldi-fst-io-inl.h"
#endif // KALDI_FSTEXT_KALDI_FST_IO_H_
// fstext/lattice-utils-inl.h
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
// Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_LATTICE_UTILS_INL_H_
#define KALDI_FSTEXT_LATTICE_UTILS_INL_H_
// Do not include this file directly. It is included by lattice-utils.h
#include <utility>
#include <vector>
namespace fst {
/* Convert from FST with arc-type Weight, to one with arc-type
CompactLatticeWeight. Uses FactorFst to identify chains
of states which can be turned into a single output arc. */
template <class Weight, class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *ofst,
bool invert) {
typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId;
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
typedef ArcTpl<CompactWeight> CompactArc;
VectorFst<ArcTpl<Weight> > ffst;
std::vector<std::vector<Int> > labels;
if (invert) { // normal case: want the ilabels as sequences on the arcs of
Factor(ifst, &ffst, &labels); // the output... Factor makes seqs of
// ilabels.
} else {
VectorFst<ArcTpl<Weight> > invfst(ifst);
Invert(&invfst);
Factor(invfst, &ffst, &labels);
}
TopSort(&ffst); // Put the states in ffst in topological order, which is
// easier on the eye when reading the text-form lattices and corresponds to
// what we get when we generate the lattices in the decoder.
ofst->DeleteStates();
// The states will be numbered exactly the same as the original FST.
// Add the states to the new FST.
StateId num_states = ffst.NumStates();
for (StateId s = 0; s < num_states; s++) {
StateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(ffst.Start());
for (StateId s = 0; s < num_states; s++) {
Weight final_weight = ffst.Final(s);
if (final_weight != Weight::Zero()) {
CompactWeight final_compact_weight(final_weight, std::vector<Int>());
ofst->SetFinal(s, final_compact_weight);
}
for (ArcIterator<ExpandedFst<Arc> > iter(ffst, s); !iter.Done();
iter.Next()) {
const Arc &arc = iter.Value();
KALDI_PARANOID_ASSERT(arc.weight != Weight::Zero());
// note: zero-weight arcs not allowed anyway so weight should not be zero,
// but no harm in checking.
CompactArc compact_arc(arc.olabel, arc.olabel,
CompactWeight(arc.weight, labels[arc.ilabel]),
arc.nextstate);
ofst->AddArc(s, compact_arc);
}
}
}
template <class Weight, class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &ifst,
MutableFst<ArcTpl<Weight> > *ofst, bool invert) {
typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
typedef ArcTpl<CompactWeight> CompactArc;
ofst->DeleteStates();
// make the states in the new FST have the same numbers as
// the original ones, and add chains of states as necessary
// to encode the string-valued weights.
StateId num_states = ifst.NumStates();
for (StateId s = 0; s < num_states; s++) {
StateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(ifst.Start());
for (StateId s = 0; s < num_states; s++) {
CompactWeight final_weight = ifst.Final(s);
if (final_weight != CompactWeight::Zero()) {
StateId cur_state = s;
size_t string_length = final_weight.String().size();
for (size_t n = 0; n < string_length; n++) {
StateId next_state = ofst->AddState();
Label ilabel = 0;
Arc arc(ilabel, final_weight.String()[n],
(n == 0 ? final_weight.Weight() : Weight::One()), next_state);
if (invert) std::swap(arc.ilabel, arc.olabel);
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state,
string_length > 0 ? Weight::One() : final_weight.Weight());
}
for (ArcIterator<ExpandedFst<CompactArc> > iter(ifst, s); !iter.Done();
iter.Next()) {
const CompactArc &arc = iter.Value();
size_t string_length = arc.weight.String().size();
StateId cur_state = s;
// for all but the last element in the string--
// add a temporary state.
for (size_t n = 0; n + 1 < string_length; n++) {
StateId next_state = ofst->AddState();
Label ilabel = (n == 0 ? arc.ilabel : 0),
olabel = static_cast<Label>(arc.weight.String()[n]);
Weight weight = (n == 0 ? arc.weight.Weight() : Weight::One());
Arc new_arc(ilabel, olabel, weight, next_state);
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
ofst->AddArc(cur_state, new_arc);
cur_state = next_state;
}
Label ilabel = (string_length <= 1 ? arc.ilabel : 0),
olabel = (string_length > 0 ? arc.weight.String()[string_length - 1]
: 0);
Weight weight =
(string_length <= 1 ? arc.weight.Weight() : Weight::One());
Arc new_arc(ilabel, olabel, weight, arc.nextstate);
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
ofst->AddArc(cur_state, new_arc);
}
}
}
// This function converts lattices between float and double;
// it works for both CompactLatticeWeight and LatticeWeight.
template <class WeightIn, class WeightOut>
void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> > &ifst,
MutableFst<ArcTpl<WeightOut> > *ofst) {
typedef ArcTpl<WeightIn> ArcIn;
typedef ArcTpl<WeightOut> ArcOut;
typedef typename ArcIn::StateId StateId;
ofst->DeleteStates();
// The states will be numbered exactly the same as the original FST.
// Add the states to the new FST.
StateId num_states = ifst.NumStates();
for (StateId s = 0; s < num_states; s++) {
StateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(ifst.Start());
for (StateId s = 0; s < num_states; s++) {
WeightIn final_iweight = ifst.Final(s);
if (final_iweight != WeightIn::Zero()) {
WeightOut final_oweight;
ConvertLatticeWeight(final_iweight, &final_oweight);
ofst->SetFinal(s, final_oweight);
}
for (ArcIterator<ExpandedFst<ArcIn> > iter(ifst, s); !iter.Done();
iter.Next()) {
ArcIn arc = iter.Value();
KALDI_PARANOID_ASSERT(arc.weight != WeightIn::Zero());
ArcOut oarc;
ConvertLatticeWeight(arc.weight, &oarc.weight);
oarc.ilabel = arc.ilabel;
oarc.olabel = arc.olabel;
oarc.nextstate = arc.nextstate;
ofst->AddArc(s, oarc);
}
}
}
template <class Weight, class ScaleFloat>
void ScaleLattice(const std::vector<std::vector<ScaleFloat> > &scale,
MutableFst<ArcTpl<Weight> > *fst) {
assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2);
if (scale == DefaultLatticeScale()) // nothing to do.
return;
typedef ArcTpl<Weight> Arc;
typedef MutableFst<Arc> Fst;
typedef typename Arc::StateId StateId;
StateId num_states = fst->NumStates();
for (StateId s = 0; s < num_states; s++) {
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
arc.weight = Weight(ScaleTupleWeight(arc.weight, scale));
aiter.SetValue(arc);
}
Weight final_weight = fst->Final(s);
if (final_weight != Weight::Zero())
fst->SetFinal(s, Weight(ScaleTupleWeight(final_weight, scale)));
}
}
template <class Weight, class Int>
void RemoveAlignmentsFromCompactLattice(
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst) {
typedef CompactLatticeWeightTpl<Weight, Int> W;
typedef ArcTpl<W> Arc;
typedef MutableFst<Arc> Fst;
typedef typename Arc::StateId StateId;
StateId num_states = fst->NumStates();
for (StateId s = 0; s < num_states; s++) {
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
arc.weight = W(arc.weight.Weight(), std::vector<Int>());
aiter.SetValue(arc);
}
W final_weight = fst->Final(s);
if (final_weight != W::Zero())
fst->SetFinal(s, W(final_weight.Weight(), std::vector<Int>()));
}
}
template <class Weight, class Int>
bool CompactLatticeHasAlignment(
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &fst) {
typedef CompactLatticeWeightTpl<Weight, Int> W;
typedef ArcTpl<W> Arc;
typedef ExpandedFst<Arc> Fst;
typedef typename Arc::StateId StateId;
StateId num_states = fst.NumStates();
for (StateId s = 0; s < num_states; s++) {
for (ArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (!arc.weight.String().empty()) return true;
}
W final_weight = fst.Final(s);
if (!final_weight.String().empty()) return true;
}
return false;
}
template <class Real>
void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> > &ifst,
MutableFst<ArcTpl<LatticeWeightTpl<Real> > > *ofst) {
int32 num_states_cache = 50000;
fst::CacheOptions cache_opts(true, num_states_cache);
fst::MapFstOptions mapfst_opts(cache_opts);
StdToLatticeMapper<Real> mapper;
MapFst<StdArc, ArcTpl<LatticeWeightTpl<Real> >, StdToLatticeMapper<Real> >
map_fst(ifst, mapper, mapfst_opts);
*ofst = map_fst;
}
} // namespace fst
#endif // KALDI_FSTEXT_LATTICE_UTILS_INL_H_
// fstext/lattice-utils.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_LATTICE_UTILS_H_
#define KALDI_FSTEXT_LATTICE_UTILS_H_
#include <vector>
#include "fst/fstlib.h"
#include "fstext/lattice-weight.h"
namespace fst {
// The template ConvertLattice does conversions to and from
// LatticeWeight FSTs and CompactLatticeWeight FSTs, and
// between float and double, and to convert from LatticeWeight
// to TropicalWeight. It's used in the I/O code for lattices,
// and for converting lattices to standard FSTs (e.g. for creating
// decoding graphs from lattices).
/**
Convert lattice from a normal FST to a CompactLattice FST.
This is a bit like converting to the Gallic semiring, except
the semiring behaves in a different way (designed to take
the best path).
Note: the ilabels end up as the symbols on the arcs of the
output acceptor, and the olabels go to the strings. To make
it the other way around (useful for the speech-recognition
application), set invert=true [the default].
*/
template <class Weight, class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *ofst,
bool invert = true);
/**
Convert lattice CompactLattice format to Lattice. This is a bit
like converting from the Gallic semiring. As for any CompactLattice, "ifst"
must be an acceptor (i.e., ilabels and olabels should be identical). If
invert=false, the labels on "ifst" become the ilabels on "ofst" and the
strings in the weights of "ifst" becomes the olabels. If invert=true
[default], this is reversed (useful for speech recognition lattices; our
standard non-compact format has the words on the output side to match HCLG).
*/
template <class Weight, class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &ifst,
MutableFst<ArcTpl<Weight> > *ofst, bool invert = true);
/**
Convert between CompactLattices and Lattices of different floating point
types... this works between any pair of weight types for which
ConvertLatticeWeight is defined (c.f. lattice-weight.h), and also includes
conversion from LatticeWeight to TropicalWeight.
*/
template <class WeightIn, class WeightOut>
void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> > &ifst,
MutableFst<ArcTpl<WeightOut> > *ofst);
// Now define some ConvertLattice functions that require two phases of
// conversion (don't bother coding these separately as they will be used rarely.
// Lattice with float to CompactLattice with double.
template <class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<LatticeWeightTpl<float> > > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
*ofst) {
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
fst;
ConvertLattice(ifst, &fst);
ConvertLattice(fst, ofst);
}
// Lattice with double to CompactLattice with float.
template <class Int>
void ConvertLattice(
const ExpandedFst<ArcTpl<LatticeWeightTpl<double> > > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
*ofst) {
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
fst;
ConvertLattice(ifst, &fst);
ConvertLattice(fst, ofst);
}
/// Converts CompactLattice with double to Lattice with float.
template <class Int>
void ConvertLattice(
const ExpandedFst<
ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > > &ifst,
MutableFst<ArcTpl<LatticeWeightTpl<float> > > *ofst) {
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
fst;
ConvertLattice(ifst, &fst);
ConvertLattice(fst, ofst);
}
/// Converts CompactLattice with float to Lattice with double.
template <class Int>
void ConvertLattice(
const ExpandedFst<
ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > > &ifst,
MutableFst<ArcTpl<LatticeWeightTpl<double> > > *ofst) {
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
fst;
ConvertLattice(ifst, &fst);
ConvertLattice(fst, ofst);
}
/// Converts TropicalWeight to LatticeWeight (puts all the weight on
/// the first float in the lattice's pair).
template <class Real>
void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> > &ifst,
MutableFst<ArcTpl<LatticeWeightTpl<Real> > > *ofst);
/** Returns a default 2x2 matrix scaling factor for LatticeWeight */
inline std::vector<std::vector<double> > DefaultLatticeScale() {
std::vector<std::vector<double> > ans(2);
ans[0].resize(2, 0.0);
ans[1].resize(2, 0.0);
ans[0][0] = ans[1][1] = 1.0;
return ans;
}
inline std::vector<std::vector<double> > AcousticLatticeScale(double acwt) {
std::vector<std::vector<double> > ans(2);
ans[0].resize(2, 0.0);
ans[1].resize(2, 0.0);
ans[0][0] = 1.0;
ans[1][1] = acwt;
return ans;
}
inline std::vector<std::vector<double> > GraphLatticeScale(double lmwt) {
std::vector<std::vector<double> > ans(2);
ans[0].resize(2, 0.0);
ans[1].resize(2, 0.0);
ans[0][0] = lmwt;
ans[1][1] = 1.0;
return ans;
}
inline std::vector<std::vector<double> > LatticeScale(double lmwt,
double acwt) {
std::vector<std::vector<double> > ans(2);
ans[0].resize(2, 0.0);
ans[1].resize(2, 0.0);
ans[0][0] = lmwt;
ans[1][1] = acwt;
return ans;
}
/** Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by
viewing the pair (a, b) as a 2-vector and pre-multiplying by the 2x2 matrix
in "scale". E.g. typically scale would equal
[ 1 0;
0 acwt ]
if we want to scale the acoustics by "acwt".
*/
template <class Weight, class ScaleFloat>
void ScaleLattice(const std::vector<std::vector<ScaleFloat> > &scale,
MutableFst<ArcTpl<Weight> > *fst);
/// Removes state-level alignments (the strings that are
/// part of the weights).
template <class Weight, class Int>
void RemoveAlignmentsFromCompactLattice(
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst);
/// Returns true if lattice has alignments, i.e. it has
/// any nonempty strings inside its weights.
template <class Weight, class Int>
bool CompactLatticeHasAlignment(
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &fst);
/// Class StdToLatticeMapper maps a normal arc (StdArc)
/// to a LatticeArc by putting the StdArc weight as the first
/// element of the LatticeWeight. Useful when doing LM
/// rescoring.
template <class Real>
class StdToLatticeMapper {
typedef LatticeWeightTpl<Real> LatticeWeight;
typedef ArcTpl<LatticeWeight> LatticeArc;
public:
LatticeArc operator()(const StdArc &arc) {
// Note: we have to check whether the arc's weight is zero below,
// and if so return (infinity, infinity) and not (infinity, zero),
// because (infinity, zero) is not a valid LatticeWeight, which should
// either be both finite, or both infinite (i.e. Zero()).
return LatticeArc(
arc.ilabel, arc.olabel,
LatticeWeight(arc.weight.Value(), arc.weight == StdArc::Weight::Zero()
? arc.weight.Value()
: 0.0),
arc.nextstate);
}
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
MapSymbolsAction InputSymbolsAction() { return MAP_COPY_SYMBOLS; }
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
// I believe all properties are preserved.
uint64 Properties(uint64 props) { return props; }
};
/// Class LatticeToStdMapper maps a LatticeArc to a normal arc (StdArc)
/// by adding the elements of the LatticeArc weight.
template <class Real>
class LatticeToStdMapper {
typedef LatticeWeightTpl<Real> LatticeWeight;
typedef ArcTpl<LatticeWeight> LatticeArc;
public:
StdArc operator()(const LatticeArc &arc) {
return StdArc(arc.ilabel, arc.olabel,
StdArc::Weight(arc.weight.Value1() + arc.weight.Value2()),
arc.nextstate);
}
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
MapSymbolsAction InputSymbolsAction() { return MAP_COPY_SYMBOLS; }
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
// I believe all properties are preserved.
uint64 Properties(uint64 props) { return props; }
};
template <class Weight, class Int>
void PruneCompactLattice(
Weight beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst);
} // end namespace fst
#include "fstext/lattice-utils-inl.h"
#endif // KALDI_FSTEXT_LATTICE_UTILS_H_
// fstext/lattice-weight.h
// Copyright 2009-2012 Microsoft Corporation
// Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
#define KALDI_FSTEXT_LATTICE_WEIGHT_H_
#include <algorithm>
#include <limits>
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
namespace fst {
// Declare weight type for lattice... will import to namespace kaldi. has two
// members, value1_ and value2_, of type BaseFloat (normally equals float). It
// is basically the same as the tropical semiring on value1_+value2_, except it
// keeps track of a and b separately. More precisely, it is equivalent to the
// lexicographic semiring on (value1_+value2_), (value1_-value2_)
template <class FloatType>
class LatticeWeightTpl;
template <class FloatType>
inline std::ostream &operator<<(std::ostream &strm,
const LatticeWeightTpl<FloatType> &w);
template <class FloatType>
inline std::istream &operator>>(std::istream &strm,
LatticeWeightTpl<FloatType> &w);
template <class FloatType>
class LatticeWeightTpl {
public:
typedef FloatType T; // normally float.
typedef LatticeWeightTpl ReverseWeight;
inline T Value1() const { return value1_; }
inline T Value2() const { return value2_; }
inline void SetValue1(T f) { value1_ = f; }
inline void SetValue2(T f) { value2_ = f; }
LatticeWeightTpl() : value1_{}, value2_{} {}
LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
LatticeWeightTpl(const LatticeWeightTpl &other)
: value1_(other.value1_), value2_(other.value2_) {}
LatticeWeightTpl &operator=(const LatticeWeightTpl &w) {
value1_ = w.value1_;
value2_ = w.value2_;
return *this;
}
LatticeWeightTpl<FloatType> Reverse() const { return *this; }
static const LatticeWeightTpl Zero() {
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
std::numeric_limits<T>::infinity());
}
static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
static const std::string &Type() {
static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
return type;
}
static const LatticeWeightTpl NoWeight() {
return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
std::numeric_limits<FloatType>::quiet_NaN());
}
bool Member() const {
// value1_ == value1_ tests for NaN.
// also test for no -inf, and either both or neither
// must be +inf, and
if (value1_ != value1_ || value2_ != value2_) return false; // NaN
if (value1_ == -std::numeric_limits<T>::infinity() ||
value2_ == -std::numeric_limits<T>::infinity())
return false; // -infty not allowed
if (value1_ == std::numeric_limits<T>::infinity() ||
value2_ == std::numeric_limits<T>::infinity()) {
if (value1_ != std::numeric_limits<T>::infinity() ||
value2_ != std::numeric_limits<T>::infinity())
return false; // both must be +infty;
// this is necessary so that the semiring has only one zero.
}
return true;
}
LatticeWeightTpl Quantize(float delta = kDelta) const {
if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
-std::numeric_limits<T>::infinity());
} else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
std::numeric_limits<T>::infinity());
} else if (value1_ + value2_ != value1_ + value2_) { // NaN
return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
} else {
return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
floor(value2_ / delta + 0.5F) * delta);
}
}
static constexpr uint64 Properties() {
return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
}
// This is used in OpenFst for binary I/O. This is OpenFst-style,
// not Kaldi-style, I/O.
std::istream &Read(std::istream &strm) {
// Always read/write as float, even if T is double,
// so we can use OpenFst-style read/write and still maintain
// compatibility when compiling with different FloatTypes
ReadType(strm, &value1_);
ReadType(strm, &value2_);
return strm;
}
// This is used in OpenFst for binary I/O. This is OpenFst-style,
// not Kaldi-style, I/O.
std::ostream &Write(std::ostream &strm) const {
WriteType(strm, value1_);
WriteType(strm, value2_);
return strm;
}
size_t Hash() const {
size_t ans;
union {
T f;
size_t s;
} u;
u.s = 0;
u.f = value1_;
ans = u.s;
u.f = value2_;
ans += u.s;
return ans;
}
protected:
inline static void WriteFloatType(std::ostream &strm, const T &f) {
if (f == std::numeric_limits<T>::infinity())
strm << "Infinity";
else if (f == -std::numeric_limits<T>::infinity())
strm << "-Infinity";
else if (f != f)
strm << "BadNumber";
else
strm << f;
}
// Internal helper function, used in ReadNoParen.
inline static void ReadFloatType(std::istream &strm, T &f) { // NOLINT
std::string s;
strm >> s;
if (s == "Infinity") {
f = std::numeric_limits<T>::infinity();
} else if (s == "-Infinity") {
f = -std::numeric_limits<T>::infinity();
} else if (s == "BadNumber") {
f = std::numeric_limits<T>::quiet_NaN();
} else {
char *p;
f = strtod(s.c_str(), &p);
if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
}
}
// Reads LatticeWeight when there are no parentheses around pair terms...
// currently the only form supported.
inline std::istream &ReadNoParen(std::istream &strm, char separator) {
int c;
do {
c = strm.get();
} while (isspace(c));
std::string s1;
while (c != separator) {
if (c == EOF) {
strm.clear(std::ios::badbit);
return strm;
}
s1 += c;
c = strm.get();
}
std::istringstream strm1(s1);
ReadFloatType(strm1, value1_); // ReadFloatType is class member function
// read second element
ReadFloatType(strm, value2_);
return strm;
}
friend std::istream &operator>>
<FloatType>(std::istream &, LatticeWeightTpl<FloatType> &);
friend std::ostream &operator<<<FloatType>(
std::ostream &, const LatticeWeightTpl<FloatType> &);
private:
T value1_;
T value2_;
};
/* ScaleTupleWeight is a function defined for LatticeWeightTpl and
CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2
matrix. Used, for example, in applying acoustic scaling.
*/
template <class FloatType, class ScaleFloatType>
inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
const LatticeWeightTpl<FloatType> &w,
const std::vector<std::vector<ScaleFloatType> > &scale) {
// Without the next special case we'd get NaNs from infinity * 0
if (w.Value1() == std::numeric_limits<FloatType>::infinity())
return LatticeWeightTpl<FloatType>::Zero();
return LatticeWeightTpl<FloatType>(
scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
}
/* For testing purposes and in case it's ever useful, we define a similar
function to apply to LexicographicWeight and the like, templated on
TropicalWeight<float> etc.; we use PairWeight which is the base class of
LexicographicWeight.
*/
template <class FloatType, class ScaleFloatType>
inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
TropicalWeightTpl<FloatType> > &w,
const std::vector<std::vector<ScaleFloatType> > &scale) {
typedef TropicalWeightTpl<FloatType> BaseType;
typedef PairWeight<BaseType, BaseType> PairType;
const BaseType zero = BaseType::Zero();
// Without the next special case we'd get NaNs from infinity * 0
if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
BaseType(scale[1][0] * f1 + scale[1][1] * f2));
}
template <class FloatType>
inline bool operator==(const LatticeWeightTpl<FloatType> &wa,
const LatticeWeightTpl<FloatType> &wb) {
// Volatile qualifier thwarts over-aggressive compiler optimizations
// that lead to problems esp. with NaturalLess().
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
vb2 = wb.Value2();
return (va1 == vb1 && va2 == vb2);
}
template <class FloatType>
inline bool operator!=(const LatticeWeightTpl<FloatType> &wa,
const LatticeWeightTpl<FloatType> &wb) {
// Volatile qualifier thwarts over-aggressive compiler optimizations
// that lead to problems esp. with NaturalLess().
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
vb2 = wb.Value2();
return (va1 != vb1 || va2 != vb2);
}
// We define a Compare function LatticeWeightTpl even though it's
// not required by the semiring standard-- it's just more efficient
// to do it this way rather than using the NaturalLess template.
/// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
template <class FloatType>
inline int Compare(const LatticeWeightTpl<FloatType> &w1,
const LatticeWeightTpl<FloatType> &w2) {
FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
if (f1 < f2) { // having smaller cost means you're larger
return 1;
} else if (f1 > f2) { // in the semiring [higher probability]
return -1;
} else if (w1.Value1() < w2.Value1()) {
// mathematically we should be comparing (w1.value1_-w1.value2_ <
// w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
// w2.value1_+w2.value2_ to both sides and divide by two, and we get the
// simpler equivalent form w1.value1_ < w2.value1_.
return 1;
} else if (w1.Value1() > w2.Value1()) {
return -1;
} else {
return 0;
}
}
template <class FloatType>
inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType> &w1,
const LatticeWeightTpl<FloatType> &w2) {
return (Compare(w1, w2) >= 0 ? w1 : w2);
}
// For efficiency, override the NaturalLess template class.
template <class FloatType>
class NaturalLess<LatticeWeightTpl<FloatType> > {
public:
typedef LatticeWeightTpl<FloatType> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
template <>
class NaturalLess<LatticeWeightTpl<float> > {
public:
typedef LatticeWeightTpl<float> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
template <>
class NaturalLess<LatticeWeightTpl<double> > {
public:
typedef LatticeWeightTpl<double> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
template <class FloatType>
inline LatticeWeightTpl<FloatType> Times(
const LatticeWeightTpl<FloatType> &w1,
const LatticeWeightTpl<FloatType> &w2) {
return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
w1.Value2() + w2.Value2());
}
// divide w1 by w2 (on left/right/any doesn't matter as
// commutative).
template <class FloatType>
inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType> &w1,
const LatticeWeightTpl<FloatType> &w2,
DivideType typ = DIVIDE_ANY) {
typedef FloatType T;
T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
b == -std::numeric_limits<T>::infinity()) {
KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
<< "[dividing by zero?] Returning zero";
return LatticeWeightTpl<T>::Zero();
}
if (a == std::numeric_limits<T>::infinity() ||
b == std::numeric_limits<T>::infinity())
return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
// infinite.
return LatticeWeightTpl<T>(a, b);
}
template <class FloatType>
inline bool ApproxEqual(const LatticeWeightTpl<FloatType> &w1,
const LatticeWeightTpl<FloatType> &w2,
float delta = kDelta) {
if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
return true; // handles Zero().
return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
delta);
}
template <class FloatType>
inline std::ostream &operator<<(std::ostream &strm,
const LatticeWeightTpl<FloatType> &w) {
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
strm << FLAGS_fst_weight_separator[0]; // comma by default;
// may or may not be settable from Kaldi programs.
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
return strm;
}
template <class FloatType>
inline std::istream &operator>>(std::istream &strm,
LatticeWeightTpl<FloatType> &w1) {
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
// separator defaults to ','
return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
}
// CompactLattice will be an acceptor (accepting the words/output-symbols),
// with the weights and input-symbol-seqs on the arcs.
// There must be a total order on W. We assume for the sake of efficiency
// that there is a function
// Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
// zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
// w2).
template <class WeightType, class IntType>
class CompactLatticeWeightTpl {
public:
typedef WeightType W;
typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
// Plus is like LexicographicWeight on the pair (weight_, string_), but where
// we use standard lexicographic order on string_ [this is not the same as
// NaturalLess on the StringWeight equivalent, which does not define a
// total order].
// Times, Divide obvious... (support both left & right division..)
// CommonDivisor would need to be coded separately.
CompactLatticeWeightTpl() {}
CompactLatticeWeightTpl(const WeightType &w, const std::vector<IntType> &s)
: weight_(w), string_(s) {}
CompactLatticeWeightTpl &operator=(
const CompactLatticeWeightTpl<WeightType, IntType> &w) {
weight_ = w.weight_;
string_ = w.string_;
return *this;
}
const W &Weight() const { return weight_; }
const std::vector<IntType> &String() const { return string_; }
void SetWeight(const W &w) { weight_ = w; }
void SetString(const std::vector<IntType> &s) { string_ = s; }
static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
std::vector<IntType>());
}
static const CompactLatticeWeightTpl<WeightType, IntType> One() {
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
std::vector<IntType>());
}
inline static std::string GetIntSizeString() {
char buf[2];
buf[0] = '0' + sizeof(IntType);
buf[1] = '\0';
return buf;
}
static const std::string &Type() {
static const std::string type =
"compact" + WeightType::Type() + GetIntSizeString();
return type;
}
static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
std::vector<IntType>());
}
CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
size_t s = string_.size();
std::vector<IntType> v(s);
for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
}
bool Member() const {
// a semiring has only one zero, this is the important property
// we're trying to maintain here. So force string_ to be empty if
// w_ == zero.
if (!weight_.Member()) return false;
if (weight_ == WeightType::Zero())
return string_.empty();
else
return true;
}
CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
}
static constexpr uint64 Properties() {
return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
}
// This is used in OpenFst for binary I/O. This is OpenFst-style,
// not Kaldi-style, I/O.
std::istream &Read(std::istream &strm) {
weight_.Read(strm);
if (strm.fail()) {
return strm;
}
int32 sz;
ReadType(strm, &sz);
if (strm.fail()) {
return strm;
}
if (sz < 0) {
KALDI_WARN << "Negative string size! Read failure";
strm.clear(std::ios::badbit);
return strm;
}
string_.resize(sz);
for (int32 i = 0; i < sz; i++) {
ReadType(strm, &(string_[i]));
}
return strm;
}
// This is used in OpenFst for binary I/O. This is OpenFst-style,
// not Kaldi-style, I/O.
std::ostream &Write(std::ostream &strm) const {
weight_.Write(strm);
if (strm.fail()) {
return strm;
}
int32 sz = static_cast<int32>(string_.size());
WriteType(strm, sz);
for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
return strm;
}
size_t Hash() const {
size_t ans = weight_.Hash();
// any weird numbers here are largish primes
size_t sz = string_.size(), mult = 6967;
for (size_t i = 0; i < sz; i++) {
ans += string_[i] * mult;
mult *= 7499;
}
return ans;
}
private:
W weight_;
std::vector<IntType> string_;
};
template <class WeightType, class IntType>
inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
}
template <class WeightType, class IntType>
inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
}
template <class WeightType, class IntType>
inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
float delta = kDelta) {
return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
w1.String() == w2.String());
}
// Compare is not part of the standard for weight types, but used internally for
// efficiency. The comparison here first compares the weight; if this is the
// same, it compares the string. The comparison on strings is: first compare
// the length, if this is the same, use lexicographical order. We can't just
// use the lexicographical order because this would destroy the distributive
// property of multiplication over addition, taking into account that addition
// uses Compare. The string element of "Compare" isn't super-important in
// practical terms; it's only needed to ensure that Plus always give consistent
// answers and is symmetric. It's essentially for tie-breaking, but we need to
// make sure all the semiring axioms are satisfied otherwise OpenFst might
// break.
template <class WeightType, class IntType>
inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
int c1 = Compare(w1.Weight(), w2.Weight());
if (c1 != 0) return c1;
int l1 = w1.String().size(), l2 = w2.String().size();
// Use opposite order on the string lengths, so that if the costs are the
// same, the shorter string wins.
if (l1 > l2)
return -1;
else if (l1 < l2)
return 1;
for (int i = 0; i < l1; i++) {
if (w1.String()[i] < w2.String()[i])
return -1;
else if (w1.String()[i] > w2.String()[i])
return 1;
}
return 0;
}
// For efficiency, override the NaturalLess template class.
template <class FloatType, class IntType>
class NaturalLess<
CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
public:
typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
template <>
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
public:
typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
template <>
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
public:
typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
NaturalLess() {}
bool operator()(const Weight &w1, const Weight &w2) const {
// NaturalLess is a negative order (opposite to normal ordering).
// This operator () corresponds to "<" in the negative order, which
// corresponds to the ">" in the normal order.
return (Compare(w1, w2) == 1);
}
};
// Make sure Compare is defined for TropicalWeight, so everything works
// if we substitute LatticeWeight for TropicalWeight.
inline int Compare(const TropicalWeight &w1, const TropicalWeight &w2) {
float f1 = w1.Value(), f2 = w2.Value();
if (f1 == f2)
return 0;
else if (f1 > f2)
return -1;
else
return 1;
}
template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
return (Compare(w1, w2) >= 0 ? w1 : w2);
}
template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Times(
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
WeightType w = Times(w1.Weight(), w2.Weight());
if (w == WeightType::Zero()) {
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
// special case to ensure zero is unique
} else {
std::vector<IntType> v;
v.resize(w1.String().size() + w2.String().size());
typename std::vector<IntType>::iterator iter = v.begin();
iter = std::copy(w1.String().begin(), w1.String().end(),
iter); // returns end of first range.
std::copy(w2.String().begin(), w2.String().end(), iter);
return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
}
}
template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
DivideType div = DIVIDE_ANY) {
if (w1.Weight() == WeightType::Zero()) {
if (w2.Weight() != WeightType::Zero()) {
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
} else {
KALDI_ERR << "Division by zero [0/0]";
}
} else if (w2.Weight() == WeightType::Zero()) {
KALDI_ERR << "Error: division by zero";
}
WeightType w = Divide(w1.Weight(), w2.Weight());
const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
if (v2.size() > v1.size()) {
KALDI_ERR << "Cannot divide, length mismatch";
}
typename std::vector<IntType>::const_iterator v1b = v1.begin(),
v1e = v1.end(),
v2b = v2.begin(),
v2e = v2.end();
if (div == DIVIDE_LEFT) {
if (!std::equal(v2b, v2e,
v1b)) { // v2 must be identical to first part of v1.
KALDI_ERR << "Cannot divide, data mismatch";
}
return CompactLatticeWeightTpl<WeightType, IntType>(
w, std::vector<IntType>(v1b + (v2e - v2b),
v1e)); // return last part of v1.
} else if (div == DIVIDE_RIGHT) {
if (!std::equal(
v2b, v2e,
v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
KALDI_ERR << "Cannot divide, data mismatch";
}
return CompactLatticeWeightTpl<WeightType, IntType>(
w, std::vector<IntType>(
v1b, v1e - (v2e - v2b))); // return first part of v1.
} else {
KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
}
return CompactLatticeWeightTpl<WeightType,
IntType>::Zero(); // keep compiler happy.
}
template <class WeightType, class IntType>
inline std::ostream &operator<<(
std::ostream &strm, const CompactLatticeWeightTpl<WeightType, IntType> &w) {
strm << w.Weight();
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
strm << FLAGS_fst_weight_separator[0]; // comma by default.
for (size_t i = 0; i < w.String().size(); i++) {
strm << w.String()[i];
if (i + 1 < w.String().size())
strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
// code.
}
return strm;
}
template <class WeightType, class IntType>
inline std::istream &operator>>(
std::istream &strm, CompactLatticeWeightTpl<WeightType, IntType> &w) {
std::string s;
strm >> s;
if (strm.fail()) {
return strm;
}
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
if (pos == std::string::npos) {
strm.clear(std::ios::badbit);
return strm;
}
// get parts of str before and after the separator (default: ',');
std::string s1(s, 0, pos), s2(s, pos + 1);
std::istringstream strm1(s1);
WeightType weight;
strm1 >> weight;
w.SetWeight(weight);
if (strm1.fail() || !strm1.eof()) {
strm.clear(std::ios::badbit);
return strm;
}
// read string part.
std::vector<IntType> string;
const char *c = s2.c_str();
while (*c != '\0') {
if (*c == kStringSeparator) // '_'
c++;
char *c2;
int64_t i = strtol(c, &c2, 10);
if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
strm.clear(std::ios::badbit);
return strm;
}
c = c2;
string.push_back(static_cast<IntType>(i));
}
w.SetString(string);
return strm;
}
template <class BaseWeightType, class IntType>
class CompactLatticeWeightCommonDivisorTpl {
public:
typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
Weight operator()(const Weight &w1, const Weight &w2) const {
// First find longest common prefix of the strings.
typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
s1e = w1.String().end(),
s2b = w2.String().begin(),
s2e = w2.String().end();
while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
s1b++;
s2b++;
}
return Weight(Plus(w1.Weight(), w2.Weight()),
std::vector<IntType>(w1.String().begin(), s1b));
}
};
/** Scales the pair (a, b) of floating-point weights inside a
CompactLatticeWeight by premultiplying it (viewed as a vector)
by a 2x2 matrix "scale".
Assumes there is a ScaleTupleWeight function that applies to "Weight";
this currently only works if Weight equals LatticeWeightTpl<FloatType>
for some FloatType.
*/
template <class Weight, class IntType, class ScaleFloatType>
inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
const CompactLatticeWeightTpl<Weight, IntType> &w,
const std::vector<std::vector<ScaleFloatType> > &scale) {
return CompactLatticeWeightTpl<Weight, IntType>(
Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
}
/** Define some ConvertLatticeWeight functions that are used in various lattice
conversions... make them all templates, some with no arguments, since some
must be templates.*/
template <class Float1, class Float2>
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
LatticeWeightTpl<Float2> *w_out) {
w_out->SetValue1(w_in.Value1());
w_out->SetValue2(w_in.Value2());
}
template <class Float1, class Float2, class Int>
inline void ConvertLatticeWeight(
const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int> &w_in,
CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int> *w_out) {
LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
w_in.Weight().Value2());
w_out->SetWeight(weight2);
w_out->SetString(w_in.String());
}
// to convert from Lattice to standard FST
template <class Float1, class Float2>
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
TropicalWeightTpl<Float2> *w_out) {
TropicalWeightTpl<Float2> w1(w_in.Value1());
TropicalWeightTpl<Float2> w2(w_in.Value2());
*w_out = Times(w1, w2);
}
template <class Float>
inline double ConvertToCost(const LatticeWeightTpl<Float> &w) {
return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
}
template <class Float, class Int>
inline double ConvertToCost(
const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int> &w) {
return static_cast<double>(w.Weight().Value1()) +
static_cast<double>(w.Weight().Value2());
}
template <class Float>
inline double ConvertToCost(const TropicalWeightTpl<Float> &w) {
return w.Value();
}
} // namespace fst
#endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_
// fstext/pre-determinize-inl.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
#define KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
/* Do not include this file directly. It is an implementation file included by
* PreDeterminize.h */
/*
Predeterminization
This is a function that makes an FST compactly determinizable by inserting
symbols on the input side as necessary for disambiguation. Note that we do
not treat epsilon as a real symbol when measuring determinizability in this
sense. The extra symbols are added to the vocabulary, on the input side;
these are of the form (prefix)1, (prefix)2, and so on without limit, where
(prefix) is some prefix the user provides, e.g. '#' (the function checks that
this will not lead to conflicts with symbols already in the FST). The
function tells us how many such symbols it created.
Note that there is a paper "Generalized optimization algorithm for speech
recognition transducers" by Allauzen and Mohri, that deals with a similar
issue, but this is a very different algorithm that only aims to ensure
determinizability, but not *compact* determinizability.
Our algorithm is slightly heuristic, and probably not optimal, but does
ensure that the output is compactly determinizable, possibly at the expense of
inserting unnecessary symbols. We considered more sophisticated algorithms,
but these were extremely complicated and would give the same output for the
kinds of inputs that we envisage.
Suppose the input FST is T. We want to ensure that in det(T), if we consider
the states of det(T) as weighted subsets of states of T, each state of T only
appears once in any given subset. This ensures that det(T) is no larger than
T in an appropriate sense. The way we do this is as follows. We identify all
states in T that have multiple input transitions (counting "being an initial
state" as an input transition). Let's call these "problematic" states. For a
problematic state p we stipulate that it can never appear in any state of
det(T) unless that state equals (p, \bar{1}) [i.e. p, unweighted]. In order
to ensure this, we insert input symbols on the transitions to these
problematic states (this may necessitate adding extra states).
We also stipulate that the path through det(T) should always be sufficient
to tell us the path through T (and we insert extra symbols sufficient to make
this so). This is to simplify the algorithm, so that we don't have to
consider the output symbols or weights when predeterminizing.
The algorithm is as follows.
(A) Definitions
(i) Define a *problematic state* as a state that either has multiple
input transitions, or is an initial state and has at least one input
transition.
(ii) For an arc a, define:
i[a] = input symbol on a
o[a] = output symbol on a
n[a] = dest-state of a
p[a] = origin-state of a
For a state q, define
E[q] = set of transitions leaving q.
For a set of states Q, define
E[Q] = set of transitions leaving some q in Q
(iii) For a state s, define Closure(s) as the union of state s, and all
states t that are reachable via sequences of arcs a such that i[a]=epsilon and
n[a] is not problematic.
For a set of states S, define Closure(S) as the union of the closures
of states s in S.
(B) Inputs and outputs.
(i) Inputs and preconditions. Input is an FST, which should have a symbol
table compiled into it, and a prefix (e.g. #) for symbols to be added. We
check that the input FST is trim, and that it does not have any symbols that
appear on its arcs, that are equal to the prefix followed by digits.
(ii) Outputs: The algorithm modifies the FST that is given to it, and
returns the number of the highest numbered "extra symbol" inserted. The extra
symbols are numbered #1, #2 and so on without limit (as integers). They are
inserted into the symbol table in a sequential way by calling AvailableKey()
for each in turn (this is stipulated in case we need to keep other
symbol tables in sync).
(C) Sub-algorithm: Closure(S). This requires the array p(s), defined
below, which is true if s is problematic. This also requires, for efficiency,
that the arcs be sorted on input label. Input: a set of states S. [plus, the
fst and the array p]. Output: a set of states T. Algorithm: set T <-- S, Q <--
S. while Q is nonempty: pop a state s from Q. for each transition a from state
s with epsilon on the input label [we can find these efficiently using the
sorting on arcs]: If p(n[a]) is false and n[a] is not in T: Insert n[a] into
T. Add n[a] to Q. return T.
(D) Main algorithm.
(i) (a) Check preconditions (FST is trim)
(b) Make sure there is just one final state (insert epsilon
transitions as necessary). (c) Sort arcs on input label (so epsilon arcs are
at the start of arc lists).
(ii) Work out the set of problematic states by constructing a boolean
array indexed by states, i.e. p(s) which is true if the state is problematic.
We can do this by constructing an array t(s) to store the number of
transitions into each state [adding one for the initial state], and then
setting p(s) = true if t(s) > 1.
Also create a boolean array d(s), defined for states, and set d(s) =
false. This array is purely for sanity-checking that we are processing each
state exactly once.
(iii) Set up an array of integers m(a), indexed by arcs (how exactly we
store these is implementation-dependent, but this will probably be a hash from
(state, arc-index) to integers. m(a) will store the extra symbol, if any, to
be added to that arc (or -1 if no such symbol; we can also simply have the arc
not present in the hash). The initial value of m(a) is -1 (if array), or
undefined (if hash).
(iv) Initialize a set of sets-of-states S, and a queue of pairs Q, as
follows. The pairs in Q are a pair of (set-of-states, integer), where the
integer is the number of "special symbols" already used up for that state.
Note that we use a special indexing for the sets in both S and Q,
rather than using std::set. We use a sorted vector of StateId's. And in S,
we index them by the lowest-numbered state-id. Because each state is supposed
to only ever be a member of one set, if there is an attempt to add another,
different set with the same lowest-numbered state-id, we detect an error.
Let I be the single initial state (OpenFST only supports one).
We set:
S = { Closure(I) }
Push (Closure(I), 0) onto Q.
Then for each state s such that p(s) = true, and s is not an initial
state: S <-- S u { Closure(s) } Push (Closure(s), 0) onto Q.
(v) While Q is nonempty:
(a) Pop pair (A, n) from Q (queue discipline is arbitrary).
(b) For each state s in A, check that d(s) is false, and set d(s) to
true. This is for sanity checking only.
(c)
Let S_\eps be the set of epsilon-transitions from members of A to
problematic states (i.e. S_\eps = \{ a \in E[A]: i[a]=\epsilon, p(n[a]) = true
\}).
Next, we will define, for each t \neq \epsilon, S_t as the set of
transitions from some state s in S with t as the input label,
i.e.: S_t = \{ a \in E[A]: i[a] = t \} We further define T_t and U_t as the
subsets of S where the destination state is problematic and non-problematic
respectively, i.e: T_t = \{ a \in E[A]: i[a] = t, p(n[a]) = true \} U_t = \{ a
\in E[A]: i[a] = t, p(n[a]) = false \}
The easiest way to obtain these sets is probably to have a hash
indexed by t that maps to a list of pairs (state, arc-offset) that stores S_t.
From this we can work out the sizes of T_t and U_t on the fly.
(d)
for each transition a in S_\eps:
m(a) <-- n # Will put symbol n on this transition.
n <-- n+1 # Note, same n as in pair (A, n)
(e)
next,
for each t\neq epsilon s.t. S_t is nonempty,
if |S_t| > 1 #if-statement is because if |S_t|=|T_t|=1, no need
for prefix. k = 0 for each transition a in T_t: set m(a) to k. set k = k+1
if |U_t| > 0
Let V_t be the set of destination-states of arcs in U_t.
if Closure(V_t) is not in S:
insert Closure(V_t) into S, and add the pair (Closure(V_t),
k) to Q.
(vi) Check that for each state in the FST, d(s) = true.
(vii) Let n = max_a m(a). This is the highest-numbered extra symbol
(extra symbols start from zero, in this numbering which doesn't correspond to
the symbol-table numbering). Here we add n+1 extra symbols to the symbol
table and store the mappings from 0, 1, ... n to the symbol-id.
(viii) Set up a hash h from (state, int) to (state-id) such that
t = h(s, k)
will be the state-id of a newly-created state that has a transition
to state s with input-label #k.
(ix) For each arc a such that m(a) != 0:
If i[a] = epsilon (the input label is epsilon):
Change i[a] to #m(a). [i.e. prefix then digit m(a)]
Otherwise:
If t = h(n[a], m(a)) is not defined [where n[a] is the
dest-state]: create a new state t with a transition to n[a], with input-label
#m(a) and no output-label or weight. Set h(n[a], m(a)) = t. Change n[a] to
h(n[a], m(a)).
*/
namespace fst {
namespace pre_determinize_helpers {
// make it inline to avoid having to put it in a .cc file which most functions
// here could not go in.
inline bool HasBannedPrefixPlusDigits(SymbolTable *symTable, std::string prefix,
std::string *bad_sym) {
// returns true if the symbol table contains any string consisting of this
// (possibly empty) prefix followed by a nonempty sequence of digits (0 to 9).
// requires symTable to be non-NULL.
// if bad_sym != NULL, puts the first bad symbol it finds in *bad_sym.
assert(symTable != NULL);
const char *prefix_ptr = prefix.c_str();
size_t prefix_len =
strlen(prefix_ptr); // allowed to be zero but not encouraged.
for (SymbolTableIterator siter(*symTable); !siter.Done(); siter.Next()) {
const std::string &sym = siter.Symbol();
if (!strncmp(prefix_ptr, sym.c_str(), prefix_len)) { // has prefix.
if (isdigit(sym[prefix_len])) { // we don't allow prefix followed by a
// digit, as a symbol.
// Has at least one digit.
size_t pos;
for (pos = prefix_len; sym[pos] != '\0'; pos++)
if (!isdigit(sym[pos])) break;
if (sym[pos] == '\0') { // All remaining characters were digits.
if (bad_sym != NULL) *bad_sym = sym;
return true;
}
} // else OK because prefix was followed by '\0' or a non-digit.
}
}
return false; // doesn't have banned symbol.
}
template <class T>
void CopySetToVector(const std::set<T> s, std::vector<T> *v) {
// adds members of s to v, in sorted order from lowest to highest
// (because the set was in sorted order).
assert(v != NULL);
v->resize(s.size());
typename std::set<T>::const_iterator siter = s.begin();
typename std::vector<T>::iterator viter = v->begin();
for (; siter != s.end(); ++siter, ++viter) {
assert(viter != v->end());
*viter = *siter;
}
}
// Warning. This function calls 'new'.
template <class T>
std::vector<T> *InsertMember(const std::vector<T> m,
std::vector<std::vector<T> *> *S) {
assert(m.size() > 0);
T idx = m[0];
assert(idx >= (T)0 && idx < (T)S->size());
if ((*S)[idx] != NULL) {
assert(*((*S)[idx]) == m);
// The vectors should be the same. Otherwise this is a bug in the
// algorithm. It could either be a programming error or a deeper conceptual
// bug.
return NULL; // nothing was inserted.
} else {
std::vector<T> *ret = (*S)[idx] = new std::vector<T>(m); // New copy of m.
return ret; // was inserted.
}
}
// See definition of Closure(S) in item A(iii) in the comment above. it's the
// set of states that are reachable from S via sequences of arcs a such that
// i[a]=epsilon and n[a] is not problematic. We assume that the fst is sorted
// on input label (so epsilon arcs first) The algorithm is described in section
// (C) above. We use the same variable for S and T.
template <class Arc>
void Closure(MutableFst<Arc> *fst, std::set<typename Arc::StateId> *S,
const std::vector<bool> &pVec) {
typedef typename Arc::StateId StateId;
std::vector<StateId> Q;
CopySetToVector(*S, &Q);
while (Q.size() != 0) {
StateId s = Q.back();
Q.pop_back();
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0)
break; // Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
if (!pVec[arc.nextstate]) { // Next state is not problematic -> we can
// use this transition.
std::pair<typename std::set<StateId>::iterator, bool> p =
S->insert(arc.nextstate);
if (p.second) { // True means: was inserted into S (wasn't already
// there).
Q.push_back(arc.nextstate);
}
}
}
}
} // end function Closure.
} // end namespace pre_determinize_helpers.
template <class Arc, class Int>
void PreDeterminize(MutableFst<Arc> *fst, typename Arc::Label first_new_sym,
std::vector<Int> *symsOut) {
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef size_t ArcId; // Our own typedef, not standard OpenFst. Use size_t
// for compatibility with argument of ArcIterator::Seek().
typedef typename Arc::Weight Weight;
assert(first_new_sym > 0);
assert(fst != NULL);
if (fst->Start() == kNoStateId) return; // for empty FST, nothing to do.
assert(symsOut != NULL &&
symsOut->size() == 0); // we will output the symbols we add into this.
{ // (D)(i)(a): check is trim (i.e. connected, in OpenFST parlance).
KALDI_VLOG(2) << "PreDeterminize: Checking FST properties";
uint64 props = fst->Properties(
kAccessible | kCoAccessible,
true); // true-> computes properties if unknown at time when called.
if (props !=
(kAccessible | kCoAccessible)) { // All states are not both accessible
// and co-accessible...
KALDI_ERR << "PreDeterminize: FST is not trim";
}
}
{ // (D)(i)(b): make single final state.
KALDI_VLOG(2) << "PreDeterminize: creating single final state";
CreateSuperFinal(fst);
}
{ // (D)(i)(c): sort arcs on input.
KALDI_VLOG(2) << "PreDeterminize: sorting arcs on input";
ILabelCompare<Arc> icomp;
ArcSort(fst, icomp);
}
StateId n_states = 0,
max_state =
0; // Compute n_states, max_state = highest-numbered state.
{ // compute nStates, maxStates.
for (StateIterator<MutableFst<Arc> > iter(*fst); !iter.Done();
iter.Next()) {
StateId state = iter.Value();
assert(state >= 0);
n_states++;
if (state > max_state) max_state = state;
}
KALDI_VLOG(2) << "PreDeterminize: n_states = " << (n_states)
<< ", max_state =" << (max_state);
}
std::vector<bool> p_vec(max_state + 1, false); // compute this next.
{ // D(ii): computing the array p. ["problematic states, i.e. states with >1
// input transition,
// counting being the initial state as an input transition"].
std::vector<bool> seen_vec(
max_state + 1,
false); // rather than counting incoming transitions we just have a
// bool that says we saw at least one.
seen_vec[fst->Start()] = true;
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
for (ArcIterator<MutableFst<Arc> > aiter(*fst, siter.Value());
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
assert(arc.nextstate >= 0 && arc.nextstate < max_state + 1);
if (seen_vec[arc.nextstate])
p_vec[arc.nextstate] =
true; // now have >1 transition in, so problematic.
else
seen_vec[arc.nextstate] = true;
}
}
}
// D(iii): set up m(a)
std::map<std::pair<StateId, ArcId>, size_t> m_map;
// This is the array m, indexed by arcs. It maps to the index of the symbol
// we add.
// WARNING: we should be sure to clean up this memory before exiting. Do not
// return or throw an exception from this function, later than this point,
// without cleaning up! Note that the vectors are shared between Q and S (they
// "belong to" S.
std::vector<std::vector<StateId> *> S(max_state + 1,
(std::vector<StateId> *)(void *)0);
std::vector<std::pair<std::vector<StateId> *, size_t> > Q;
// D(iv): initialize S and Q.
{
std::vector<StateId>
all_seed_states; // all "problematic" states, plus initial state (if
// not problematic).
if (!p_vec[fst->Start()]) all_seed_states.push_back(fst->Start());
for (StateId s = 0; s <= max_state; s++)
if (p_vec[s]) all_seed_states.push_back(s);
for (size_t idx = 0; idx < all_seed_states.size(); idx++) {
StateId s = all_seed_states[idx];
std::set<StateId> closure_s;
closure_s.insert(s); // insert "seed" state.
pre_determinize_helpers::Closure(
fst, &closure_s,
p_vec); // follow epsilons to non-problematic states.
// Closure in this case whis will usually not add anything, for typical
// topologies in speech
std::vector<StateId> closure_s_vec;
pre_determinize_helpers::CopySetToVector(closure_s, &closure_s_vec);
KALDI_ASSERT(closure_s_vec.size() != 0);
std::vector<StateId> *ptr =
pre_determinize_helpers::InsertMember(closure_s_vec, &S);
KALDI_ASSERT(ptr != NULL); // Or conceptual bug or programming error.
Q.push_back(std::pair<std::vector<StateId> *, size_t>(ptr, 0));
}
}
std::vector<bool> d_vec(max_state + 1,
false); // "done vector". Purely for debugging.
size_t num_extra_det_states = 0;
// (D)(v)
while (Q.size() != 0) {
// (D)(v)(a)
std::pair<std::vector<StateId> *, size_t> cur_pair(Q.back());
Q.pop_back();
const std::vector<StateId> &A(*cur_pair.first);
size_t n = cur_pair.second; // next special symbol to add.
// (D)(v)(b)
for (size_t idx = 0; idx < A.size(); idx++) {
assert(d_vec[A[idx]] == false &&
"This state has been seen before. Algorithm error.");
d_vec[A[idx]] = true;
}
// From here is (D)(v)(c). We work out S_\eps and S_t (for t\neq eps)
// simultaneously at first.
std::map<Label, std::set<std::pair<std::pair<StateId, ArcId>, StateId> > >
arc_hash;
// arc_hash is a hash with info of all arcs from states in the set A to
// non-problematic states.
// It is a map from ilabel to pair(pair(start-state, arc-offset),
// end-state). Here, arc-offset reflects the order in which we accessed the
// arc using the ArcIterator (zero for the first arc).
{ // This block sets up arc_hash
for (size_t idx = 0; idx < A.size(); idx++) {
StateId s = A[idx];
assert(s >= 0 && s <= max_state);
ArcId arc_id = 0;
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
aiter.Next(), ++arc_id) {
const Arc &arc = aiter.Value();
std::pair<std::pair<StateId, ArcId>, StateId> this_pair(
std::pair<StateId, ArcId>(s, arc_id), arc.nextstate);
bool inserted = (arc_hash[arc.ilabel].insert(this_pair)).second;
assert(inserted); // Otherwise we had a duplicate.
}
}
}
// (D)(v)(d)
if (arc_hash.count(0) == 1) { // We have epsilon transitions out.
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &eps_set =
arc_hash[0];
typedef typename std::set<
std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t;
for (set_iter_t siter = eps_set.begin(); siter != eps_set.end();
++siter) {
const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr = *siter;
if (p_vec[this_pr.second]) { // Eps-transition to problematic state.
assert(m_map.count(this_pr.first) == 0);
m_map[this_pr.first] = n;
n++;
}
}
}
// (D)(v)(e)
{
typedef typename std::map<
Label,
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > >::iterator
map_iter_t;
typedef typename std::set<
std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t2;
for (map_iter_t miter = arc_hash.begin(); miter != arc_hash.end();
++miter) {
Label t = miter->first;
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &S_t =
miter->second;
if (t != 0) { // For t != epsilon,
std::set<StateId> V_t; // set of destination non-problem states. Will
// create this set now.
// exists_noproblem is true iff |U_t| > 0.
size_t k = 0;
// First loop "for each transition a in T_t" (i.e. transitions to
// problematic states) The if-statement if (|S_t|>1) is pushed inside
// the loop, as the loop also computes the set V_t.
for (set_iter_t2 siter = S_t.begin(); siter != S_t.end(); ++siter) {
const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr =
*siter;
if (p_vec[this_pr.second]) { // only consider problematic states
// (just set T_t)
if (S_t.size() >
1) { // This is where we pushed the if-statement in.
assert(m_map.count(this_pr.first) == 0);
m_map[this_pr.first] = k;
k++;
num_extra_det_states++;
}
} else { // Create the set V_t.
V_t.insert(this_pr.second);
}
}
if (V_t.size() != 0) {
pre_determinize_helpers::Closure(
fst, &V_t,
p_vec); // follow epsilons to non-problematic states.
std::vector<StateId> closure_V_t_vec;
pre_determinize_helpers::CopySetToVector(V_t, &closure_V_t_vec);
std::vector<StateId> *ptr =
pre_determinize_helpers::InsertMember(closure_V_t_vec, &S);
if (ptr != NULL) { // was inserted.
Q.push_back(std::pair<std::vector<StateId> *, size_t>(ptr, k));
}
}
}
}
}
} // end while (Q.size() != 0)
{ // (D)(vi): Check that for each state in the FST, d(s) = true.
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
StateId val = siter.Value();
assert(d_vec[val] == true);
}
}
{ // (D)(vii): compute symbol-table ID's.
// sets up symsOut array.
int64 n = -1;
for (typename std::map<std::pair<StateId, ArcId>, size_t>::iterator m_iter =
m_map.begin();
m_iter != m_map.end(); ++m_iter) {
n = std::max(n,
static_cast<int64>(
m_iter->second)); // m_iter->second is of type size_t.
}
// At this point n is the highest symbol-id (type size_t) of symbols we must
// add.
n++; // This is now the number of symbols we must add.
for (size_t i = 0; static_cast<int64>(i) < n; i++)
symsOut->push_back(first_new_sym + i);
}
// (D)(viii): set up hash.
std::map<std::pair<StateId, size_t>, StateId> h_map;
{ // D(ix): add extra symbols! This is where the work gets done.
// Core part of this is below, search for (*)
size_t n_states_added = 0;
for (typename std::map<std::pair<StateId, ArcId>, size_t>::iterator m_iter =
m_map.begin();
m_iter != m_map.end(); ++m_iter) {
StateId state = m_iter->first.first;
ArcId arcpos = m_iter->first.second;
size_t m_a = m_iter->second;
MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
aiter.Seek(arcpos);
Arc arc = aiter.Value();
// (*) core part here.
if (arc.ilabel == 0) {
arc.ilabel = (*symsOut)[m_a];
} else {
std::pair<StateId, size_t> pr(arc.nextstate, m_a);
if (!h_map.count(pr)) {
n_states_added++;
StateId newstate = fst->AddState();
assert(newstate >= 0);
Arc new_arc((*symsOut)[m_a], (Label)0, Weight::One(), arc.nextstate);
fst->AddArc(newstate, new_arc);
h_map[pr] = newstate;
}
arc.nextstate = h_map[pr];
}
aiter.SetValue(arc);
}
KALDI_VLOG(2) << "Added " << (n_states_added)
<< " new states and added/changed " << (m_map.size())
<< " arcs";
}
// Now free up memory.
for (size_t i = 0; i < S.size(); i++) delete S[i];
} // end function PreDeterminize
template <class Label>
void CreateNewSymbols(SymbolTable *input_sym_table, int nSym,
std::string prefix, std::vector<Label> *symsOut) {
// Creates nSym new symbols named (prefix)0, (prefix)1 and so on.
// Crashes if it cannot create them because one or more of them were in the
// symbol table already.
assert(symsOut && symsOut->size() == 0);
for (int i = 0; i < nSym; i++) {
std::stringstream ss;
ss << prefix << i;
std::string str = ss.str();
if (input_sym_table->Find(str) != -1) { // should not be present.
}
assert(symsOut);
symsOut->push_back((Label)input_sym_table->AddSymbol(str));
}
}
// see pre-determinize.h for documentation.
template <class Arc>
void AddSelfLoops(MutableFst<Arc> *fst,
const std::vector<typename Arc::Label> &isyms,
const std::vector<typename Arc::Label> &osyms) {
assert(fst != NULL);
assert(isyms.size() == osyms.size());
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
size_t n = isyms.size();
if (n == 0) return; // Nothing to do.
// {
// the following declarations and statements are for quick detection of these
// symbols, which is purely for debugging/checking purposes.
Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
isyms_max = *std::max_element(isyms.begin(), isyms.end()),
osyms_min = *std::min_element(osyms.begin(), osyms.end()),
osyms_max = *std::max_element(osyms.begin(), osyms.end());
std::set<Label> isyms_set, osyms_set;
for (size_t i = 0; i < isyms.size(); i++) {
assert(isyms[i] > 0 &&
osyms[i] > 0); // should not have epsilon or invalid symbols.
isyms_set.insert(isyms[i]);
osyms_set.insert(osyms[i]);
}
assert(isyms_set.size() == n && osyms_set.size() == n);
// } end block.
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
StateId state = siter.Value();
bool this_state_needs_self_loops = (fst->Final(state) != Weight::Zero());
for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
// If one of the following asserts fails, it means that the input FST
// already had the symbols we are inserting. This is contrary to the
// preconditions of this algorithm.
assert(!(arc.ilabel >= isyms_min && arc.ilabel <= isyms_max &&
isyms_set.count(arc.ilabel) != 0));
assert(!(arc.olabel >= osyms_min && arc.olabel <= osyms_max &&
osyms_set.count(arc.olabel) != 0));
if (arc.olabel != 0) // Has non-epsilon output label -> need self loops.
this_state_needs_self_loops = true;
}
if (this_state_needs_self_loops) {
for (size_t i = 0; i < n; i++) {
Arc arc;
arc.ilabel = isyms[i];
arc.olabel = osyms[i];
arc.weight = Weight::One();
arc.nextstate = state;
fst->AddArc(state, arc);
}
}
}
}
template <class Arc>
int64 DeleteISymbols(MutableFst<Arc> *fst,
std::vector<typename Arc::Label> isyms) {
// We could do this using the Mapper concept, but this is much easier to
// understand.
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
int64 num_deleted = 0;
if (isyms.size() == 0) return 0;
Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
isyms_max = *std::max_element(isyms.begin(), isyms.end());
bool isyms_consecutive =
(isyms_max + 1 - isyms_min == static_cast<Label>(isyms.size()));
std::set<Label> isyms_set;
if (!isyms_consecutive) {
for (size_t i = 0; i < isyms.size(); i++) isyms_set.insert(isyms[i]);
}
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
siter.Next()) {
StateId state = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel >= isyms_min && arc.ilabel <= isyms_max) {
if (isyms_consecutive || isyms_set.count(arc.ilabel) != 0) {
num_deleted++;
Arc mod_arc(arc);
mod_arc.ilabel = 0; // change label to epsilon.
aiter.SetValue(mod_arc);
}
}
}
}
return num_deleted;
}
template <class Arc>
typename Arc::StateId CreateSuperFinal(MutableFst<Arc> *fst) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
assert(fst != NULL);
StateId num_states = fst->NumStates();
StateId num_final = 0;
std::vector<StateId> final_states;
for (StateId s = 0; s < num_states; s++) {
if (fst->Final(s) != Weight::Zero()) {
num_final++;
final_states.push_back(s);
}
}
if (final_states.size() == 1) {
if (fst->Final(final_states[0]) == Weight::One()) {
ArcIterator<MutableFst<Arc> > iter(*fst, final_states[0]);
if (iter.Done()) {
// We already have a final state w/ no transitions out and unit weight.
// So we're done.
return final_states[0];
}
}
}
StateId final_state = fst->AddState();
fst->SetFinal(final_state, Weight::One());
for (size_t idx = 0; idx < final_states.size(); idx++) {
StateId s = final_states[idx];
Weight weight = fst->Final(s);
fst->SetFinal(s, Weight::Zero());
Arc arc;
arc.ilabel = 0;
arc.olabel = 0;
arc.nextstate = final_state;
arc.weight = weight;
fst->AddArc(s, arc);
}
return final_state;
}
} // namespace fst
#endif // KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
// fstext/pre-determinize.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_PRE_DETERMINIZE_H_
#define KALDI_FSTEXT_PRE_DETERMINIZE_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "base/kaldi-common.h"
namespace fst {
/* PreDeterminize inserts extra symbols on the input side of an FST as necessary
to ensure that, after epsilon removal, it will be compactly determinizable by
the determinize* algorithm. By compactly determinizable we mean that no
original FST state is represented in more than one determinized state).
Caution: this code is now only used in testing.
The new symbols start from the value "first_new_symbol", which should be
higher than the largest-numbered symbol currently in the FST. The new
symbols added are put in the array syms_out, which should be empty at start.
*/
template <class Arc, class Int>
void PreDeterminize(MutableFst<Arc> *fst, typename Arc::Label first_new_symbol,
std::vector<Int> *syms_out);
/* CreateNewSymbols is a helper function used inside PreDeterminize, and is also
useful when you need to add a number of extra symbols to a different
vocabulary from the one modified by PreDeterminize. */
template <class Label>
void CreateNewSymbols(SymbolTable *inputSymTable, int nSym, std::string prefix,
std::vector<Label> *syms_out);
/** AddSelfLoops is a function you will probably want to use alongside
PreDeterminize, to add self-loops to any FSTs that you compose on the left
hand side of the one modified by PreDeterminize.
This function inserts loops with "special symbols" [e.g. \#0, \#1] into an
FST. This is done at each final state and each state with non-epsilon output
symbols on at least one arc out of it. This is to ensure that these symbols,
when inserted into the input side of an FST we will compose with on the
right, can "pass through" this FST.
At input, isyms and osyms must be vectors of the same size n, corresponding
to symbols that currently do not exist in 'fst'. For each state in n that
has non-epsilon symbols on the output side of arcs leaving it, or which is a
final state, this function inserts n self-loops with unit weight and one of
the n pairs of symbols on its input and output.
*/
template <class Arc>
void AddSelfLoops(MutableFst<Arc> *fst,
const std::vector<typename Arc::Label> &isyms,
const std::vector<typename Arc::Label> &osyms);
/* DeleteSymbols replaces any instances of symbols in the vector symsIn,
appearing on the input side, with epsilon. */
/* It returns the number of instances of symbols deleted. */
template <class Arc>
int64 DeleteISymbols(MutableFst<Arc> *fst,
std::vector<typename Arc::Label> symsIn);
/* CreateSuperFinal takes an FST, and creates an equivalent FST with a single
final state with no transitions out and unit final weight, by inserting
epsilon transitions as necessary. */
template <class Arc>
typename Arc::StateId CreateSuperFinal(MutableFst<Arc> *fst);
} // end namespace fst
#include "fstext/pre-determinize-inl.h"
#endif // KALDI_FSTEXT_PRE_DETERMINIZE_H_
// fstext/remove-eps-local-inl.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#include <vector>
namespace fst {
template <class Weight>
struct ReweightPlusDefault {
inline Weight operator()(const Weight &a, const Weight &b) {
return Plus(a, b);
}
};
struct ReweightPlusLogArc {
inline TropicalWeight operator()(const TropicalWeight &a,
const TropicalWeight &b) {
LogWeight a_log(a.Value()), b_log(b.Value());
return TropicalWeight(Plus(a_log, b_log).Value());
}
};
template <class Arc,
class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
class RemoveEpsLocalClass {
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
public:
explicit RemoveEpsLocalClass(MutableFst<Arc> *fst) : fst_(fst) {
if (fst_->Start() == kNoStateId) return; // empty.
non_coacc_state_ = fst_->AddState();
InitNumArcs();
StateId num_states = fst_->NumStates();
for (StateId s = 0; s < num_states; s++)
for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
assert(CheckNumArcs());
Connect(fst); // remove inaccessible states.
}
private:
MutableFst<Arc> *fst_;
StateId non_coacc_state_; // use this to delete arcs: make it nextstate
std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus
// one if it's the start state.
std::vector<StateId> num_arcs_out_; // The number of arcs out of the state,
// plus one if it's a final state.
ReweightPlus reweight_plus_;
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
if (a.ilabel != 0 && b.ilabel != 0) return false;
if (a.olabel != 0 && b.olabel != 0) return false;
c->weight = Times(a.weight, b.weight);
c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
c->nextstate = b.nextstate;
return true;
}
static bool CanCombineFinal(const Arc &a, Weight final_prob,
Weight *final_prob_out) {
if (a.ilabel != 0 || a.olabel != 0) {
return false;
} else {
*final_prob_out = Times(a.weight, final_prob);
return true;
}
}
void InitNumArcs() { // init num transitions in/out of each state.
StateId num_states = fst_->NumStates();
num_arcs_in_.resize(num_states);
num_arcs_out_.resize(num_states);
num_arcs_in_[fst_->Start()]++; // count start as trans in.
for (StateId s = 0; s < num_states; s++) {
if (fst_->Final(s) != Weight::Zero())
num_arcs_out_[s]++; // count final as transition.
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
aiter.Next()) {
num_arcs_in_[aiter.Value().nextstate]++;
num_arcs_out_[s]++;
}
}
}
bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
num_arcs_in_[fst_->Start()]--; // count start as trans in.
StateId num_states = fst_->NumStates();
for (StateId s = 0; s < num_states; s++) {
if (s == non_coacc_state_) continue;
if (fst_->Final(s) != Weight::Zero())
num_arcs_out_[s]--; // count final as transition.
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
aiter.Next()) {
if (aiter.Value().nextstate == non_coacc_state_) continue;
num_arcs_in_[aiter.Value().nextstate]--;
num_arcs_out_[s]--;
}
}
for (StateId s = 0; s < num_states; s++) {
assert(num_arcs_in_[s] == 0);
assert(num_arcs_out_[s] == 0);
}
return true; // always does this. so we can assert it w/o warnings.
}
inline void GetArc(StateId s, size_t pos, Arc *arc) const {
ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
aiter.Seek(pos);
*arc = aiter.Value();
}
inline void SetArc(StateId s, size_t pos, const Arc &arc) {
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
aiter.Seek(pos);
aiter.SetValue(arc);
}
void Reweight(StateId s, size_t pos, Weight reweight) {
// Reweight is called from RemoveEpsPattern1; it is a step we
// do to preserve stochasticity. This function multiplies the
// arc at (s, pos) by reweight and divides all the arcs [+final-prob]
// out of the next state by the same. This is only valid if
// the next state has only one arc in and is not the start state.
assert(reweight != Weight::Zero());
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
aiter.Seek(pos);
Arc arc = aiter.Value();
assert(num_arcs_in_[arc.nextstate] == 1);
arc.weight = Times(arc.weight, reweight);
aiter.SetValue(arc);
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
!aiter_next.Done(); aiter_next.Next()) {
Arc nextarc = aiter_next.Value();
if (nextarc.nextstate != non_coacc_state_) {
nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
aiter_next.SetValue(nextarc);
}
}
Weight final = fst_->Final(arc.nextstate);
if (final != Weight::Zero()) {
fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
}
}
// RemoveEpsPattern1 applies where this arc, which is not a
// self-loop, enters a state which has only one input transition
// [and is not the start state], and has multiple output
// transitions [counting being the final-state as a final-transition].
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
const StateId nextstate = arc.nextstate;
Weight total_removed = Weight::Zero(),
total_kept = Weight::Zero(); // totals out of nextstate.
std::vector<Arc> arcs_to_add; // to add to state s.
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
!aiter_next.Done(); aiter_next.Next()) {
Arc nextarc = aiter_next.Value();
if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
Arc combined;
if (CanCombineArcs(arc, nextarc, &combined)) {
total_removed = reweight_plus_(total_removed, nextarc.weight);
num_arcs_out_[nextstate]--;
num_arcs_in_[nextarc.nextstate]--;
nextarc.nextstate = non_coacc_state_;
aiter_next.SetValue(nextarc);
arcs_to_add.push_back(combined);
} else {
total_kept = reweight_plus_(total_kept, nextarc.weight);
}
}
{ // now final-state.
Weight next_final = fst_->Final(nextstate);
if (next_final != Weight::Zero()) {
Weight new_final;
if (CanCombineFinal(arc, next_final, &new_final)) {
total_removed = reweight_plus_(total_removed, next_final);
if (fst_->Final(s) == Weight::Zero())
num_arcs_out_[s]++; // final is counted as arc.
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
num_arcs_out_[nextstate]--;
fst_->SetFinal(nextstate, Weight::Zero());
} else {
total_kept = reweight_plus_(total_kept, next_final);
}
}
}
if (total_removed != Weight::Zero()) { // did something...
if (total_kept == Weight::Zero()) { // removed everything: remove arc.
num_arcs_out_[s]--;
num_arcs_in_[arc.nextstate]--;
arc.nextstate = non_coacc_state_;
SetArc(s, pos, arc);
} else {
// Have to reweight.
Weight total = reweight_plus_(total_removed, total_kept);
Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
Reweight(s, pos, reweight);
}
}
// Now add the arcs we were going to add.
for (size_t i = 0; i < arcs_to_add.size(); i++) {
num_arcs_out_[s]++;
num_arcs_in_[arcs_to_add[i].nextstate]++;
fst_->AddArc(s, arcs_to_add[i]);
}
}
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
// Pattern 2 is where "nextstate" has only one arc out, counting
// being-the-final-state as an arc, but possibly multiple arcs in.
// Also, nextstate != s.
const StateId nextstate = arc.nextstate;
bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
// we combine, can delete the corresponding out-arc/final-prob
// of nextstate.
bool delete_arc = false; // set to true if this arc to be deleted.
Weight next_final = fst_->Final(arc.nextstate);
if (next_final !=
Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
Weight new_final;
if (CanCombineFinal(arc, next_final, &new_final)) {
if (fst_->Final(s) == Weight::Zero())
num_arcs_out_[s]++; // final is counted as arc.
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
delete_arc = true; // will delete "arc".
if (can_delete_next) {
num_arcs_out_[nextstate]--;
fst_->SetFinal(nextstate, Weight::Zero());
}
}
} else { // has an arc but no final prob.
MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
assert(!aiter_next.Done());
while (aiter_next.Value().nextstate == non_coacc_state_) {
aiter_next.Next();
assert(!aiter_next.Done());
}
// now aiter_next points to a real arc out of nextstate.
Arc nextarc = aiter_next.Value();
Arc combined;
if (CanCombineArcs(arc, nextarc, &combined)) {
delete_arc = true;
if (can_delete_next) { // do it before we invalidate iterators
num_arcs_out_[nextstate]--;
num_arcs_in_[nextarc.nextstate]--;
nextarc.nextstate = non_coacc_state_;
aiter_next.SetValue(nextarc);
}
num_arcs_out_[s]++;
num_arcs_in_[combined.nextstate]++;
fst_->AddArc(s, combined);
}
}
if (delete_arc) {
num_arcs_out_[s]--;
num_arcs_in_[nextstate]--;
arc.nextstate = non_coacc_state_;
SetArc(s, pos, arc);
}
}
void RemoveEps(StateId s, size_t pos) {
// Tries to do local epsilon-removal for arc sequences starting with this
// arc
Arc arc;
GetArc(s, pos, &arc);
StateId nextstate = arc.nextstate;
if (nextstate == non_coacc_state_) return; // deleted arc.
if (nextstate == s) return; // don't handle self-loops: too complex.
if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
RemoveEpsPattern1(s, pos, arc);
} else if (num_arcs_out_[nextstate] == 1) {
RemoveEpsPattern2(s, pos, arc);
}
}
};
template <class Arc>
void RemoveEpsLocal(MutableFst<Arc> *fst) {
RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
}
void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
// work gets done in initializer.
RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
}
} // end namespace fst.
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
// fstext/remove-eps-local.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
namespace fst {
/// RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST,
/// using an algorithm that is guaranteed to never increase the number of arcs
/// in the FST (and will also never increase the number of states). The
/// algorithm is not optimal but is reasonably clever. It does not just remove
/// epsilon arcs;it also combines pairs of input-epsilon and output-epsilon arcs
/// into one.
/// The algorithm preserves equivalence and stochasticity in the given semiring.
/// If you want to preserve stochasticity in a different semiring (e.g. log),
/// then use RemoveEpsLocalSpecial, which only works for StdArc but which
/// preserves stochasticity, where possible (*) in the LogArc sense. The reason
/// that we can't just cast to a different semiring is that in that case we
/// would no longer be able to guarantee equivalence in the original semiring
/// (this arises from what happens when we combine identical arcs).
/// (*) by "where possible".. there are situations where we wouldn't be able to
/// preserve stochasticity in the LogArc sense while maintaining equivalence in
/// the StdArc sense, so in these situations we maintain equivalence.
template <class Arc>
void RemoveEpsLocal(MutableFst<Arc> *fst);
/// As RemoveEpsLocal but takes care to preserve stochasticity
/// when cast to LogArc.
inline void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst);
} // namespace fst
#include "fstext/remove-eps-local-inl.h"
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
// fstext/table-matcher.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_TABLE_MATCHER_H_
#define KALDI_FSTEXT_TABLE_MATCHER_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <memory>
#include <vector>
namespace fst {
/// TableMatcher is a matcher specialized for the case where the output
/// side of the left FST always has either all-epsilons coming out of
/// a state, or a majority of the symbol table. Therefore we can
/// either store nothing (for the all-epsilon case) or store a lookup
/// table from Labels to arc offsets. Since the TableMatcher has to
/// iterate over all arcs in each left-hand state the first time it sees
/// it, this matcher type is not efficient if you compose with
/// something very small on the right-- unless you do it multiple
/// times and keep the matcher around. To do this requires using the
/// most advanced form of ComposeFst in Compose.h, that initializes
/// with ComposeFstImplOptions.
struct TableMatcherOptions {
float
table_ratio; // we construct the table if it would be at least this full.
int min_table_size;
TableMatcherOptions() : table_ratio(0.25), min_table_size(4) {}
};
// Introducing an "impl" class for TableMatcher because
// we need to do a shallow copy of the Matcher for when
// we want to cache tables for multiple compositions.
template <class F, class BackoffMatcher = SortedMatcher<F> >
class TableMatcherImpl : public MatcherBase<typename F::Arc> {
public:
typedef F FST;
typedef typename F::Arc Arc;
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef StateId
ArcId; // Use this type to store arc offsets [it's actually size_t
// in the Seek function of ArcIterator, but StateId should be big enough].
typedef typename Arc::Weight Weight;
public:
TableMatcherImpl(const FST &fst, MatchType match_type,
const TableMatcherOptions &opts = TableMatcherOptions())
: match_type_(match_type),
fst_(fst.Copy()),
loop_(match_type == MATCH_INPUT
? Arc(kNoLabel, 0, Weight::One(), kNoStateId)
: Arc(0, kNoLabel, Weight::One(), kNoStateId)),
aiter_(NULL),
s_(kNoStateId),
opts_(opts),
backoff_matcher_(fst, match_type) {
assert(opts_.min_table_size > 0);
if (match_type == MATCH_INPUT)
assert(fst_->Properties(kILabelSorted, true) == kILabelSorted);
else if (match_type == MATCH_OUTPUT)
assert(fst_->Properties(kOLabelSorted, true) == kOLabelSorted);
else
assert(0 && "Invalid FST properties");
}
virtual const FST &GetFst() const { return *fst_; }
virtual ~TableMatcherImpl() {
std::vector<ArcId> *const empty =
((std::vector<ArcId> *)(NULL)) + 1; // special marker.
for (size_t i = 0; i < tables_.size(); i++) {
if (tables_[i] != NULL && tables_[i] != empty) delete tables_[i];
}
delete aiter_;
delete fst_;
}
virtual MatchType Type(bool test) const { return match_type_; }
void SetState(StateId s) {
if (aiter_) {
delete aiter_;
aiter_ = NULL;
}
if (match_type_ == MATCH_NONE) LOG(FATAL) << "TableMatcher: bad match type";
s_ = s;
std::vector<ArcId> *const empty =
((std::vector<ArcId> *)(NULL)) + 1; // special marker.
if (static_cast<size_t>(s) >= tables_.size()) {
assert(s >= 0);
tables_.resize(s + 1, NULL);
}
std::vector<ArcId> *&this_table_ = tables_[s]; // note: ref to ptr.
if (this_table_ == empty) {
backoff_matcher_.SetState(s);
return;
} else if (this_table_ == NULL) { // NULL means has not been set.
ArcId num_arcs = fst_->NumArcs(s);
if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
this_table_ = empty;
backoff_matcher_.SetState(s);
return;
}
ArcIterator<FST> aiter(*fst_, s);
aiter.SetFlags(
kArcNoCache |
(match_type_ == MATCH_OUTPUT ? kArcOLabelValue : kArcILabelValue),
kArcNoCache | kArcValueFlags);
// the statement above, says: "Don't cache stuff; and I only need the
// ilabel/olabel to be computed.
aiter.Seek(num_arcs - 1);
Label highest_label =
(match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
: aiter.Value().ilabel);
if ((highest_label + 1) * opts_.table_ratio > num_arcs) {
this_table_ = empty;
backoff_matcher_.SetState(s);
return; // table would be too sparse.
}
// OK, now we are creating the table.
this_table_ = new std::vector<ArcId>(highest_label + 1, kNoStateId);
ArcId pos = 0;
for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
Label label = (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
: aiter.Value().ilabel);
assert(static_cast<size_t>(label) <=
static_cast<size_t>(highest_label)); // also checks >= 0.
if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
// set this_table_[label] to first position where arc has this
// label.
}
}
// At this point in the code, this_table_ != NULL and != empty.
aiter_ = new ArcIterator<FST>(*fst_, s);
aiter_->SetFlags(kArcNoCache,
kArcNoCache); // don't need to cache arcs as may only
// need a small subset.
loop_.nextstate = s;
// aiter_ = NULL;
// backoff_matcher_.SetState(s);
}
bool Find(Label match_label) {
if (!aiter_) {
return backoff_matcher_.Find(match_label);
} else {
match_label_ = match_label;
current_loop_ = (match_label == 0);
// kNoLabel means the implicit loop on the other FST --
// matches real epsilons but not the self-loop.
match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
(*(tables_[s_]))[match_label_] != kNoStateId) {
aiter_->Seek((*(tables_[s_]))[match_label_]); // label exists.
return true;
}
return current_loop_;
}
}
const Arc &Value() const {
if (aiter_)
return current_loop_ ? loop_ : aiter_->Value();
else
return backoff_matcher_.Value();
}
void Next() {
if (aiter_) {
if (current_loop_)
current_loop_ = false;
else
aiter_->Next();
} else {
backoff_matcher_.Next();
}
}
bool Done() const {
if (aiter_ != NULL) {
if (current_loop_) return false;
if (aiter_->Done()) return true;
Label label = (match_type_ == MATCH_OUTPUT ? aiter_->Value().olabel
: aiter_->Value().ilabel);
return (label != match_label_);
} else {
return backoff_matcher_.Done();
}
}
const Arc &Value() {
if (aiter_ != NULL) {
return (current_loop_ ? loop_ : aiter_->Value());
} else {
return backoff_matcher_.Value();
}
}
virtual TableMatcherImpl<FST> *Copy(bool safe = false) const {
assert(0); // shouldn't be called. This is not a "real" matcher,
// although we derive from MatcherBase for convenience.
return NULL;
}
virtual uint64 Properties(uint64 props) const {
return props;
} // simple matcher that does
// not change its FST, so properties are properties of FST it is applied to
private:
virtual void SetState_(StateId s) { SetState(s); }
virtual bool Find_(Label label) { return Find(label); }
virtual bool Done_() const { return Done(); }
virtual const Arc &Value_() const { return Value(); }
virtual void Next_() { Next(); }
MatchType match_type_;
FST *fst_;
bool current_loop_;
Label match_label_;
Arc loop_;
ArcIterator<FST> *aiter_;
StateId s_;
std::vector<std::vector<ArcId> *> tables_;
TableMatcherOptions opts_;
BackoffMatcher backoff_matcher_;
};
template <class F, class BackoffMatcher = SortedMatcher<F> >
class TableMatcher : public MatcherBase<typename F::Arc> {
public:
typedef F FST;
typedef typename F::Arc Arc;
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef StateId
ArcId; // Use this type to store arc offsets [it's actually size_t
// in the Seek function of ArcIterator, but StateId should be big enough].
typedef typename Arc::Weight Weight;
typedef TableMatcherImpl<F, BackoffMatcher> Impl;
TableMatcher(const FST &fst, MatchType match_type,
const TableMatcherOptions &opts = TableMatcherOptions())
: impl_(std::make_shared<Impl>(fst, match_type, opts)) {}
TableMatcher(const TableMatcher<FST, BackoffMatcher> &matcher,
bool safe = false)
: impl_(matcher.impl_) {
if (safe == true) {
LOG(FATAL) << "TableMatcher: Safe copy not supported";
}
}
virtual const FST &GetFst() const { return impl_->GetFst(); }
virtual MatchType Type(bool test) const { return impl_->Type(test); }
void SetState(StateId s) { return impl_->SetState(s); }
bool Find(Label match_label) { return impl_->Find(match_label); }
const Arc &Value() const { return impl_->Value(); }
void Next() { return impl_->Next(); }
bool Done() const { return impl_->Done(); }
const Arc &Value() { return impl_->Value(); }
virtual TableMatcher<FST, BackoffMatcher> *Copy(bool safe = false) const {
return new TableMatcher<FST, BackoffMatcher>(*this, safe);
}
virtual uint64 Properties(uint64 props) const {
return impl_->Properties(props);
} // simple matcher that does
// not change its FST, so properties are properties of FST it is applied to
private:
std::shared_ptr<Impl> impl_;
virtual void SetState_(StateId s) { impl_->SetState(s); }
virtual bool Find_(Label label) { return impl_->Find(label); }
virtual bool Done_() const { return impl_->Done(); }
virtual const Arc &Value_() const { return impl_->Value(); }
virtual void Next_() { impl_->Next(); }
TableMatcher &operator=(const TableMatcher &) = delete;
};
struct TableComposeOptions : public TableMatcherOptions {
bool connect; // Connect output
ComposeFilter filter_type; // Which pre-defined filter to use
MatchType table_match_type;
explicit TableComposeOptions(const TableMatcherOptions &mo, bool c = true,
ComposeFilter ft = SEQUENCE_FILTER,
MatchType tms = MATCH_OUTPUT)
: TableMatcherOptions(mo),
connect(c),
filter_type(ft),
table_match_type(tms) {}
TableComposeOptions()
: connect(true),
filter_type(SEQUENCE_FILTER),
table_match_type(MATCH_OUTPUT) {}
};
template <class Arc>
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
const TableComposeOptions &opts = TableComposeOptions()) {
typedef Fst<Arc> F;
CacheOptions nopts;
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
if (opts.table_match_type == MATCH_OUTPUT) {
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
impl_opts.matcher1 = new TableMatcher<F>(ifst1, MATCH_OUTPUT, opts);
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
} else {
assert(opts.table_match_type == MATCH_INPUT);
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
impl_opts.matcher2 = new TableMatcher<F>(ifst2, MATCH_INPUT, opts);
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
}
if (opts.connect) Connect(ofst);
}
/// TableComposeCache lets us do multiple compositions while caching the same
/// matcher.
template <class F>
struct TableComposeCache {
TableMatcher<F> *matcher;
TableComposeOptions opts;
explicit TableComposeCache(
const TableComposeOptions &opts = TableComposeOptions())
: matcher(NULL), opts(opts) {}
~TableComposeCache() { delete (matcher); }
};
template <class Arc>
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst, TableComposeCache<Fst<Arc> > *cache) {
typedef Fst<Arc> F;
assert(cache != NULL);
CacheOptions nopts;
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
if (cache->opts.table_match_type == MATCH_OUTPUT) {
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
if (cache->matcher == NULL)
cache->matcher = new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
impl_opts.matcher1 = cache->matcher->Copy(); // not passing "safe": may not
// be thread-safe-- anway I don't understand this part.
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
} else {
assert(cache->opts.table_match_type == MATCH_INPUT);
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
if (cache->matcher == NULL)
cache->matcher = new TableMatcher<F>(ifst2, MATCH_INPUT, cache->opts);
impl_opts.matcher2 = cache->matcher->Copy();
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
}
if (cache->opts.connect) Connect(ofst);
}
} // namespace fst
#endif // KALDI_FSTEXT_TABLE_MATCHER_H_
add_library(kaldi-lat
determinize-lattice-pruned.cc
lattice-functions.cc
)
target_link_libraries(kaldi-lat PUBLIC kaldi-util)
\ No newline at end of file
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code. They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.
// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinzable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
typedef kaldi::int32 Int;
typedef typename Arc::Weight Weight;
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
for(int i = 0; i < 100; i++) {
RandFstOptions opts;
opts.n_states = 4;
opts.n_arcs = 10;
opts.n_final = 2;
opts.allow_empty = false;
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
opts.acyclic = true;
// to be exactly representable in float,
// or this test fails because numerical differences can cause symmetry in
// weights to be broken, which causes the wrong path to be chosen as far
// as the string part is concerned.
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
bool sorted = TopSort(fst);
KALDI_ASSERT(sorted);
ILabelCompare<Arc> ilabel_comp;
if (kaldi::Rand() % 2 == 0)
ArcSort(fst, ilabel_comp);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> det_fst;
try {
DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
// OK, now determinize it a different way and check equivalence.
// [note: it's not normal determinization, it's taking the best path
// for any input-symbol sequence....
VectorFst<Arc> pruned_fst(*fst);
if (pruned_fst.NumStates() != 0)
kaldi::PruneLattice(10.0, &pruned_fst);
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
std::cout << "Compact pruned FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
std::cout << "Compact version of determinized FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
if (ans)
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
} catch (...) {
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
}
delete fst;
}
}
// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
typedef typename Arc::Weight Weight;
RandFstOptions opts;
opts.acyclic = true;
for(int i = 0; i < 100; i++) {
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> ofst;
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
delete fst;
}
}
} // end namespace fst
int main() {
using namespace fst;
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
std::cout << "Tests succeeded\n";
}
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include "fstext/determinize-lattice.h" // for LatticeStringRepository #include "fstext/determinize-lattice.h" // for LatticeStringRepository
#include "fstext/fstext-utils.h" #include "fstext/fstext-utils.h"
#include "lat/lattice-functions.h" // for PruneLattice #include "lat/lattice-functions.h" // for PruneLattice
#include "lat/minimize-lattice.h" // for minimization // #include "lat/minimize-lattice.h" // for minimization
#include "lat/push-lattice.h" // for minimization // #include "lat/push-lattice.h" // for minimization
#include "lat/determinize-lattice-pruned.h" #include "lat/determinize-lattice-pruned.h"
namespace fst { namespace fst {
...@@ -223,6 +223,10 @@ template<class Weight, class IntType> class LatticeDeterminizerPruned { ...@@ -223,6 +223,10 @@ template<class Weight, class IntType> class LatticeDeterminizerPruned {
iter != initial_hash_.end(); ++iter) iter != initial_hash_.end(); ++iter)
delete iter->first; delete iter->first;
{ InitialSubsetHash tmp; tmp.swap(initial_hash_); } { InitialSubsetHash tmp; tmp.swap(initial_hash_); }
for (size_t i = 0; i < output_states_.size(); i++) {
vector<Element> tmp;
tmp.swap(output_states_[i]->minimal_subset);
}
{ vector<char> tmp; tmp.swap(isymbol_or_final_); } { vector<char> tmp; tmp.swap(isymbol_or_final_); }
{ // Free up the queue. I'm not sure how to make sure all { // Free up the queue. I'm not sure how to make sure all
// the memory is really freed (no swap() function)... doesn't really // the memory is really freed (no swap() function)... doesn't really
...@@ -1288,222 +1292,222 @@ bool DeterminizeLatticePruned(const ExpandedFst<ArcTpl<Weight> > &ifst, ...@@ -1288,222 +1292,222 @@ bool DeterminizeLatticePruned(const ExpandedFst<ArcTpl<Weight> > &ifst,
return false; // Suppress compiler warning; this code is unreachable. return false; // Suppress compiler warning; this code is unreachable.
} }
template<class Weight> // template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones( // typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<ArcTpl<Weight> > *fst) { // MutableFst<ArcTpl<Weight> > *fst) {
// Define some types. // // Define some types.
typedef ArcTpl<Weight> Arc; // typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId; // typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label; // typedef typename Arc::Label Label;
//
// Work out the first phone symbol. This is more related to the phone // // Work out the first phone symbol. This is more related to the phone
// insertion function, so we put it here and make it the returning value of // // insertion function, so we put it here and make it the returning value of
// DeterminizeLatticeInsertPhones(). // // DeterminizeLatticeInsertPhones().
Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1; // Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1;
//
// Insert phones here. // // Insert phones here.
for (StateIterator<MutableFst<Arc> > siter(*fst); // for (StateIterator<MutableFst<Arc> > siter(*fst);
!siter.Done(); siter.Next()) { // !siter.Done(); siter.Next()) {
StateId state = siter.Value(); // StateId state = siter.Value();
if (state == fst->Start()) // if (state == fst->Start())
continue; // continue;
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state); // for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
!aiter.Done(); aiter.Next()) { // !aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value(); // Arc arc = aiter.Value();
//
// Note: the words are on the input symbol side and transition-id's are on // // Note: the words are on the input symbol side and transition-id's are on
// the output symbol side. // // the output symbol side.
if ((arc.olabel != 0) // if ((arc.olabel != 0)
&& (trans_model.TransitionIdIsStartOfPhone(arc.olabel)) // && (trans_model.TransitionIdToHmmState(arc.olabel) == 0)
&& (!trans_model.IsSelfLoop(arc.olabel))) { // && (!trans_model.IsSelfLoop(arc.olabel))) {
Label phone = // Label phone =
static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel)); // static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel));
//
// Skips <eps>. // // Skips <eps>.
KALDI_ASSERT(phone != 0); // KALDI_ASSERT(phone != 0);
//
if (arc.ilabel == 0) { // if (arc.ilabel == 0) {
// If there is no word on the arc, insert the phone directly. // // If there is no word on the arc, insert the phone directly.
arc.ilabel = first_phone_label + phone; // arc.ilabel = first_phone_label + phone;
} else { // } else {
// Otherwise, add an additional arc. // // Otherwise, add an additional arc.
StateId additional_state = fst->AddState(); // StateId additional_state = fst->AddState();
StateId next_state = arc.nextstate; // StateId next_state = arc.nextstate;
arc.nextstate = additional_state; // arc.nextstate = additional_state;
fst->AddArc(additional_state, // fst->AddArc(additional_state,
Arc(first_phone_label + phone, 0, // Arc(first_phone_label + phone, 0,
Weight::One(), next_state)); // Weight::One(), next_state));
} // }
} // }
//
aiter.SetValue(arc); // aiter.SetValue(arc);
} // }
} // }
//
return first_phone_label; // return first_phone_label;
} // }
//
template<class Weight> // template<class Weight>
void DeterminizeLatticeDeletePhones( // void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label, // typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst) { // MutableFst<ArcTpl<Weight> > *fst) {
// Define some types. // // Define some types.
typedef ArcTpl<Weight> Arc; // typedef ArcTpl<Weight> Arc;
typedef typename Arc::StateId StateId; // typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label; // typedef typename Arc::Label Label;
//
// Delete phones here. // // Delete phones here.
for (StateIterator<MutableFst<Arc> > siter(*fst); // for (StateIterator<MutableFst<Arc> > siter(*fst);
!siter.Done(); siter.Next()) { // !siter.Done(); siter.Next()) {
StateId state = siter.Value(); // StateId state = siter.Value();
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state); // for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
!aiter.Done(); aiter.Next()) { // !aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value(); // Arc arc = aiter.Value();
//
if (arc.ilabel >= first_phone_label) // if (arc.ilabel >= first_phone_label)
arc.ilabel = 0; // arc.ilabel = 0;
//
aiter.SetValue(arc); // aiter.SetValue(arc);
} // }
} // }
} // }
// instantiate for type LatticeWeight // instantiate for type LatticeWeight
template // template
void DeterminizeLatticeDeletePhones( // void DeterminizeLatticeDeletePhones(
ArcTpl<kaldi::LatticeWeight>::Label first_phone_label, // ArcTpl<kaldi::LatticeWeight>::Label first_phone_label,
MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst); // MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst);
//
/** This function does a first pass determinization with phone symbols inserted // /** This function does a first pass determinization with phone symbols inserted
at phone boundary. It uses a transition model to work out the transition-id // at phone boundary. It uses a transition model to work out the transition-id
to phone map. First, phones will be inserted into the word level lattice. // to phone map. First, phones will be inserted into the word level lattice.
Second, determinization will be applied on top of the phone + word lattice. // Second, determinization will be applied on top of the phone + word lattice.
Finally, the inserted phones will be removed, converting the lattice back to // Finally, the inserted phones will be removed, converting the lattice back to
a word level lattice. The output lattice of this pass is not deterministic, // a word level lattice. The output lattice of this pass is not deterministic,
since we remove the phone symbols as a last step. It is supposed to be // since we remove the phone symbols as a last step. It is supposed to be
followed by another pass of determinization at the word level. It could also // followed by another pass of determinization at the word level. It could also
be useful for some other applications such as fMLLR estimation, confidence // be useful for some other applications such as fMLLR estimation, confidence
estimation, discriminative training, etc. // estimation, discriminative training, etc.
*/ // */
template<class Weight, class IntType> // template<class Weight, class IntType>
bool DeterminizeLatticePhonePrunedFirstPass( // bool DeterminizeLatticePhonePrunedFirstPass(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
double beam, // double beam,
MutableFst<ArcTpl<Weight> > *fst, // MutableFst<ArcTpl<Weight> > *fst,
const DeterminizeLatticePrunedOptions &opts) { // const DeterminizeLatticePrunedOptions &opts) {
// First, insert the phones. // // First, insert the phones.
typename ArcTpl<Weight>::Label first_phone_label = // typename ArcTpl<Weight>::Label first_phone_label =
DeterminizeLatticeInsertPhones(trans_model, fst); // DeterminizeLatticeInsertPhones(trans_model, fst);
TopSort(fst); // TopSort(fst);
//
// Second, do determinization with phone inserted. // // Second, do determinization with phone inserted.
bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts); // bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts);
//
// Finally, remove the inserted phones. // // Finally, remove the inserted phones.
DeterminizeLatticeDeletePhones(first_phone_label, fst); // DeterminizeLatticeDeletePhones(first_phone_label, fst);
TopSort(fst); // TopSort(fst);
//
return ans; // return ans;
} // }
//
// "Destructive" version of DeterminizeLatticePhonePruned() where the input // // "Destructive" version of DeterminizeLatticePhonePruned() where the input
// lattice might be modified. // // lattice might be modified.
template<class Weight, class IntType> // template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned( // bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<ArcTpl<Weight> > *ifst, // MutableFst<ArcTpl<Weight> > *ifst,
double beam, // double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst, // MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts) { // DeterminizeLatticePhonePrunedOptions opts) {
// Returning status. // // Returning status.
bool ans = true; // bool ans = true;
//
// Make sure at least one of opts.phone_determinize and opts.word_determinize // // Make sure at least one of opts.phone_determinize and opts.word_determinize
// is not false, otherwise calling this function doesn't make any sense. // // is not false, otherwise calling this function doesn't make any sense.
if ((opts.phone_determinize || opts.word_determinize) == false) { // if ((opts.phone_determinize || opts.word_determinize) == false) {
KALDI_WARN << "Both --phone-determinize and --word-determinize are set to " // KALDI_WARN << "Both --phone-determinize and --word-determinize are set to "
<< "false, copying lattice without determinization."; // << "false, copying lattice without determinization.";
// We are expecting the words on the input side. // // We are expecting the words on the input side.
ConvertLattice<Weight, IntType>(*ifst, ofst, false); // ConvertLattice<Weight, IntType>(*ifst, ofst, false);
return ans; // return ans;
} // }
//
// Determinization options. // // Determinization options.
DeterminizeLatticePrunedOptions det_opts; // DeterminizeLatticePrunedOptions det_opts;
det_opts.delta = opts.delta; // det_opts.delta = opts.delta;
det_opts.max_mem = opts.max_mem; // det_opts.max_mem = opts.max_mem;
//
// If --phone-determinize is true, do the determinization on phone + word // // If --phone-determinize is true, do the determinization on phone + word
// lattices. // // lattices.
if (opts.phone_determinize) { // if (opts.phone_determinize) {
KALDI_VLOG(3) << "Doing first pass of determinization on phone + word " // KALDI_VLOG(3) << "Doing first pass of determinization on phone + word "
<< "lattices."; // << "lattices.";
ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>( // ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>(
trans_model, beam, ifst, det_opts) && ans; // trans_model, beam, ifst, det_opts) && ans;
//
// If --word-determinize is false, we've finished the job and return here. // // If --word-determinize is false, we've finished the job and return here.
if (!opts.word_determinize) { // if (!opts.word_determinize) {
// We are expecting the words on the input side. // // We are expecting the words on the input side.
ConvertLattice<Weight, IntType>(*ifst, ofst, false); // ConvertLattice<Weight, IntType>(*ifst, ofst, false);
return ans; // return ans;
} // }
} // }
//
// If --word-determinize is true, do the determinization on word lattices. // // If --word-determinize is true, do the determinization on word lattices.
if (opts.word_determinize) { // if (opts.word_determinize) {
KALDI_VLOG(3) << "Doing second pass of determinization on word lattices."; // KALDI_VLOG(3) << "Doing second pass of determinization on word lattices.";
ans = DeterminizeLatticePruned<Weight, IntType>( // ans = DeterminizeLatticePruned<Weight, IntType>(
*ifst, beam, ofst, det_opts) && ans; // *ifst, beam, ofst, det_opts) && ans;
} // }
//
// If --minimize is true, push and minimize after determinization. // // If --minimize is true, push and minimize after determinization.
if (opts.minimize) { // if (opts.minimize) {
KALDI_VLOG(3) << "Pushing and minimizing on word lattices."; // KALDI_VLOG(3) << "Pushing and minimizing on word lattices.";
ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans; // ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans;
ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans; // ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans;
ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans; // ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans;
} // }
//
return ans; // return ans;
} // }
//
// Normal verson of DeterminizeLatticePhonePruned(), where the input lattice // // Normal verson of DeterminizeLatticePhonePruned(), where the input lattice
// will be kept as unchanged. // // will be kept as unchanged.
template<class Weight, class IntType> // template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned( // bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst, // const ExpandedFst<ArcTpl<Weight> > &ifst,
double beam, // double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst, // MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts) { // DeterminizeLatticePhonePrunedOptions opts) {
VectorFst<ArcTpl<Weight> > temp_fst(ifst); // VectorFst<ArcTpl<Weight> > temp_fst(ifst);
return DeterminizeLatticePhonePruned(trans_model, &temp_fst, // return DeterminizeLatticePhonePruned(trans_model, &temp_fst,
beam, ofst, opts); // beam, ofst, opts);
} // }
//
bool DeterminizeLatticePhonePrunedWrapper( // bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<kaldi::LatticeArc> *ifst, // MutableFst<kaldi::LatticeArc> *ifst,
double beam, // double beam,
MutableFst<kaldi::CompactLatticeArc> *ofst, // MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts) { // DeterminizeLatticePhonePrunedOptions opts) {
bool ans = true; // bool ans = true;
Invert(ifst); // Invert(ifst);
if (ifst->Properties(fst::kTopSorted, true) == 0) { // if (ifst->Properties(fst::kTopSorted, true) == 0) {
if (!TopSort(ifst)) { // if (!TopSort(ifst)) {
// Cannot topologically sort the lattice -- determinization will fail. // // Cannot topologically sort the lattice -- determinization will fail.
KALDI_ERR << "Topological sorting of state-level lattice failed (probably" // KALDI_ERR << "Topological sorting of state-level lattice failed (probably"
<< " your lexicon has empty words or your LM has epsilon cycles" // << " your lexicon has empty words or your LM has epsilon cycles"
<< ")."; // << ").";
} // }
} // }
ILabelCompare<kaldi::LatticeArc> ilabel_comp; // ILabelCompare<kaldi::LatticeArc> ilabel_comp;
ArcSort(ifst, ilabel_comp); // ArcSort(ifst, ilabel_comp);
ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>( // ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
trans_model, ifst, beam, ofst, opts); // trans_model, ifst, beam, ofst, opts);
Connect(ofst); // Connect(ofst);
return ans; // return ans;
} // }
// Instantiate the templates for the types we might need. // Instantiate the templates for the types we might need.
// Note: there are actually four templates, each of which // Note: there are actually four templates, each of which
...@@ -1522,20 +1526,20 @@ bool DeterminizeLatticePruned<kaldi::LatticeWeight>( ...@@ -1522,20 +1526,20 @@ bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
MutableFst<kaldi::LatticeArc> *ofst, MutableFst<kaldi::LatticeArc> *ofst,
DeterminizeLatticePrunedOptions opts); DeterminizeLatticePrunedOptions opts);
template // template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>( // bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
const ExpandedFst<kaldi::LatticeArc> &ifst, // const ExpandedFst<kaldi::LatticeArc> &ifst,
double prune, // double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst, // MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts); // DeterminizeLatticePhonePrunedOptions opts);
//
template // template
bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>( // bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<kaldi::LatticeArc> *ifst, // MutableFst<kaldi::LatticeArc> *ifst,
double prune, // double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst, // MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts); // DeterminizeLatticePhonePrunedOptions opts);
} }
...@@ -28,8 +28,8 @@ ...@@ -28,8 +28,8 @@
#include <set> #include <set>
#include <vector> #include <vector>
#include "fstext/lattice-weight.h" #include "fstext/lattice-weight.h"
#include "itf/transition-information.h" // #include "hmm/transition-model.h"
#include "itf/options-itf.h" #include "util/options-itf.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
namespace fst { namespace fst {
...@@ -212,82 +212,82 @@ bool DeterminizeLatticePruned( ...@@ -212,82 +212,82 @@ bool DeterminizeLatticePruned(
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst, MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions()); DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/** This function takes in lattices and inserts phones at phone boundaries. It // /** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The // uses the transition model to work out the transition_id to phone map. The
returning value is the starting index of the phone label. Typically we pick // returning value is the starting index of the phone label. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then // (maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (returning_value + original_phone_label) in the new lattice. The // mapped to (returning_value + original_phone_label) in the new lattice. The
returning value will be used by DeterminizeLatticeDeletePhones() where it // returning value will be used by DeterminizeLatticeDeletePhones() where it
works out the phones according to this value. // works out the phones according to this value.
*/ // */
template<class Weight> // template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones( // typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<ArcTpl<Weight> > *fst); // MutableFst<ArcTpl<Weight> > *fst);
//
/** This function takes in lattices and deletes "phones" from them. The "phones" // /** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label because // here are actually any label that is larger than first_phone_label because
when we insert phones into the lattice, we map the original phone label to // when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used // (first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones() // together with DeterminizeLatticeInsertPhones()
*/ // */
template<class Weight> // template<class Weight>
void DeterminizeLatticeDeletePhones( // void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label, // typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst); // MutableFst<ArcTpl<Weight> > *fst);
//
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and // /** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first // DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of // calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set // determinization on the phone + word lattices. If --word-determinize is set
true, it then does a second pass of determinization on the word lattices by // true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives // calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copying the lattices without determinization. // a warning and copying the lattices without determinization.
//
Note: the point of doing first a phone-level determinization pass and then // Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize // a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice // deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level // due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster, // determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default. // despite the fact that we now have two passes of determinization by default.
*/ // */
template<class Weight, class IntType> // template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned( // bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst, // const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune, // double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst, // MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts // DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions()); // = DeterminizeLatticePhonePrunedOptions());
//
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input // /** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed. // lattice might be changed.
*/ // */
template<class Weight, class IntType> // template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned( // bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<ArcTpl<Weight> > *ifst, // MutableFst<ArcTpl<Weight> > *ifst,
double prune, // double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst, // MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts // DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions()); // = DeterminizeLatticePhonePrunedOptions());
//
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for // /** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling // Lattice type FSTs. It simplifies the calling process by calling
TopSort() Invert() and ArcSort() for you. // TopSort() Invert() and ArcSort() for you.
Unlike other determinization routines, the function // Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the // requires "ifst" to have transition-id's on the input side and words on the
output side. // output side.
This function can be used as the top-level interface to all the determinization // This function can be used as the top-level interface to all the determinization
code. // code.
*/ // */
bool DeterminizeLatticePhonePrunedWrapper( // bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model, // const kaldi::TransitionModel &trans_model,
MutableFst<kaldi::LatticeArc> *ifst, // MutableFst<kaldi::LatticeArc> *ifst,
double prune, // double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst, // MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts // DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions()); // = DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions" /// @} end "addtogroup fst_extensions"
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "base/kaldi-common.h" #include "base/kaldi-common.h"
#include "util/common-utils.h" // #include "util/common-utils.h"
namespace kaldi { namespace kaldi {
...@@ -142,13 +142,13 @@ class LatticeHolder { ...@@ -142,13 +142,13 @@ class LatticeHolder {
T *t_; T *t_;
}; };
typedef TableWriter<LatticeHolder> LatticeWriter; // typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader; // typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader; // typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
//
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter; // typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader; // typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader; // typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi } // namespace kaldi
......
...@@ -23,211 +23,214 @@ ...@@ -23,211 +23,214 @@
// limitations under the License. // limitations under the License.
#include "base/kaldi-math.h"
#include "lat/lattice-functions.h" #include "lat/lattice-functions.h"
// #include "hmm/transition-model.h"
// #include "util/stl-utils.h"
#include "base/kaldi-math.h"
// #include "hmm/hmm-utils.h"
namespace kaldi { namespace kaldi {
using std::map; using std::map;
using std::vector; using std::vector;
void GetPerFrameAcousticCosts(const Lattice &nbest, // void GetPerFrameAcousticCosts(const Lattice &nbest,
Vector<BaseFloat> *per_frame_loglikes) { // Vector<BaseFloat> *per_frame_loglikes) {
using namespace fst; // using namespace fst;
typedef Lattice::Arc::Weight Weight; // typedef Lattice::Arc::Weight Weight;
vector<BaseFloat> loglikes; // vector<BaseFloat> loglikes;
//
int32 cur_state = nbest.Start(); // int32 cur_state = nbest.Start();
int32 prev_frame = -1; // int32 prev_frame = -1;
BaseFloat eps_acwt = 0.0; // BaseFloat eps_acwt = 0.0;
while(1) { // while(1) {
Weight w = nbest.Final(cur_state); // Weight w = nbest.Final(cur_state);
if (w != Weight::Zero()) { // if (w != Weight::Zero()) {
KALDI_ASSERT(nbest.NumArcs(cur_state) == 0); // KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
if (per_frame_loglikes != NULL) { // if (per_frame_loglikes != NULL) {
SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size()); // SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
Vector<BaseFloat> vec(subvec); // Vector<BaseFloat> vec(subvec);
*per_frame_loglikes = vec; // *per_frame_loglikes = vec;
} // }
break; // break;
} else { // } else {
KALDI_ASSERT(nbest.NumArcs(cur_state) == 1); // KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
fst::ArcIterator<Lattice> iter(nbest, cur_state); // fst::ArcIterator<Lattice> iter(nbest, cur_state);
const Lattice::Arc &arc = iter.Value(); // const Lattice::Arc &arc = iter.Value();
BaseFloat acwt = arc.weight.Value2(); // BaseFloat acwt = arc.weight.Value2();
if (arc.ilabel != 0) { // if (arc.ilabel != 0) {
if (eps_acwt > 0) { // if (eps_acwt > 0) {
acwt += eps_acwt; // acwt += eps_acwt;
eps_acwt = 0.0; // eps_acwt = 0.0;
} // }
loglikes.push_back(acwt); // loglikes.push_back(acwt);
prev_frame++; // prev_frame++;
} else if (acwt == acwt){ // } else if (acwt == acwt){
if (prev_frame > -1) { // if (prev_frame > -1) {
loglikes[prev_frame] += acwt; // loglikes[prev_frame] += acwt;
} else { // } else {
eps_acwt += acwt; // eps_acwt += acwt;
} // }
} // }
cur_state = arc.nextstate; // cur_state = arc.nextstate;
} // }
} // }
} // }
//
int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) { // int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
if (!lat.Properties(fst::kTopSorted, true)) // if (!lat.Properties(fst::kTopSorted, true))
KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0); // KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates(); // int32 num_states = lat.NumStates();
times->clear(); // times->clear();
times->resize(num_states, -1); // times->resize(num_states, -1);
(*times)[0] = 0; // (*times)[0] = 0;
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
int32 cur_time = (*times)[state]; // int32 cur_time = (*times)[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done(); // for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const LatticeArc &arc = aiter.Value(); // const LatticeArc &arc = aiter.Value();
//
if (arc.ilabel != 0) { // Non-epsilon input label on arc // if (arc.ilabel != 0) { // Non-epsilon input label on arc
// next time instance // // next time instance
if ((*times)[arc.nextstate] == -1) { // if ((*times)[arc.nextstate] == -1) {
(*times)[arc.nextstate] = cur_time + 1; // (*times)[arc.nextstate] = cur_time + 1;
} else { // } else {
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1); // KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
} // }
} else { // epsilon input label on arc // } else { // epsilon input label on arc
// Same time instance // // Same time instance
if ((*times)[arc.nextstate] == -1) // if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time; // (*times)[arc.nextstate] = cur_time;
else // else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time); // KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
} // }
} // }
} // }
return (*std::max_element(times->begin(), times->end())); // return (*std::max_element(times->begin(), times->end()));
} // }
//
int32 CompactLatticeStateTimes(const CompactLattice &lat, // int32 CompactLatticeStateTimes(const CompactLattice &lat,
vector<int32> *times) { // vector<int32> *times) {
if (!lat.Properties(fst::kTopSorted, true)) // if (!lat.Properties(fst::kTopSorted, true))
KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0); // KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates(); // int32 num_states = lat.NumStates();
times->clear(); // times->clear();
times->resize(num_states, -1); // times->resize(num_states, -1);
(*times)[0] = 0; // (*times)[0] = 0;
int32 utt_len = -1; // int32 utt_len = -1;
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
int32 cur_time = (*times)[state]; // int32 cur_time = (*times)[state];
for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done(); // for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value(); // const CompactLatticeArc &arc = aiter.Value();
int32 arc_len = static_cast<int32>(arc.weight.String().size()); // int32 arc_len = static_cast<int32>(arc.weight.String().size());
if ((*times)[arc.nextstate] == -1) // if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time + arc_len; // (*times)[arc.nextstate] = cur_time + arc_len;
else // else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len); // KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
} // }
if (lat.Final(state) != CompactLatticeWeight::Zero()) { // if (lat.Final(state) != CompactLatticeWeight::Zero()) {
int32 this_utt_len = (*times)[state] + lat.Final(state).String().size(); // int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
if (utt_len == -1) utt_len = this_utt_len; // if (utt_len == -1) utt_len = this_utt_len;
else { // else {
if (this_utt_len != utt_len) { // if (this_utt_len != utt_len) {
KALDI_WARN << "Utterance does not " // KALDI_WARN << "Utterance does not "
"seem to have a consistent length."; // "seem to have a consistent length.";
utt_len = std::max(utt_len, this_utt_len); // utt_len = std::max(utt_len, this_utt_len);
} // }
} // }
} // }
} // }
if (utt_len == -1) { // if (utt_len == -1) {
KALDI_WARN << "Utterance does not have a final-state."; // KALDI_WARN << "Utterance does not have a final-state.";
return 0; // return 0;
} // }
return utt_len; // return utt_len;
} // }
//
bool ComputeCompactLatticeAlphas(const CompactLattice &clat, // bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
vector<double> *alpha) { // vector<double> *alpha) {
using namespace fst; // using namespace fst;
//
// typedef the arc, weight types // // typedef the arc, weight types
typedef CompactLattice::Arc Arc; // typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
typedef Arc::StateId StateId; // typedef Arc::StateId StateId;
//
//Make sure the lattice is topologically sorted. // //Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) { // if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted."; // KALDI_WARN << "Input lattice must be topologically sorted.";
return false; // return false;
} // }
if (clat.Start() != 0) { // if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0."; // KALDI_WARN << "Input lattice must start from state 0.";
return false; // return false;
} // }
//
int32 num_states = clat.NumStates(); // int32 num_states = clat.NumStates();
(*alpha).resize(0); // (*alpha).resize(0);
(*alpha).resize(num_states, kLogZeroDouble); // (*alpha).resize(num_states, kLogZeroDouble);
//
// Now propagate alphas forward. Note that we don't acount the weight of the // // Now propagate alphas forward. Note that we don't acount the weight of the
// final state to alpha[final_state] -- we acount it to beta[final_state]; // // final state to alpha[final_state] -- we acount it to beta[final_state];
(*alpha)[0] = 0.0; // (*alpha)[0] = 0.0;
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s]; // double this_alpha = (*alpha)[s];
for (ArcIterator<CompactLattice> aiter(clat, s); // for (ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) { // !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() + // double arc_like = -(arc.weight.Weight().Value1() +
arc.weight.Weight().Value2()); // arc.weight.Weight().Value2());
(*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate], // (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
this_alpha + arc_like); // this_alpha + arc_like);
} // }
} // }
//
return true; // return true;
} // }
//
bool ComputeCompactLatticeBetas(const CompactLattice &clat, // bool ComputeCompactLatticeBetas(const CompactLattice &clat,
vector<double> *beta) { // vector<double> *beta) {
using namespace fst; // using namespace fst;
//
// typedef the arc, weight types // // typedef the arc, weight types
typedef CompactLattice::Arc Arc; // typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
typedef Arc::StateId StateId; // typedef Arc::StateId StateId;
//
// Make sure the lattice is topologically sorted. // // Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) { // if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted."; // KALDI_WARN << "Input lattice must be topologically sorted.";
return false; // return false;
} // }
if (clat.Start() != 0) { // if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0."; // KALDI_WARN << "Input lattice must start from state 0.";
return false; // return false;
} // }
//
int32 num_states = clat.NumStates(); // int32 num_states = clat.NumStates();
(*beta).resize(0); // (*beta).resize(0);
(*beta).resize(num_states, kLogZeroDouble); // (*beta).resize(num_states, kLogZeroDouble);
//
// Now propagate betas backward. Note that beta[final_state] contains the // // Now propagate betas backward. Note that beta[final_state] contains the
// weight of the final state in the lattice -- compare that with alpha. // // weight of the final state in the lattice -- compare that with alpha.
for (StateId s = num_states-1; s >= 0; s--) { // for (StateId s = num_states-1; s >= 0; s--) {
Weight f = clat.Final(s); // Weight f = clat.Final(s);
double this_beta = -(f.Weight().Value1()+f.Weight().Value2()); // double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
for (ArcIterator<CompactLattice> aiter(clat, s); // for (ArcIterator<CompactLattice> aiter(clat, s);
!aiter.Done(); aiter.Next()) { // !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() + // double arc_like = -(arc.weight.Weight().Value1() +
arc.weight.Weight().Value2()); // arc.weight.Weight().Value2());
double arc_beta = (*beta)[arc.nextstate] + arc_like; // double arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta); // this_beta = LogAdd(this_beta, arc_beta);
} // }
(*beta)[s] = this_beta; // (*beta)[s] = this_beta;
} // }
//
return true; // return true;
} // }
template<class LatType> // could be Lattice or CompactLattice template<class LatType> // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) { bool PruneLattice(BaseFloat beam, LatType *lat) {
...@@ -315,1566 +318,1675 @@ template bool PruneLattice(BaseFloat beam, Lattice *lat); ...@@ -315,1566 +318,1675 @@ template bool PruneLattice(BaseFloat beam, Lattice *lat);
template bool PruneLattice(BaseFloat beam, CompactLattice *lat); template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, // BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
double *acoustic_like_sum) { // double *acoustic_like_sum) {
// Note, Posterior is defined as follows: Indexed [frame], then a list // // Note, Posterior is defined as follows: Indexed [frame], then a list
// of (transition-id, posterior-probability) pairs. // // of (transition-id, posterior-probability) pairs.
// typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior; // // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
using namespace fst; // using namespace fst;
typedef Lattice::Arc Arc; // typedef Lattice::Arc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
typedef Arc::StateId StateId; // typedef Arc::StateId StateId;
//
if (acoustic_like_sum) *acoustic_like_sum = 0.0; // if (acoustic_like_sum) *acoustic_like_sum = 0.0;
//
// Make sure the lattice is topologically sorted. // // Make sure the lattice is topologically sorted.
if (lat.Properties(fst::kTopSorted, true) == 0) // if (lat.Properties(fst::kTopSorted, true) == 0)
KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0); // KALDI_ASSERT(lat.Start() == 0);
//
int32 num_states = lat.NumStates(); // int32 num_states = lat.NumStates();
vector<int32> state_times; // vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times); // int32 max_time = LatticeStateTimes(lat, &state_times);
std::vector<double> alpha(num_states, kLogZeroDouble); // std::vector<double> alpha(num_states, kLogZeroDouble);
std::vector<double> &beta(alpha); // we re-use the same memory for // std::vector<double> &beta(alpha); // we re-use the same memory for
// this, but it's semantically distinct so we name it differently. // // this, but it's semantically distinct so we name it differently.
double tot_forward_prob = kLogZeroDouble; // double tot_forward_prob = kLogZeroDouble;
//
post->clear(); // post->clear();
post->resize(max_time); // post->resize(max_time);
//
alpha[0] = 0.0; // alpha[0] = 0.0;
// Propagate alphas forward. // // Propagate alphas forward.
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s]; // double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight); // double arc_like = -ConvertToCost(arc.weight);
alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like); // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
} // }
Weight f = lat.Final(s); // Weight f = lat.Final(s);
if (f != Weight::Zero()) { // if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2()); // double final_like = this_alpha - (f.Value1() + f.Value2());
tot_forward_prob = LogAdd(tot_forward_prob, final_like); // tot_forward_prob = LogAdd(tot_forward_prob, final_like);
KALDI_ASSERT(state_times[s] == max_time && // KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)"); // "Lattice is inconsistent (final-prob not at max_time)");
} // }
} // }
for (StateId s = num_states-1; s >= 0; s--) { // for (StateId s = num_states-1; s >= 0; s--) {
Weight f = lat.Final(s); // Weight f = lat.Final(s);
double this_beta = -(f.Value1() + f.Value2()); // double this_beta = -(f.Value1() + f.Value2());
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight), // double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like; // arc_beta = beta[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta); // this_beta = LogAdd(this_beta, arc_beta);
int32 transition_id = arc.ilabel; // int32 transition_id = arc.ilabel;
//
// The following "if" is an optimization to avoid un-needed exp(). // // The following "if" is an optimization to avoid un-needed exp().
if (transition_id != 0 || acoustic_like_sum != NULL) { // if (transition_id != 0 || acoustic_like_sum != NULL) {
double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob); // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
//
if (transition_id != 0) // Arc has a transition-id on it [not epsilon] // if (transition_id != 0) // Arc has a transition-id on it [not epsilon]
(*post)[state_times[s]].push_back(std::make_pair(transition_id, // (*post)[state_times[s]].push_back(std::make_pair(transition_id,
static_cast<kaldi::BaseFloat>(posterior))); // static_cast<kaldi::BaseFloat>(posterior)));
if (acoustic_like_sum != NULL) // if (acoustic_like_sum != NULL)
*acoustic_like_sum -= posterior * arc.weight.Value2(); // *acoustic_like_sum -= posterior * arc.weight.Value2();
} // }
} // }
if (acoustic_like_sum != NULL && f != Weight::Zero()) { // if (acoustic_like_sum != NULL && f != Weight::Zero()) {
double final_logprob = - ConvertToCost(f), // double final_logprob = - ConvertToCost(f),
posterior = Exp(alpha[s] + final_logprob - tot_forward_prob); // posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
*acoustic_like_sum -= posterior * f.Value2(); // *acoustic_like_sum -= posterior * f.Value2();
} // }
beta[s] = this_beta; // beta[s] = this_beta;
} // }
double tot_backward_prob = beta[0]; // double tot_backward_prob = beta[0];
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob // KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob; // << ", while total backward probability = " << tot_backward_prob;
} // }
// Now combine any posteriors with the same transition-id. // // Now combine any posteriors with the same transition-id.
for (int32 t = 0; t < max_time; t++) // for (int32 t = 0; t < max_time; t++)
MergePairVectorSumming(&((*post)[t])); // MergePairVectorSumming(&((*post)[t]));
return tot_backward_prob; // return tot_backward_prob;
} // }
//
//
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans, // void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const vector<int32> &silence_phones, // const vector<int32> &silence_phones,
vector< std::set<int32> > *active_phones) { // vector< std::set<int32> > *active_phones) {
KALDI_ASSERT(IsSortedAndUniq(silence_phones)); // KALDI_ASSERT(IsSortedAndUniq(silence_phones));
vector<int32> state_times; // vector<int32> state_times;
int32 num_states = lat.NumStates(); // int32 num_states = lat.NumStates();
int32 max_time = LatticeStateTimes(lat, &state_times); // int32 max_time = LatticeStateTimes(lat, &state_times);
active_phones->clear(); // active_phones->clear();
active_phones->resize(max_time); // active_phones->resize(max_time);
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
int32 cur_time = state_times[state]; // int32 cur_time = state_times[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done(); // for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const LatticeArc &arc = aiter.Value(); // const LatticeArc &arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc // if (arc.ilabel != 0) { // Non-epsilon arc
int32 phone = trans.TransitionIdToPhone(arc.ilabel); // int32 phone = trans.TransitionIdToPhone(arc.ilabel);
if (!std::binary_search(silence_phones.begin(), // if (!std::binary_search(silence_phones.begin(),
silence_phones.end(), phone)) // silence_phones.end(), phone))
(*active_phones)[cur_time].insert(phone); // (*active_phones)[cur_time].insert(phone);
} // }
} // end looping over arcs // } // end looping over arcs
} // end looping over states // } // end looping over states
} // }
//
void ConvertLatticeToPhones(const TransitionInformation &trans, // void ConvertLatticeToPhones(const TransitionModel &trans,
Lattice *lat) { // Lattice *lat) {
typedef LatticeArc Arc; // typedef LatticeArc Arc;
int32 num_states = lat->NumStates(); // int32 num_states = lat->NumStates();
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done(); // for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
Arc arc(aiter.Value()); // Arc arc(aiter.Value());
arc.olabel = 0; // remove any word. // arc.olabel = 0; // remove any word.
if ((arc.ilabel != 0) // has a transition-id on input.. // if ((arc.ilabel != 0) // has a transition-id on input..
&& (trans.TransitionIdIsStartOfPhone(arc.ilabel)) // && (trans.TransitionIdToHmmState(arc.ilabel) == 0)
&& (!trans.IsSelfLoop(arc.ilabel))) { // && (!trans.IsSelfLoop(arc.ilabel))) {
// && trans.IsFinal(arc.ilabel)) // there is one of these per phone... // // && trans.IsFinal(arc.ilabel)) // there is one of these per phone...
arc.olabel = trans.TransitionIdToPhone(arc.ilabel); // arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
} // }
aiter.SetValue(arc); // aiter.SetValue(arc);
} // end looping over arcs // } // end looping over arcs
} // end looping over states // } // end looping over states
} // }
//
//
static inline double LogAddOrMax(bool viterbi, double a, double b) { // static inline double LogAddOrMax(bool viterbi, double a, double b) {
if (viterbi) // if (viterbi)
return std::max(a, b); // return std::max(a, b);
else // else
return LogAdd(a, b); // return LogAdd(a, b);
} // }
//
template<typename LatticeType> // template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, // double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
bool viterbi, // bool viterbi,
vector<double> *alpha, // vector<double> *alpha,
vector<double> *beta) { // vector<double> *beta) {
typedef typename LatticeType::Arc Arc; // typedef typename LatticeType::Arc Arc;
typedef typename Arc::Weight Weight; // typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId; // typedef typename Arc::StateId StateId;
//
StateId num_states = lat.NumStates(); // StateId num_states = lat.NumStates();
KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted); // KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
KALDI_ASSERT(lat.Start() == 0); // KALDI_ASSERT(lat.Start() == 0);
alpha->clear(); // alpha->clear();
beta->clear(); // beta->clear();
alpha->resize(num_states, kLogZeroDouble); // alpha->resize(num_states, kLogZeroDouble);
beta->resize(num_states, kLogZeroDouble); // beta->resize(num_states, kLogZeroDouble);
//
double tot_forward_prob = kLogZeroDouble; // double tot_forward_prob = kLogZeroDouble;
(*alpha)[0] = 0.0; // (*alpha)[0] = 0.0;
// Propagate alphas forward. // // Propagate alphas forward.
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s]; // double this_alpha = (*alpha)[s];
for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done(); // for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight); // double arc_like = -ConvertToCost(arc.weight);
(*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate], // (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
this_alpha + arc_like); // this_alpha + arc_like);
} // }
Weight f = lat.Final(s); // Weight f = lat.Final(s);
if (f != Weight::Zero()) { // if (f != Weight::Zero()) {
double final_like = this_alpha - ConvertToCost(f); // double final_like = this_alpha - ConvertToCost(f);
tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like); // tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
} // }
} // }
for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed. // for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
double this_beta = -ConvertToCost(lat.Final(s)); // double this_beta = -ConvertToCost(lat.Final(s));
for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done(); // for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight), // double arc_like = -ConvertToCost(arc.weight),
arc_beta = (*beta)[arc.nextstate] + arc_like; // arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAddOrMax(viterbi, this_beta, arc_beta); // this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
} // }
(*beta)[s] = this_beta; // (*beta)[s] = this_beta;
} // }
double tot_backward_prob = (*beta)[lat.Start()]; // double tot_backward_prob = (*beta)[lat.Start()];
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob // KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob; // << ", while total backward probability = " << tot_backward_prob;
} // }
// Split the difference when returning... they should be the same. // // Split the difference when returning... they should be the same.
return 0.5 * (tot_backward_prob + tot_forward_prob); // return 0.5 * (tot_backward_prob + tot_forward_prob);
} // }
//
// instantiate the template for Lattice and CompactLattice // // instantiate the template for Lattice and CompactLattice
template // template
double ComputeLatticeAlphasAndBetas(const Lattice &lat, // double ComputeLatticeAlphasAndBetas(const Lattice &lat,
bool viterbi, // bool viterbi,
vector<double> *alpha, // vector<double> *alpha,
vector<double> *beta); // vector<double> *beta);
//
template // template
double ComputeLatticeAlphasAndBetas(const CompactLattice &lat, // double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
bool viterbi, // bool viterbi,
vector<double> *alpha, // vector<double> *alpha,
vector<double> *beta); // vector<double> *beta);
//
//
//
/// This is used in CompactLatticeLimitDepth. // /// This is used in CompactLatticeLimitDepth.
struct LatticeArcRecord { // struct LatticeArcRecord {
BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc, // BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
// minus the overall best-cost of the lattice. // // minus the overall best-cost of the lattice.
CompactLatticeArc::StateId state; // state in the lattice. // CompactLatticeArc::StateId state; // state in the lattice.
size_t arc; // arc index within the state. // size_t arc; // arc index within the state.
bool operator < (const LatticeArcRecord &other) const { // bool operator < (const LatticeArcRecord &other) const {
return logprob < other.logprob; // return logprob < other.logprob;
} // }
}; // };
//
void CompactLatticeLimitDepth(int32 max_depth_per_frame, // void CompactLatticeLimitDepth(int32 max_depth_per_frame,
CompactLattice *clat) { // CompactLattice *clat) {
typedef CompactLatticeArc Arc; // typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
typedef Arc::StateId StateId; // typedef Arc::StateId StateId;
//
if (clat->Start() == fst::kNoStateId) { // if (clat->Start() == fst::kNoStateId) {
KALDI_WARN << "Limiting depth of empty lattice."; // KALDI_WARN << "Limiting depth of empty lattice.";
return; // return;
} // }
if (clat->Properties(fst::kTopSorted, true) == 0) { // if (clat->Properties(fst::kTopSorted, true) == 0) {
if (!TopSort(clat)) // if (!TopSort(clat))
KALDI_ERR << "Topological sorting of lattice failed."; // KALDI_ERR << "Topological sorting of lattice failed.";
} // }
//
vector<int32> state_times; // vector<int32> state_times;
int32 T = CompactLatticeStateTimes(*clat, &state_times); // int32 T = CompactLatticeStateTimes(*clat, &state_times);
//
// The alpha and beta quantities here are "viterbi" alphas and beta. // // The alpha and beta quantities here are "viterbi" alphas and beta.
std::vector<double> alpha; // std::vector<double> alpha;
std::vector<double> beta; // std::vector<double> beta;
bool viterbi = true; // bool viterbi = true;
double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi, // double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
&alpha, &beta); // &alpha, &beta);
//
std::vector<std::vector<LatticeArcRecord> > arc_records(T); // std::vector<std::vector<LatticeArcRecord> > arc_records(T);
//
StateId num_states = clat->NumStates(); // StateId num_states = clat->NumStates();
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done(); // for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
LatticeArcRecord arc_record; // LatticeArcRecord arc_record;
arc_record.state = s; // arc_record.state = s;
arc_record.arc = aiter.Position(); // arc_record.arc = aiter.Position();
arc_record.logprob = // arc_record.logprob =
(alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight)) // (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
- best_prob; // - best_prob;
KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative. // KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
int32 num_frames = arc.weight.String().size(), start_t = state_times[s]; // int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
for (int32 t = start_t; t < start_t + num_frames; t++) { // for (int32 t = start_t; t < start_t + num_frames; t++) {
KALDI_ASSERT(t < T); // KALDI_ASSERT(t < T);
arc_records[t].push_back(arc_record); // arc_records[t].push_back(arc_record);
} // }
} // }
} // }
StateId dead_state = clat->AddState(); // A non-coaccesible state which we use // StateId dead_state = clat->AddState(); // A non-coaccesible state which we use
// to remove arcs (make them end // // to remove arcs (make them end
// there). // // there).
size_t max_depth = max_depth_per_frame; // size_t max_depth = max_depth_per_frame;
for (int32 t = 0; t < T; t++) { // for (int32 t = 0; t < T; t++) {
size_t size = arc_records[t].size(); // size_t size = arc_records[t].size();
if (size > max_depth) { // if (size > max_depth) {
// we sort from worst to best, so we keep the later-numbered ones, // // we sort from worst to best, so we keep the later-numbered ones,
// and delete the lower-numbered ones. // // and delete the lower-numbered ones.
size_t cutoff = size - max_depth; // size_t cutoff = size - max_depth;
std::nth_element(arc_records[t].begin(), // std::nth_element(arc_records[t].begin(),
arc_records[t].begin() + cutoff, // arc_records[t].begin() + cutoff,
arc_records[t].end()); // arc_records[t].end());
for (size_t index = 0; index < cutoff; index++) { // for (size_t index = 0; index < cutoff; index++) {
LatticeArcRecord record(arc_records[t][index]); // LatticeArcRecord record(arc_records[t][index]);
fst::MutableArcIterator<CompactLattice> aiter(clat, record.state); // fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
aiter.Seek(record.arc); // aiter.Seek(record.arc);
Arc arc = aiter.Value(); // Arc arc = aiter.Value();
if (arc.nextstate != dead_state) { // not already killed. // if (arc.nextstate != dead_state) { // not already killed.
arc.nextstate = dead_state; // arc.nextstate = dead_state;
aiter.SetValue(arc); // aiter.SetValue(arc);
} // }
} // }
} // }
} // }
Connect(clat); // Connect(clat);
TopSortCompactLatticeIfNeeded(clat); // TopSortCompactLatticeIfNeeded(clat);
} // }
//
//
void TopSortCompactLatticeIfNeeded(CompactLattice *clat) { // void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
if (clat->Properties(fst::kTopSorted, true) == 0) { // if (clat->Properties(fst::kTopSorted, true) == 0) {
if (fst::TopSort(clat) == false) { // if (fst::TopSort(clat) == false) {
KALDI_ERR << "Topological sorting failed"; // KALDI_ERR << "Topological sorting failed";
} // }
} // }
} // }
//
void TopSortLatticeIfNeeded(Lattice *lat) { // void TopSortLatticeIfNeeded(Lattice *lat) {
if (lat->Properties(fst::kTopSorted, true) == 0) { // if (lat->Properties(fst::kTopSorted, true) == 0) {
if (fst::TopSort(lat) == false) { // if (fst::TopSort(lat) == false) {
KALDI_ERR << "Topological sorting failed"; // KALDI_ERR << "Topological sorting failed";
} // }
} // }
} // }
//
//
/// Returns the depth of the lattice, defined as the average number of // /// Returns the depth of the lattice, defined as the average number of
/// arcs crossing any given frame. Returns 1 for empty lattices. // /// arcs crossing any given frame. Returns 1 for empty lattices.
/// Requires that input is topologically sorted. // /// Requires that input is topologically sorted.
BaseFloat CompactLatticeDepth(const CompactLattice &clat, // BaseFloat CompactLatticeDepth(const CompactLattice &clat,
int32 *num_frames) { // int32 *num_frames) {
typedef CompactLattice::Arc::StateId StateId; // typedef CompactLattice::Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) { // if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically " // KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
<< "sorted."; // << "sorted.";
} // }
if (clat.Start() == fst::kNoStateId) { // if (clat.Start() == fst::kNoStateId) {
*num_frames = 0; // *num_frames = 0;
return 1.0; // return 1.0;
} // }
size_t num_arc_frames = 0; // size_t num_arc_frames = 0;
int32 t; // int32 t;
{ // {
vector<int32> state_times; // vector<int32> state_times;
t = CompactLatticeStateTimes(clat, &state_times); // t = CompactLatticeStateTimes(clat, &state_times);
} // }
if (num_frames != NULL) // if (num_frames != NULL)
*num_frames = t; // *num_frames = t;
for (StateId s = 0; s < clat.NumStates(); s++) { // for (StateId s = 0; s < clat.NumStates(); s++) {
for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); // for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value(); // const CompactLatticeArc &arc = aiter.Value();
num_arc_frames += arc.weight.String().size(); // num_arc_frames += arc.weight.String().size();
} // }
num_arc_frames += clat.Final(s).String().size(); // num_arc_frames += clat.Final(s).String().size();
} // }
return num_arc_frames / static_cast<BaseFloat>(t); // return num_arc_frames / static_cast<BaseFloat>(t);
} // }
//
//
void CompactLatticeDepthPerFrame(const CompactLattice &clat, // void CompactLatticeDepthPerFrame(const CompactLattice &clat,
std::vector<int32> *depth_per_frame) { // std::vector<int32> *depth_per_frame) {
typedef CompactLattice::Arc::StateId StateId; // typedef CompactLattice::Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) { // if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not " // KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
<< "topologically sorted."; // << "topologically sorted.";
} // }
if (clat.Start() == fst::kNoStateId) { // if (clat.Start() == fst::kNoStateId) {
depth_per_frame->clear(); // depth_per_frame->clear();
return; // return;
} // }
vector<int32> state_times; // vector<int32> state_times;
int32 T = CompactLatticeStateTimes(clat, &state_times); // int32 T = CompactLatticeStateTimes(clat, &state_times);
//
depth_per_frame->clear(); // depth_per_frame->clear();
if (T <= 0) { // if (T <= 0) {
return; // return;
} else { // } else {
depth_per_frame->resize(T, 0); // depth_per_frame->resize(T, 0);
for (StateId s = 0; s < clat.NumStates(); s++) { // for (StateId s = 0; s < clat.NumStates(); s++) {
int32 start_time = state_times[s]; // int32 start_time = state_times[s];
for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); // for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
const CompactLatticeArc &arc = aiter.Value(); // const CompactLatticeArc &arc = aiter.Value();
int32 len = arc.weight.String().size(); // int32 len = arc.weight.String().size();
for (int32 t = start_time; t < start_time + len; t++) { // for (int32 t = start_time; t < start_time + len; t++) {
KALDI_ASSERT(t < T); // KALDI_ASSERT(t < T);
(*depth_per_frame)[t]++; // (*depth_per_frame)[t]++;
} // }
} // }
int32 final_len = clat.Final(s).String().size(); // int32 final_len = clat.Final(s).String().size();
for (int32 t = start_time; t < start_time + final_len; t++) { // for (int32 t = start_time; t < start_time + final_len; t++) {
KALDI_ASSERT(t < T); // KALDI_ASSERT(t < T);
(*depth_per_frame)[t]++; // (*depth_per_frame)[t]++;
} // }
} // }
} // }
} // }
//
//
//
void ConvertCompactLatticeToPhones(const TransitionInformation &trans, // void ConvertCompactLatticeToPhones(const TransitionModel &trans,
CompactLattice *clat) { // CompactLattice *clat) {
typedef CompactLatticeArc Arc; // typedef CompactLatticeArc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
int32 num_states = clat->NumStates(); // int32 num_states = clat->NumStates();
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state); // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
!aiter.Done(); // !aiter.Done();
aiter.Next()) { // aiter.Next()) {
Arc arc(aiter.Value()); // Arc arc(aiter.Value());
std::vector<int32> phone_seq; // std::vector<int32> phone_seq;
const std::vector<int32> &tid_seq = arc.weight.String(); // const std::vector<int32> &tid_seq = arc.weight.String();
for (std::vector<int32>::const_iterator iter = tid_seq.begin(); // for (std::vector<int32>::const_iterator iter = tid_seq.begin();
iter != tid_seq.end(); ++iter) { // iter != tid_seq.end(); ++iter) {
if (trans.IsFinal(*iter))// note: there is one of these per phone... // if (trans.IsFinal(*iter))// note: there is one of these per phone...
phone_seq.push_back(trans.TransitionIdToPhone(*iter)); // phone_seq.push_back(trans.TransitionIdToPhone(*iter));
} // }
arc.weight.SetString(phone_seq); // arc.weight.SetString(phone_seq);
aiter.SetValue(arc); // aiter.SetValue(arc);
} // end looping over arcs // } // end looping over arcs
Weight f = clat->Final(state); // Weight f = clat->Final(state);
if (f != Weight::Zero()) { // if (f != Weight::Zero()) {
std::vector<int32> phone_seq; // std::vector<int32> phone_seq;
const std::vector<int32> &tid_seq = f.String(); // const std::vector<int32> &tid_seq = f.String();
for (std::vector<int32>::const_iterator iter = tid_seq.begin(); // for (std::vector<int32>::const_iterator iter = tid_seq.begin();
iter != tid_seq.end(); ++iter) { // iter != tid_seq.end(); ++iter) {
if (trans.IsFinal(*iter))// note: there is one of these per phone... // if (trans.IsFinal(*iter))// note: there is one of these per phone...
phone_seq.push_back(trans.TransitionIdToPhone(*iter)); // phone_seq.push_back(trans.TransitionIdToPhone(*iter));
} // }
f.SetString(phone_seq); // f.SetString(phone_seq);
clat->SetFinal(state, f); // clat->SetFinal(state, f);
} // }
} // end looping over states // } // end looping over states
} // }
//
bool LatticeBoost(const TransitionInformation &trans, // bool LatticeBoost(const TransitionModel &trans,
const std::vector<int32> &alignment, // const std::vector<int32> &alignment,
const std::vector<int32> &silence_phones, // const std::vector<int32> &silence_phones,
BaseFloat b, // BaseFloat b,
BaseFloat max_silence_error, // BaseFloat max_silence_error,
Lattice *lat) { // Lattice *lat) {
TopSortLatticeIfNeeded(lat); // TopSortLatticeIfNeeded(lat);
//
// get all stored properties (test==false means don't test if not known). // // get all stored properties (test==false means don't test if not known).
uint64 props = lat->Properties(fst::kFstProperties, // uint64 props = lat->Properties(fst::kFstProperties,
false); // false);
//
KALDI_ASSERT(IsSortedAndUniq(silence_phones)); // KALDI_ASSERT(IsSortedAndUniq(silence_phones));
KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0); // KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
vector<int32> state_times; // vector<int32> state_times;
int32 num_states = lat->NumStates(); // int32 num_states = lat->NumStates();
int32 num_frames = LatticeStateTimes(*lat, &state_times); // int32 num_frames = LatticeStateTimes(*lat, &state_times);
KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size())); // KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
for (int32 state = 0; state < num_states; state++) { // for (int32 state = 0; state < num_states; state++) {
int32 cur_time = state_times[state]; // int32 cur_time = state_times[state];
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done(); // for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) { // aiter.Next()) {
LatticeArc arc = aiter.Value(); // LatticeArc arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc // if (arc.ilabel != 0) { // Non-epsilon arc
if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) { // if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
KALDI_WARN << "Lattice has out-of-range transition-ids: " // KALDI_WARN << "Lattice has out-of-range transition-ids: "
<< "lattice/model mismatch?"; // << "lattice/model mismatch?";
return false; // return false;
} // }
int32 phone = trans.TransitionIdToPhone(arc.ilabel), // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(alignment[cur_time]); // ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
BaseFloat frame_error; // BaseFloat frame_error;
if (phone == ref_phone) { // if (phone == ref_phone) {
frame_error = 0.0; // frame_error = 0.0;
} else { // an error... // } else { // an error...
if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone)) // if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
frame_error = max_silence_error; // frame_error = max_silence_error;
else // else
frame_error = 1.0; // frame_error = 1.0;
} // }
BaseFloat delta_cost = -b * frame_error; // negative cost if // BaseFloat delta_cost = -b * frame_error; // negative cost if
// frame is wrong, to boost likelihood of arcs with errors on them. // // frame is wrong, to boost likelihood of arcs with errors on them.
// Add this cost to the graph part. // // Add this cost to the graph part.
arc.weight.SetValue1(arc.weight.Value1() + delta_cost); // arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
aiter.SetValue(arc); // aiter.SetValue(arc);
} // }
} // }
} // }
// All we changed is the weights, so any properties that were // // All we changed is the weights, so any properties that were
// known before, are still known, except for whether or not the // // known before, are still known, except for whether or not the
// lattice was weighted. // // lattice was weighted.
lat->SetProperties(props, // lat->SetProperties(props,
~(fst::kWeighted|fst::kUnweighted)); // ~(fst::kWeighted|fst::kUnweighted));
//
return true; // return true;
} // }
//
//
//
BaseFloat LatticeForwardBackwardMpeVariants( // BaseFloat LatticeForwardBackwardMpeVariants(
const TransitionInformation &trans, // const TransitionModel &trans,
const std::vector<int32> &silence_phones, // const std::vector<int32> &silence_phones,
const Lattice &lat, // const Lattice &lat,
const std::vector<int32> &num_ali, // const std::vector<int32> &num_ali,
std::string criterion, // std::string criterion,
bool one_silence_class, // bool one_silence_class,
Posterior *post) { // Posterior *post) {
using namespace fst; // using namespace fst;
typedef Lattice::Arc Arc; // typedef Lattice::Arc Arc;
typedef Arc::Weight Weight; // typedef Arc::Weight Weight;
typedef Arc::StateId StateId; // typedef Arc::StateId StateId;
//
KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr"); // KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
bool is_mpfe = (criterion == "mpfe"); // bool is_mpfe = (criterion == "mpfe");
//
if (lat.Properties(fst::kTopSorted, true) == 0) // if (lat.Properties(fst::kTopSorted, true) == 0)
KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0); // KALDI_ASSERT(lat.Start() == 0);
//
int32 num_states = lat.NumStates(); // int32 num_states = lat.NumStates();
vector<int32> state_times; // vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times); // int32 max_time = LatticeStateTimes(lat, &state_times);
KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size())); // KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
std::vector<double> alpha(num_states, kLogZeroDouble), // std::vector<double> alpha(num_states, kLogZeroDouble),
alpha_smbr(num_states, 0), //forward variable for sMBR // alpha_smbr(num_states, 0), //forward variable for sMBR
beta(num_states, kLogZeroDouble), // beta(num_states, kLogZeroDouble),
beta_smbr(num_states, 0); //backward variable for sMBR // beta_smbr(num_states, 0); //backward variable for sMBR
//
double tot_forward_prob = kLogZeroDouble; // double tot_forward_prob = kLogZeroDouble;
double tot_forward_score = 0; // double tot_forward_score = 0;
//
post->clear(); // post->clear();
post->resize(max_time); // post->resize(max_time);
//
alpha[0] = 0.0; // alpha[0] = 0.0;
// First Pass Forward, // // First Pass Forward,
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s]; // double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight); // double arc_like = -ConvertToCost(arc.weight);
alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like); // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
} // }
Weight f = lat.Final(s); // Weight f = lat.Final(s);
if (f != Weight::Zero()) { // if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2()); // double final_like = this_alpha - (f.Value1() + f.Value2());
tot_forward_prob = LogAdd(tot_forward_prob, final_like); // tot_forward_prob = LogAdd(tot_forward_prob, final_like);
KALDI_ASSERT(state_times[s] == max_time && // KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)"); // "Lattice is inconsistent (final-prob not at max_time)");
} // }
} // }
// First Pass Backward, // // First Pass Backward,
for (StateId s = num_states-1; s >= 0; s--) { // for (StateId s = num_states-1; s >= 0; s--) {
Weight f = lat.Final(s); // Weight f = lat.Final(s);
double this_beta = -(f.Value1() + f.Value2()); // double this_beta = -(f.Value1() + f.Value2());
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight), // double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like; // arc_beta = beta[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta); // this_beta = LogAdd(this_beta, arc_beta);
} // }
beta[s] = this_beta; // beta[s] = this_beta;
} // }
// First Pass Forward-Backward Check // // First Pass Forward-Backward Check
double tot_backward_prob = beta[0]; // double tot_backward_prob = beta[0];
// may loose the condition somehow here 1e-6 (was 1e-8) // // may loose the condition somehow here 1e-6 (was 1e-8)
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) { // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob // KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob; // << ", while total backward probability = " << tot_backward_prob;
} // }
//
alpha_smbr[0] = 0.0; // alpha_smbr[0] = 0.0;
// Second Pass Forward, calculate forward for MPFE/SMBR // // Second Pass Forward, calculate forward for MPFE/SMBR
for (StateId s = 0; s < num_states; s++) { // for (StateId s = 0; s < num_states; s++) {
double this_alpha = alpha[s]; // double this_alpha = alpha[s];
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight); // double arc_like = -ConvertToCost(arc.weight);
double frame_acc = 0.0; // double frame_acc = 0.0;
if (arc.ilabel != 0) { // if (arc.ilabel != 0) {
int32 cur_time = state_times[s]; // int32 cur_time = state_times[s];
int32 phone = trans.TransitionIdToPhone(arc.ilabel), // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]); // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
bool phone_is_sil = std::binary_search(silence_phones.begin(), // bool phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(), // silence_phones.end(),
phone), // phone),
ref_phone_is_sil = std::binary_search(silence_phones.begin(), // ref_phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(), // silence_phones.end(),
ref_phone), // ref_phone),
both_sil = phone_is_sil && ref_phone_is_sil; // both_sil = phone_is_sil && ref_phone_is_sil;
if (!is_mpfe) { // smbr. // if (!is_mpfe) { // smbr.
int32 pdf = trans.TransitionIdToPdf(arc.ilabel), // int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]); // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
if (!one_silence_class) // old behavior // if (!one_silence_class) // old behavior
frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0; // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
else // else
frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0; // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
} else { // } else {
if (!one_silence_class) // old behavior // if (!one_silence_class) // old behavior
frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0; // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
else // else
frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0; // frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
} // }
} // }
double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]); // double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc); // alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
} // }
Weight f = lat.Final(s); // Weight f = lat.Final(s);
if (f != Weight::Zero()) { // if (f != Weight::Zero()) {
double final_like = this_alpha - (f.Value1() + f.Value2()); // double final_like = this_alpha - (f.Value1() + f.Value2());
double arc_scale = Exp(final_like - tot_forward_prob); // double arc_scale = Exp(final_like - tot_forward_prob);
tot_forward_score += arc_scale * alpha_smbr[s]; // tot_forward_score += arc_scale * alpha_smbr[s];
KALDI_ASSERT(state_times[s] == max_time && // KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)"); // "Lattice is inconsistent (final-prob not at max_time)");
} // }
} // }
// Second Pass Backward, collect Mpe style posteriors // // Second Pass Backward, collect Mpe style posteriors
for (StateId s = num_states-1; s >= 0; s--) { // for (StateId s = num_states-1; s >= 0; s--) {
for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight), // double arc_like = -ConvertToCost(arc.weight),
arc_beta = beta[arc.nextstate] + arc_like; // arc_beta = beta[arc.nextstate] + arc_like;
double frame_acc = 0.0; // double frame_acc = 0.0;
int32 transition_id = arc.ilabel; // int32 transition_id = arc.ilabel;
if (arc.ilabel != 0) { // if (arc.ilabel != 0) {
int32 cur_time = state_times[s]; // int32 cur_time = state_times[s];
int32 phone = trans.TransitionIdToPhone(arc.ilabel), // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]); // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
bool phone_is_sil = std::binary_search(silence_phones.begin(), // bool phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(), phone), // silence_phones.end(), phone),
ref_phone_is_sil = std::binary_search(silence_phones.begin(), // ref_phone_is_sil = std::binary_search(silence_phones.begin(),
silence_phones.end(), // silence_phones.end(),
ref_phone), // ref_phone),
both_sil = phone_is_sil && ref_phone_is_sil; // both_sil = phone_is_sil && ref_phone_is_sil;
if (!is_mpfe) { // smbr. // if (!is_mpfe) { // smbr.
int32 pdf = trans.TransitionIdToPdf(arc.ilabel), // int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]); // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
if (!one_silence_class) // old behavior // if (!one_silence_class) // old behavior
frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0; // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
else // else
frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0; // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
} else { // } else {
if (!one_silence_class) // old behavior // if (!one_silence_class) // old behavior
frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0; // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
else // else
frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0; // frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
} // }
} // }
double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]); // double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
// check arc_scale NAN, // // check arc_scale NAN,
// this is to prevent partial paths in Lattices // // this is to prevent partial paths in Lattices
// i.e., paths don't survive to the final state // // i.e., paths don't survive to the final state
if (KALDI_ISNAN(arc_scale)) arc_scale = 0; // if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc); // beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
//
if (transition_id != 0) { // Arc has a transition-id on it [not epsilon] // if (transition_id != 0) { // Arc has a transition-id on it [not epsilon]
double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob); // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate] // double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate]
- tot_forward_score; // - tot_forward_score;
double posterior_smbr = posterior * acc_diff; // double posterior_smbr = posterior * acc_diff;
(*post)[state_times[s]].push_back(std::make_pair(transition_id, // (*post)[state_times[s]].push_back(std::make_pair(transition_id,
static_cast<BaseFloat>(posterior_smbr))); // static_cast<BaseFloat>(posterior_smbr)));
} // }
} // }
} // }
//
//Second Pass Forward Backward check // //Second Pass Forward Backward check
double tot_backward_score = beta_smbr[0]; // Initial state id == 0 // double tot_backward_score = beta_smbr[0]; // Initial state id == 0
// may loose the condition somehow here 1e-5/1e-4 // // may loose the condition somehow here 1e-5/1e-4
if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) { // if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
KALDI_ERR << "Total forward score over lattice = " << tot_forward_score // KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
<< ", while total backward score = " << tot_backward_score; // << ", while total backward score = " << tot_backward_score;
} // }
//
// Output the computed posteriors // // Output the computed posteriors
for (int32 t = 0; t < max_time; t++) // for (int32 t = 0; t < max_time; t++)
MergePairVectorSumming(&((*post)[t])); // MergePairVectorSumming(&((*post)[t]));
return tot_forward_score; // return tot_forward_score;
} // }
//
bool CompactLatticeToWordAlignment(const CompactLattice &clat, // bool CompactLatticeToWordAlignment(const CompactLattice &clat,
std::vector<int32> *words, // std::vector<int32> *words,
std::vector<int32> *begin_times, // std::vector<int32> *begin_times,
std::vector<int32> *lengths) { // std::vector<int32> *lengths) {
words->clear(); // words->clear();
begin_times->clear(); // begin_times->clear();
lengths->clear(); // lengths->clear();
typedef CompactLattice::Arc Arc; // typedef CompactLattice::Arc Arc;
typedef Arc::Label Label; // typedef Arc::Label Label;
typedef CompactLattice::StateId StateId; // typedef CompactLattice::StateId StateId;
typedef CompactLattice::Weight Weight; // typedef CompactLattice::Weight Weight;
using namespace fst; // using namespace fst;
StateId state = clat.Start(); // StateId state = clat.Start();
int32 cur_time = 0; // int32 cur_time = 0;
if (state == kNoStateId) { // if (state == kNoStateId) {
KALDI_WARN << "Empty lattice."; // KALDI_WARN << "Empty lattice.";
return false; // return false;
} // }
while (1) { // while (1) {
Weight final = clat.Final(state); // Weight final = clat.Final(state);
size_t num_arcs = clat.NumArcs(state); // size_t num_arcs = clat.NumArcs(state);
if (final != Weight::Zero()) { // if (final != Weight::Zero()) {
if (num_arcs != 0) { // if (num_arcs != 0) {
KALDI_WARN << "Lattice is not linear."; // KALDI_WARN << "Lattice is not linear.";
return false; // return false;
} // }
if (! final.String().empty()) { // if (! final.String().empty()) {
KALDI_WARN << "Lattice has alignments on final-weight: probably " // KALDI_WARN << "Lattice has alignments on final-weight: probably "
"was not word-aligned (alignments will be approximate)"; // "was not word-aligned (alignments will be approximate)";
} // }
return true; // return true;
} else { // } else {
if (num_arcs != 1) { // if (num_arcs != 1) {
KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs; // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
return false; // return false;
} // }
fst::ArcIterator<CompactLattice> aiter(clat, state); // fst::ArcIterator<CompactLattice> aiter(clat, state);
const Arc &arc = aiter.Value(); // const Arc &arc = aiter.Value();
Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor. // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// Also note: word_id may be zero; we output it anyway. // // Also note: word_id may be zero; we output it anyway.
int32 length = arc.weight.String().size(); // int32 length = arc.weight.String().size();
words->push_back(word_id); // words->push_back(word_id);
begin_times->push_back(cur_time); // begin_times->push_back(cur_time);
lengths->push_back(length); // lengths->push_back(length);
cur_time += length; // cur_time += length;
state = arc.nextstate; // state = arc.nextstate;
} // }
} // }
} // }
//
//
void CompactLatticeShortestPath(const CompactLattice &clat, // bool CompactLatticeToWordProns(
CompactLattice *shortest_path) { // const TransitionModel &tmodel,
using namespace fst; // const CompactLattice &clat,
if (clat.Properties(fst::kTopSorted, true) == 0) { // std::vector<int32> *words,
CompactLattice clat_copy(clat); // std::vector<int32> *begin_times,
if (!TopSort(&clat_copy)) // std::vector<int32> *lengths,
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // std::vector<std::vector<int32> > *prons,
CompactLatticeShortestPath(clat_copy, shortest_path); // std::vector<std::vector<int32> > *phone_lengths) {
return; // words->clear();
} // begin_times->clear();
// Now we can assume it's topologically sorted. // lengths->clear();
shortest_path->DeleteStates(); // prons->clear();
if (clat.Start() == kNoStateId) return; // phone_lengths->clear();
typedef CompactLatticeArc Arc; // typedef CompactLattice::Arc Arc;
typedef Arc::StateId StateId; // typedef Arc::Label Label;
typedef CompactLatticeWeight Weight; // typedef CompactLattice::StateId StateId;
vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1); // typedef CompactLattice::Weight Weight;
StateId superfinal = clat.NumStates(); // using namespace fst;
for (StateId s = 0; s <= clat.NumStates(); s++) { // StateId state = clat.Start();
best_cost_and_pred[s].first = std::numeric_limits<double>::infinity(); // int32 cur_time = 0;
best_cost_and_pred[s].second = fst::kNoStateId; // if (state == kNoStateId) {
} // KALDI_WARN << "Empty lattice.";
best_cost_and_pred[clat.Start()].first = 0; // return false;
for (StateId s = 0; s < clat.NumStates(); s++) { // }
double my_cost = best_cost_and_pred[s].first; // while (1) {
for (ArcIterator<CompactLattice> aiter(clat, s); // Weight final = clat.Final(state);
!aiter.Done(); // size_t num_arcs = clat.NumArcs(state);
aiter.Next()) { // if (final != Weight::Zero()) {
const Arc &arc = aiter.Value(); // if (num_arcs != 0) {
double arc_cost = ConvertToCost(arc.weight), // KALDI_WARN << "Lattice is not linear.";
next_cost = my_cost + arc_cost; // return false;
if (next_cost < best_cost_and_pred[arc.nextstate].first) { // }
best_cost_and_pred[arc.nextstate].first = next_cost; // if (! final.String().empty()) {
best_cost_and_pred[arc.nextstate].second = s; // KALDI_WARN << "Lattice has alignments on final-weight: probably "
} // "was not word-aligned (alignments will be approximate)";
} // }
double final_cost = ConvertToCost(clat.Final(s)), // return true;
tot_final = my_cost + final_cost; // } else {
if (tot_final < best_cost_and_pred[superfinal].first) { // if (num_arcs != 1) {
best_cost_and_pred[superfinal].first = tot_final; // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
best_cost_and_pred[superfinal].second = s; // return false;
} // }
} // fst::ArcIterator<CompactLattice> aiter(clat, state);
std::vector<StateId> states; // states on best path. // const Arc &arc = aiter.Value();
StateId cur_state = superfinal, start_state = clat.Start(); // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
while (cur_state != start_state) { // // Also note: word_id may be zero; we output it anyway.
StateId prev_state = best_cost_and_pred[cur_state].second; // int32 length = arc.weight.String().size();
if (prev_state == kNoStateId) { // words->push_back(word_id);
KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)"; // begin_times->push_back(cur_time);
return; // return empty best-path. // lengths->push_back(length);
} // const std::vector<int32> &arc_alignment = arc.weight.String();
states.push_back(prev_state); // std::vector<std::vector<int32> > split_alignment;
KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles"); // SplitToPhones(tmodel, arc_alignment, &split_alignment);
cur_state = prev_state; // std::vector<int32> phones(split_alignment.size());
} // std::vector<int32> plengths(split_alignment.size());
std::reverse(states.begin(), states.end()); // for (size_t i = 0; i < split_alignment.size(); i++) {
for (size_t i = 0; i < states.size(); i++) // KALDI_ASSERT(!split_alignment[i].empty());
shortest_path->AddState(); // phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) { // plengths[i] = split_alignment[i].size();
if (s == 0) shortest_path->SetStart(s); // }
if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state. // prons->push_back(phones);
bool have_arc = false; // phone_lengths->push_back(plengths);
Arc cur_arc; //
for (ArcIterator<CompactLattice> aiter(clat, states[s]); // cur_time += length;
!aiter.Done(); // state = arc.nextstate;
aiter.Next()) { // }
const Arc &arc = aiter.Value(); // }
if (arc.nextstate == states[s+1]) { // }
if (!have_arc || //
ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) { //
cur_arc = arc; //
have_arc = true; // void CompactLatticeShortestPath(const CompactLattice &clat,
} // CompactLattice *shortest_path) {
} // using namespace fst;
} // if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_ASSERT(have_arc && "Code error."); // CompactLattice clat_copy(clat);
shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel, // if (!TopSort(&clat_copy))
cur_arc.weight, s+1)); // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
} else { // final-prob. // CompactLatticeShortestPath(clat_copy, shortest_path);
shortest_path->SetFinal(s, clat.Final(states[s])); // return;
} // }
} // // Now we can assume it's topologically sorted.
} // shortest_path->DeleteStates();
// if (clat.Start() == kNoStateId) return;
// typedef CompactLatticeArc Arc;
void ExpandCompactLattice(const CompactLattice &clat, // typedef Arc::StateId StateId;
double epsilon, // typedef CompactLatticeWeight Weight;
CompactLattice *expand_clat) { // vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
using namespace fst; // StateId superfinal = clat.NumStates();
typedef CompactLattice::Arc Arc; // for (StateId s = 0; s <= clat.NumStates(); s++) {
typedef Arc::Weight Weight; // best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
typedef Arc::StateId StateId; // best_cost_and_pred[s].second = fst::kNoStateId;
typedef std::pair<StateId, StateId> StatePair; // }
typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType; // best_cost_and_pred[clat.Start()].first = 0;
typedef MapType::iterator IterType; // for (StateId s = 0; s < clat.NumStates(); s++) {
// double my_cost = best_cost_and_pred[s].first;
if (clat.Start() == kNoStateId) return; // for (ArcIterator<CompactLattice> aiter(clat, s);
// Make sure the input lattice is topologically sorted. // !aiter.Done();
if (clat.Properties(kTopSorted, true) == 0) { // aiter.Next()) {
CompactLattice clat_copy(clat); // const Arc &arc = aiter.Value();
KALDI_LOG << "Topsort this lattice."; // double arc_cost = ConvertToCost(arc.weight),
if (!TopSort(&clat_copy)) // next_cost = my_cost + arc_cost;
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // if (next_cost < best_cost_and_pred[arc.nextstate].first) {
ExpandCompactLattice(clat_copy, epsilon, expand_clat); // best_cost_and_pred[arc.nextstate].first = next_cost;
return; // best_cost_and_pred[arc.nextstate].second = s;
} // }
// }
// Compute backward logprobs betas for the expanded lattice. // double final_cost = ConvertToCost(clat.Final(s)),
// Note: the backward logprobs in the original lattice <clat> and the // tot_final = my_cost + final_cost;
// expanded lattice <expand_clat> are the same. // if (tot_final < best_cost_and_pred[superfinal].first) {
int32 num_states = clat.NumStates(); // best_cost_and_pred[superfinal].first = tot_final;
std::vector<double> beta(num_states, kLogZeroDouble); // best_cost_and_pred[superfinal].second = s;
ComputeCompactLatticeBetas(clat, &beta); // }
double tot_backward_logprob = beta[0]; // }
std::vector<double> alpha; // std::vector<StateId> states; // states on best path.
alpha.push_back(0.0); // StateId cur_state = superfinal, start_state = clat.Start();
expand_clat->DeleteStates(); // while (cur_state != start_state) {
MapType state_map; // Map from state pair (orig_state, copy_state) to // StateId prev_state = best_cost_and_pred[cur_state].second;
// copy_state, where orig_state is a state in the original lattice, and // if (prev_state == kNoStateId) {
// copy_state is its corresponding one in the expanded lattice. // KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
unordered_map<StateId, StateId> states; // Map from orig_state to its // return; // return empty best-path.
// copy_state for states with incoming arcs' posteriors <= epsilon. // }
std::queue<StatePair> state_queue; // states.push_back(prev_state);
// KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
// Set start state in the expanded lattice. // cur_state = prev_state;
StateId start_state = expand_clat->AddState(); // }
expand_clat->SetStart(start_state); // std::reverse(states.begin(), states.end());
StatePair start_pair(clat.Start(), start_state); // for (size_t i = 0; i < states.size(); i++)
state_queue.push(start_pair); // shortest_path->AddState();
std::pair<IterType, bool> result = // for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
state_map.insert(std::make_pair(start_pair, start_state)); // if (s == 0) shortest_path->SetStart(s);
KALDI_ASSERT(result.second == true); // if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state.
// bool have_arc = false;
// Expand <clat> and update forward logprobs alphas in <expand_clat>. // Arc cur_arc;
while (!state_queue.empty()) { // for (ArcIterator<CompactLattice> aiter(clat, states[s]);
StatePair s = state_queue.front(); // !aiter.Done();
StateId s1 = s.first, // aiter.Next()) {
s2 = s.second; // const Arc &arc = aiter.Value();
state_queue.pop(); // if (arc.nextstate == states[s+1]) {
// if (!have_arc ||
Weight f = clat.Final(s1); // ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
if (f != Weight::Zero()) { // cur_arc = arc;
KALDI_ASSERT(state_map.find(s) != state_map.end()); // have_arc = true;
expand_clat->SetFinal(state_map[s], f); // }
} // }
// }
for (ArcIterator<CompactLattice> aiter(clat, s1); // KALDI_ASSERT(have_arc && "Code error.");
!aiter.Done(); aiter.Next()) { // shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
const Arc &arc = aiter.Value(); // cur_arc.weight, s+1));
StateId orig_state = arc.nextstate; // } else { // final-prob.
double arc_like = -ConvertToCost(arc.weight), // shortest_path->SetFinal(s, clat.Final(states[s]));
this_alpha = alpha[s2] + arc_like, // }
arc_post = Exp(this_alpha + beta[orig_state] - // }
tot_backward_logprob); // }
// Generate the expanded lattice. //
StateId copy_state; //
if (arc_post > epsilon) { // void ExpandCompactLattice(const CompactLattice &clat,
copy_state = expand_clat->AddState(); // double epsilon,
StatePair next_pair(orig_state, copy_state); // CompactLattice *expand_clat) {
std::pair<IterType, bool> result = // using namespace fst;
state_map.insert(std::make_pair(next_pair, copy_state)); // typedef CompactLattice::Arc Arc;
KALDI_ASSERT(result.second == true); // typedef Arc::Weight Weight;
state_queue.push(next_pair); // typedef Arc::StateId StateId;
} else { // typedef std::pair<StateId, StateId> StatePair;
unordered_map<StateId, StateId>::iterator iter = states.find(orig_state); // typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
if (iter == states.end() ) { // The counterpart state of orig_state // typedef MapType::iterator IterType;
// has not been created in <expand_clat> yet. //
copy_state = expand_clat->AddState(); // if (clat.Start() == kNoStateId) return;
StatePair next_pair(orig_state, copy_state); // // Make sure the input lattice is topologically sorted.
std::pair<IterType, bool> result = // if (clat.Properties(kTopSorted, true) == 0) {
state_map.insert(std::make_pair(next_pair, copy_state)); // CompactLattice clat_copy(clat);
KALDI_ASSERT(result.second == true); // KALDI_LOG << "Topsort this lattice.";
state_queue.push(next_pair); // if (!TopSort(&clat_copy))
states[orig_state] = copy_state; // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
} else { // ExpandCompactLattice(clat_copy, epsilon, expand_clat);
copy_state = iter->second; // return;
} // }
} //
// Create an arc from state_map[s] to copy_state in the expanded lattice. // // Compute backward logprobs betas for the expanded lattice.
expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight, // // Note: the backward logprobs in the original lattice <clat> and the
copy_state)); // // expanded lattice <expand_clat> are the same.
// Compute forward logprobs alpha for the expanded lattice. // int32 num_states = clat.NumStates();
if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha // std::vector<double> beta(num_states, kLogZeroDouble);
// for copy_state in <expand_clat>. // ComputeCompactLatticeBetas(clat, &beta);
alpha.push_back(this_alpha); // double tot_backward_logprob = beta[0];
} else { // Accumulate alpha. // std::vector<double> alpha;
alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha); // alpha.push_back(0.0);
} // expand_clat->DeleteStates();
} // MapType state_map; // Map from state pair (orig_state, copy_state) to
} // end while // // copy_state, where orig_state is a state in the original lattice, and
} // // copy_state is its corresponding one in the expanded lattice.
// unordered_map<StateId, StateId> states; // Map from orig_state to its
// // copy_state for states with incoming arcs' posteriors <= epsilon.
void CompactLatticeBestCostsAndTracebacks( // std::queue<StatePair> state_queue;
const CompactLattice &clat, //
CostTraceType *forward_best_cost_and_pred, // // Set start state in the expanded lattice.
CostTraceType *backward_best_cost_and_pred) { // StateId start_state = expand_clat->AddState();
// expand_clat->SetStart(start_state);
// typedef the arc, weight types // StatePair start_pair(clat.Start(), start_state);
typedef CompactLatticeArc Arc; // state_queue.push(start_pair);
typedef Arc::Weight Weight; // std::pair<IterType, bool> result =
typedef Arc::StateId StateId; // state_map.insert(std::make_pair(start_pair, start_state));
// KALDI_ASSERT(result.second == true);
forward_best_cost_and_pred->clear(); //
backward_best_cost_and_pred->clear(); // // Expand <clat> and update forward logprobs alphas in <expand_clat>.
forward_best_cost_and_pred->resize(clat.NumStates()); // while (!state_queue.empty()) {
backward_best_cost_and_pred->resize(clat.NumStates()); // StatePair s = state_queue.front();
// Initialize the cost and predecessor state for each state. // StateId s1 = s.first,
for (StateId s = 0; s < clat.NumStates(); s++) { // s2 = s.second;
(*forward_best_cost_and_pred)[s].first = // state_queue.pop();
std::numeric_limits<double>::infinity(); //
(*backward_best_cost_and_pred)[s].first = // Weight f = clat.Final(s1);
std::numeric_limits<double>::infinity(); // if (f != Weight::Zero()) {
(*forward_best_cost_and_pred)[s].second = fst::kNoStateId; // KALDI_ASSERT(state_map.find(s) != state_map.end());
(*backward_best_cost_and_pred)[s].second = fst::kNoStateId; // expand_clat->SetFinal(state_map[s], f);
} // }
//
StateId start_state = clat.Start(); // for (ArcIterator<CompactLattice> aiter(clat, s1);
(*forward_best_cost_and_pred)[start_state].first = 0; // !aiter.Done(); aiter.Next()) {
// Transverse the lattice forwardly to compute the best cost from the start // const Arc &arc = aiter.Value();
// state to each state and the best predecessor state of each state. // StateId orig_state = arc.nextstate;
for (StateId s = 0; s < clat.NumStates(); s++) { // double arc_like = -ConvertToCost(arc.weight),
double cur_cost = (*forward_best_cost_and_pred)[s].first; // this_alpha = alpha[s2] + arc_like,
for (fst::ArcIterator<CompactLattice> aiter(clat, s); // arc_post = Exp(this_alpha + beta[orig_state] -
!aiter.Done(); aiter.Next()) { // tot_backward_logprob);
const Arc &arc = aiter.Value(); // // Generate the expanded lattice.
double next_cost = cur_cost + ConvertToCost(arc.weight); // StateId copy_state;
if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) { // if (arc_post > epsilon) {
(*forward_best_cost_and_pred)[arc.nextstate].first = next_cost; // copy_state = expand_clat->AddState();
(*forward_best_cost_and_pred)[arc.nextstate].second = s; // StatePair next_pair(orig_state, copy_state);
} // std::pair<IterType, bool> result =
} // state_map.insert(std::make_pair(next_pair, copy_state));
} // KALDI_ASSERT(result.second == true);
// Transverse the lattice backwardly to compute the best cost from a final // state_queue.push(next_pair);
// state to each state and the best predecessor state of each state. // } else {
for (StateId s = clat.NumStates() - 1; s >= 0; s--) { // unordered_map<StateId, StateId>::iterator iter = states.find(orig_state);
double this_cost = ConvertToCost(clat.Final(s)); // if (iter == states.end() ) { // The counterpart state of orig_state
for (fst::ArcIterator<CompactLattice> aiter(clat, s); // // has not been created in <expand_clat> yet.
!aiter.Done(); aiter.Next()) { // copy_state = expand_clat->AddState();
const Arc &arc = aiter.Value(); // StatePair next_pair(orig_state, copy_state);
double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first + // std::pair<IterType, bool> result =
ConvertToCost(arc.weight); // state_map.insert(std::make_pair(next_pair, copy_state));
if (next_cost < this_cost) { // KALDI_ASSERT(result.second == true);
this_cost = next_cost; // state_queue.push(next_pair);
(*backward_best_cost_and_pred)[s].second = arc.nextstate; // states[orig_state] = copy_state;
} // } else {
} // copy_state = iter->second;
(*backward_best_cost_and_pred)[s].first = this_cost; // }
} // }
} // // Create an arc from state_map[s] to copy_state in the expanded lattice.
// expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight,
// copy_state));
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores, // // Compute forward logprobs alpha for the expanded lattice.
CompactLattice *clat) { // if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha
if (clat->Start() == fst::kNoStateId) return; // // for copy_state in <expand_clat>.
// Make sure the input lattice is topologically sorted. // alpha.push_back(this_alpha);
if (clat->Properties(fst::kTopSorted, true) == 0) { // } else { // Accumulate alpha.
KALDI_LOG << "Topsort this lattice."; // alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
if (!TopSort(clat)) // }
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // }
AddNnlmScoreToCompactLattice(nnlm_scores, clat); // } // end while
return; // }
} //
//
// typedef the arc, weight types // void CompactLatticeBestCostsAndTracebacks(
typedef CompactLatticeArc Arc; // const CompactLattice &clat,
typedef Arc::Weight Weight; // CostTraceType *forward_best_cost_and_pred,
typedef Arc::StateId StateId; // CostTraceType *backward_best_cost_and_pred) {
typedef std::pair<int32, int32> StatePair; //
// // typedef the arc, weight types
int32 num_states = clat->NumStates(); // typedef CompactLatticeArc Arc;
unordered_map<StatePair, bool, PairHasher<int32> > final_state_check; // typedef Arc::Weight Weight;
for (StateId s = 0; s < num_states; s++) { // typedef Arc::StateId StateId;
for (fst::MutableArcIterator<CompactLattice> aiter(clat, s); //
!aiter.Done(); aiter.Next()) { // forward_best_cost_and_pred->clear();
Arc arc(aiter.Value()); // backward_best_cost_and_pred->clear();
StatePair arc_index = std::make_pair(static_cast<int32>(s), // forward_best_cost_and_pred->resize(clat.NumStates());
static_cast<int32>(arc.nextstate)); // backward_best_cost_and_pred->resize(clat.NumStates());
MapT::const_iterator it = nnlm_scores.find(arc_index); // // Initialize the cost and predecessor state for each state.
double nnlm_score; // for (StateId s = 0; s < clat.NumStates(); s++) {
if (it != nnlm_scores.end()) // (*forward_best_cost_and_pred)[s].first =
nnlm_score = it->second; // std::numeric_limits<double>::infinity();
else // (*backward_best_cost_and_pred)[s].first =
KALDI_ERR << "Some arc does not have neural language model score."; // std::numeric_limits<double>::infinity();
if (arc.ilabel != 0) { // if there is a word on this arc // (*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
LatticeWeight weight = arc.weight.Weight(); // (*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
// Add associated neural LM score to each arc. // }
weight.SetValue1(weight.Value1() + nnlm_score); //
arc.weight.SetWeight(weight); // StateId start_state = clat.Start();
aiter.SetValue(arc); // (*forward_best_cost_and_pred)[start_state].first = 0;
} // // Transverse the lattice forwardly to compute the best cost from the start
Weight clat_final = clat->Final(arc.nextstate); // // state to each state and the best predecessor state of each state.
StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate); // for (StateId s = 0; s < clat.NumStates(); s++) {
// Add neural LM scores to each final state only once. // double cur_cost = (*forward_best_cost_and_pred)[s].first;
if (clat_final != CompactLatticeWeight::Zero() && // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
final_state_check.find(final_pair) == final_state_check.end()) { // !aiter.Done(); aiter.Next()) {
MapT::const_iterator final_it = nnlm_scores.find(final_pair); // const Arc &arc = aiter.Value();
double final_nnlm_score = 0.0; // double next_cost = cur_cost + ConvertToCost(arc.weight);
if (final_it != nnlm_scores.end()) // if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
final_nnlm_score = final_it->second; // (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
// Add neural LM scores to the final weight. // (*forward_best_cost_and_pred)[arc.nextstate].second = s;
Weight final_weight(LatticeWeight(clat_final.Weight().Value1() + // }
final_nnlm_score, // }
clat_final.Weight().Value2()), // }
clat_final.String()); // // Transverse the lattice backwardly to compute the best cost from a final
clat->SetFinal(arc.nextstate, final_weight); // // state to each state and the best predecessor state of each state.
final_state_check[final_pair] = true; // for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
} // double this_cost = ConvertToCost(clat.Final(s));
} // end looping over arcs // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
} // end looping over states // !aiter.Done(); aiter.Next()) {
} // const Arc &arc = aiter.Value();
// double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first +
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty, // ConvertToCost(arc.weight);
CompactLattice *clat) { // if (next_cost < this_cost) {
typedef CompactLatticeArc Arc; // this_cost = next_cost;
int32 num_states = clat->NumStates(); // (*backward_best_cost_and_pred)[s].second = arc.nextstate;
// }
//scan the lattice // }
for (int32 state = 0; state < num_states; state++) { // (*backward_best_cost_and_pred)[s].first = this_cost;
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state); // }
!aiter.Done(); aiter.Next()) { // }
//
Arc arc(aiter.Value()); //
// void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
if (arc.ilabel != 0) { // if there is a word on this arc // CompactLattice *clat) {
LatticeWeight weight = arc.weight.Weight(); // if (clat->Start() == fst::kNoStateId) return;
// add word insertion penalty to lattice // // Make sure the input lattice is topologically sorted.
weight.SetValue1( weight.Value1() + word_ins_penalty); // if (clat->Properties(fst::kTopSorted, true) == 0) {
arc.weight.SetWeight(weight); // KALDI_LOG << "Topsort this lattice.";
aiter.SetValue(arc); // if (!TopSort(clat))
} // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
} // end looping over arcs // AddNnlmScoreToCompactLattice(nnlm_scores, clat);
} // end looping over states // return;
} // }
//
struct ClatRescoreTuple { // // typedef the arc, weight types
ClatRescoreTuple(int32 state, int32 arc, int32 tid): // typedef CompactLatticeArc Arc;
state_id(state), arc_id(arc), tid(tid) { } // typedef Arc::Weight Weight;
int32 state_id; // typedef Arc::StateId StateId;
int32 arc_id; // typedef std::pair<int32, int32> StatePair;
int32 tid; //
}; // int32 num_states = clat->NumStates();
// unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
/** RescoreCompactLatticeInternal is the internal code for both // for (StateId s = 0; s < num_states; s++) {
RescoreCompactLattice and RescoreCompatLatticeSpeedup. For // for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0. // !aiter.Done(); aiter.Next()) {
*/ // Arc arc(aiter.Value());
bool RescoreCompactLatticeInternal( // StatePair arc_index = std::make_pair(static_cast<int32>(s),
const TransitionInformation *tmodel, // static_cast<int32>(arc.nextstate));
BaseFloat speedup_factor, // MapT::const_iterator it = nnlm_scores.find(arc_index);
DecodableInterface *decodable, // double nnlm_score;
CompactLattice *clat) { // if (it != nnlm_scores.end())
KALDI_ASSERT(speedup_factor >= 1.0); // nnlm_score = it->second;
if (clat->NumStates() == 0) { // else
KALDI_WARN << "Rescoring empty lattice"; // KALDI_ERR << "Some arc does not have neural language model score.";
return false; // if (arc.ilabel != 0) { // if there is a word on this arc
} // LatticeWeight weight = arc.weight.Weight();
if (!clat->Properties(fst::kTopSorted, true)) { // // Add associated neural LM score to each arc.
if (fst::TopSort(clat) == false) { // weight.SetValue1(weight.Value1() + nnlm_score);
KALDI_WARN << "Cycles detected in lattice."; // arc.weight.SetWeight(weight);
return false; // aiter.SetValue(arc);
} // }
} // Weight clat_final = clat->Final(arc.nextstate);
std::vector<int32> state_times; // StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times); // // Add neural LM scores to each final state only once.
// if (clat_final != CompactLatticeWeight::Zero() &&
std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len); // final_state_check.find(final_pair) == final_state_check.end()) {
// MapT::const_iterator final_it = nnlm_scores.find(final_pair);
int32 num_states = clat->NumStates(); // double final_nnlm_score = 0.0;
KALDI_ASSERT(num_states == state_times.size()); // if (final_it != nnlm_scores.end())
for (size_t state = 0; state < num_states; state++) { // final_nnlm_score = final_it->second;
KALDI_ASSERT(state_times[state] >= 0); // // Add neural LM scores to the final weight.
int32 t = state_times[state]; // Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
int32 arc_id = 0; // final_nnlm_score,
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state); // clat_final.Weight().Value2()),
!aiter.Done(); aiter.Next(), arc_id++) { // clat_final.String());
CompactLatticeArc arc = aiter.Value(); // clat->SetFinal(arc.nextstate, final_weight);
std::vector<int32> arc_string = arc.weight.String(); // final_state_check[final_pair] = true;
// }
for (size_t offset = 0; offset < arc_string.size(); offset++) { // } // end looping over arcs
if (t < utt_len) { // end state may be past this.. // } // end looping over states
int32 tid = arc_string[offset]; // }
time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid)); //
} else { // void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
if (t != utt_len) { // CompactLattice *clat) {
KALDI_WARN << "There appears to be lattice/feature mismatch, " // typedef CompactLatticeArc Arc;
<< "aborting."; // int32 num_states = clat->NumStates();
return false; //
} // //scan the lattice
} // for (int32 state = 0; state < num_states; state++) {
} // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
} // !aiter.Done(); aiter.Next()) {
if (clat->Final(state) != CompactLatticeWeight::Zero()) { //
arc_id = -1; // Arc arc(aiter.Value());
std::vector<int32> arc_string = clat->Final(state).String(); //
for (size_t offset = 0; offset < arc_string.size(); offset++) { // if (arc.ilabel != 0) { // if there is a word on this arc
KALDI_ASSERT(t + offset < utt_len); // already checked in // LatticeWeight weight = arc.weight.Weight();
// CompactLatticeStateTimes, so would be code error. // // add word insertion penalty to lattice
time_to_state[t+offset].push_back( // weight.SetValue1( weight.Value1() + word_ins_penalty);
ClatRescoreTuple(state, arc_id, arc_string[offset])); // arc.weight.SetWeight(weight);
} // aiter.SetValue(arc);
} // }
} // } // end looping over arcs
// } // end looping over states
for (int32 t = 0; t < utt_len; t++) { // }
if ((t < utt_len - 1) && decodable->IsLastFrame(t)) { //
KALDI_WARN << "Features are too short for lattice: utt-len is " // struct ClatRescoreTuple {
<< utt_len << ", " << t << " is last frame"; // ClatRescoreTuple(int32 state, int32 arc, int32 tid):
return false; // state_id(state), arc_id(arc), tid(tid) { }
} // int32 state_id;
// frame_scale is the scale we put on the computed acoustic probs for this // int32 arc_id;
// frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing // int32 tid;
// the "speedup" code). For frames with multiple pdf-ids it will be one. // };
// For frames with only one pdf-id, it will equal speedup_factor (>=1.0) //
// with probability 1.0 / speedup_factor, and zero otherwise. If it is zero, // /** RescoreCompactLatticeInternal is the internal code for both
// we can avoid computing the probabilities. // RescoreCompactLattice and RescoreCompatLatticeSpeedup. For
BaseFloat frame_scale = 1.0; // RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0.
KALDI_ASSERT(!time_to_state[t].empty()); // */
if (tmodel != NULL) { // bool RescoreCompactLatticeInternal(
int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid); // const TransitionModel *tmodel,
bool frame_has_multiple_pdfs = false; // BaseFloat speedup_factor,
for (size_t i = 1; i < time_to_state[t].size(); i++) { // DecodableInterface *decodable,
if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) { // CompactLattice *clat) {
frame_has_multiple_pdfs = true; // KALDI_ASSERT(speedup_factor >= 1.0);
break; // if (clat->NumStates() == 0) {
} // KALDI_WARN << "Rescoring empty lattice";
} // return false;
if (frame_has_multiple_pdfs) { // }
frame_scale = 1.0; // if (!clat->Properties(fst::kTopSorted, true)) {
} else { // if (fst::TopSort(clat) == false) {
if (WithProb(1.0 / speedup_factor)) { // KALDI_WARN << "Cycles detected in lattice.";
frame_scale = speedup_factor; // return false;
} else { // }
frame_scale = 0.0; // }
} // std::vector<int32> state_times;
} // int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
if (frame_scale == 0.0) //
continue; // the code below would be pointless. // std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
} //
// int32 num_states = clat->NumStates();
for (size_t i = 0; i < time_to_state[t].size(); i++) { // KALDI_ASSERT(num_states == state_times.size());
int32 state = time_to_state[t][i].state_id; // for (size_t state = 0; state < num_states; state++) {
int32 arc_id = time_to_state[t][i].arc_id; // KALDI_ASSERT(state_times[state] >= 0);
int32 tid = time_to_state[t][i].tid; // int32 t = state_times[state];
// int32 arc_id = 0;
if (arc_id == -1) { // Final state // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// Access the trans_id // !aiter.Done(); aiter.Next(), arc_id++) {
CompactLatticeWeight curr_clat_weight = clat->Final(state); // CompactLatticeArc arc = aiter.Value();
// std::vector<int32> arc_string = arc.weight.String();
// Calculate likelihood //
BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale; // for (size_t offset = 0; offset < arc_string.size(); offset++) {
// update weight // if (t < utt_len) { // end state may be past this..
CompactLatticeWeight new_clat_weight = curr_clat_weight; // int32 tid = arc_string[offset];
LatticeWeight new_lat_weight = new_clat_weight.Weight(); // time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid));
new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2()); // } else {
new_clat_weight.SetWeight(new_lat_weight); // if (t != utt_len) {
clat->SetFinal(state, new_clat_weight); // KALDI_WARN << "There appears to be lattice/feature mismatch, "
} else { // << "aborting.";
fst::MutableArcIterator<CompactLattice> aiter(clat, state); // return false;
// }
aiter.Seek(arc_id); // }
CompactLatticeArc arc = aiter.Value(); // }
// }
// Calculate likelihood // if (clat->Final(state) != CompactLatticeWeight::Zero()) {
BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale; // arc_id = -1;
// update weight // std::vector<int32> arc_string = clat->Final(state).String();
LatticeWeight new_weight = arc.weight.Weight(); // for (size_t offset = 0; offset < arc_string.size(); offset++) {
new_weight.SetValue2(-log_like + arc.weight.Weight().Value2()); // KALDI_ASSERT(t + offset < utt_len); // already checked in
arc.weight.SetWeight(new_weight); // // CompactLatticeStateTimes, so would be code error.
aiter.SetValue(arc); // time_to_state[t+offset].push_back(
} // ClatRescoreTuple(state, arc_id, arc_string[offset]));
} // }
} // }
return true; // }
} //
// for (int32 t = 0; t < utt_len; t++) {
// if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
bool RescoreCompactLatticeSpeedup( // KALDI_WARN << "Features are too short for lattice: utt-len is "
const TransitionInformation &tmodel, // << utt_len << ", " << t << " is last frame";
BaseFloat speedup_factor, // return false;
DecodableInterface *decodable, // }
CompactLattice *clat) { // // frame_scale is the scale we put on the computed acoustic probs for this
return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat); // // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
} // // the "speedup" code). For frames with multiple pdf-ids it will be one.
// // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
bool RescoreCompactLattice(DecodableInterface *decodable, // // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
CompactLattice *clat) { // // we can avoid computing the probabilities.
return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat); // BaseFloat frame_scale = 1.0;
} // KALDI_ASSERT(!time_to_state[t].empty());
// if (tmodel != NULL) {
// int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
bool RescoreLattice(DecodableInterface *decodable, // bool frame_has_multiple_pdfs = false;
Lattice *lat) { // for (size_t i = 1; i < time_to_state[t].size(); i++) {
if (lat->NumStates() == 0) { // if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
KALDI_WARN << "Rescoring empty lattice"; // frame_has_multiple_pdfs = true;
return false; // break;
} // }
if (!lat->Properties(fst::kTopSorted, true)) { // }
if (fst::TopSort(lat) == false) { // if (frame_has_multiple_pdfs) {
KALDI_WARN << "Cycles detected in lattice."; // frame_scale = 1.0;
return false; // } else {
} // if (WithProb(1.0 / speedup_factor)) {
} // frame_scale = speedup_factor;
std::vector<int32> state_times; // } else {
int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times); // frame_scale = 0.0;
// }
std::vector<std::vector<int32> > time_to_state(utt_len ); // }
// if (frame_scale == 0.0)
int32 num_states = lat->NumStates(); // continue; // the code below would be pointless.
KALDI_ASSERT(num_states == state_times.size()); // }
for (size_t state = 0; state < num_states; state++) { //
int32 t = state_times[state]; // for (size_t i = 0; i < time_to_state[t].size(); i++) {
// Don't check t >= 0 because non-accessible states could have t = -1. // int32 state = time_to_state[t][i].state_id;
KALDI_ASSERT(t <= utt_len); // int32 arc_id = time_to_state[t][i].arc_id;
if (t >= 0 && t < utt_len) // int32 tid = time_to_state[t][i].tid;
time_to_state[t].push_back(state); //
} // if (arc_id == -1) { // Final state
// // Access the trans_id
for (int32 t = 0; t < utt_len; t++) { // CompactLatticeWeight curr_clat_weight = clat->Final(state);
if ((t < utt_len - 1) && decodable->IsLastFrame(t)) { //
KALDI_WARN << "Features are too short for lattice: utt-len is " // // Calculate likelihood
<< utt_len << ", " << t << " is last frame"; // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
return false; // // update weight
} // CompactLatticeWeight new_clat_weight = curr_clat_weight;
for (size_t i = 0; i < time_to_state[t].size(); i++) { // LatticeWeight new_lat_weight = new_clat_weight.Weight();
int32 state = time_to_state[t][i]; // new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2());
for (fst::MutableArcIterator<Lattice> aiter(lat, state); // new_clat_weight.SetWeight(new_lat_weight);
!aiter.Done(); aiter.Next()) { // clat->SetFinal(state, new_clat_weight);
LatticeArc arc = aiter.Value(); // } else {
if (arc.ilabel != 0) { // fst::MutableArcIterator<CompactLattice> aiter(clat, state);
int32 trans_id = arc.ilabel; // Note: it doesn't necessarily //
// have to be a transition-id, just whatever the Decodable // aiter.Seek(arc_id);
// object is expecting, but it's normally a transition-id. // CompactLatticeArc arc = aiter.Value();
//
BaseFloat log_like = decodable->LogLikelihood(t, trans_id); // // Calculate likelihood
arc.weight.SetValue2(-log_like + arc.weight.Value2()); // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
aiter.SetValue(arc); // // update weight
} // LatticeWeight new_weight = arc.weight.Weight();
} // new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
} // arc.weight.SetWeight(new_weight);
} // aiter.SetValue(arc);
return true; // }
} // }
// }
// return true;
int32 LongestSentenceLength(const Lattice &lat) { // }
typedef Lattice::Arc Arc; //
typedef Arc::Label Label; //
typedef Arc::StateId StateId; // bool RescoreCompactLatticeSpeedup(
// const TransitionModel &tmodel,
if (lat.Properties(fst::kTopSorted, true) == 0) { // BaseFloat speedup_factor,
Lattice lat_copy(lat); // DecodableInterface *decodable,
if (!TopSort(&lat_copy)) // CompactLattice *clat) {
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat);
return LongestSentenceLength(lat_copy); // }
} //
std::vector<int32> max_length(lat.NumStates(), 0); // bool RescoreCompactLattice(DecodableInterface *decodable,
int32 lattice_max_length = 0; // CompactLattice *clat) {
for (StateId s = 0; s < lat.NumStates(); s++) { // return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
int32 this_max_length = max_length[s]; // }
for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) { //
const Arc &arc = aiter.Value(); //
bool arc_has_word = (arc.olabel != 0); // bool RescoreLattice(DecodableInterface *decodable,
StateId nextstate = arc.nextstate; // Lattice *lat) {
KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size()); // if (lat->NumStates() == 0) {
if (arc_has_word) { // KALDI_WARN << "Rescoring empty lattice";
// A lattice should ideally not have cycles anyway; a cycle with a word // return false;
// on is something very bad. // }
KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on."); // if (!lat->Properties(fst::kTopSorted, true)) {
max_length[nextstate] = std::max(max_length[nextstate], // if (fst::TopSort(lat) == false) {
this_max_length + 1); // KALDI_WARN << "Cycles detected in lattice.";
} else { // return false;
max_length[nextstate] = std::max(max_length[nextstate], // }
this_max_length); // }
} // std::vector<int32> state_times;
} // int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
if (lat.Final(s) != LatticeWeight::Zero()) //
lattice_max_length = std::max(lattice_max_length, max_length[s]); // std::vector<std::vector<int32> > time_to_state(utt_len );
} //
return lattice_max_length; // int32 num_states = lat->NumStates();
} // KALDI_ASSERT(num_states == state_times.size());
// for (size_t state = 0; state < num_states; state++) {
int32 LongestSentenceLength(const CompactLattice &clat) { // int32 t = state_times[state];
typedef CompactLattice::Arc Arc; // // Don't check t >= 0 because non-accessible states could have t = -1.
typedef Arc::Label Label; // KALDI_ASSERT(t <= utt_len);
typedef Arc::StateId StateId; // if (t >= 0 && t < utt_len)
// time_to_state[t].push_back(state);
if (clat.Properties(fst::kTopSorted, true) == 0) { // }
CompactLattice clat_copy(clat); //
if (!TopSort(&clat_copy)) // for (int32 t = 0; t < utt_len; t++) {
KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
return LongestSentenceLength(clat_copy); // KALDI_WARN << "Features are too short for lattice: utt-len is "
} // << utt_len << ", " << t << " is last frame";
std::vector<int32> max_length(clat.NumStates(), 0); // return false;
int32 lattice_max_length = 0; // }
for (StateId s = 0; s < clat.NumStates(); s++) { // for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 this_max_length = max_length[s]; // int32 state = time_to_state[t][i];
for (fst::ArcIterator<CompactLattice> aiter(clat, s); // for (fst::MutableArcIterator<Lattice> aiter(lat, state);
!aiter.Done(); aiter.Next()) { // !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value(); // LatticeArc arc = aiter.Value();
bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel. // if (arc.ilabel != 0) {
// also note: for normal CompactLattice, e.g. as produced by // int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
// determinization, all arcs will have nonzero labels, but the user might // // have to be a transition-id, just whatever the Decodable
// decide to remplace some of the labels with zero for some reason, and we // // object is expecting, but it's normally a transition-id.
// want to support this. //
StateId nextstate = arc.nextstate; // BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size()); // arc.weight.SetValue2(-log_like + arc.weight.Value2());
KALDI_ASSERT(nextstate > s && "CompactLattice has cycles"); // aiter.SetValue(arc);
if (arc_has_word) // }
max_length[nextstate] = std::max(max_length[nextstate], // }
this_max_length + 1); // }
else // }
max_length[nextstate] = std::max(max_length[nextstate], // return true;
this_max_length); // }
} //
if (clat.Final(s) != CompactLatticeWeight::Zero()) //
lattice_max_length = std::max(lattice_max_length, max_length[s]); // BaseFloat LatticeForwardBackwardMmi(
} // const TransitionModel &tmodel,
return lattice_max_length; // const Lattice &lat,
} // const std::vector<int32> &num_ali,
// bool drop_frames,
void ComposeCompactLatticeDeterministic( // bool convert_to_pdf_ids,
const CompactLattice& clat, // bool cancel,
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst, // Posterior *post) {
CompactLattice* composed_clat) { // // First compute the MMI posteriors.
// StdFst::Arc and CompactLatticeArc has the same StateId type. //
typedef fst::StdArc::StateId StateId; // Posterior den_post;
typedef fst::StdArc::Weight Weight1; // BaseFloat ans = LatticeForwardBackward(lat,
typedef CompactLatticeArc::Weight Weight2; // &den_post,
typedef std::pair<StateId, StateId> StatePair; // NULL);
typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType; //
typedef MapType::iterator IterType; // Posterior num_post;
// AlignmentToPosterior(num_ali, &num_post);
// Empties the output FST. //
KALDI_ASSERT(composed_clat != NULL); // // Now negate the MMI posteriors and add the numerator
composed_clat->DeleteStates(); // // posteriors.
// ScalePosterior(-1.0, &den_post);
MapType state_map; //
std::queue<StatePair> state_queue; // if (convert_to_pdf_ids) {
// Posterior num_tmp;
// Sets start state in <composed_clat>. // ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
StateId start_state = composed_clat->AddState(); // num_tmp.swap(num_post);
StatePair start_pair(clat.Start(), det_fst->Start()); // Posterior den_tmp;
composed_clat->SetStart(start_state); // ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
state_queue.push(start_pair); // den_tmp.swap(den_post);
std::pair<IterType, bool> result = // }
state_map.insert(std::make_pair(start_pair, start_state)); //
KALDI_ASSERT(result.second == true); // MergePosteriors(num_post, den_post,
// cancel, drop_frames, post);
// Starts composition here. //
while (!state_queue.empty()) { // return ans;
// Gets the first state in the queue. // }
StatePair s = state_queue.front(); //
StateId s1 = s.first; //
StateId s2 = s.second; // int32 LongestSentenceLength(const Lattice &lat) {
state_queue.pop(); // typedef Lattice::Arc Arc;
// typedef Arc::Label Label;
// typedef Arc::StateId StateId;
Weight2 clat_final = clat.Final(s1); //
if (clat_final.Weight().Value1() != // if (lat.Properties(fst::kTopSorted, true) == 0) {
std::numeric_limits<BaseFloat>::infinity()) { // Lattice lat_copy(lat);
// Test for whether the final-prob of state s1 was zero. // if (!TopSort(&lat_copy))
Weight1 det_fst_final = det_fst->Final(s2); // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
if (det_fst_final.Value() != // return LongestSentenceLength(lat_copy);
std::numeric_limits<BaseFloat>::infinity()) { // }
// Test for whether the final-prob of state s2 was zero. If neither // std::vector<int32> max_length(lat.NumStates(), 0);
// source-state final prob was zero, then we should create final state // int32 lattice_max_length = 0;
// in fst_composed. We compute the product manually since this is more // for (StateId s = 0; s < lat.NumStates(); s++) {
// efficient. // int32 this_max_length = max_length[s];
Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() + // for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
det_fst_final.Value(), // const Arc &arc = aiter.Value();
clat_final.Weight().Value2()), // bool arc_has_word = (arc.olabel != 0);
clat_final.String()); // StateId nextstate = arc.nextstate;
// we can assume final_weight is not Zero(), since neither of // KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
// the sources was zero. // if (arc_has_word) {
KALDI_ASSERT(state_map.find(s) != state_map.end()); // // A lattice should ideally not have cycles anyway; a cycle with a word
composed_clat->SetFinal(state_map[s], final_weight); // // on is something very bad.
} // KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
} // max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length + 1);
// Loops over pair of edges at s1 and s2. // } else {
for (fst::ArcIterator<CompactLattice> aiter(clat, s1); // max_length[nextstate] = std::max(max_length[nextstate],
!aiter.Done(); aiter.Next()) { // this_max_length);
const CompactLatticeArc& arc1 = aiter.Value(); // }
fst::StdArc arc2; // }
StateId next_state1 = arc1.nextstate, next_state2; // if (lat.Final(s) != LatticeWeight::Zero())
bool matched = false; // lattice_max_length = std::max(lattice_max_length, max_length[s]);
// }
if (arc1.olabel == 0) { // return lattice_max_length;
// If the symbol on <arc1> is <epsilon>, we transit to the next state // }
// for <clat>, but keep <det_fst> at the current state. //
matched = true; // int32 LongestSentenceLength(const CompactLattice &clat) {
next_state2 = s2; // typedef CompactLattice::Arc Arc;
} else { // typedef Arc::Label Label;
// Otherwise try to find the matched arc in <det_fst>. // typedef Arc::StateId StateId;
matched = det_fst->GetArc(s2, arc1.olabel, &arc2); //
if (matched) { // if (clat.Properties(fst::kTopSorted, true) == 0) {
next_state2 = arc2.nextstate; // CompactLattice clat_copy(clat);
} // if (!TopSort(&clat_copy))
} // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// return LongestSentenceLength(clat_copy);
// If matched arc is found in <det_fst>, then we have to add new arcs to // }
// <composed_clat>. // std::vector<int32> max_length(clat.NumStates(), 0);
if (matched) { // int32 lattice_max_length = 0;
StatePair next_state_pair(next_state1, next_state2); // for (StateId s = 0; s < clat.NumStates(); s++) {
IterType siter = state_map.find(next_state_pair); // int32 this_max_length = max_length[s];
StateId next_state; // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// Adds composed state to <state_map>. // const Arc &arc = aiter.Value();
if (siter == state_map.end()) { // bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
// If the composed state has not been created yet, create it. // // also note: for normal CompactLattice, e.g. as produced by
next_state = composed_clat->AddState(); // // determinization, all arcs will have nonzero labels, but the user might
std::pair<const StatePair, StateId> next_state_map(next_state_pair, // // decide to remplace some of the labels with zero for some reason, and we
next_state); // // want to support this.
std::pair<IterType, bool> result = state_map.insert(next_state_map); // StateId nextstate = arc.nextstate;
KALDI_ASSERT(result.second); // KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
state_queue.push(next_state_pair); // KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
} else { // if (arc_has_word)
// If the composed state is already in <state_map>, we can directly // max_length[nextstate] = std::max(max_length[nextstate],
// use that. // this_max_length + 1);
next_state = siter->second; // else
} // max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length);
// Adds arc to <composed_clat>. // }
if (arc1.olabel == 0) { // if (clat.Final(s) != CompactLatticeWeight::Zero())
composed_clat->AddArc(state_map[s], // lattice_max_length = std::max(lattice_max_length, max_length[s]);
CompactLatticeArc(arc1.ilabel, 0, // }
arc1.weight, next_state)); // return lattice_max_length;
} else { // }
Weight2 composed_weight( //
LatticeWeight(arc1.weight.Weight().Value1() + // void ComposeCompactLatticeDeterministic(
arc2.weight.Value(), // const CompactLattice& clat,
arc1.weight.Weight().Value2()), // fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
arc1.weight.String()); // CompactLattice* composed_clat) {
composed_clat->AddArc(state_map[s], // // StdFst::Arc and CompactLatticeArc has the same StateId type.
CompactLatticeArc(arc1.ilabel, arc2.olabel, // typedef fst::StdArc::StateId StateId;
composed_weight, next_state)); // typedef fst::StdArc::Weight Weight1;
} // typedef CompactLatticeArc::Weight Weight2;
} // typedef std::pair<StateId, StateId> StatePair;
} // typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
} // typedef MapType::iterator IterType;
fst::Connect(composed_clat); //
} // // Empties the output FST.
// KALDI_ASSERT(composed_clat != NULL);
// composed_clat->DeleteStates();
void ComputeAcousticScoresMap( //
const Lattice &lat, // MapType state_map;
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // std::queue<StatePair> state_queue;
PairHasher<int32> > *acoustic_scores) { //
// typedef the arc, weight types // // Sets start state in <composed_clat>.
typedef Lattice::Arc Arc; // StateId start_state = composed_clat->AddState();
typedef Arc::Weight LatticeWeight; // StatePair start_pair(clat.Start(), det_fst->Start());
typedef Arc::StateId StateId; // composed_clat->SetStart(start_state);
// state_queue.push(start_pair);
acoustic_scores->clear(); // std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(start_pair, start_state));
std::vector<int32> state_times; // KALDI_ASSERT(result.second == true);
LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted //
// // Starts composition here.
KALDI_ASSERT(lat.Start() == 0); // while (!state_queue.empty()) {
// // Gets the first state in the queue.
for (StateId s = 0; s < lat.NumStates(); s++) { // StatePair s = state_queue.front();
int32 t = state_times[s]; // StateId s1 = s.first;
for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); // StateId s2 = s.second;
aiter.Next()) { // state_queue.pop();
const Arc &arc = aiter.Value(); //
const LatticeWeight &weight = arc.weight; //
// Weight2 clat_final = clat.Final(s1);
int32 tid = arc.ilabel; // if (clat_final.Weight().Value1() !=
// std::numeric_limits<BaseFloat>::infinity()) {
if (tid != 0) { // // Test for whether the final-prob of state s1 was zero.
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // Weight1 det_fst_final = det_fst->Final(s2);
PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid)); // if (det_fst_final.Value() !=
if (it == acoustic_scores->end()) { // std::numeric_limits<BaseFloat>::infinity()) {
acoustic_scores->insert(std::make_pair(std::make_pair(t, tid), // // Test for whether the final-prob of state s2 was zero. If neither
std::make_pair(weight.Value2(), 1))); // // source-state final prob was zero, then we should create final state
} else { // // in fst_composed. We compute the product manually since this is more
if (it->second.second == 2 // // efficient.
&& it->second.first / it->second.second != weight.Value2()) { // Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
KALDI_VLOG(2) << "Transitions on the same frame have different " // det_fst_final.Value(),
<< "acoustic costs for tid " << tid << "; " // clat_final.Weight().Value2()),
<< it->second.first / it->second.second // clat_final.String());
<< " vs " << weight.Value2(); // // we can assume final_weight is not Zero(), since neither of
} // // the sources was zero.
it->second.first += weight.Value2(); // KALDI_ASSERT(state_map.find(s) != state_map.end());
it->second.second++; // composed_clat->SetFinal(state_map[s], final_weight);
} // }
} else { // }
// Arcs with epsilon input label (tid) must have 0 acoustic cost //
KALDI_ASSERT(weight.Value2() == 0); // // Loops over pair of edges at s1 and s2.
} // for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
} // !aiter.Done(); aiter.Next()) {
// const CompactLatticeArc& arc1 = aiter.Value();
LatticeWeight f = lat.Final(s); // fst::StdArc arc2;
if (f != LatticeWeight::Zero()) { // StateId next_state1 = arc1.nextstate, next_state2;
// Final acoustic cost must be 0 as we are reading from // bool matched = false;
// non-determinized, non-compact lattice //
KALDI_ASSERT(f.Value2() == 0.0); // if (arc1.olabel == 0) {
} // // If the symbol on <arc1> is <epsilon>, we transit to the next state
} // // for <clat>, but keep <det_fst> at the current state.
} // matched = true;
// next_state2 = s2;
void ReplaceAcousticScoresFromMap( // } else {
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // // Otherwise try to find the matched arc in <det_fst>.
PairHasher<int32> > &acoustic_scores, // matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
Lattice *lat) { // if (matched) {
// typedef the arc, weight types // next_state2 = arc2.nextstate;
typedef Lattice::Arc Arc; // }
typedef Arc::Weight LatticeWeight; // }
typedef Arc::StateId StateId; //
// // If matched arc is found in <det_fst>, then we have to add new arcs to
TopSortLatticeIfNeeded(lat); // // <composed_clat>.
// if (matched) {
std::vector<int32> state_times; // StatePair next_state_pair(next_state1, next_state2);
LatticeStateTimes(*lat, &state_times); // IterType siter = state_map.find(next_state_pair);
// StateId next_state;
KALDI_ASSERT(lat->Start() == 0); //
// // Adds composed state to <state_map>.
for (StateId s = 0; s < lat->NumStates(); s++) { // if (siter == state_map.end()) {
int32 t = state_times[s]; // // If the composed state has not been created yet, create it.
for (fst::MutableArcIterator<Lattice> aiter(lat, s); // next_state = composed_clat->AddState();
!aiter.Done(); aiter.Next()) { // std::pair<const StatePair, StateId> next_state_map(next_state_pair,
Arc arc(aiter.Value()); // next_state);
// std::pair<IterType, bool> result = state_map.insert(next_state_map);
int32 tid = arc.ilabel; // KALDI_ASSERT(result.second);
if (tid != 0) { // state_queue.push(next_state_pair);
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // } else {
PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid)); // // If the composed state is already in <state_map>, we can directly
if (it == acoustic_scores.end()) { // // use that.
KALDI_ERR << "Could not find tid " << tid << " at time " << t // next_state = siter->second;
<< " in the acoustic scores map."; // }
} else { //
arc.weight.SetValue2(it->second.first / it->second.second); // // Adds arc to <composed_clat>.
} // if (arc1.olabel == 0) {
} else { // composed_clat->AddArc(state_map[s],
// For epsilon arcs, set acoustic cost to 0.0 // CompactLatticeArc(arc1.ilabel, 0,
arc.weight.SetValue2(0.0); // arc1.weight, next_state));
} // } else {
aiter.SetValue(arc); // Weight2 composed_weight(
} // LatticeWeight(arc1.weight.Weight().Value1() +
// arc2.weight.Value(),
LatticeWeight f = lat->Final(s); // arc1.weight.Weight().Value2()),
if (f != LatticeWeight::Zero()) { // arc1.weight.String());
// Set final acoustic cost to 0.0 // composed_clat->AddArc(state_map[s],
f.SetValue2(0.0); // CompactLatticeArc(arc1.ilabel, arc2.olabel,
lat->SetFinal(s, f); // composed_weight, next_state));
} // }
} // }
} // }
// }
// fst::Connect(composed_clat);
// }
//
//
// void ComputeAcousticScoresMap(
// const Lattice &lat,
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > *acoustic_scores) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// acoustic_scores->clear();
//
// std::vector<int32> state_times;
// LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
//
// KALDI_ASSERT(lat.Start() == 0);
//
// for (StateId s = 0; s < lat.NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// const LatticeWeight &weight = arc.weight;
//
// int32 tid = arc.ilabel;
//
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid));
// if (it == acoustic_scores->end()) {
// acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
// std::make_pair(weight.Value2(), 1)));
// } else {
// if (it->second.second == 2
// && it->second.first / it->second.second != weight.Value2()) {
// KALDI_VLOG(2) << "Transitions on the same frame have different "
// << "acoustic costs for tid " << tid << "; "
// << it->second.first / it->second.second
// << " vs " << weight.Value2();
// }
// it->second.first += weight.Value2();
// it->second.second++;
// }
// } else {
// // Arcs with epsilon input label (tid) must have 0 acoustic cost
// KALDI_ASSERT(weight.Value2() == 0);
// }
// }
//
// LatticeWeight f = lat.Final(s);
// if (f != LatticeWeight::Zero()) {
// // Final acoustic cost must be 0 as we are reading from
// // non-determinized, non-compact lattice
// KALDI_ASSERT(f.Value2() == 0.0);
// }
// }
// }
//
// void ReplaceAcousticScoresFromMap(
// const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > &acoustic_scores,
// Lattice *lat) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// TopSortLatticeIfNeeded(lat);
//
// std::vector<int32> state_times;
// LatticeStateTimes(*lat, &state_times);
//
// KALDI_ASSERT(lat->Start() == 0);
//
// for (StateId s = 0; s < lat->NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::MutableArcIterator<Lattice> aiter(lat, s);
// !aiter.Done(); aiter.Next()) {
// Arc arc(aiter.Value());
//
// int32 tid = arc.ilabel;
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid));
// if (it == acoustic_scores.end()) {
// KALDI_ERR << "Could not find tid " << tid << " at time " << t
// << " in the acoustic scores map.";
// } else {
// arc.weight.SetValue2(it->second.first / it->second.second);
// }
// } else {
// // For epsilon arcs, set acoustic cost to 0.0
// arc.weight.SetValue2(0.0);
// }
// aiter.SetValue(arc);
// }
//
// LatticeWeight f = lat->Final(s);
// if (f != LatticeWeight::Zero()) {
// // Set final acoustic cost to 0.0
// f.SetValue2(0.0);
// lat->SetFinal(s, f);
// }
// }
// }
} // namespace kaldi } // namespace kaldi
...@@ -28,374 +28,427 @@ ...@@ -28,374 +28,427 @@
#include <map> #include <map>
#include "base/kaldi-common.h" #include "base/kaldi-common.h"
// #include "hmm/posterior.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h" // #include "hmm/transition-model.h"
#include "itf/transition-information.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
// #include "itf/decodable-itf.h"
namespace kaldi { namespace kaldi {
// Redundant with the typedef in hmm/posterior.h. We want functions // /**
// using the Posterior type to be usable without a dependency on the // This function extracts the per-frame log likelihoods from a linear
// hmm library. // lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior; // The dimension of *per_frame_loglikes will be set to the
// number of input symbols in 'nbest'. The elements of
/** // '*per_frame_loglikes' will be set to the .Value2() elements of the lattice
This function extracts the per-frame log likelihoods from a linear // weights, which represent the acoustic costs; you may want to scale this
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code). // vector afterward by -1/acoustic_scale to get the original loglikes.
The dimension of *per_frame_loglikes will be set to the // If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
number of input symbols in 'nbest'. The elements of // (and this should not normally be the case in situations where it makes
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice // sense to call this function), they will be included to the cost of the
weights, which represent the acoustic costs; you may want to scale this // preceding input symbol, or the following input symbol for input-epsilons
vector afterward by -1/acoustic_scale to get the original loglikes. // encountered prior to any input symbol. If 'nbest' has no input symbols,
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest' // 'per_frame_loglikes' will be set to the empty vector.
(and this should not normally be the case in situations where it makes // **/
sense to call this function), they will be included to the cost of the // void GetPerFrameAcousticCosts(const Lattice &nbest,
preceding input symbol, or the following input symbol for input-epsilons // Vector<BaseFloat> *per_frame_loglikes);
encountered prior to any input symbol. If 'nbest' has no input symbols, //
'per_frame_loglikes' will be set to the empty vector. // /// This function iterates over the states of a topologically sorted lattice and
**/ // /// counts the time instance corresponding to each state. The times are returned
void GetPerFrameAcousticCosts(const Lattice &nbest, // /// in a vector of integers 'times' which is resized to have a size equal to the
Vector<BaseFloat> *per_frame_loglikes); // /// number of states in the lattice. The function also returns the maximum time
// /// in the lattice (this will equal the number of frames in the file).
/// This function iterates over the states of a topologically sorted lattice and // int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
/// counts the time instance corresponding to each state. The times are returned //
/// in a vector of integers 'times' which is resized to have a size equal to the // /// As LatticeStateTimes, but in the CompactLattice format. Note: must
/// number of states in the lattice. The function also returns the maximum time // /// be topologically sorted. Returns length of the utterance in frames, which
/// in the lattice (this will equal the number of frames in the file). // /// might not be the same as the maximum time in the lattice, due to frames
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times); // /// in the final-prob.
// int32 CompactLatticeStateTimes(const CompactLattice &clat,
/// As LatticeStateTimes, but in the CompactLattice format. Note: must // std::vector<int32> *times);
/// be topologically sorted. Returns length of the utterance in frames, which //
/// might not be the same as the maximum time in the lattice, due to frames // /// This function does the forward-backward over lattices and computes the
/// in the final-prob. // /// posterior probabilities of the arcs. It returns the total log-probability
int32 CompactLatticeStateTimes(const CompactLattice &clat, // /// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
std::vector<int32> *times); // /// on each frame.
// /// If the pointer "acoustic_like_sum" is provided, this value is set to
/// This function does the forward-backward over lattices and computes the // /// the sum over the arcs, of the posterior of the arc times the
/// posterior probabilities of the arcs. It returns the total log-probability // /// acoustic likelihood [i.e. negated acoustic score] on that link.
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight) // /// This is used in combination with other quantities to work out
/// on each frame. // /// the objective function in MMI discriminative training.
/// If the pointer "acoustic_like_sum" is provided, this value is set to // BaseFloat LatticeForwardBackward(const Lattice &lat,
/// the sum over the arcs, of the posterior of the arc times the // Posterior *arc_post,
/// acoustic likelihood [i.e. negated acoustic score] on that link. // double *acoustic_like_sum = NULL);
/// This is used in combination with other quantities to work out //
/// the objective function in MMI discriminative training. // // This function is something similar to LatticeForwardBackward(), but it is on
BaseFloat LatticeForwardBackward(const Lattice &lat, // // the CompactLattice lattice format. Also we only need the alpha in the forward
Posterior *arc_post, // // path, not the posteriors.
double *acoustic_like_sum = NULL); // bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
// std::vector<double> *alpha);
// This function is something similar to LatticeForwardBackward(), but it is on //
// the CompactLattice lattice format. Also we only need the alpha in the forward // // A sibling of the function CompactLatticeAlphas()... We compute the beta from
// path, not the posteriors. // // the backward path here.
bool ComputeCompactLatticeAlphas(const CompactLattice &lat, // bool ComputeCompactLatticeBetas(const CompactLattice &lat,
std::vector<double> *alpha); // std::vector<double> *beta);
//
// A sibling of the function CompactLatticeAlphas()... We compute the beta from //
// the backward path here. // // Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
bool ComputeCompactLatticeBetas(const CompactLattice &lat, // // best-path negated cost) Note: in either case, the alphas and betas are
std::vector<double> *beta); // // negated costs. Requires that lat be topologically sorted. This code
// // will work for either CompactLattice or Latice.
// template<typename LatticeType>
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or // double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
// best-path negated cost) Note: in either case, the alphas and betas are // bool viterbi,
// negated costs. Requires that lat be topologically sorted. This code // std::vector<double> *alpha,
// will work for either CompactLattice or Lattice. // std::vector<double> *beta);
template<typename LatticeType> //
double ComputeLatticeAlphasAndBetas(const LatticeType &lat, //
bool viterbi, // /// Topologically sort the compact lattice if not already topologically sorted.
std::vector<double> *alpha, // /// Will crash if the lattice cannot be topologically sorted.
std::vector<double> *beta); // void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
//
//
/// Topologically sort the compact lattice if not already topologically sorted. // /// Topologically sort the lattice if not already topologically sorted.
/// Will crash if the lattice cannot be topologically sorted. // /// Will crash if lattice cannot be topologically sorted.
void TopSortCompactLatticeIfNeeded(CompactLattice *clat); // void TopSortLatticeIfNeeded(Lattice *clat);
//
// /// Returns the depth of the lattice, defined as the average number of arcs (or
/// Topologically sort the lattice if not already topologically sorted. // /// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
/// Will crash if lattice cannot be topologically sorted. // /// Requires that clat is topologically sorted!
void TopSortLatticeIfNeeded(Lattice *clat); // BaseFloat CompactLatticeDepth(const CompactLattice &clat,
// int32 *num_frames = NULL);
/// Returns the depth of the lattice, defined as the average number of arcs (or //
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices. // /// This function returns, for each frame, the number of arcs crossing that
/// Requires that clat is topologically sorted! // /// frame.
BaseFloat CompactLatticeDepth(const CompactLattice &clat, // void CompactLatticeDepthPerFrame(const CompactLattice &clat,
int32 *num_frames = NULL); // std::vector<int32> *depth_per_frame);
//
/// This function returns, for each frame, the number of arcs crossing that //
/// frame. // /// This function limits the depth of the lattice, per frame: that means, it
void CompactLatticeDepthPerFrame(const CompactLattice &clat, // /// does not allow more than a specified number of arcs active on any given
std::vector<int32> *depth_per_frame); // /// frame. This can be used to reduce the size of the "very deep" portions of
// /// the lattice.
// void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
/// This function limits the depth of the lattice, per frame: that means, it // CompactLattice *clat);
/// does not allow more than a specified number of arcs active on any given //
/// frame. This can be used to reduce the size of the "very deep" portions of //
/// the lattice. // /// Given a lattice, and a transition model to map pdf-ids to phones,
void CompactLatticeLimitDepth(int32 max_arcs_per_frame, // /// outputs for each frame the set of phones active on that frame. If
CompactLattice *clat); // /// sil_phones (which must be sorted and uniq) is nonempty, it excludes
// /// phones in this list.
// void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
/// Given a lattice, and a transition model to map pdf-ids to phones, // const std::vector<int32> &sil_phones,
/// outputs for each frame the set of phones active on that frame. If // std::vector<std::set<int32> > *active_phones);
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes //
/// phones in this list. // /// Given a lattice, and a transition model to map pdf-ids to phones,
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans, // /// replace the output symbols (presumably words), with phones; we
const std::vector<int32> &sil_phones, // /// use the TransitionModel to work out the phone sequence. Note
std::vector<std::set<int32> > *active_phones); // /// that the phone labels are not exactly aligned with the phone
// /// boundaries. We put a phone label to coincide with any transition
/// Given a lattice, and a transition model to map pdf-ids to phones, // /// to the final, nonemitting state of a phone (this state always exists,
/// replace the output symbols (presumably words), with phones; we // /// we ensure this in HmmTopology::Check()). This would be the last
/// use the TransitionModel to work out the phone sequence. Note // /// transition-id in the phone if reordering is not done (but typically
/// that the phone labels are not exactly aligned with the phone // /// we do reorder).
/// boundaries. We put a phone label to coincide with any transition // /// Also see PhoneAlignLattice, in phone-align-lattice.h.
/// to the final, nonemitting state of a phone (this state always exists, // void ConvertLatticeToPhones(const TransitionModel &trans_model,
/// we ensure this in HmmTopology::Check()). This would be the last // Lattice *lat);
/// transition-id in the phone if reordering is not done (but typically
/// we do reorder).
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
Lattice *lat);
/// Prunes a lattice or compact lattice. Returns true on success, false if /// Prunes a lattice or compact lattice. Returns true on success, false if
/// there was some kind of failure. /// there was some kind of failure.
template<class LatticeType> template<class LatticeType>
bool PruneLattice(BaseFloat beam, LatticeType *lat); bool PruneLattice(BaseFloat beam, LatticeType *lat);
//
/// Given a lattice, and a transition model to map pdf-ids to phones, // /// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the sequences of transition-ids with sequences of phones. // /// replace the sequences of transition-ids with sequences of phones.
/// Note that this is different from ConvertLatticeToPhones, in that // /// Note that this is different from ConvertLatticeToPhones, in that
/// we replace the transition-ids not the words. // /// we replace the transition-ids not the words.
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model, // void ConvertCompactLatticeToPhones(const TransitionModel &trans_model,
CompactLattice *clat); // CompactLattice *clat);
//
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds // /// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path. // /// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
/// There is a frame error if a particular transition-id on a particular frame // /// There is a frame error if a particular transition-id on a particular frame
/// corresponds to a phone not matching transcription's alignment for that frame. // /// corresponds to a phone not matching transcription's alignment for that frame.
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI. // /// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
/// The TransitionInformation is used to map transition-ids in the lattice // /// The TransitionModel is used to map transition-ids in the lattice
/// input-side to phones; the phones appearing in // /// input-side to phones; the phones appearing in
/// "silence_phones" are treated specially in that we replace the frame error f // /// "silence_phones" are treated specially in that we replace the frame error f
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error. // /// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
/// For the normal recipe, max_silence_error would be zero. // /// For the normal recipe, max_silence_error would be zero.
/// Returns true on success, false if there was some kind of mismatch. // /// Returns true on success, false if there was some kind of mismatch.
/// At input, silence_phones must be sorted and unique. // /// At input, silence_phones must be sorted and unique.
bool LatticeBoost(const TransitionInformation &trans, // bool LatticeBoost(const TransitionModel &trans,
const std::vector<int32> &alignment, // const std::vector<int32> &alignment,
const std::vector<int32> &silence_phones, // const std::vector<int32> &silence_phones,
BaseFloat b, // BaseFloat b,
BaseFloat max_silence_error, // BaseFloat max_silence_error,
Lattice *lat); // Lattice *lat);
//
//
/** // /**
This function implements either the MPFE (minimum phone frame error) or SMBR // This function implements either the MPFE (minimum phone frame error) or SMBR
(state-level minimum bayes risk) forward-backward, depending on whether // (state-level minimum bayes risk) forward-backward, depending on whether
"criterion" is "mpfe" or "smbr". It returns the MPFE // "criterion" is "mpfe" or "smbr". It returns the MPFE
criterion of SMBR criterion for this utterance, and outputs the posteriors (which // criterion of SMBR criterion for this utterance, and outputs the posteriors (which
may be positive or negative) into "post". // may be positive or negative) into "post".
//
@param [in] trans The transition model. Used to map the // @param [in] trans The transition model. Used to map the
transition-ids to phones or pdfs. // transition-ids to phones or pdfs.
@param [in] silence_phones A list of integer ids of silence phones. The // @param [in] silence_phones A list of integer ids of silence phones. The
silence frames i.e. the frames where num_ali // silence frames i.e. the frames where num_ali
corresponds to a silence phones are treated specially. // corresponds to a silence phones are treated specially.
The behavior is determined by 'one_silence_class' // The behavior is determined by 'one_silence_class'
being false (traditional behavior) or true. // being false (traditional behavior) or true.
Usually in our setup, several phones including // Usually in our setup, several phones including
the silence, vocalized noise, non-spoken noise // the silence, vocalized noise, non-spoken noise
and unk are treated as "silence phones" // and unk are treated as "silence phones"
@param [in] lat The denominator lattice // @param [in] lat The denominator lattice
@param [in] num_ali The numerator alignment // @param [in] num_ali The numerator alignment
@param [in] criterion The objective function. Must be "mpfe" or "smbr" // @param [in] criterion The objective function. Must be "mpfe" or "smbr"
for MPFE (minimum phone frame error) or sMBR // for MPFE (minimum phone frame error) or sMBR
(state minimum bayes risk) training. // (state minimum bayes risk) training.
@param [in] one_silence_class Determines how the silence frames are treated. // @param [in] one_silence_class Determines how the silence frames are treated.
Setting this to false gives the old traditional behavior, // Setting this to false gives the old traditional behavior,
where the silence frames (according to num_ali) are // where the silence frames (according to num_ali) are
treated as incorrect. However, this means that the // treated as incorrect. However, this means that the
insertions are not penalized by the objective. // insertions are not penalized by the objective.
Setting this to true gives the new behaviour, where we // Setting this to true gives the new behaviour, where we
treat silence as any other phone, except that all pdfs // treat silence as any other phone, except that all pdfs
of silence phones are collapsed into a single class for // of silence phones are collapsed into a single class for
the frame-error computation. This can possible reduce // the frame-error computation. This can possible reduce
the insertions in the trained model. This is closer to // the insertions in the trained model. This is closer to
the WER metric that we actually care about, since WER is // the WER metric that we actually care about, since WER is
generally computed after filtering out noises, but // generally computed after filtering out noises, but
does penalize insertions. // does penalize insertions.
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the // @param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
pseudo log-likelihoods of states at each frame. // pseudo log-likelihoods of states at each frame.
*/ // */
BaseFloat LatticeForwardBackwardMpeVariants( // BaseFloat LatticeForwardBackwardMpeVariants(
const TransitionInformation &trans, // const TransitionModel &trans,
const std::vector<int32> &silence_phones, // const std::vector<int32> &silence_phones,
const Lattice &lat, // const Lattice &lat,
const std::vector<int32> &num_ali, // const std::vector<int32> &num_ali,
std::string criterion, // std::string criterion,
bool one_silence_class, // bool one_silence_class,
Posterior *post); // Posterior *post);
//
/// This function takes a CompactLattice that should only contain a single // /**
/// linear sequence (e.g. derived from lattice-1best), and that should have been // This function can be used to compute posteriors for MMI, with a positive contribution
/// processed so that the arcs in the CompactLattice align correctly with the // for the numerator and a negative one for the denominator. This function is not actually
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the // used in our normal MMI training recipes, where it's instead done using various command
/// same size, which give, for each word in the lattice (in sequence), the word // line programs that each do a part of the job. This function was written for use in
/// label and the begin time and length in frames. This is done even for zero // neural-net MMI training.
/// (epsilon) words, generally corresponding to optional silence-- if you don't //
/// want them, just ignore them in the output. // @param [in] trans The transition model. Used to map the
/// This function will print a warning and return false, if the lattice // transition-ids to phones or pdfs.
/// did not have the correct format (e.g. if it is empty or it is not // @param [in] lat The denominator lattice
/// linear). // @param [in] num_ali The numerator alignment
bool CompactLatticeToWordAlignment(const CompactLattice &clat, // @param [in] drop_frames If "drop_frames" is true, it will not compute any
std::vector<int32> *words, // posteriors on frames where the num and den have disjoint
std::vector<int32> *begin_times, // pdf-ids.
std::vector<int32> *lengths); // @param [in] convert_to_pdf_ids If "convert_to_pdfs_ids" is true, it will
// convert the output to be at the level of pdf-ids, not
/// A form of the shortest-path/best-path algorithm that's specially coded for // transition-ids.
/// CompactLattice. Requires that clat be acyclic. // @param [in] cancel If "cancel" is true, it will cancel out any positive and
void CompactLatticeShortestPath(const CompactLattice &clat, // negative parts from the same transition-id (or pdf-id,
CompactLattice *shortest_path); // if convert_to_pdf_ids == true).
// @param [out] arc_post The output MMI posteriors of transition-ids (or
/// This function expands a CompactLattice to ensure high-probability paths // pdf-ids if convert_to_pdf_ids == true) at each frame
/// have unique histories. Arcs with posteriors larger than epsilon get splitted. // i.e. the difference between the numerator
void ExpandCompactLattice(const CompactLattice &clat, // and denominator posteriors.
double epsilon, //
CompactLattice *expand_clat); // It returns the forward-backward likelihood of the lattice. */
// BaseFloat LatticeForwardBackwardMmi(
/// For each state, compute forward and backward best (viterbi) costs and its // const TransitionModel &trans,
/// traceback states (for generating best paths later). The forward best cost // const Lattice &lat,
/// for a state is the cost of the best path from the start state to the state. // const std::vector<int32> &num_ali,
/// The traceback state of this state is its predecessor state in the best path. // bool drop_frames,
/// The backward best cost for a state is the cost of the best path from the // bool convert_to_pdf_ids,
/// state to a final one. Its traceback state is the successor state in the best // bool cancel,
/// path in the forward direction. // Posterior *arc_post);
/// Note: final weights of states are in backward_best_cost_and_pred. //
/// Requires the input CompactLattice clat be acyclic. //
typedef std::vector<std::pair<double, // /// This function takes a CompactLattice that should only contain a single
CompactLatticeArc::StateId> > CostTraceType; // /// linear sequence (e.g. derived from lattice-1best), and that should have been
void CompactLatticeBestCostsAndTracebacks( // /// processed so that the arcs in the CompactLattice align correctly with the
const CompactLattice &clat, // /// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
CostTraceType *forward_best_cost_and_pred, // /// same size, which give, for each word in the lattice (in sequence), the word
CostTraceType *backward_best_cost_and_pred); // /// label and the begin time and length in frames. This is done even for zero
// /// (epsilon) words, generally corresponding to optional silence-- if you don't
/// This function adds estimated neural language model scores of words in a // /// want them, just ignore them in the output.
/// minimal list of hypotheses that covers a lattice, to the graph scores on the // /// This function will print a warning and return false, if the lattice
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover. // /// did not have the correct format (e.g. if it is empty or it is not
typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT; // /// linear).
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores, // bool CompactLatticeToWordAlignment(const CompactLattice &clat,
CompactLattice *clat); // std::vector<int32> *words,
// std::vector<int32> *begin_times,
/// This function add the word insertion penalty to graph score of each word // std::vector<int32> *lengths);
/// in the compact lattice //
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty, // /// This function takes a CompactLattice that should only contain a single
CompactLattice *clat); // /// linear sequence (e.g. derived from lattice-1best), and that should have been
// /// processed so that the arcs in the CompactLattice align correctly with the
/// This function *adds* the negated scores obtained from the Decodable object, // /// word boundaries (e.g. by lattice-align-words). It outputs 4 vectors of the
/// to the acoustic scores on the arcs. If you want to replace them, you should // /// same size, which give, for each word in the lattice (in sequence), the word
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns // /// label, the begin time and length in frames, and the pronunciation (sequence
/// true on success, false on error (typically some kind of mismatched inputs). // /// of phones). This is done even for zero words, corresponding to optional
bool RescoreCompactLattice(DecodableInterface *decodable, // /// silences -- if you don't want them, just ignore them in the output.
CompactLattice *clat); // /// This function will print a warning and return false, if the lattice
// /// did not have the correct format (e.g. if it is empty or it is not
// /// linear).
/// This function returns the number of words in the longest sentence in a // bool CompactLatticeToWordProns(
/// CompactLattice (i.e. the the maximum of any path, of the count of // const TransitionModel &tmodel,
/// olabels on that path). // const CompactLattice &clat,
int32 LongestSentenceLength(const Lattice &lat); // std::vector<int32> *words,
// std::vector<int32> *begin_times,
/// This function returns the number of words in the longest sentence in a // std::vector<int32> *lengths,
/// CompactLattice, i.e. the the maximum of any path, of the count of // std::vector<std::vector<int32> > *prons,
/// labels on that path... note, in CompactLattice, the ilabels and olabels // std::vector<std::vector<int32> > *phone_lengths);
/// are identical because it is an acceptor. //
int32 LongestSentenceLength(const CompactLattice &lat); //
// /// A form of the shortest-path/best-path algorithm that's specially coded for
// /// CompactLattice. Requires that clat be acyclic.
/// This function is like RescoreCompactLattice, but it is modified to avoid // void CompactLatticeShortestPath(const CompactLattice &clat,
/// computing probabilities on most frames where all the pdf-ids are the same. // CompactLattice *shortest_path);
/// (it needs the transition-model to work out whether two transition-ids map to //
/// the same pdf-id, and it assumes that the lattice has transition-ids on it). // /// This function expands a CompactLattice to ensure high-probability paths
/// The naive thing would be to just set all probabilities to zero on frames // /// have unique histories. Arcs with posteriors larger than epsilon get splitted.
/// where all the pdf-ids are the same (because this value won't affect the // void ExpandCompactLattice(const CompactLattice &clat,
/// lattice posterior). But this would become confusing when we compute // double epsilon,
/// corpus-level diagnostics such as the MMI objective function. Instead, // CompactLattice *expand_clat);
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 / //
/// speedup_factor) we compute those likelihoods and multiply them by // /// For each state, compute forward and backward best (viterbi) costs and its
/// speedup_factor; otherwise we set them to zero. This gives the right // /// traceback states (for generating best paths later). The forward best cost
/// expected probability so our corpus-level diagnostics will be about right. // /// for a state is the cost of the best path from the start state to the state.
bool RescoreCompactLatticeSpeedup( // /// The traceback state of this state is its predecessor state in the best path.
const TransitionInformation &tmodel, // /// The backward best cost for a state is the cost of the best path from the
BaseFloat speedup_factor, // /// state to a final one. Its traceback state is the successor state in the best
DecodableInterface *decodable, // /// path in the forward direction.
CompactLattice *clat); // /// Note: final weights of states are in backward_best_cost_and_pred.
// /// Requires the input CompactLattice clat be acyclic.
// typedef std::vector<std::pair<double,
/// This function *adds* the negated scores obtained from the Decodable object, // CompactLatticeArc::StateId> > CostTraceType;
/// to the acoustic scores on the arcs. If you want to replace them, you should // void CompactLatticeBestCostsAndTracebacks(
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns // const CompactLattice &clat,
/// true on success, false on error (e.g. some kind of mismatched inputs). // CostTraceType *forward_best_cost_and_pred,
/// The input labels, if nonzero, are interpreted as transition-ids or whatever // CostTraceType *backward_best_cost_and_pred);
/// other index the Decodable object expects. //
bool RescoreLattice(DecodableInterface *decodable, // /// This function adds estimated neural language model scores of words in a
Lattice *lat); // /// minimal list of hypotheses that covers a lattice, to the graph scores on the
// /// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
/// This function Composes a CompactLattice format lattice with a // typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another // void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
/// CompactLattice format lattice. The first element (the one that corresponds // CompactLattice *clat);
/// to LM weight) in CompactLatticeWeight is used for composition. //
/// // /// This function add the word insertion penalty to graph score of each word
/// Note that the DeterministicOnDemandFst interface is not "const", therefore // /// in the compact lattice
/// we cannot use "const" for <det_fst>. // void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
void ComposeCompactLatticeDeterministic( // CompactLattice *clat);
const CompactLattice& clat, //
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst, // /// This function *adds* the negated scores obtained from the Decodable object,
CompactLattice* composed_clat); // /// to the acoustic scores on the arcs. If you want to replace them, you should
// /// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// This function computes the mapping from the pair // /// true on success, false on error (typically some kind of mismatched inputs).
/// (frame-index, transition-id) to the pair // bool RescoreCompactLattice(DecodableInterface *decodable,
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the // CompactLattice *clat);
/// transition-id in that frame. //
/// frame-index in the lattice. //
/// This function is useful for retaining the acoustic scores in a // /// This function returns the number of words in the longest sentence in a
/// non-compact lattice after a process like determinization where the // /// CompactLattice (i.e. the the maximum of any path, of the count of
/// frame-level acoustic scores are typically lost. // /// olabels on that path).
/// The function ReplaceAcousticScoresFromMap is used to restore the // int32 LongestSentenceLength(const Lattice &lat);
/// acoustic scores computed by this function. //
/// // /// This function returns the number of words in the longest sentence in a
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the // /// CompactLattice, i.e. the the maximum of any path, of the count of
/// function will crash. // /// labels on that path... note, in CompactLattice, the ilabels and olabels
/// @param [out] acoustic_scores // /// are identical because it is an acceptor.
/// Pointer to a map from the pair (frame-index, // int32 LongestSentenceLength(const CompactLattice &lat);
/// transition-id) to a pair (sum-of-acoustic-scores, //
/// num-of-occurences). //
/// Usually the acoustic scores for a pdf-id (and hence // /// This function is like RescoreCompactLattice, but it is modified to avoid
/// transition-id) on a frame will be the same for all the // /// computing probabilities on most frames where all the pdf-ids are the same.
/// occurences of the pdf-id in that frame. // /// (it needs the transition-model to work out whether two transition-ids map to
/// But if not, we will take the average of the acoustic // /// the same pdf-id, and it assumes that the lattice has transition-ids on it).
/// scores. Hence, we store both the sum-of-acoustic-scores // /// The naive thing would be to just set all probabilities to zero on frames
/// and the num-of-occurences of the transition-id in that // /// where all the pdf-ids are the same (because this value won't affect the
/// frame. // /// lattice posterior). But this would become confusing when we compute
void ComputeAcousticScoresMap( // /// corpus-level diagnostics such as the MMI objective function. Instead,
const Lattice &lat, // /// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // /// speedup_factor) we compute those likelihoods and multiply them by
PairHasher<int32> > *acoustic_scores); // /// speedup_factor; otherwise we set them to zero. This gives the right
// /// expected probability so our corpus-level diagnostics will be about right.
/// This function restores acoustic scores computed using the function // bool RescoreCompactLatticeSpeedup(
/// ComputeAcousticScoresMap into the lattice. // const TransitionModel &tmodel,
/// // BaseFloat speedup_factor,
/// @param [in] acoustic_scores // DecodableInterface *decodable,
/// A map from the pair (frame-index, transition-id) to a // CompactLattice *clat);
/// pair (sum-of-acoustic-scores, num-of-occurences) of //
/// the occurences of the transition-id in that frame. //
/// See the comments for ComputeAcousticScoresMap for // /// This function *adds* the negated scores obtained from the Decodable object,
/// details. // /// to the acoustic scores on the arcs. If you want to replace them, you should
/// @param [out] lat Pointer to the output lattice. // /// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
void ReplaceAcousticScoresFromMap( // /// true on success, false on error (e.g. some kind of mismatched inputs).
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>, // /// The input labels, if nonzero, are interpreted as transition-ids or whatever
PairHasher<int32> > &acoustic_scores, // /// other index the Decodable object expects.
Lattice *lat); // bool RescoreLattice(DecodableInterface *decodable,
// Lattice *lat);
//
// /// This function Composes a CompactLattice format lattice with a
// /// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
// /// CompactLattice format lattice. The first element (the one that corresponds
// /// to LM weight) in CompactLatticeWeight is used for composition.
// ///
// /// Note that the DeterministicOnDemandFst interface is not "const", therefore
// /// we cannot use "const" for <det_fst>.
// void ComposeCompactLatticeDeterministic(
// const CompactLattice& clat,
// fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
// CompactLattice* composed_clat);
//
// /// This function computes the mapping from the pair
// /// (frame-index, transition-id) to the pair
// /// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
// /// transition-id in that frame.
// /// frame-index in the lattice.
// /// This function is useful for retaining the acoustic scores in a
// /// non-compact lattice after a process like determinization where the
// /// frame-level acoustic scores are typically lost.
// /// The function ReplaceAcousticScoresFromMap is used to restore the
// /// acoustic scores computed by this function.
// ///
// /// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
// /// function will crash.
// /// @param [out] acoustic_scores
// /// Pointer to a map from the pair (frame-index,
// /// transition-id) to a pair (sum-of-acoustic-scores,
// /// num-of-occurences).
// /// Usually the acoustic scores for a pdf-id (and hence
// /// transition-id) on a frame will be the same for all the
// /// occurences of the pdf-id in that frame.
// /// But if not, we will take the average of the acoustic
// /// scores. Hence, we store both the sum-of-acoustic-scores
// /// and the num-of-occurences of the transition-id in that
// /// frame.
// void ComputeAcousticScoresMap(
// const Lattice &lat,
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > *acoustic_scores);
//
// /// This function restores acoustic scores computed using the function
// /// ComputeAcousticScoresMap into the lattice.
// ///
// /// @param [in] acoustic_scores
// /// A map from the pair (frame-index, transition-id) to a
// /// pair (sum-of-acoustic-scores, num-of-occurences) of
// /// the occurences of the transition-id in that frame.
// /// See the comments for ComputeAcousticScoresMap for
// /// details.
// /// @param [out] lat Pointer to the output lattice.
// void ReplaceAcousticScoresFromMap(
// const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > &acoustic_scores,
// Lattice *lat);
} // namespace kaldi } // namespace kaldi
......
...@@ -22,8 +22,13 @@ using std::vector; ...@@ -22,8 +22,13 @@ using std::vector;
using kaldi::Vector; using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend) const std::shared_ptr<FrontendInterface>& frontend,
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} kaldi::BaseFloat acoustic_scale)
: frontend_(frontend),
nnet_(nnet),
frame_offset_(0),
frames_ready_(0),
acoustic_scale_(acoustic_scale) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_cache_ = likelihood; nnet_cache_ = likelihood;
...@@ -33,16 +38,30 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { ...@@ -33,16 +38,30 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
// Decodable::Init(DecodableConfig config) { // Decodable::Init(DecodableConfig config) {
//} //}
bool Decodable::IsLastFrame(int32 frame) const { // return the size of frame have computed.
CHECK_LE(frame, frames_ready_); int32 Decodable::NumFramesReady() const { return frames_ready_; }
return IsInputFinished() && (frame == frames_ready_ - 1);
// frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame);
return frame >= frames_ready_;
} }
int32 Decodable::NumIndices() const { return 0; } int32 Decodable::NumIndices() const { return 0; }
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob id.
int32 Decodable::TokenId2NnetId(int32 token_id) {
return token_id - 1;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(index, nnet_cache_.NumCols());
return 0; CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
} }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
...@@ -59,20 +78,23 @@ bool Decodable::AdvanceChunk() { ...@@ -59,20 +78,23 @@ bool Decodable::AdvanceChunk() {
} }
int32 nnet_dim = 0; int32 nnet_dim = 0;
Vector<BaseFloat> inferences; Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences); nnet_cache_.CopyRowsFromVec(inferences);
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return true; return true;
} }
bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) { bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
std::vector<BaseFloat> result; std::vector<BaseFloat> result;
if (EnsureFrameHaveComputed(frame) == false) return false; if (EnsureFrameHaveComputed(frame) == false) return false;
likelihood->resize(nnet_cache_.NumCols()); likelihood->resize(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
(*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); (*likelihood)[idx] =
nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_;
} }
return true; return true;
} }
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "base/common.h" #include "base/common.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
namespace ppspeech { namespace ppspeech {
...@@ -25,32 +25,38 @@ struct DecodableOpts; ...@@ -25,32 +25,38 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet, explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend); const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame);
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
virtual bool FrameLogLikelihood(int32 frame, // not logprob
virtual bool FrameLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood); std::vector<kaldi::BaseFloat>* likelihood);
virtual int32 NumFramesReady() const;
// for offline test // for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
void Reset(); void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); } bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame); bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id);
private: private:
bool AdvanceChunk(); bool AdvanceChunk();
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_; // the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame
// eg: 35 frame features output 8 frame inferences
int32 frame_offset_; int32 frame_offset_;
int32 frames_ready_; int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame // todo: feature frame mismatch with nnet inference frame
// eg: 35 frame features output 8 frame inferences
// so use subsampled_frame // so use subsampled_frame
int32 current_log_post_subsampled_offset_; int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_; int32 num_chunk_computed_;
kaldi::BaseFloat acoustic_scale_;
}; };
} // namespace ppspeech } // namespace ppspeech
...@@ -94,7 +94,6 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { ...@@ -94,7 +94,6 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
void PaddleNnet::Reset() { InitCacheEncouts(opts_); } void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
paddle_infer::Predictor* PaddleNnet::GetPredictor() { paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
paddle_infer::Predictor* predictor = nullptr; paddle_infer::Predictor* predictor = nullptr;
std::lock_guard<std::mutex> guard(pool_mutex); std::lock_guard<std::mutex> guard(pool_mutex);
int pred_id = 0; int pred_id = 0;
...@@ -110,7 +109,6 @@ paddle_infer::Predictor* PaddleNnet::GetPredictor() { ...@@ -110,7 +109,6 @@ paddle_infer::Predictor* PaddleNnet::GetPredictor() {
if (predictor) { if (predictor) {
pool_usages[pred_id] = true; pool_usages[pred_id] = true;
predictor_to_thread_id[predictor] = pred_id; predictor_to_thread_id[predictor] = pred_id;
LOG(INFO) << pred_id << " predictor create success";
} else { } else {
LOG(INFO) << "Failed to get predictor from pool !!!"; LOG(INFO) << "Failed to get predictor from pool !!!";
} }
...@@ -119,7 +117,6 @@ paddle_infer::Predictor* PaddleNnet::GetPredictor() { ...@@ -119,7 +117,6 @@ paddle_infer::Predictor* PaddleNnet::GetPredictor() {
} }
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
LOG(INFO) << "attempt to releae a predictor";
std::lock_guard<std::mutex> guard(pool_mutex); std::lock_guard<std::mutex> guard(pool_mutex);
auto iter = predictor_to_thread_id.find(predictor); auto iter = predictor_to_thread_id.find(predictor);
...@@ -128,10 +125,8 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { ...@@ -128,10 +125,8 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
return 0; return 0;
} }
LOG(INFO) << iter->second << " predictor will be release";
pool_usages[iter->second] = false; pool_usages[iter->second] = false;
predictor_to_thread_id.erase(predictor); predictor_to_thread_id.erase(predictor);
LOG(INFO) << "release success";
return 0; return 0;
} }
...@@ -152,7 +147,6 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features, ...@@ -152,7 +147,6 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
int feat_row = features.Dim() / feature_dim; int feat_row = features.Dim() / feature_dim;
std::vector<std::string> input_names = predictor->GetInputNames(); std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames(); std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim;
std::unique_ptr<paddle_infer::Tensor> input_tensor = std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]); predictor->GetInputHandle(input_names[0]);
...@@ -183,7 +177,6 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features, ...@@ -183,7 +177,6 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
LOG(INFO) << "predictor run occurs error"; LOG(INFO) << "predictor run occurs error";
} }
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]); predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape()); assert(h_cache->get_shape() == h_out->shape());
......
// fstbin/fstaddselfloops.cc
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "util/parse-options.h"
#include "util/simple-io-funcs.h"
/* some test examples:
pushd ~/tmpdir
( echo 3; echo 4) > in.list
( echo 5; echo 6) > out.list
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstaddselfloops in.list out.list
| fstprint ( echo "0 1 0 1"; echo " 0 2 1 0"; echo "1 0"; echo "2 0"; ) |
fstcompile | fstaddselfloops in.list out.list | fstprint
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n"
"They are added on each final state and each state with non-epsilon "
"output symbols\n"
"on at least one arc out of the state. Useful in conjunction with "
"predeterminize\n"
"\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst "
"[out.fst] ]\n"
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
"same length.\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}
// fstbin/fstdeterminizestar.cc
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "util/parse-options.h"
#if !defined(_MSC_VER) && !defined(__APPLE__)
#include <signal.h> // Comment this line and the call to signal below if
// it causes compilation problems. It is only to enable a debugging procedure
// when determinization does not terminate. We are disabling this code if
// compiling on Windows because signal.h is not available there, and on
// MacOS due to a problem with <signal.h> in the initial release of Sierra.
#endif
/* some test examples:
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
( echo "0 0 1 0"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
( echo "0 0 1 0"; echo "0 1 1 0"; echo "0 0" ) | fstcompile |
fstdeterminizestar | fstprint # this last one fails [correctly]: ( echo "0 0 0
1"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
cd ~/tmpdir
while true; do
fstrand > 1.fst
fstpredeterminize out.lst 1.fst | fstdeterminizestar | fstrmsymbols out.lst
> 2.fst fstequivalent --random=true 1.fst 2.fst || echo "Test failed" echo -n
"." done
Test of debugging [with non-determinizable input]:
( echo " 0 0 1 0 1.0"; echo "0 1 1 0"; echo "1 1 1 0 0"; echo "0 2 2 0"; echo
"2"; echo "1" ) | fstcompile | fstdeterminizestar kill -SIGUSR1 [the process-id
of fstdeterminizestar] # prints out a bunch of debugging output showing the
mess it got itself into.
*/
bool debug_location = false;
void signal_handler(int) { debug_location = true; }
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Removes epsilons and determinizes in one step\n"
"\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n"
"See also: fstdeterminizelog, lattice-determinize\n";
float delta = kDelta;
int max_states = -1;
bool use_log = false;
ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta", &delta,
"Delta value used to determine equivalence of weights.");
po.Register(
"max-states", &max_states,
"Maximum number of states in determinized FST before it will abort.");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
// This enables us to get traceback info from determinization that is
// not seeming to terminate.
#if !defined(_MSC_VER) && !defined(__APPLE__)
signal(SIGUSR1, signal_handler);
#endif
// Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway.
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
// fstbin/fstisstochastic.cc
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "util/kaldi-io.h"
#include "util/parse-options.h"
// e.g. of test:
// echo " 0 0" | fstcompile | fstisstochastic
// should return 0 and print "0 0" [meaning, min and
// max weight are one = exp(0)]
// echo " 0 1" | fstcompile | fstisstochastic
// should return 1, not stochastic, and print 1 1
// (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo "1 0" ) |
// fstcompile | fstisstochastic should return 0, stochastic; it prints "0
// -1.78e-07" for me (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo
// "1 0" ) | fstcompile | fstisstochastic --test-in-log=false should return 1,
// not stochastic in tropical; it prints "0 0.693147" for me (echo "0 0 0 0 0 ";
// echo "0 1 0 0 0 "; echo "1 0" ) | fstcompile | fstisstochastic
// --test-in-log=false should return 0, stochastic in tropical; it prints "0 0"
// for me (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo "1 0" ) |
// fstcompile | fstisstochastic --test-in-log=false --delta=1 returns 0 even
// though not stochastic because we gave it an absurdly large delta.
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Checks whether an FST is stochastic and exits with success if so.\n"
"Prints out maximum error (in log units).\n"
"\n"
"Usage: fstisstochastic [ in.fst ]\n";
float delta = 0.01;
bool test_in_log = true;
ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept.");
po.Register("test-in-log", &test_in_log,
"Test stochasticity in log semiring.");
po.Read(argc, argv);
if (po.NumArgs() > 1) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
bool ans;
StdArc::Weight min, max;
if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else
ans = IsStochasticFst(*fst, delta, &min, &max);
std::cout << min.Value() << " " << max.Value() << '\n';
delete fst;
if (ans)
return 0; // success;
else
return 1;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
// fstbin/fstminimizeencoded.cc
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "util/kaldi-io.h"
#include "util/parse-options.h"
#include "util/text-utils.h"
/* some test examples:
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstminimizeencoded | fstprint
( echo "0 1 0 0"; echo " 0 2 0 0"; echo "1 0"; echo "2 0"; ) | fstcompile |
fstminimizeencoded | fstprint
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n"
"\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
float delta = kDelta;
ParseOptions po(usage);
po.Register("delta", &delta,
"Delta likelihood used for quantization of weights");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
MinimizeEncoded(fst, delta);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}
// fstbin/fsttablecompose.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "fstext/table-matcher.h"
#include "util/parse-options.h"
/*
cd ~/tmpdir
while true; do
fstrand | fstarcsort --sort_type=olabel > 1.fst; fstrand | fstarcsort
> 2.fst fstcompose 1.fst 2.fst > 3a.fst fsttablecompose 1.fst 2.fst > 3b.fst
fstequivalent --random=true 3a.fst 3b.fst || echo "Test failed"
echo -n "."
done
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
/*
fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST
is possible.
*/
const char *usage =
"Composition algorithm [between two FSTs of standard type, in "
"tropical\n"
"semiring] that is more efficient for certain cases-- in particular,\n"
"where one of the FSTs (the left one, if --match-side=left) has large\n"
"out-degree\n"
"\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n";
ParseOptions po(usage);
TableComposeOptions opts;
std::string match_side = "left";
std::string compose_filter = "sequence";
po.Register("connect", &opts.connect, "If true, trim FST before output.");
po.Register("match-side", &match_side,
"Side of composition to do table "
"match, one of: \"left\" or \"right\".");
po.Register("compose-filter", &compose_filter,
"Composition filter to use, "
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
#!/usr/bin/env bash
current_path=`pwd`
current_dir=`basename "$current_path"`
if [ "tools" != "$current_dir" ]; then
echo "You should run this script in tools/ directory!!"
exit 1
fi
if [ ! -d liblbfgs-1.10 ]; then
echo Installing libLBFGS library to support MaxEnt LMs
bash extras/install_liblbfgs.sh || exit 1
fi
! command -v gawk > /dev/null && \
echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1;
if [ $# -ne 3 ]; then
echo "SRILM download requires some information about you"
echo
echo "Usage: $0 <name> <organization> <email>"
exit 1
fi
srilm_url="http://www.speech.sri.com/projects/srilm/srilm_download.php"
post_data="WWW_file=srilm-1.7.3.tar.gz&WWW_name=$1&WWW_org=$2&WWW_email=$3"
if ! wget --post-data "$post_data" -O ./srilm.tar.gz "$srilm_url"; then
echo 'There was a problem downloading the file.'
echo 'Check you internet connection and try again.'
exit 1
fi
mkdir -p srilm
cd srilm
if [ -f ../srilm.tgz ]; then
tar -xvzf ../srilm.tgz # Old SRILM format
elif [ -f ../srilm.tar.gz ]; then
tar -xvzf ../srilm.tar.gz # Changed format type from tgz to tar.gz
fi
major=`gawk -F. '{ print $1 }' RELEASE`
minor=`gawk -F. '{ print $2 }' RELEASE`
micro=`gawk -F. '{ print $3 }' RELEASE`
if [ $major -le 1 ] && [ $minor -le 7 ] && [ $micro -le 1 ]; then
echo "Detected version 1.7.1 or earlier. Applying patch."
patch -p0 < ../extras/srilm.patch
fi
# set the SRILM variable in the top-level Makefile to this directory.
cp Makefile tmpf
cat tmpf | gawk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \
> Makefile || exit 1
rm tmpf
mtype=`sbin/machine-type`
echo HAVE_LIBLBFGS=1 >> common/Makefile.machine.$mtype
grep ADDITIONAL_INCLUDES common/Makefile.machine.$mtype | \
sed 's|$| -I$(SRILM)/../liblbfgs-1.10/include|' \
>> common/Makefile.machine.$mtype
grep ADDITIONAL_LDFLAGS common/Makefile.machine.$mtype | \
sed 's|$| -L$(SRILM)/../liblbfgs-1.10/lib/ -Wl,-rpath -Wl,$(SRILM)/../liblbfgs-1.10/lib/|' \
>> common/Makefile.machine.$mtype
make || exit
cd ..
(
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM variable is aleady defined. Undefining..." && \
unset SRILM
[ -f ./env.sh ] && . ./env.sh
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM config is already in env.sh" && exit
wd=`pwd`
wd=`readlink -f $wd || pwd`
echo "export SRILM=$wd/srilm"
dirs="\${PATH}"
for directory in $(cd srilm && find bin -type d ) ; do
dirs="$dirs:\${SRILM}/$directory"
done
echo "export PATH=$dirs"
) >> env.sh
echo >&2 "Installation of SRILM finished successfully"
echo >&2 "Please source the tools/env.sh in your path.sh to enable it"
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(arpa2fst )
// bin/arpa2fst.cc
//
// Copyright 2009-2011 Gilles Boulianne.
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lm/arpa-lm-compiler.h"
#include "util/kaldi-io.h"
#include "util/parse-options.h"
int main(int argc, char *argv[]) {
using namespace kaldi; // NOLINT
try {
const char *usage =
"Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will contain\n"
"an embedded symbol table. This is compatible with the way a previous\n"
"version of arpa2fst worked.\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
std::string disambig_symbol;
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register("disambig-symbol", &disambig_symbol,
"Disambiguator. If provided (e. g. #0), used on input side of "
"backoff links, and <s> and </s> are replaced with epsilons");
po.Register("read-symbol-table", &read_syms_filename,
"Use existing symbol table");
po.Register("write-symbol-table", &write_syms_filename,
"Write generated symbol table to a file");
po.Register("keep-symbols", &keep_symbols,
"Store symbol table with FST. Symbols always saved to FST if "
"symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册