未验证 提交 fdc189a3 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1599 from SmileGoat/add_tlg

[Speechx] add tlg decoder
...@@ -48,7 +48,7 @@ wer=./aishell_wer ...@@ -48,7 +48,7 @@ wer=./aishell_wer
nj=40 nj=40
export GLOG_logtostderr=1 export GLOG_logtostderr=1
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj #./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data data=$PWD/data
# 3. gen linear feat # 3. gen linear feat
...@@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ ...@@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \ --param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \ --dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > $label_file cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
offline_decoder_sliding_chunk_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph
if [ ! -d $ ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip -d aishell_graph.zip
fi
# 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
offline_wfst_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
local/compute-wer.py --char=1 --v=1 $label_file $text > $wer cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
tail $wer local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
\ No newline at end of file
...@@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_ ...@@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc) add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
......
...@@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); ...@@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model"); DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=5) downsampling module."); "receptive field of two CNN(kernel=5) downsampling module.");
...@@ -45,7 +45,6 @@ using kaldi::BaseFloat; ...@@ -45,7 +45,6 @@ using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test ds2 online decoder by feeding speech feature // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
...@@ -63,7 +62,6 @@ int main(int argc, char* argv[]) { ...@@ -63,7 +62,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path; LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;
...@@ -138,10 +136,16 @@ int main(int argc, char* argv[]) { ...@@ -138,10 +136,16 @@ int main(int argc, char* argv[]) {
} }
std::string result; std::string result;
result = decoder.GetFinalBestPath(); result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done; ++num_done;
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table;
LOG(INFO) << "graph path: " << graph_path;
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts;
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
...@@ -6,5 +6,6 @@ add_library(decoder STATIC ...@@ -6,5 +6,6 @@ add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst) target_link_libraries(decoder PUBLIC kenlm utils fst)
\ No newline at end of file
...@@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( ...@@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode(
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob; vector<BaseFloat> frame_prob;
bool flag = bool flag =
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break; if (flag == false) break;
likelihood.push_back(frame_prob); likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h" #include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#pragma once #pragma once
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++;
}
void TLGDecoder::Reset() {
InitDecoder();
return;
}
std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
}
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
namespace ppspeech {
struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
};
class TLGDecoder {
public:
explicit TLGDecoder(TLGDecoderOptions opts);
void InitDecoder();
void Decode();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
};
} // namespace ppspeech
\ No newline at end of file
...@@ -120,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) { ...@@ -120,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats); ApplyCmvn(stats_, var_norm_, feats);
} }
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
...@@ -4,3 +4,6 @@ add_subdirectory(base) ...@@ -4,3 +4,6 @@ add_subdirectory(base)
add_subdirectory(util) add_subdirectory(util)
add_subdirectory(feat) add_subdirectory(feat)
add_subdirectory(matrix) add_subdirectory(matrix)
add_subdirectory(lat)
add_subdirectory(fstext)
add_subdirectory(decoder)
add_library(kaldi-decoder
lattice-faster-decoder.cc
lattice-faster-online-decoder.cc
)
target_link_libraries(kaldi-decoder PUBLIC kaldi-lat)
...@@ -121,7 +121,7 @@ class DecodableInterface { ...@@ -121,7 +121,7 @@ class DecodableInterface {
/// decoding-from-matrix setting where we want to allow the last delta or /// decoding-from-matrix setting where we want to allow the last delta or
/// LDA /// LDA
/// features to be flushed out for compatibility with the baseline setup. /// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0; virtual bool IsLastFrame(int32 frame) = 0;
/// The call NumFramesReady() will return the number of frames currently /// The call NumFramesReady() will return the number of frames currently
/// available /// available
...@@ -143,7 +143,7 @@ class DecodableInterface { ...@@ -143,7 +143,7 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
virtual bool FrameLogLikelihood( virtual bool FrameLikelihood(
int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0; int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0;
......
...@@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken> ...@@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >; template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >; template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>; template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;
} // end namespace kaldi. } // end namespace kaldi.
...@@ -23,11 +23,10 @@ ...@@ -23,11 +23,10 @@
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include "decoder/grammar-fst.h"
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "fst/memory.h" #include "fst/memory.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h" #include "decoder/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h" #include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
#include "util/hash-list.h" #include "util/hash-list.h"
......
...@@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned( ...@@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >; template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >; //template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >; //template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi. } // end namespace kaldi.
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include "util/stl-utils.h" #include "util/stl-utils.h"
#include "util/hash-list.h" #include "util/hash-list.h"
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "itf/decodable-itf.h" #include "decoder/decodable-itf.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h" #include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h" #include "lat/kaldi-lattice.h"
......
add_library(kaldi-fstext
kaldi-fst-io.cc
)
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
此差异已折叠。
// fstext/determinize-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
#define KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice.cc
/*
DeterminizeLattice implements a special form of determinization
with epsilon removal, optimized for a phase of lattice generation.
Its input is an FST with weight-type BaseWeightType (usually a pair of
floats, with a lexicographical type of order, such as
LatticeWeightTpl<float>). Typically this would be a state-level lattice, with
input symbols equal to words, and output-symbols equal to p.d.f's (so like
the inverse of HCLG). Imagine representing this as an acceptor of type
CompactLatticeWeightTpl<float>, in which the input/output symbols are words,
and the weights contain the original weights together with strings (with zero
or one symbol in them) containing the original output labels (the p.d.f.'s).
We determinize this using acceptor determinization with epsilon removal.
Remember (from lattice-weight.h) that CompactLatticeWeightTpl has a special
kind of semiring where we always take the string corresponding to the best
cost (of type BaseWeightType), and discard the other. This corresponds to
taking the best output-label sequence (of p.d.f.'s) for each input-label
sequence (of words). We couldn't use the Gallic weight for this, or it would
die as soon as it detected that the input FST was non-functional. In our
case, any acyclic FST (and many cyclic ones) can be determinized. We assume
that there is a function Compare(const BaseWeightType &a, const
BaseWeightType &b) that returns (-1, 0, 1) according to whether (a < b, a ==
b, a > b) in the total order on the BaseWeightType... this information should
be the same as NaturalLess would give, but it's more efficient to do it this
way. You can define this for things like TropicalWeight if you need to
instantiate this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and
the ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its
(end-state, weight) pairs, this will be a valid and more compact
representation, and will lead to a smaller set of determinized states (like
early minimization). Call this collection of (end-state, weight) pairs the
"minimal representation". As a mechanism to reduce compute, we can also
consider another representation. In the determinization algorithm, we start
off with a set of (begin-state, weight) pairs (where the "begin-states" are
initial or have a label on the transition into them), and the "canonical
representation" consists of the epsilon-closure of this set (i.e. follow
epsilons). Call this set of (begin-state, weight) pairs, appropriately
normalized, the "initial representation". If two initial representations are
the same, the "canonical representation" and hence the "minimal
representation" will be the same. We can use this to reduce compute. Note
that if two initial representations are different, this does not preclude the
other representations from being the same.
*/
struct DeterminizeLatticeOptions {
float delta; // A small offset used to measure equality of weights.
int max_mem; // If >0, determinization will fail and return false
// when the algorithm's (approximate) memory consumption crosses this
// threshold.
int max_loop; // If >0, can be used to detect non-determinizable input
// (a case that wouldn't be caught by max_mem).
DeterminizeLatticeOptions() : delta(kDelta), max_mem(-1), max_loop(-1) {}
};
/**
This function implements the normal version of DeterminizeLattice, in which
the output strings are represented using sequences of arcs, where all but
the first one has an epsilon on the input side. The debug_ptr argument is
an optional pointer to a bool that, if it becomes true while the algorithm
is executing, the algorithm will print a traceback and terminate (used in
fstdeterminizestar.cc debug non-terminating determinization). More
efficient if ifst is arc-sorted on input label. If the number of arcs gets
more than max_states, it will throw std::runtime_error (otherwise this code
does not use exceptions). This is mainly useful for debug. */
template <class Weight, class IntType>
bool DeterminizeLattice(
const Fst<ArcTpl<Weight> > &ifst, MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
bool *debug_ptr = NULL);
/* This is a version of DeterminizeLattice with a slightly more "natural"
output format, where the output sequences are encoded using the
CompactLatticeArcTpl template (i.e. the sequences of output symbols are
represented directly as strings) More efficient if ifst is arc-sorted on
input label. If the #arcs gets more than max_arcs, it will throw
std::runtime_error (otherwise this code does not use exceptions). This is
mainly useful for debug.
*/
template <class Weight, class IntType>
bool DeterminizeLattice(
const Fst<ArcTpl<Weight> > &ifst,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
bool *debug_ptr = NULL);
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#include "fstext/determinize-lattice-inl.h"
#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
此差异已折叠。
// fstext/determinize-star.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Guoguo Chen
// 2015 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_H_
#define KALDI_FSTEXT_DETERMINIZE_STAR_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <stdexcept> // this algorithm uses exceptions
#include <vector>
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-star.cc
/*
DeterminizeStar implements determinization with epsilon removal, which we
distinguish with a star.
We define a determinized* FST as one in which no state has more than one
transition with the same input-label. Epsilon input labels are not allowed
except starting from states that have exactly one arc exiting them (and are
not final). [In the normal definition of determinized, epsilon-input labels
are not allowed at all, whereas in Mohri's definition, epsilons are treated
as ordinary symbols]. The determinized* definition is intended to simulate
the effect of allowing strings of output symbols at each state.
The algorithm implemented here takes an Fst<Arc>, and a pointer to a
MutableFst<Arc> where it puts its output. The weight type is assumed to be a
float-weight. It does epsilon removal and determinization.
This algorithm may fail if the input has epsilon cycles under
certain circumstances (i.e. the semiring is non-idempotent, e.g. the log
semiring, or there are negative cost epsilon cycles).
This implementation is much less fancy than the one in fst/determinize.h, and
does not have an "on-demand" version.
The algorithm is a fairly normal determinization algorithm. We keep in
memory the subsets of states, together with their leftover strings and their
weights. The only difference is we detect input epsilon transitions and
treat them "specially".
*/
// This algorithm will be slightly faster if you sort the input fst on input
// label.
/**
This function implements the normal version of DeterminizeStar, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. The debug_ptr argument is an
optional pointer to a bool that, if it becomes true while the algorithm is
executing, the algorithm will print a traceback and terminate (used in
fstdeterminizestar.cc debug non-terminating determinization).
If max_states is positive, it will stop determinization and throw an
exception as soon as the max-states is reached. This can be useful in test.
If allow_partial is true, the algorithm will output partial results when the
specified max_states is reached (when larger than zero), instead of throwing
out an error.
Caution, the return status is un-intuitive: this function will return false
if determinization completed normally, and true if it was stopped early by
reaching the 'max-states' limit, and a partial FST was generated.
*/
template <class F>
bool DeterminizeStar(F &ifst, MutableFst<typename F::Arc> *ofst, // NOLINT
float delta = kDelta, bool *debug_ptr = NULL,
int max_states = -1, bool allow_partial = false);
/* This is a version of DeterminizeStar with a slightly more "natural" output
format, where the output sequences are encoded using the GallicArc (i.e. the
output symbols are strings. If max_states is positive, it will stop
determinization and throw an exception as soon as the max-states is reached.
This can be useful in test. If allow_partial is true, the algorithm will
output partial results when the specified max_states is reached (when larger
than zero), instead of throwing out an error.
Caution, the return status is un-intuitive: this function will return false
if determinization completed normally, and true if it was stopped early by
reaching the 'max-states' limit, and a partial FST was generated.
*/
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<GallicArc<typename F::Arc> > *ofst,
float delta = kDelta, bool *debug_ptr = NULL,
int max_states = -1, bool allow_partial = false);
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#include "fstext/determinize-star-inl.h"
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_H_
// fstext/fstext-lib.h
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (author:
// Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_FSTEXT_LIB_H_
#define KALDI_FSTEXT_FSTEXT_LIB_H_
#include "fst/fstlib.h"
#include "fstext/determinize-lattice.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
#include "fstext/lattice-utils.h"
#include "fstext/lattice-weight.h"
#include "fstext/pre-determinize.h"
#include "fstext/table-matcher.h"
#endif // KALDI_FSTEXT_FSTEXT_LIB_H_
此差异已折叠。
此差异已折叠。
// fstext/kaldi-fst-io-inl.h
// Copyright 2009-2011 Microsoft Corporation
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_KALDI_FST_IO_INL_H_
#define KALDI_FSTEXT_KALDI_FST_IO_INL_H_
#include <string>
#include <vector>
#include "util/text-utils.h"
namespace fst {
template <class Arc>
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst<Arc> &t) {
bool ok;
if (binary) {
// Binary-mode writing.
ok = t.Write(os, FstWriteOptions());
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached. Write a
// newline to start the FST; in a table, the first line of the FST will
// appear on its own line.
os << '\n';
bool acceptor = false, write_one = false;
FstPrinter<Arc> printer(t, t.InputSymbols(), t.OutputSymbols(), NULL,
acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail()) KALDI_ERR << "Stream failure detected writing FST to stream";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not something in the original
// OpenFst code].
os << '\n';
ok = os.good();
}
if (!ok) {
KALDI_ERR << "Error writing FST to stream";
}
}
// Utility function used in ReadFstKaldi
template <class W>
inline bool StrToWeight(const std::string &s, bool allow_zero, W *w) {
std::istringstream strm(s);
strm >> *w;
if (strm.fail() || (!allow_zero && *w == W::Zero())) {
return false;
}
return true;
}
template <class Arc>
void ReadFstKaldi(std::istream &is, bool binary, VectorFst<Arc> *fst) {
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
if (binary) {
// We don't have access to the filename here, so write [unknown].
VectorFst<Arc> *ans =
VectorFst<Arc>::Read(is, fst::FstReadOptions(std::string("[unknown]")));
if (ans == NULL) {
KALDI_ERR << "Error reading FST from stream.";
}
*fst = *ans; // shallow copy.
delete ans;
} else {
// Consume the \r on Windows, the \n that the text-form FST format starts
// with, and any extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') {
is.get(); // consume the newline.
} else { // saw spaces but no newline.. this is not expected.
KALDI_ERR << "Reading FST: unexpected sequence of spaces "
<< " at file position " << is.tellg();
}
using kaldi::ConvertStringToInteger;
using kaldi::SplitStringToIntegers;
using std::string;
using std::vector;
fst->DeleteStates();
string line;
size_t nline = 0;
string separator = FLAGS_fst_field_separator + "\r\n";
while (std::getline(is, line)) {
nline++;
vector<string> col;
// on Windows we'll write in text and read in binary mode.
kaldi::SplitStringToVector(line, separator.c_str(), true, &col);
if (col.size() == 0) break; // Empty line is a signal to stop, in our
// archive format.
if (col.size() > 5) {
KALDI_ERR << "Bad line in FST: " << line;
}
StateId s;
if (!ConvertStringToInteger(col[0], &s)) {
KALDI_ERR << "Bad line in FST: " << line;
}
while (s >= fst->NumStates()) fst->AddState();
if (nline == 1) fst->SetStart(s);
bool ok = true;
Arc arc;
Weight w;
StateId d = s;
switch (col.size()) {
case 1:
fst->SetFinal(s, Weight::One());
break;
case 2:
if (!StrToWeight(col[1], true, &w))
ok = false;
else
fst->SetFinal(s, w);
break;
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
ok = false;
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel);
if (ok) {
d = arc.nextstate;
arc.weight = Weight::One();
fst->AddArc(s, arc);
}
break;
case 5:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel) &&
StrToWeight(col[4], false, &arc.weight);
if (ok) {
d = arc.nextstate;
fst->AddArc(s, arc);
}
break;
default:
ok = false;
}
while (d >= fst->NumStates()) fst->AddState();
if (!ok) KALDI_ERR << "Bad line in FST: " << line;
}
}
}
template <class Arc> // static
bool VectorFstTplHolder<Arc>::Write(std::ostream &os, bool binary, const T &t) {
try {
WriteFstKaldi(os, binary, t);
return true;
} catch (...) {
return false;
}
}
template <class Arc> // static
bool VectorFstTplHolder<Arc>::Read(std::istream &is) {
Clear();
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading Fst";
return false;
} else if (isspace(c)) { // The text form of the FST begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is
// not space).
try {
t_ = new VectorFst<Arc>();
ReadFstKaldi(is, false, t_);
} catch (...) {
Clear();
return false;
}
} else { // reading a binary FST.
try {
t_ = new VectorFst<Arc>();
ReadFstKaldi(is, true, t_);
} catch (...) {
Clear();
return false;
}
}
return true;
}
} // namespace fst.
#endif // KALDI_FSTEXT_KALDI_FST_IO_INL_H_
// fstext/kaldi-fst-io.cc
// Copyright 2009-2011 Microsoft Corporation
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
// 2013 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "fstext/kaldi-fst-io.h"
#include <string>
#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "util/kaldi-io.h"
namespace fst {
VectorFst<StdArc> *ReadFstKaldi(std::string rxfilename) {
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
// for compatibility with OpenFst conventions.
kaldi::Input ki(rxfilename);
fst::FstHeader hdr;
if (!hdr.Read(ki.Stream(), rxfilename))
KALDI_ERR << "Reading FST: error reading FST header from "
<< kaldi::PrintableRxfilename(rxfilename);
FstReadOptions ropts("<unspecified>", &hdr);
VectorFst<StdArc> *fst = VectorFst<StdArc>::Read(ki.Stream(), ropts);
if (!fst)
KALDI_ERR << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename);
return fst;
}
// Register const fst to load it automatically. Other types like
// olabel_lookahead or ngram or compact_fst should be registered
// through OpenFst registration API.
static fst::FstRegisterer<VectorFst<StdArc>> VectorFst_StdArc_registerer;
static fst::FstRegisterer<ConstFst<StdArc>> ConstFst_StdArc_registerer;
Fst<StdArc> *ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err) {
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
// for compatibility with OpenFst conventions.
kaldi::Input ki(rxfilename);
fst::FstHeader hdr;
// Read FstHeader which contains the type of FST
if (!hdr.Read(ki.Stream(), rxfilename)) {
if (throw_on_err) {
KALDI_ERR << "Reading FST: error reading FST header from "
<< kaldi::PrintableRxfilename(rxfilename);
} else {
KALDI_WARN << "We fail to read FST header from "
<< kaldi::PrintableRxfilename(rxfilename)
<< ". A NULL pointer is returned.";
return NULL;
}
}
// Check the type of Arc
if (hdr.ArcType() != fst::StdArc::Type()) {
if (throw_on_err) {
KALDI_ERR << "FST with arc type " << hdr.ArcType()
<< " is not supported.";
} else {
KALDI_WARN << "Fst with arc type" << hdr.ArcType()
<< " is not supported. A NULL pointer is returned.";
return NULL;
}
}
// Read the FST
FstReadOptions ropts("<unspecified>", &hdr);
Fst<StdArc> *fst = Fst<StdArc>::Read(ki.Stream(), ropts);
if (!fst) {
if (throw_on_err) {
KALDI_ERR << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename);
} else {
KALDI_WARN << "Could not read fst from "
<< kaldi::PrintableRxfilename(rxfilename)
<< ". A NULL pointer is returned.";
return NULL;
}
}
return fst;
}
VectorFst<StdArc> *CastOrConvertToVectorFst(Fst<StdArc> *fst) {
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>
std::string real_type = fst->Type();
KALDI_ASSERT(real_type == "vector" || real_type == "const");
if (real_type == "vector") {
return dynamic_cast<VectorFst<StdArc> *>(fst);
} else {
// As the 'fst' can't cast to VectorFst, we create a new
// VectorFst<StdArc> initialized by 'fst', and delete 'fst'.
VectorFst<StdArc> *new_fst = new VectorFst<StdArc>(*fst);
delete fst;
return new_fst;
}
}
void ReadFstKaldi(std::string rxfilename, fst::StdVectorFst *ofst) {
fst::StdVectorFst *fst = ReadFstKaldi(rxfilename);
*ofst = *fst;
delete fst;
}
void WriteFstKaldi(const VectorFst<StdArc> &fst, std::string wxfilename) {
if (wxfilename == "") wxfilename = "-"; // interpret "" as stdout,
// for compatibility with OpenFst conventions.
bool write_binary = true, write_header = false;
kaldi::Output ko(wxfilename, write_binary, write_header);
FstWriteOptions wopts(kaldi::PrintableWxfilename(wxfilename));
fst.Write(ko.Stream(), wopts);
}
fst::VectorFst<fst::StdArc> *ReadAndPrepareLmFst(std::string rxfilename) {
// ReadFstKaldi() will die with exception on failure.
fst::VectorFst<fst::StdArc> *ans = fst::ReadFstKaldi(rxfilename);
if (ans->Properties(fst::kAcceptor, true) == 0) {
// If it's not already an acceptor, project on the output, i.e. copy olabels
// to ilabels. Generally the G.fst's on disk will have the disambiguation
// symbol #0 on the input symbols of the backoff arc, and projection will
// replace them with epsilons which is what is on the output symbols of
// those arcs.
fst::Project(ans, fst::PROJECT_OUTPUT);
}
if (ans->Properties(fst::kILabelSorted, true) == 0) {
// Make sure LM is sorted on ilabel.
fst::ILabelCompare<fst::StdArc> ilabel_comp;
fst::ArcSort(ans, ilabel_comp);
}
return ans;
}
} // end namespace fst
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
add_library(kaldi-lat
determinize-lattice-pruned.cc
lattice-functions.cc
)
target_link_libraries(kaldi-lat PUBLIC kaldi-util)
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(arpa2fst )
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册