提交 e9043828 编写于 作者: S SmileGoat

add offline_deocder_main

上级 d14ee800
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
int32 chunk_size,
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
CTCBeamSearchOptions opts;
CTCBeamSearch decoder(opts);
ModelOptions model_opts;
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
Decodable decodable();
decodable.SetNnet(nnet);
int32 chunk_size = 0;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
vector<Matrix<BaseFloat>> feature_chunks;
SplitFeature(feature, chunk_size, &feature_chunks);
for (auto feature_chunk : feature_chunks) {
decodable.FeedFeatures(feature_chunk);
decoder.InitDecoder();
decoder.AdvanceDecode(decodable, chunk_size);
}
decodable.InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
...@@ -14,22 +14,28 @@ CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) : ...@@ -14,22 +14,28 @@ CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) :
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id(-1), blank_id(-1),
space_id(-1), space_id(-1),
num_frame_decoded(0),
root(nullptr) { root(nullptr) {
LOG(INFO) << "dict path: " << _opts.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
vocabulary_ = std::make_shared<vector<string>>(); vocabulary_ = std::make_shared<vector<string>>();
if (!basr::ReadDictToVector(_opts.dict_file, *vocabulary_)) { if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
} }
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size(); LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size();
LOG(INFO) << "language model path: " << _opts.lm_path; LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(_opts.alpha, init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
_opts.beta, opts_.beta,
_opts.lm_path, opts_.lm_path,
*vocabulary_); *vocabulary_);
} }
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() { void CTCBeamSearch::InitDecoder() {
blank_id = 0; blank_id = 0;
...@@ -41,7 +47,7 @@ void CTCBeamSearch::InitDecoder() { ...@@ -41,7 +47,7 @@ void CTCBeamSearch::InitDecoder() {
space_id = -2; space_id = -2;
} }
clear_prefixes(); ResetPrefixes();
root = std::make_shared<PathTrie>(); root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0; root->score = root->log_prob_b_prev = 0.0;
...@@ -57,6 +63,23 @@ void CTCBeamSearch::InitDecoder() { ...@@ -57,6 +63,23 @@ void CTCBeamSearch::InitDecoder() {
} }
} }
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(result);
max_frames--;
}
}
void CTCBeamSearch::ResetPrefixes() { void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) { for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) { if (prefixes[i] != nullptr) {
...@@ -81,19 +104,32 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs, ...@@ -81,19 +104,32 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
} }
timer.Reset(); timer.Reset();
vector<std::pair<double, string>> results = AdvanceDecoding(double_probs); AdvanceDecoding(double_probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f; LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
for (const auto& item : results) {
nbest_words.push_back(item.second);
}
return 0; return 0;
} }
vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
return result[0]->second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
size_t beam_size = _opts.beam_size; size_t beam_size = opts_.beam_size;
double cutoff_prob = _opts.cutoff_prob; double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = _opts.cutoff_top_n; size_t cutoff_top_n = opts_.cutoff_top_n;
for (size_t time_step = 0; time_step < num_time_steps; time_step++) { for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step]; const auto& prob = probs_seq[time_step];
...@@ -137,18 +173,14 @@ vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(const vector<ve ...@@ -137,18 +173,14 @@ vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(const vector<ve
prefixes[i]->remove(); prefixes[i]->remove();
} }
} // if } // if
num_frame_decoded_++;
} // for probs_seq } // for probs_seq
// score the last word of each prefix that doesn't end with space
LMRescore();
CalculateApproxScore();
return get_beam_search_result(prefixes, *vocabulary_, beam_size);
} }
int CTCBeamSearch::SearchOneChar(const bool& full_beam, int CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, float>& log_prob_idx, const std::pair<size_t, float>& log_prob_idx,
const float& min_cutoff) { const float& min_cutoff) {
size_t beam_size = _opts.beam_size; size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first; const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second; const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size); size_t prefixes_len = std::min(prefixes.size(), beam_size);
...@@ -219,7 +251,7 @@ int CTCBeamSearch::SearchOneChar(const bool& full_beam, ...@@ -219,7 +251,7 @@ int CTCBeamSearch::SearchOneChar(const bool& full_beam,
} }
void CTCBeamSearch::CalculateApproxScore() { void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = _opts.beam_size; size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin(),
...@@ -246,7 +278,7 @@ void CTCBeamSearch::CalculateApproxScore() { ...@@ -246,7 +278,7 @@ void CTCBeamSearch::CalculateApproxScore() {
} }
void CTCBeamSearch::LMRescore() { void CTCBeamSearch::LMRescore() {
size_t beam_size = _opts.beam_size; size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
......
#include "base/basic_types.h" #include "base/basic_types.h"
#include "nnet/decodable-itf.h"
#pragma once #pragma once
...@@ -44,12 +45,14 @@ public: ...@@ -44,12 +45,14 @@ public:
~CTCBeamSearch() { ~CTCBeamSearch() {
} }
bool InitDecoder(); bool InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void Reset();
std::vector<DecodeResult>& GetDecodeResult() {
return decoder_results_;
}
private: private:
void ResetPrefixes(); void ResetPrefixes();
...@@ -58,17 +61,18 @@ private: ...@@ -58,17 +61,18 @@ private:
const BaseFloat& min_cutoff); const BaseFloat& min_cutoff);
void CalculateApproxScore(); void CalculateApproxScore();
void LMRescore(); void LMRescore();
std::vector<std::pair<double, std::string>> void AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<DecodeResult> decoder_results_; //std::vector<DecodeResult> decoder_results_;
std::vector<std::vector<std::string>> vocabulary_; // todo remove later std::vector<std::vector<std::string>> vocabulary_; // todo remove later
size_t blank_id; size_t blank_id;
int space_id; int space_id;
std::shared_ptr<PathTrie> root; std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes; std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
}; };
} // namespace basr } // namespace basr
\ No newline at end of file
#include "nnet/decodable.h"
namespace ppspeech {
Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
}
Decodable::Init(DecodableConfig config) {
}
Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
void Decodable::LogLikelihood(int32 frame, int32 index) {
return ;
}
void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) {
// skip frame ???
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();
}
} // namespace ppspeech
...@@ -2,17 +2,27 @@ ...@@ -2,17 +2,27 @@
#include "base/common.h" #include "base/common.h"
namespace ppsepeech { namespace ppspeech {
struct DecodeableConfig;
class Decodeable : public kaldi::DecodableInterface { struct DecodableConfig;
class Decodable : public kaldi::DecodableInterface {
public: public:
virtual Init(Decodeable config) = 0; virtual void Init(DecodableOpts config);
virtual Acceptlikeihood() = 0; virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Reset();
void InputFinished() { finished_ = true; }
private: private:
std::share_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FeatureExtractorInterface> frontend_;
std::share_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
//Cache nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
} bool finished_;
int32 frames_ready_;
};
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
...@@ -10,7 +10,8 @@ class NnetInterface { ...@@ -10,7 +10,8 @@ class NnetInterface {
public: public:
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const = 0; kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
}; };
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace ppspeech { namespace ppspeech {
void PaddleNnet::init_cache_encouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ", "); cache_names = absl::StrSplit(opts.cache_names, ", ");
std::vector<std::string> cache_shapes; std::vector<std::string> cache_shapes;
...@@ -66,7 +66,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) { ...@@ -66,7 +66,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
} }
release_predictor(predictor); release_predictor(predictor);
init_cache_encouts(opts); InitCacheEncouts(opts);
} }
paddle_infer::Predictor* PaddleNnet::get_predictor() { paddle_infer::Predictor* PaddleNnet::get_predictor() {
......
...@@ -94,7 +94,7 @@ class PaddleNnet : public NnetInterface { ...@@ -94,7 +94,7 @@ class PaddleNnet : public NnetInterface {
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const; kaldi::Matrix<kaldi::BaseFloat>* inferences) const;
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name); std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
void init_cache_encouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
private: private:
std::unique_ptr<paddle_infer::services::PredictorPool> pool; std::unique_ptr<paddle_infer::services::PredictorPool> pool;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册