提交 e9043828 编写于 作者: S SmileGoat

add offline_deocder_main

上级 d14ee800
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
int32 chunk_size,
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
CTCBeamSearchOptions opts;
CTCBeamSearch decoder(opts);
ModelOptions model_opts;
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
Decodable decodable();
decodable.SetNnet(nnet);
int32 chunk_size = 0;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
vector<Matrix<BaseFloat>> feature_chunks;
SplitFeature(feature, chunk_size, &feature_chunks);
for (auto feature_chunk : feature_chunks) {
decodable.FeedFeatures(feature_chunk);
decoder.InitDecoder();
decoder.AdvanceDecode(decodable, chunk_size);
}
decodable.InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
\ No newline at end of file
......@@ -14,22 +14,28 @@ CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) :
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded(0),
root(nullptr) {
LOG(INFO) << "dict path: " << _opts.dict_file;
LOG(INFO) << "dict path: " << opts_.dict_file;
vocabulary_ = std::make_shared<vector<string>>();
if (!basr::ReadDictToVector(_opts.dict_file, *vocabulary_)) {
if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size();
LOG(INFO) << "language model path: " << _opts.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(_opts.alpha,
_opts.beta,
_opts.lm_path,
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta,
opts_.lm_path,
*vocabulary_);
}
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
......@@ -41,7 +47,7 @@ void CTCBeamSearch::InitDecoder() {
space_id = -2;
}
clear_prefixes();
ResetPrefixes();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
......@@ -57,6 +63,23 @@ void CTCBeamSearch::InitDecoder() {
}
}
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(result);
max_frames--;
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
......@@ -81,19 +104,32 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
}
timer.Reset();
vector<std::pair<double, string>> results = AdvanceDecoding(double_probs);
AdvanceDecoding(double_probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
for (const auto& item : results) {
nbest_words.push_back(item.second);
}
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
return result[0]->second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
size_t num_time_steps = probs_seq.size();
size_t beam_size = _opts.beam_size;
double cutoff_prob = _opts.cutoff_prob;
size_t cutoff_top_n = _opts.cutoff_top_n;
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
......@@ -137,18 +173,14 @@ vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(const vector<ve
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
// score the last word of each prefix that doesn't end with space
LMRescore();
CalculateApproxScore();
return get_beam_search_result(prefixes, *vocabulary_, beam_size);
}
int CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, float>& log_prob_idx,
const float& min_cutoff) {
size_t beam_size = _opts.beam_size;
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
......@@ -219,7 +251,7 @@ int CTCBeamSearch::SearchOneChar(const bool& full_beam,
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = _opts.beam_size;
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
......@@ -246,7 +278,7 @@ void CTCBeamSearch::CalculateApproxScore() {
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = _opts.beam_size;
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
......
#include "base/basic_types.h"
#include "nnet/decodable-itf.h"
#pragma once
......@@ -44,12 +45,14 @@ public:
~CTCBeamSearch() {
}
bool InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words);
std::vector<DecodeResult>& GetDecodeResult() {
return decoder_results_;
}
void Reset();
private:
void ResetPrefixes();
......@@ -58,17 +61,18 @@ private:
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
std::vector<std::pair<double, std::string>>
AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
void AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<DecodeResult> decoder_results_;
//std::vector<DecodeResult> decoder_results_;
std::vector<std::vector<std::string>> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
};
} // namespace basr
\ No newline at end of file
#include "nnet/decodable.h"
namespace ppspeech {
Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
}
Decodable::Init(DecodableConfig config) {
}
Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
void Decodable::LogLikelihood(int32 frame, int32 index) {
return ;
}
void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) {
// skip frame ???
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();
}
} // namespace ppspeech
......@@ -2,17 +2,27 @@
#include "base/common.h"
namespace ppsepeech {
struct DecodeableConfig;
namespace ppspeech {
class Decodeable : public kaldi::DecodableInterface {
struct DecodableConfig;
class Decodable : public kaldi::DecodableInterface {
public:
virtual Init(Decodeable config) = 0;
virtual Acceptlikeihood() = 0;
virtual void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Reset();
void InputFinished() { finished_ = true; }
private:
std::share_ptr<FeatureExtractorInterface> frontend_;
std::share_ptr<NnetInterface> nnet_;
//Cache nnet_cache_;
}
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
bool finished_;
int32 frames_ready_;
};
} // namespace ppspeech
\ No newline at end of file
......@@ -10,7 +10,8 @@ class NnetInterface {
public:
virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const = 0;
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
};
......
......@@ -3,7 +3,7 @@
namespace ppspeech {
void PaddleNnet::init_cache_encouts(const ModelOptions& opts) {
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ", ");
std::vector<std::string> cache_shapes;
......@@ -66,7 +66,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
}
release_predictor(predictor);
init_cache_encouts(opts);
InitCacheEncouts(opts);
}
paddle_infer::Predictor* PaddleNnet::get_predictor() {
......
......@@ -94,7 +94,7 @@ class PaddleNnet : public NnetInterface {
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const;
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
void init_cache_encouts(const ModelOptions& opts);
void InitCacheEncouts(const ModelOptions& opts);
private:
std::unique_ptr<paddle_infer::services::PredictorPool> pool;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册