未验证 提交 b584b969 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1400 from SmileGoat/feature_dev

[speechx]add linear spectrogram feature extractor
...@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks ...@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md) * [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
- ISC License - ISC License
- Audio feature - Audio feature
* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
- zlib License
- ThreadPool
...@@ -39,15 +39,40 @@ FetchContent_Declare( ...@@ -39,15 +39,40 @@ FetchContent_Declare(
GIT_TAG "20210324.1" GIT_TAG "20210324.1"
) )
FetchContent_MakeAvailable(absl) FetchContent_MakeAvailable(absl)
include_directories(${absl_SOURCE_DIR})
# libsndfile # libsndfile
#include(FetchContent)
#FetchContent_Declare(
# libsndfile
# GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
# GIT_TAG "1.0.31"
#)
#FetchContent_MakeAvailable(libsndfile)
# todo boost build
#include(FetchContent)
#FetchContent_Declare(
# Boost
# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
#)
#FetchContent_MakeAvailable(Boost)
#include_directories(${Boost_SOURCE_DIR})
set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib)
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
libsndfile kenlm
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
GIT_TAG "1.0.31" GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
) )
FetchContent_MakeAvailable(libsndfile) FetchContent_MakeAvailable(kenlm)
add_dependencies(kenlm Boost)
include_directories(${kenlm_SOURCE_DIR})
# gflags # gflags
FetchContent_Declare( FetchContent_Declare(
...@@ -65,7 +90,7 @@ FetchContent_Declare( ...@@ -65,7 +90,7 @@ FetchContent_Declare(
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
) )
FetchContent_MakeAvailable(glog) FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR}) include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
# gtest # gtest
FetchContent_Declare(googletest FetchContent_Declare(googletest
...@@ -93,6 +118,22 @@ add_dependencies(openfst gflags glog) ...@@ -93,6 +118,22 @@ add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib) link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include) include_directories(${openfst_PREFIX_DIR}/include)
set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference)
include_directories("${PADDLE_LIB}/paddle/include")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
add_subdirectory(speechx) add_subdirectory(speechx)
#openblas #openblas
...@@ -121,4 +162,4 @@ add_subdirectory(speechx) ...@@ -121,4 +162,4 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt # if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc) #add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1) #target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target) #add_dependencies(lib_name depend-target)
\ No newline at end of file
...@@ -4,11 +4,43 @@ project(speechx LANGUAGES CXX) ...@@ -4,11 +4,43 @@ project(speechx LANGUAGES CXX)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas) link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/kaldi ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
) )
add_subdirectory(kaldi) add_subdirectory(kaldi)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/utils
)
add_subdirectory(utils)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/frontend
)
add_subdirectory(frontend)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/nnet
)
add_subdirectory(nnet)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc) target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
#target_link_libraries(offline_decoder_main nnet decoder gflags glog)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "kaldi/base/kaldi-types.h" #include "kaldi/base/kaldi-types.h"
#include <limits.h> #include <limits>
typedef float BaseFloat; typedef float BaseFloat;
typedef double double64; typedef double double64;
...@@ -35,7 +35,7 @@ typedef unsigned char uint8; ...@@ -35,7 +35,7 @@ typedef unsigned char uint8;
typedef unsigned short uint16; typedef unsigned short uint16;
typedef unsigned int uint32; typedef unsigned int uint32;
if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64; typedef unsigned long uint64;
#else #else
typedef unsigned long long uint64; typedef unsigned long long uint64;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <iostream>
#include <istream>
#include <fstream>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h"
#include "base/macros.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gflags/gflags.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
namespace ppspeech { namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \ TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif
} // namespace pp_speech } // namespace pp_speech
\ No newline at end of file
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
// 3. This notice may not be removed or altered from any source
// distribution.
// this code is from https://github.com/progschj/ThreadPool
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
}
#endif
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
int32 chunk_size,
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
CTCBeamSearchOptions opts;
CTCBeamSearch decoder(opts);
ModelOptions model_opts;
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
Decodable decodable();
decodable.SetNnet(nnet);
int32 chunk_size = 0;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
vector<Matrix<BaseFloat>> feature_chunks;
SplitFeature(feature, chunk_size, &feature_chunks);
for (auto feature_chunk : feature_chunks) {
decodable.FeedFeatures(feature_chunk);
decoder.InitDecoder();
decoder.AdvanceDecode(decodable, chunk_size);
}
decodable.InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(waveform);
linear_spectrogram.ReadFeats(&features);
feat_writer.Write(utt, features);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
aux_source_directory(. DIR_LIB_SRCS) project(decoder)
add_library(decoder STATIC ${DIR_LIB_SRCS})
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries(decoder kenlm)
\ No newline at end of file
#include "base/basic_types.h"
struct DecoderResult {
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
};
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded_(0),
root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta,
opts_.lm_path,
vocabulary_);
}
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2;
}
ResetPrefixes();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood);
max_frames--;
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
delete prefixes[i];
prefixes[i] = nullptr;
}
}
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech
\ No newline at end of file
#include "base/common.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once
namespace ppspeech {
struct CTCBeamSearchOptions {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions() :
dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {
}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch {
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace basr
\ No newline at end of file
../../../third_party/ctc_decoders
\ No newline at end of file
project(frontend)
add_library(frontend
normalizer.cc
linear_spectrogram.cc
)
target_link_libraries(frontend kaldi-matrix)
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the fbank feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
#incldue "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
#include "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FeatureExtractorInterface {
public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat;
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
//todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) {
if (input.empty()) return;
output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
}
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
}
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input);
}
void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i];
}
}
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real,
vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1));
}
real->push_back(v->at(1));
img->push_back(0);
return true;
}
// todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
}
waveform_.Resize(0);
}
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo
return;
}
// only for test, remove later
// todo: compute the feature frame by frame.
void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
VectorBase<kaldi::BaseFloat>* feature) {
base_extractor_->Read(feature);
}
// Compute spectrogram feat, only for test, remove later
// todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) {
int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * frame_shift;
if (num_samples < frame_length) {
return true;
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
if (NumpyFft(&v, &fft_real, &fft_img)) {
LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data";
return false;
}
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale;
} else {
feat[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14);
}
}
return true;
}
} // namespace ppspeech
\ No newline at end of file
#pragma once
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions():
frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
frame_opts.Register(opts);
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat);
void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feature);
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
std::vector<kaldi::BaseFloat>* real,
std::vector<kaldi::BaseFloat>* img) const;
kaldi::int32 fft_points_;
size_t dim_;
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the mfcc feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
\ No newline at end of file
#include "frontend/normalizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
opts_ = opts;
dim_ = 0;
}
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
}
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return;
Compute(waveform_, feat);
}
//todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) {
if (input.empty()) return;
assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i);
}
// square
for (auto &d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto &item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
}
/*
PPNormalizer::PPNormalizer(
const PPNormalizerOptions& opts,
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
}
void PPNormalizer::AcceptWavefrom(const Vector<BaseFloat>& input) {
}
void PPNormalizer::Read(Vector<BaseFloat>* feat) {
}
bool PPNormalizer::Compute(const Vector<BaseFloat>& input,
Vector<BaseFloat>>* feat) {
if ((input.Dim() % mean_.Dim()) == 0) {
LOG(ERROR) << "CMVN dimension is wrong!";
return false;
}
try {
int32 size = mean_.Dim();
feat->Resize(input.Dim());
for (int32 row_idx = 0; row_idx < j; ++row_idx) {
int32 base_idx = row_idx * size;
for (int32 idx = 0; idx < mean_.Dim(); ++idx) {
(*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx))* variance_(idx);
}
}
} catch(const std::exception& e) {
std::cerr << e.what() << '\n';
return false;
}
return true;
}*/
} // namespace ppspeech
\ No newline at end of file
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
struct DecibelNormalizerOptions {
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions() :
target_db(-20),
max_gain_db(300.0),
convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization");
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return 0; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private:
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_;
};
/*
struct NormalizerOptions {
std::string mean_std_path;
NormalizerOptions() :
mean_std_path("") {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("mean-std", &mean_std_path, "mean std file");
}
};
// todo refactor later (SmileGoat)
class PPNormalizer : public FeatureExtractorInterface {
public:
explicit PPNormalizer(const NormalizerOptions& opts,
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor);
~PPNormalizer() {}
virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
virtual size_t Dim() const;
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& input,
kaldi::Vector<kaldi::BaseFloat>>& feat);
private:
bool _initialized;
kaldi::Vector<float> mean_;
kaldi::Vector<float> variance_;
NormalizerOptions _opts;
};
*/
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// extract the window of kaldi feat.
此差异已折叠。
此差异已折叠。
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include "decoder/lattice-faster-online-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
bool use_final_probs) const {
Lattice lat1;
{
Lattice raw_lat;
this->GetRawLattice(&raw_lat, use_final_probs);
ShortestPath(raw_lat, &lat1);
}
Lattice lat2;
GetBestPath(&lat2, use_final_probs);
BaseFloat delta = 0.1;
int32 num_paths = 1;
if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
KALDI_WARN << "Best-path test failed";
return false;
} else {
return true;
}
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
bool use_final_probs) const {
olat->DeleteStates();
BaseFloat final_graph_cost;
BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
if (iter.Done())
return false; // would have printed warning.
StateId state = olat->AddState();
olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
while (!iter.Done()) {
LatticeArc arc;
iter = TraceBackBestPath(iter, &arc);
arc.nextstate = state;
StateId new_state = olat->AddState();
olat->AddArc(new_state, arc);
state = new_state;
}
olat->SetStart(state);
return true;
}
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
bool use_final_probs,
BaseFloat *final_cost_out) const {
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "BestPathEnd() with use_final_probs == false";
KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
"You cannot call BestPathEnd if no frames were decoded.");
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
// Singly linked list of tokens on last frame (access list through "next"
// pointer).
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_final_cost = 0;
Token *best_tok = NULL;
for (Token *tok = this->active_toks_.back().toks;
tok != NULL; tok = tok->next) {
BaseFloat cost = tok->tot_cost, final_cost = 0.0;
if (use_final_probs && !final_costs.empty()) {
// if we are instructed to use final-probs, and any final tokens were
// active on final frame, include the final-prob in the cost of the token.
typename unordered_map<Token*, BaseFloat>::const_iterator
iter = final_costs.find(tok);
if (iter != final_costs.end()) {
final_cost = iter->second;
cost += final_cost;
} else {
cost = std::numeric_limits<BaseFloat>::infinity();
}
}
if (cost < best_cost) {
best_cost = cost;
best_tok = tok;
best_final_cost = final_cost;
}
}
if (best_tok == NULL) { // this should not happen, and is likely a code error or
// caused by infinities in likelihoods, but I'm not making
// it a fatal error for now.
KALDI_WARN << "No final token found.";
}
if (final_cost_out)
*final_cost_out = best_final_cost;
return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
BestPathIterator iter, LatticeArc *oarc) const {
KALDI_ASSERT(!iter.Done() && oarc != NULL);
Token *tok = static_cast<Token*>(iter.tok);
int32 cur_t = iter.frame, step_t = 0;
if (tok->backpointer != NULL) {
// retrieve the correct forward link(with the best link cost)
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
ForwardLinkT *link;
for (link = tok->backpointer->links;
link != NULL; link = link->next) {
if (link->next_tok == tok) { // this is a link to "tok"
BaseFloat graph_cost = link->graph_cost,
acoustic_cost = link->acoustic_cost;
BaseFloat cost = graph_cost + acoustic_cost;
if (cost < best_cost) {
oarc->ilabel = link->ilabel;
oarc->olabel = link->olabel;
if (link->ilabel != 0) {
KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
acoustic_cost -= this->cost_offsets_[cur_t];
step_t = -1;
} else {
step_t = 0;
}
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
best_cost = cost;
}
}
}
if (link == NULL &&
best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
KALDI_ERR << "Error tracing best-path back (likely "
<< "bug in token-pruning algorithm)";
}
} else {
oarc->ilabel = 0;
oarc->olabel = 0;
oarc->weight = LatticeWeight::One(); // zero costs.
}
return BestPathIterator(tok->backpointer, cur_t + step_t);
}
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
unordered_map<Token*, BaseFloat> final_costs_local;
const unordered_map<Token*, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
// num-frames plus one (since frames are one-based, and we have
// an extra frame for the start-state).
int32 num_frames = this->active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
for (int32 f = 0; f <= num_frames; f++) {
if (this->active_toks_[f].toks == NULL) {
KALDI_WARN << "No tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
}
unordered_map<Token*, StateId> tok_map;
std::queue<std::pair<Token*, int32> > tok_queue;
// First initialize the queue and states. Put the initial state on the queue;
// this is the last token in the list active_toks_[0].toks.
for (Token *tok = this->active_toks_[0].toks;
tok != NULL; tok = tok->next) {
if (tok->next == NULL) {
tok_map[tok] = ofst->AddState();
ofst->SetStart(tok_map[tok]);
std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
tok_queue.push(tok_pair);
}
}
// Next create states for "good" tokens
while (!tok_queue.empty()) {
std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
tok_queue.pop();
Token *cur_tok = cur_tok_pair.first;
int32 cur_frame = cur_tok_pair.second;
KALDI_ASSERT(cur_frame >= 0 &&
cur_frame <= this->cost_offsets_.size());
typename unordered_map<Token*, StateId>::const_iterator iter =
tok_map.find(cur_tok);
KALDI_ASSERT(iter != tok_map.end());
StateId cur_state = iter->second;
for (ForwardLinkT *l = cur_tok->links;
l != NULL;
l = l->next) {
Token *next_tok = l->next_tok;
if (next_tok->extra_cost < beam) {
// so both the current and the next token are good; create the arc
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
StateId nextstate;
if (tok_map.find(next_tok) == tok_map.end()) {
nextstate = tok_map[next_tok] = ofst->AddState();
tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
} else {
nextstate = tok_map[next_tok];
}
BaseFloat cost_offset = (l->ilabel != 0 ?
this->cost_offsets_[cur_frame] : 0);
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(cur_state, arc);
}
}
if (cur_frame == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token*, BaseFloat>::const_iterator iter =
final_costs.find(cur_tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi.
// decoder/lattice-faster-online-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.h, about how to maintain this
// file in sync with lattice-faster-decoder.h
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "decoder/lattice-faster-decoder.h"
namespace kaldi {
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
supports an efficient way to get the best path (see the function
BestPathEnd()), which is useful in endpointing and in situations where you
might want to frequently access the best path.
This is only templated on the FST type, since the Token type is required to
be BackpointerToken. Actually it only makes sense to instantiate
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
this child class.
*/
template <typename FST>
class LatticeFasterOnlineDecoderTpl:
public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Token = decoder::BackpointerToken;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterOnlineDecoderTpl(const FST &fst,
const LatticeFasterDecoderConfig &config):
LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
// This version of the initializer takes ownership of 'fst', and will delete
// it when this object is destroyed.
LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
FST *fst):
LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
struct BestPathIterator {
void *tok;
int32 frame;
// note, "frame" is the frame-index of the frame you'll get the
// transition-id for next time, if you call TraceBackBestPath on this
// iterator (assuming it's not an epsilon transition). Note that this
// is one less than you might reasonably expect, e.g. it's -1 for
// the nonemitting transitions before the first frame.
BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
bool Done() const { return tok == NULL; }
};
/// Outputs an FST corresponding to the single best path through the lattice.
/// This is quite efficient because it doesn't get the entire raw lattice and find
/// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
/// so it basically traces it back through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
bool GetBestPath(Lattice *ofst,
bool use_final_probs = true) const;
/// This function does a self-test of GetBestPath(). Returns true on
/// success; returns false and prints a warning on failure.
bool TestGetBestPath(bool use_final_probs = true) const;
/// This function returns an iterator that can be used to trace back
/// the best path. If use_final_probs == true and at least one final state
/// survived till the end, it will use the final-probs in working out the best
/// final Token, and will output the final cost to *final_cost (if non-NULL),
/// else it will use only the forward likelihood, and will put zero in
/// *final_cost (if non-NULL).
/// Requires that NumFramesDecoded() > 0.
BestPathIterator BestPathEnd(bool use_final_probs,
BaseFloat *final_cost = NULL) const;
/// This function can be used in conjunction with BestPathEnd() to trace back
/// the best path one link at a time (e.g. this can be useful in endpoint
/// detection). By "link" we mean a link in the graph; not all links cross
/// frame boundaries, but each time you see a nonzero ilabel you can interpret
/// that as a frame. The return value is the updated iterator. It outputs
/// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
/// while leaving its "nextstate" variable unchanged.
BestPathIterator TraceBackBestPath(
BestPathIterator iter, LatticeArc *arc) const;
/// Behaves the same as GetRawLattice but only processes tokens whose
/// extra_cost is smaller than the best-cost plus the specified beam.
/// It is only worthwhile to call this function if beam is less than
/// the lattice_beam specified in the config; otherwise, it would
/// return essentially the same thing as GetRawLattice, but more slowly.
bool GetRawLatticePruned(Lattice *ofst,
bool use_final_probs,
BaseFloat beam) const;
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
};
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
} // end namespace kaldi.
#endif
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code. They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.
// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinzable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
typedef kaldi::int32 Int;
typedef typename Arc::Weight Weight;
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
for(int i = 0; i < 100; i++) {
RandFstOptions opts;
opts.n_states = 4;
opts.n_arcs = 10;
opts.n_final = 2;
opts.allow_empty = false;
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
opts.acyclic = true;
// to be exactly representable in float,
// or this test fails because numerical differences can cause symmetry in
// weights to be broken, which causes the wrong path to be chosen as far
// as the string part is concerned.
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
bool sorted = TopSort(fst);
KALDI_ASSERT(sorted);
ILabelCompare<Arc> ilabel_comp;
if (kaldi::Rand() % 2 == 0)
ArcSort(fst, ilabel_comp);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> det_fst;
try {
DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
// OK, now determinize it a different way and check equivalence.
// [note: it's not normal determinization, it's taking the best path
// for any input-symbol sequence....
VectorFst<Arc> pruned_fst(*fst);
if (pruned_fst.NumStates() != 0)
kaldi::PruneLattice(10.0, &pruned_fst);
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
std::cout << "Compact pruned FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
std::cout << "Compact version of determinized FST is:\n";
{
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
if (ans)
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
} catch (...) {
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
}
delete fst;
}
}
// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
typedef typename Arc::Weight Weight;
RandFstOptions opts;
opts.acyclic = true;
for(int i = 0; i < 100; i++) {
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> ofst;
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
delete fst;
}
}
} // end namespace fst
int main() {
using namespace fst;
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
std::cout << "Tests succeeded\n";
}
此差异已折叠。
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
#include "itf/transition-information.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
somtimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
struct DeterminizeLatticePrunedOptions {
float delta; // A small offset used to measure equality of weights.
int max_mem; // If >0, determinization will fail and return false
// when the algorithm's (approximate) memory consumption crosses this threshold.
int max_loop; // If >0, can be used to detect non-determinizable input
// (a case that wouldn't be caught by max_mem).
int max_states;
int max_arcs;
float retry_cutoff;
DeterminizeLatticePrunedOptions(): delta(kDelta),
max_mem(-1),
max_loop(-1),
max_states(-1),
max_arcs(-1),
retry_cutoff(0.5) { }
void Register (kaldi::OptionsItf *opts) {
opts->Register("delta", &delta, "Tolerance used in determinization");
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
"determinization (real usage might be many times this)");
opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
"output FST (total, not per state");
opts->Register("max-states", &max_states, "Maximum number of arcs in output "
"FST (total, not per state");
opts->Register("max-loop", &max_loop, "Option used to detect a particular "
"type of determinization failure, typically due to invalid input "
"(e.g., negative-cost loops)");
opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
"lattice and retrying determinization: if effective-beam < "
"retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
"ever getting empty output for long segments.");
}
};
struct DeterminizeLatticePhonePrunedOptions {
// delta: a small offset used to measure equality of weights.
float delta;
// max_mem: if > 0, determinization will fail and return false when the
// algorithm's (approximate) memory consumption crosses this threshold.
int max_mem;
// phone_determinize: if true, do a first pass determinization on both phones
// and words.
bool phone_determinize;
// word_determinize: if true, do a second pass determinization on words only.
bool word_determinize;
// minimize: if true, push and minimize after determinization.
bool minimize;
DeterminizeLatticePhonePrunedOptions(): delta(kDelta),
max_mem(50000000),
phone_determinize(true),
word_determinize(true),
minimize(false) {}
void Register (kaldi::OptionsItf *opts) {
opts->Register("delta", &delta, "Tolerance used in determinization");
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
"determinization (real usage might be many times this).");
opts->Register("phone-determinize", &phone_determinize, "If true, do an "
"initial pass of determinization on both phones and words (see"
" also --word-determinize)");
opts->Register("word-determinize", &word_determinize, "If true, do a second "
"pass of determinization on words only (see also "
"--phone-determinize)");
opts->Register("minimize", &minimize, "If true, push and minimize after "
"determinization.");
}
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter. The input FST must be topologically sorted in order
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
Returns true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
template<class Weight>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
where the output sequences are encoded using the CompactLatticeArcTpl template
(i.e. the sequences of output symbols are represented directly as strings The input
FST must be topologically sorted in order for the algorithm to work. For efficiency
it is recommended to sort the ilabel for the input FST as well.
Returns true on normal success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The
returning value is the starting index of the phone label. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (returning_value + original_phone_label) in the new lattice. The
returning value will be used by DeterminizeLatticeDeletePhones() where it
works out the phones according to this value.
*/
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *fst);
/** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label because
when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones()
*/
template<class Weight>
void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst);
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set
true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copying the lattices without determinization.
Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling
TopSort() Invert() and ArcSort() for you.
Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the
output side.
This function can be used as the top-level interface to all the determinization
code.
*/
bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#endif
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {
/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
if (!ifst) return NULL;
CompactLattice *ofst = new CompactLattice();
ConvertLattice(*ifst, ofst);
delete ifst;
return ofst;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
return ifst;
}
/// Converts lattice types if necessary, deleting its input.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
if (!ifst) return NULL;
Lattice *ofst = new Lattice();
ConvertLattice(*ifst, ofst);
delete ifst;
return ofst;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
return ifst;
}
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &t) {
if (binary) {
fst::FstWriteOptions opts;
// Leave all the options default. Normally these lattices wouldn't have any
// osymbols/isymbols so no point directing it not to write them (who knows what
// we'd want to if we had them).
return t.Write(os, opts);
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached.
// Write a newline after the key, so the first line of the FST appears
// on its own line.
os << '\n';
bool acceptor = true, write_one = false;
fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
t.OutputSymbols(),
NULL, acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail())
KALDI_WARN << "Stream failure detected.";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not somethig in the original
// OpenFst code].
os << '\n';
return os.good();
}
}
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
typedef LatticeArc Arc;
typedef LatticeWeight Weight;
typedef CompactLatticeArc CArc;
typedef CompactLatticeWeight CWeight;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
public:
// everything is static in this class.
/** This function reads from the FST text format; it does not know in advance
whether it's a Lattice or CompactLattice in the stream so it tries to
read both formats until it becomes clear which is the correct one.
*/
static std::pair<Lattice*, CompactLattice*> ReadText(
std::istream &is) {
typedef std::pair<Lattice*, CompactLattice*> PairT;
using std::string;
using std::vector;
Lattice *fst = new Lattice();
CompactLattice *cfst = new CompactLattice();
string line;
size_t nline = 0;
string separator = FLAGS_fst_field_separator + "\r\n";
while (std::getline(is, line)) {
nline++;
vector<string> col;
// on Windows we'll write in text and read in binary mode.
SplitStringToVector(line, separator.c_str(), true, &col);
if (col.size() == 0) break; // Empty line is a signal to stop, in our
// archive format.
if (col.size() > 5) {
KALDI_WARN << "Reading lattice: bad line in FST: " << line;
delete fst;
delete cfst;
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
StateId s;
if (!ConvertStringToInteger(col[0], &s)) {
KALDI_WARN << "FstCompiler: bad line in FST: " << line;
delete fst;
delete cfst;
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
if (fst)
while (s >= fst->NumStates())
fst->AddState();
if (cfst)
while (s >= cfst->NumStates())
cfst->AddState();
if (nline == 1) {
if (fst) fst->SetStart(s);
if (cfst) cfst->SetStart(s);
}
if (fst) { // we still have fst; try to read that arc.
bool ok = true;
Arc arc;
Weight w;
StateId d = s;
switch (col.size()) {
case 1 :
fst->SetFinal(s, Weight::One());
break;
case 2:
if (!StrToWeight(col[1], true, &w)) ok = false;
else fst->SetFinal(s, w);
break;
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
ok = false;
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel);
if (ok) {
d = arc.nextstate;
arc.weight = Weight::One();
fst->AddArc(s, arc);
}
break;
case 5:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
ConvertStringToInteger(col[3], &arc.olabel) &&
StrToWeight(col[4], false, &arc.weight);
if (ok) {
d = arc.nextstate;
fst->AddArc(s, arc);
}
break;
default:
ok = false;
}
while (d >= fst->NumStates())
fst->AddState();
if (!ok) {
delete fst;
fst = NULL;
}
}
if (cfst) {
bool ok = true;
CArc arc;
CWeight w;
StateId d = s;
switch (col.size()) {
case 1 :
cfst->SetFinal(s, CWeight::One());
break;
case 2:
if (!StrToCWeight(col[1], true, &w)) ok = false;
else cfst->SetFinal(s, w);
break;
case 3: // compact-lattice is acceptor format: state, next-state, label.
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel);
if (ok) {
d = arc.nextstate;
arc.olabel = arc.ilabel;
arc.weight = CWeight::One();
cfst->AddArc(s, arc);
}
break;
case 4:
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
ConvertStringToInteger(col[2], &arc.ilabel) &&
StrToCWeight(col[3], false, &arc.weight);
if (ok) {
d = arc.nextstate;
arc.olabel = arc.ilabel;
cfst->AddArc(s, arc);
}
break;
case 5: default:
ok = false;
}
while (d >= cfst->NumStates())
cfst->AddState();
if (!ok) {
delete cfst;
cfst = NULL;
}
}
if (!fst && !cfst) {
KALDI_WARN << "Bad line in lattice text format: " << line;
// read until we get an empty line, so at least we
// have a chance to read the next one (although this might
// be a bit futile since the calling code will get unhappy
// about failing to read this one.
while (std::getline(is, line)) {
SplitStringToVector(line, separator.c_str(), true, &col);
if (col.empty()) break;
}
return PairT(static_cast<Lattice*>(NULL),
static_cast<CompactLattice*>(NULL));
}
}
return PairT(fst, cfst);
}
static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
std::istringstream strm(s);
strm >> *w;
if (!strm || (!allow_zero && *w == Weight::Zero())) {
return false;
}
return true;
}
static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
std::istringstream strm(s);
strm >> *w;
if (!strm || (!allow_zero && *w == CWeight::Zero())) {
return false;
}
return true;
}
};
CompactLattice *ReadCompactLatticeText(std::istream &is) {
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
if (lat_pair.second != NULL) {
delete lat_pair.first;
return lat_pair.second;
} else if (lat_pair.first != NULL) {
// note: ConvertToCompactLattice frees its input.
return ConvertToCompactLattice(lat_pair.first);
} else {
return NULL;
}
}
Lattice *ReadLatticeText(std::istream &is) {
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
if (lat_pair.first != NULL) {
delete lat_pair.second;
return lat_pair.first;
} else if (lat_pair.second != NULL) {
// note: ConvertToLattice frees its input.
return ConvertToLattice(lat_pair.second);
} else {
return NULL;
}
}
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat) {
KALDI_ASSERT(*clat == NULL);
if (binary) {
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
KALDI_WARN << "Reading compact lattice: error reading FST header.";
return false;
}
if (hdr.FstType() != "vector") {
KALDI_WARN << "Reading compact lattice: unsupported FST type: "
<< hdr.FstType();
return false;
}
fst::FstReadOptions ropts("<unspecified>",
&hdr);
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
typedef fst::LatticeWeightTpl<float> T3;
typedef fst::LatticeWeightTpl<double> T4;
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
CompactLattice *ans = NULL;
if (hdr.ArcType() == T1::Type()) {
ans = ConvertToCompactLattice(F1::Read(is, ropts));
} else if (hdr.ArcType() == T2::Type()) {
ans = ConvertToCompactLattice(F2::Read(is, ropts));
} else if (hdr.ArcType() == T3::Type()) {
ans = ConvertToCompactLattice(F3::Read(is, ropts));
} else if (hdr.ArcType() == T4::Type()) {
ans = ConvertToCompactLattice(F4::Read(is, ropts));
} else {
KALDI_WARN << "FST with arc type " << hdr.ArcType()
<< " cannot be converted to CompactLattice.\n";
return false;
}
if (ans == NULL) {
KALDI_WARN << "Error reading compact lattice (after reading header).";
return false;
}
*clat = ans;
return true;
} else {
// The next line would normally consume the \r on Windows, plus any
// extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') is.get(); // consume the newline.
else { // saw spaces but no newline.. this is not expected.
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
<< " at file position " << is.tellg();
return false;
}
*clat = ReadCompactLatticeText(is); // that routine will warn on error.
return (*clat != NULL);
}
}
bool CompactLatticeHolder::Read(std::istream &is) {
Clear(); // in case anything currently stored.
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading CompactLattice.";
return false;
} else if (isspace(c)) { // The text form of the lattice begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is not
// space).
return ReadCompactLattice(is, false, &t_);
} else if (c != 214) { // 214 is first char of FST magic number,
// on little-endian machines which is all we support (\326 octal)
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
<< " [non-space but no magic number detected], file pos is "
<< is.tellg();
return false;
} else {
return ReadCompactLattice(is, true, &t_);
}
}
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
if (binary) {
fst::FstWriteOptions opts;
// Leave all the options default. Normally these lattices wouldn't have any
// osymbols/isymbols so no point directing it not to write them (who knows what
// we'd want to do if we had them).
return t.Write(os, opts);
} else {
// Text-mode output. Note: we expect that t.InputSymbols() and
// t.OutputSymbols() would always return NULL. The corresponding input
// routine would not work if the FST actually had symbols attached.
// Write a newline after the key, so the first line of the FST appears
// on its own line.
os << '\n';
bool acceptor = false, write_one = false;
fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
t.OutputSymbols(),
NULL, acceptor, write_one, "\t");
printer.Print(&os, "<unknown>");
if (os.fail())
KALDI_WARN << "Stream failure detected.";
// Write another newline as a terminating character. The read routine will
// detect this [this is a Kaldi mechanism, not somethig in the original
// OpenFst code].
os << '\n';
return os.good();
}
}
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat) {
KALDI_ASSERT(*lat == NULL);
if (binary) {
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
KALDI_WARN << "Reading lattice: error reading FST header.";
return false;
}
if (hdr.FstType() != "vector") {
KALDI_WARN << "Reading lattice: unsupported FST type: "
<< hdr.FstType();
return false;
}
fst::FstReadOptions ropts("<unspecified>",
&hdr);
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
typedef fst::LatticeWeightTpl<float> T3;
typedef fst::LatticeWeightTpl<double> T4;
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
Lattice *ans = NULL;
if (hdr.ArcType() == T1::Type()) {
ans = ConvertToLattice(F1::Read(is, ropts));
} else if (hdr.ArcType() == T2::Type()) {
ans = ConvertToLattice(F2::Read(is, ropts));
} else if (hdr.ArcType() == T3::Type()) {
ans = ConvertToLattice(F3::Read(is, ropts));
} else if (hdr.ArcType() == T4::Type()) {
ans = ConvertToLattice(F4::Read(is, ropts));
} else {
KALDI_WARN << "FST with arc type " << hdr.ArcType()
<< " cannot be converted to Lattice.\n";
return false;
}
if (ans == NULL) {
KALDI_WARN << "Error reading lattice (after reading header).";
return false;
}
*lat = ans;
return true;
} else {
// The next line would normally consume the \r on Windows, plus any
// extra spaces that might have got in there somehow.
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
if (is.peek() == '\n') is.get(); // consume the newline.
else { // saw spaces but no newline.. this is not expected.
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
<< " at file position " << is.tellg();
return false;
}
*lat = ReadLatticeText(is); // that routine will warn on error.
return (*lat != NULL);
}
}
/* Since we don't write the binary headers for this type of holder,
we use a different method to work out whether we're in binary mode.
*/
bool LatticeHolder::Read(std::istream &is) {
Clear(); // in case anything currently stored.
int c = is.peek();
if (c == -1) {
KALDI_WARN << "End of stream detected reading Lattice.";
return false;
} else if (isspace(c)) { // The text form of the lattice begins
// with space (normally, '\n'), so this means it's text (the binary form
// cannot begin with space because it starts with the FST Type() which is not
// space).
return ReadLattice(is, false, &t_);
} else if (c != 214) { // 214 is first char of FST magic number,
// on little-endian machines which is all we support (\326 octal)
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
<< " [non-space but no magic number detected], file pos is "
<< is.tellg();
return false;
} else {
return ReadLattice(is, true, &t_);
}
}
} // end namespace kaldi
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
namespace kaldi {
// will import some things above...
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
// careful: kaldi::int32 is not always the same C type as fst::int32
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
CompactLatticeWeightCommonDivisor;
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
typedef fst::VectorFst<LatticeArc> Lattice;
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &clat);
bool WriteLattice(std::ostream &os, bool binary,
const Lattice &lat);
// the following function requires that *clat be
// NULL when called.
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat);
// the following function requires that *lat be
// NULL when called.
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat);
class CompactLatticeHolder {
public:
typedef CompactLattice T;
CompactLatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteCompactLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(CompactLatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~CompactLatticeHolder() { Clear(); }
private:
T *t_;
};
class LatticeHolder {
public:
typedef Lattice T;
LatticeHolder() { t_ = NULL; }
static bool Write(std::ostream &os, bool binary, const T &t) {
// Note: we don't include the binary-mode header when writing
// this object to disk; this ensures that if we write to single
// files, the result can be read by OpenFst.
return WriteLattice(os, binary, t);
}
bool Read(std::istream &is);
static bool IsReadInBinary() { return true; }
T &Value() {
KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
return *t_;
}
void Clear() { delete t_; t_ = NULL; }
void Swap(LatticeHolder *other) {
std::swap(t_, other->t_);
}
bool ExtractRange(const LatticeHolder &other, const std::string &range) {
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
return false;
}
~LatticeHolder() { Clear(); }
private:
T *t_;
};
typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi
#endif // KALDI_LAT_KALDI_LATTICE_H_
此差异已折叠。
此差异已折叠。
aux_source_directory(. DIR_LIB_SRCS)
add_library(nnet STATIC ${DIR_LIB_SRCS})
此差异已折叠。
#include "nnet/decodable.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
frontend_(NULL),
nnet_(nnet),
finished_(false),
frames_ready_(0) {
}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
}
//Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();
}
} // namespace ppspeech
\ No newline at end of file
#include "nnet/decodable-itf.h"
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h"
#include "nnet/nnet_interface.h"
namespace ppspeech {
struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Reset();
void InputFinished() { finished_ = true; }
private:
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
bool finished_;
int32 frames_ready_;
};
} // namespace ppspeech
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
#include "base/common.h"
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册