diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 5d7a4f77acd69411227bd5985419b749c5c79dd2..b4caa8e7bcc97cccff52cc06da4975f5635d7f97 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( vector> likelihood; vector frame_prob; bool flag = - decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); if (flag == false) break; likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h index 19e074982703c030938fb2c2c1a58deaffb0fd08..b8ce9143e9583aea614c84adf17ad61b2d42c130 100644 --- a/speechx/speechx/kaldi/decoder/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -143,7 +143,7 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; - virtual bool FrameLogLikelihood( + virtual bool FrameLikelihood( int32 frame, std::vector* likelihood) = 0; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index ce26965091ce60776c86340056fa3595e2a866e5..d52b249f7c57693b9c9beb4fbc3cf7b53df0ff5a 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -49,11 +49,18 @@ bool Decodable::IsLastFrame(int32 frame) { int32 Decodable::NumIndices() const { return 0; } +// The ilabels (TokenIds) of the WFST (TLG) have an extra entry (id = 0) inserted in front of the nnet prob ids, so the nnet id is the token id minus one. 
+int32 Decodable::TokenId2NnetId(int32 token_id) { + return token_id - 1; +} + BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(frame, frames_ready_); int32 frame_idx = frame - frame_offset_; - return acoustic_scale_ * std::log(nnet_cache_(frame_idx, index - 1) + + // the nnet output is a prob rather than a log prob, so std::log is applied here + // index is an ilabel (token id); TokenId2NnetId maps it to the nnet column + return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) + std::numeric_limits::min()); } @@ -81,7 +88,7 @@ bool Decodable::AdvanceChunk() { return true; } -bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { +bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { std::vector result; if (EnsureFrameHaveComputed(frame) == false) return false; likelihood->resize(nnet_cache_.NumCols()); diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index b18ef07c2db4ec55bece36ed5ae2480ae32a906e..9555fea792ddb3afed9e9dc0db838c041a9c876b 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -31,24 +31,28 @@ class Decodable : public kaldi::DecodableInterface { virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame); virtual int32 NumIndices() const; - virtual bool FrameLogLikelihood(int32 frame, - std::vector* likelihood); + // likelihood is prob, not logprob + virtual bool FrameLikelihood(int32 frame, + std::vector* likelihood); virtual int32 NumFramesReady() const; // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); void Reset(); bool IsInputFinished() const { return frontend_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); + int32 TokenId2NnetId(int32 token_id); private: bool AdvanceChunk(); std::shared_ptr frontend_; std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; + // a frame here is an nnet prob frame rather than an audio feature frame + // nnet frames are subsampled from the feature frames + // eg: 35 
frame features output 8 frame inferences int32 frame_offset_; int32 frames_ready_; // todo: feature frame mismatch with nnet inference frame - // eg: 35 frame features output 8 frame inferences // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_;