diff --git a/fluid/DeepASR/decoder/post_decode_faster.cc b/fluid/DeepASR/decoder/post_decode_faster.cc index 5c0027c4ba1e8262c7c5b89d621a01ad1f0a8a7e..957af550e333edef26112e16702e0c6c94a0326a 100644 --- a/fluid/DeepASR/decoder/post_decode_faster.cc +++ b/fluid/DeepASR/decoder/post_decode_faster.cc @@ -22,7 +22,7 @@ using fst::StdArc; Decoder::Decoder(std::string word_syms_filename, std::string fst_in_filename, std::string logprior_rxfilename) { - const char *usage = + const char* usage = "Decode, reading log-likelihoods (of transition-ids or whatever symbol " "is on the graph) as matrices."; @@ -68,6 +68,23 @@ Decoder::~Decoder() { delete decoder; } +std::string Decoder::decode( + std::string key, std::vector>& log_probs) { + size_t num_frames = log_probs.size(); + size_t dim_label = log_probs[0].size(); + + kaldi::Matrix loglikes( + num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols); + for (size_t i = 0; i < num_frames; ++i) { + memcpy(loglikes.Data() + i * dim_label, + log_probs[i].data(), + sizeof(kaldi::BaseFloat) * dim_label); + } + + return decode(key, loglikes); +} + + std::vector Decoder::decode(std::string posterior_rspecifier) { kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier); std::vector decoding_results; @@ -139,3 +156,46 @@ std::vector Decoder::decode(std::string posterior_rspecifier) { << " frames."; return decoding_results; } + + +std::string Decoder::decode(std::string key, + kaldi::Matrix& loglikes) { + std::string decoding_result; + + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << key; + } + KALDI_ASSERT(loglikes.NumCols() == logprior.Dim()); + + loglikes.ApplyLog(); + loglikes.AddVecToRows(-1.0, logprior); + + kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale); + decoder->Decode(&decodable); + + VectorFst decoded; // linear FST. + + if ((allow_partial || decoder->ReachedFinal()) && + decoder->GetBestPath(&decoded)) { + if (!decoder->ReachedFinal()) + KALDI_WARN << "Decoder did not reach end-state, outputting partial " + "traceback."; + + std::vector alignment; + std::vector words; + kaldi::LatticeWeight weight; + + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + + if (word_syms != NULL) { + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + decoding_result += s; + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + } + } + } + + return decoding_result; +} diff --git a/fluid/DeepASR/decoder/post_decode_faster.h b/fluid/DeepASR/decoder/post_decode_faster.h index 49d680c58a1443ccd6572fff1ae1226855fbae86..6a5830e296634d1e64dbc544526e4f6191899bf8 100644 --- a/fluid/DeepASR/decoder/post_decode_faster.h +++ b/fluid/DeepASR/decoder/post_decode_faster.h @@ -34,7 +34,13 @@ public: std::vector decode(std::string posterior_rspecifier); + std::string decode(std::string key, + std::vector> &log_probs); + private: + std::string decode(std::string key, + kaldi::Matrix &loglikes); + fst::SymbolTable *word_syms; fst::VectorFst *decode_fst; kaldi::FasterDecoder *decoder; diff --git a/fluid/DeepASR/decoder/pybind.cc b/fluid/DeepASR/decoder/pybind.cc index efa37d5d51feaa7df6402ccae98a7217cf7fdec1..1b91f02b89ffd259f7a24a005ec96bd780287e25 100644 --- a/fluid/DeepASR/decoder/pybind.cc +++ b/fluid/DeepASR/decoder/pybind.cc @@ -25,7 +25,14 @@ PYBIND11_MODULE(post_decode_faster, m) { py::class_(m, "Decoder") .def(py::init()) .def("decode", - &Decoder::decode, + (std::vector (Decoder::*)(std::string)) & + Decoder::decode, + "Decode one input probability matrix " + "and return the transcription") + .def("decode", + (std::string (Decoder::*)( + std::string, std::vector>&)) & + Decoder::decode, "Decode one input probability matrix " "and return the transcription"); }