// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "nnet/decodable.h" namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; using kaldi::Vector; Decodable::Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend) : frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { frames_ready_ += likelihood.NumRows(); } // Decodable::Init(DecodableConfig config) { //} int32 Decodable::NumFramesReady() const { return frames_ready_; } bool Decodable::IsLastFrame(int32 frame) { CHECK_LE(frame, frames_ready_); bool flag = EnsureFrameHaveComputed(frame); return (flag == false) && (frame == frames_ready_ - 1); } int32 Decodable::NumIndices() const { return 0; } BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(frame, frames_ready_); int32 frame_idx = frame - frame_offset_; return nnet_cache_(frame_idx, index); } bool Decodable::EnsureFrameHaveComputed(int32 frame) { if (frame >= frames_ready_) { return AdvanceChunk(); } return true; } bool Decodable::AdvanceChunk() { Vector features; if (frontend_->Read(&features) == false) { return false; } int32 nnet_dim = 0; Vector inferences; nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.CopyRowsFromVec(inferences); frame_offset_ = frames_ready_; frames_ready_ += nnet_cache_.NumRows(); return true; } bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { std::vector result; if (EnsureFrameHaveComputed(frame) == false) return false; likelihood->resize(nnet_cache_.NumCols()); for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); } return true; } void Decodable::Reset() { frontend_->Reset(); nnet_->Reset(); frame_offset_ = 0; frames_ready_ = 0; } } // namespace ppspeech