diff --git a/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc b/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d7b09df8d26db6d28f2252812f1333fd286fc6c --- /dev/null +++ b/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc @@ -0,0 +1,58 @@ +// todo refactor, repalce with gtest + +#include "decoder/ctc_beam_search_decoder.h" +#include "kaldi/util/table-types.h" +#include "base/log.h" +#include "base/flags.h" + +DEFINE_string(feature_respecifier, "", "test nnet prob"); + +using kaldi::BaseFloat; + +void SplitFeature(kaldi::Matrix feature, + int32 chunk_size, + std::vector> feature_chunks) { + +} + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier); + + // test nnet_output --> decoder result + int32 num_done = 0, num_err = 0; + + CTCBeamSearchOptions opts; + CTCBeamSearch decoder(opts); + + ModelOptions model_opts; + std::shared_ptr nnet(new PaddleNnet(model_opts)); + + Decodable decodable(); + decodable.SetNnet(nnet); + + int32 chunk_size = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + const kaldi::Matrix feature = feature_reader.Value(); + vector> feature_chunks; + SplitFeature(feature, chunk_size, &feature_chunks); + for (auto feature_chunk : feature_chunks) { + decodable.FeedFeatures(feature_chunk); + decoder.InitDecoder(); + decoder.AdvanceDecode(decodable, chunk_size); + } + decodable.InputFinished(); + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable.Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index dc21dcb47951f77502c711deec4701bf84d4d8ab..d4407b535c45b41b107c2d3efd7bbcc14745eee1 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -14,22 +14,28 @@ CTCBeamSearch::CTCBeamSearch(std::shared_ptr opts) : init_ext_scorer_(nullptr), blank_id(-1), space_id(-1), + num_frame_decoded(0), root(nullptr) { - LOG(INFO) << "dict path: " << _opts.dict_file; + LOG(INFO) << "dict path: " << opts_.dict_file; vocabulary_ = std::make_shared>(); - if (!basr::ReadDictToVector(_opts.dict_file, *vocabulary_)) { + if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) { LOG(INFO) << "load the dict failed"; } LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size(); - LOG(INFO) << "language model path: " << _opts.lm_path; - init_ext_scorer_ = std::make_shared(_opts.alpha, - _opts.beta, - _opts.lm_path, + LOG(INFO) << "language model path: " << opts_.lm_path; + init_ext_scorer_ = std::make_shared(opts_.alpha, + opts_.beta, + opts_.lm_path, *vocabulary_); } +void CTCBeamSearch::Reset() { + num_frame_decoded_ = 0; + ResetPrefixes(); +} + void CTCBeamSearch::InitDecoder() { blank_id = 0; @@ -41,7 +47,7 @@ void CTCBeamSearch::InitDecoder() { space_id = -2; } - clear_prefixes(); + ResetPrefixes(); root = std::make_shared(); root->score = root->log_prob_b_prev = 0.0; @@ -57,6 +63,23 @@ void CTCBeamSearch::InitDecoder() { } } +int32 CTCBeamSearch::NumFrameDecoded() { + return num_frame_decoded_; +} + +// todo rename, refactor +void CTCBeamSearch::AdvanceDecode(const std::shared_ptr& decodable, int max_frames) { + while (max_frames > 0) { + vector> likelihood; + if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { + break; + } + likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); + AdvanceDecoding(result); + max_frames--; + } +} + void CTCBeamSearch::ResetPrefixes() { for (size_t i = 0; i < prefixes.size(); i++) { if (prefixes[i] != nullptr) { @@ -81,19 +104,32 @@ int CTCBeamSearch::DecodeLikelihoods(const vector>&probs, } timer.Reset(); - vector> results = AdvanceDecoding(double_probs); + AdvanceDecoding(double_probs); LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast(timer.Elapsed()) / 1000.0f; - for (const auto& item : results) { - nbest_words.push_back(item.second); - } return 0; } -vector> CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { +vector> CTCBeamSearch::GetNBestPath() { + return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); +} + +string CTCBeamSearch::GetBestPath() { + std::vector> result; + result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); + return result[0]->second; +} + +string CTCBeamSearch::GetFinalBestPath() { + CalculateApproxScore(); + LMRescore(); + return GetBestPath(); +} + +void CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { size_t num_time_steps = probs_seq.size(); - size_t beam_size = _opts.beam_size; - double cutoff_prob = _opts.cutoff_prob; - size_t cutoff_top_n = _opts.cutoff_top_n; + size_t beam_size = opts_.beam_size; + double cutoff_prob = opts_.cutoff_prob; + size_t cutoff_top_n = opts_.cutoff_top_n; for (size_t time_step = 0; time_step < num_time_steps; time_step++) { const auto& prob = probs_seq[time_step]; @@ -137,18 +173,14 @@ vector> CTCBeamSearch::AdvanceDecoding(const vectorremove(); } } // if + num_frame_decoded_++; } // for probs_seq - - // score the last word of each prefix that doesn't end with space - LMRescore(); - CalculateApproxScore(); - return get_beam_search_result(prefixes, *vocabulary_, beam_size); } int CTCBeamSearch::SearchOneChar(const bool& full_beam, const std::pair& log_prob_idx, const float& min_cutoff) { - size_t beam_size = _opts.beam_size; + size_t beam_size = opts_.beam_size; const auto& c = log_prob_idx.first; const auto& log_prob_c = log_prob_idx.second; size_t prefixes_len = std::min(prefixes.size(), beam_size); @@ -219,7 +251,7 @@ int CTCBeamSearch::SearchOneChar(const bool& full_beam, } void CTCBeamSearch::CalculateApproxScore() { - size_t beam_size = _opts.beam_size; + size_t beam_size = opts_.beam_size; size_t num_prefixes = std::min(prefixes.size(), beam_size); std::sort( prefixes.begin(), @@ -246,7 +278,7 @@ void CTCBeamSearch::CalculateApproxScore() { } void CTCBeamSearch::LMRescore() { - size_t beam_size = _opts.beam_size; + size_t beam_size = opts_.beam_size; if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { auto prefix = prefixes[i]; diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 5bf388d3a8a19210d22e8301e3aa3d0582447efa..b461db8888bf4722a0287ab7682f7d43b87d2412 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -1,4 +1,5 @@ #include "base/basic_types.h" +#include "nnet/decodable-itf.h" #pragma once @@ -44,12 +45,14 @@ public: ~CTCBeamSearch() { } bool InitDecoder(); + void Decode(std::shared_ptr decodable); + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::string GetFinalBestPath(); + int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>&probs, std::vector& nbest_words); - - std::vector& GetDecodeResult() { - return decoder_results_; - } + void Reset(); private: void ResetPrefixes(); @@ -58,17 +61,18 @@ private: const BaseFloat& min_cutoff); void CalculateApproxScore(); void LMRescore(); - std::vector> - AdvanceDecoding(const std::vector>& probs_seq); + void AdvanceDecoding(const std::vector>& probs_seq); + CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later - std::vector decoder_results_; + //std::vector decoder_results_; std::vector> vocabulary_; // todo remove later size_t blank_id; int space_id; std::shared_ptr root; std::vector prefixes; + int num_frame_decoded_; }; } // namespace basr \ No newline at end of file diff --git a/speechx/speechx/nnet/ctc_decodable.h b/speechx/speechx/nnet/ctc_decodable.h deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c03b4a4097a786b650ea71e62f1caf214fcb635 --- /dev/null +++ b/speechx/speechx/nnet/decodable.cc @@ -0,0 +1,38 @@ +#include "nnet/decodable.h" + +namespace ppspeech { + +Decodable::Acceptlikelihood(const kaldi::Matrix& likelihood) { + frames_ready_ += likelihood.NumRows(); +} + +Decodable::Init(DecodableConfig config) { + +} + +Decodable::IsLastFrame(int32 frame) const { + CHECK_LE(frame, frames_ready_); + return finished_ && (frame == frames_ready_ - 1); +} + +int32 Decodable::NumIndices() const { + return 0; +} + +void Decodable::LogLikelihood(int32 frame, int32 index) { + return ; +} + +void Decodable::FeedFeatures(const kaldi::Matrix& features) { + // skip frame ??? + nnet_->FeedForward(features, &nnet_cache_); + frames_ready_ += nnet_cache_.NumRows(); + return ; +} + +void Decodable::Reset() { + // frontend_.Reset(); + nnet_->Reset(); +} + +} // namespace ppspeech diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index eb7ac20a205912ab133ad93eb3a30a3688bde4a6..0bf28d9427a49af6f37133fb3e0f7c97bbdfb196 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -2,17 +2,27 @@ #include "base/common.h" -namespace ppsepeech { - struct DecodeableConfig; +namespace ppspeech { - class Decodeable : public kaldi::DecodableInterface { - public: - virtual Init(Decodeable config) = 0; - virtual Acceptlikeihood() = 0; - private: - std::share_ptr frontend_; - std::share_ptr nnet_; - //Cache nnet_cache_; - } +struct DecodableConfig; + +class Decodable : public kaldi::DecodableInterface { + public: + virtual void Init(DecodableOpts config); + virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); + virtual bool IsLastFrame(int32 frame) const; + virtual int32 NumIndices() const; + void Acceptlikelihood(const kaldi::Matrix& likelihood); // remove later + void FeedFeatures(const kaldi::Matrix& feature); // only for test, todo remove later + std::vector FrameLogLikelihood(int32 frame); + void Reset(); + void InputFinished() { finished_ = true; } + private: + std::shared_ptr frontend_; + std::shared_ptr nnet_; + kaldi::Matrix nnet_cache_; + bool finished_; + int32 frames_ready_; +}; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/dnn_decodable.h b/speechx/speechx/nnet/dnn_decodable.h deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index c47f38094dcf18e441e1a65a396d56a512305cdf..5965f7e8b4eed7df6bbc7a90f252dde971c8c719 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -9,8 +9,9 @@ namespace ppspeech { class NnetInterface { public: virtual ~NnetInterface() {} - virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences) const = 0; + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Matrix* inferences); + virtual void Reset(); }; diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index d6f826194061cc7553ad1833674b5f680b131be1..e64850cb0095e103628dbb7ff216b289ddf79d3f 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -3,7 +3,7 @@ namespace ppspeech { -void PaddleNnet::init_cache_encouts(const ModelOptions& opts) { +void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; cache_names = absl::StrSplit(opts.cache_names, ", "); std::vector cache_shapes; @@ -66,7 +66,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) { } release_predictor(predictor); - init_cache_encouts(opts); + InitCacheEncouts(opts); } paddle_infer::Predictor* PaddleNnet::get_predictor() { diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index 1b3cad978a2055be99c4496e71b035a0423276d9..7f34eeafb1b6ec087216450c8c98c102bb27de88 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -94,7 +94,7 @@ class PaddleNnet : public NnetInterface { virtual void FeedForward(const kaldi::Matrix& features, kaldi::Matrix* inferences) const; std::shared_ptr> GetCacheEncoder(const std::string& name); - void init_cache_encouts(const ModelOptions& opts); + void InitCacheEncouts(const ModelOptions& opts); private: std::unique_ptr pool;