From 7dc9cba3be0706cb024f1d998c69b97a5d6816f3 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 13 Oct 2022 11:51:54 +0000 Subject: [PATCH] ctc prefix beam search for u2, test can run --- speechx/examples/codelab/u2/.gitignore | 1 + speechx/examples/codelab/u2/README.md | 1 + speechx/examples/codelab/u2/local/decode.sh | 22 + speechx/examples/codelab/u2/local/feat.sh | 27 + speechx/examples/codelab/u2/local/nnet.sh | 23 + .../examples/codelab/{u2nnet => u2}/path.sh | 3 +- .../examples/codelab/{u2nnet => u2}/run.sh | 27 +- speechx/examples/codelab/u2nnet/.gitignore | 3 - speechx/examples/codelab/u2nnet/README.md | 3 - speechx/examples/codelab/u2nnet/valgrind.sh | 21 - speechx/speechx/decoder/CMakeLists.txt | 13 +- .../decoder/ctc_beam_search_decoder.cc | 10 +- .../speechx/decoder/ctc_beam_search_decoder.h | 13 +- speechx/speechx/decoder/ctc_beam_search_opt.h | 65 +++ .../decoder/ctc_prefix_beam_search_decoder.cc | 519 ++++++++++-------- .../decoder/ctc_prefix_beam_search_decoder.h | 71 ++- .../ctc_prefix_beam_search_decoder_main.cc | 188 +++++++ .../decoder/ctc_prefix_beam_search_result.h | 41 ++ speechx/speechx/decoder/ctc_tlg_decoder.cc | 17 +- speechx/speechx/decoder/ctc_tlg_decoder.h | 23 +- speechx/speechx/decoder/decoder_itf.h | 22 +- speechx/speechx/nnet/u2_nnet_main.cc | 11 - speechx/speechx/utils/math.cc | 7 +- 23 files changed, 763 insertions(+), 368 deletions(-) create mode 100644 speechx/examples/codelab/u2/.gitignore create mode 100644 speechx/examples/codelab/u2/README.md create mode 100755 speechx/examples/codelab/u2/local/decode.sh create mode 100755 speechx/examples/codelab/u2/local/feat.sh create mode 100755 speechx/examples/codelab/u2/local/nnet.sh rename speechx/examples/codelab/{u2nnet => u2}/path.sh (84%) rename speechx/examples/codelab/{u2nnet => u2}/run.sh (54%) delete mode 100644 speechx/examples/codelab/u2nnet/.gitignore delete mode 100644 speechx/examples/codelab/u2nnet/README.md delete mode 100755 
speechx/examples/codelab/u2nnet/valgrind.sh create mode 100644 speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc create mode 100644 speechx/speechx/decoder/ctc_prefix_beam_search_result.h diff --git a/speechx/examples/codelab/u2/.gitignore b/speechx/examples/codelab/u2/.gitignore new file mode 100644 index 00000000..1269488f --- /dev/null +++ b/speechx/examples/codelab/u2/.gitignore @@ -0,0 +1 @@ +data diff --git a/speechx/examples/codelab/u2/README.md b/speechx/examples/codelab/u2/README.md new file mode 100644 index 00000000..3c85dc91 --- /dev/null +++ b/speechx/examples/codelab/u2/README.md @@ -0,0 +1 @@ +# u2/u2pp Streaming Test diff --git a/speechx/examples/codelab/u2/local/decode.sh b/speechx/examples/codelab/u2/local/decode.sh new file mode 100755 index 00000000..12297661 --- /dev/null +++ b/speechx/examples/codelab/u2/local/decode.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -x +set -e + +. path.sh + +data=data +exp=exp +mkdir -p $exp +ckpt_dir=./data/model +model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/ + +ctc_prefix_beam_search_decoder_main \ + --model_path=$model_dir/export.jit \ + --nnet_decoder_chunk=16 \ + --receptive_field_length=7 \ + --downsampling_rate=4 \ + --vocab_path=$model_dir/unit.txt \ + --feature_rspecifier=ark,t:$exp/fbank.ark \ + --result_wspecifier=ark,t:$exp/result.ark + +echo "u2 ctc prefix beam search decode." diff --git a/speechx/examples/codelab/u2/local/feat.sh b/speechx/examples/codelab/u2/local/feat.sh new file mode 100755 index 00000000..1eec3aae --- /dev/null +++ b/speechx/examples/codelab/u2/local/feat.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -x +set -e + +. path.sh + +data=data +exp=exp +mkdir -p $exp +ckpt_dir=./data/model +model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/ + + +cmvn_json2kaldi_main \ + --json_file $model_dir/mean_std.json \ + --cmvn_write_path $exp/cmvn.ark \ + --binary=false + +echo "convert json cmvn to kaldi ark." 
+ +compute_fbank_main \ + --num_bins 80 \ + --wav_rspecifier=scp:$data/wav.scp \ + --cmvn_file=$exp/cmvn.ark \ + --feature_wspecifier=ark,t:$exp/fbank.ark + +echo "compute fbank feature." diff --git a/speechx/examples/codelab/u2/local/nnet.sh b/speechx/examples/codelab/u2/local/nnet.sh new file mode 100755 index 00000000..78663e9c --- /dev/null +++ b/speechx/examples/codelab/u2/local/nnet.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -x +set -e + +. path.sh + +data=data +exp=exp +mkdir -p $exp +ckpt_dir=./data/model +model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/ + +u2_nnet_main \ + --model_path=$model_dir/export.jit \ + --feature_rspecifier=ark,t:$exp/fbank.ark \ + --nnet_decoder_chunk=16 \ + --receptive_field_length=7 \ + --downsampling_rate=4 \ + --acoustic_scale=1.0 \ + --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \ + --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark +echo "u2 nnet decode." + diff --git a/speechx/examples/codelab/u2nnet/path.sh b/speechx/examples/codelab/u2/path.sh similarity index 84% rename from speechx/examples/codelab/u2nnet/path.sh rename to speechx/examples/codelab/u2/path.sh index 564e9fed..7f32fbce 100644 --- a/speechx/examples/codelab/u2nnet/path.sh +++ b/speechx/examples/codelab/u2/path.sh @@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -SPEECHX_BIN=$SPEECHX_BUILD/nnet -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN +export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio PADDLE_LIB_PATH=$(python -c "import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')") export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/examples/codelab/u2nnet/run.sh b/speechx/examples/codelab/u2/run.sh similarity index 54% rename from speechx/examples/codelab/u2nnet/run.sh rename to speechx/examples/codelab/u2/run.sh index 704653e7..d314262b 100755 --- a/speechx/examples/codelab/u2nnet/run.sh +++ 
b/speechx/examples/codelab/u2/run.sh @@ -36,29 +36,8 @@ ckpt_dir=./data/model model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/ -cmvn_json2kaldi_main \ - --json_file $model_dir/mean_std.json \ - --cmvn_write_path $exp/cmvn.ark \ - --binary=false +./local/feat.sh -echo "convert json cmvn to kaldi ark." +./local/nnet.sh -compute_fbank_main \ - --num_bins 80 \ - --wav_rspecifier=scp:$data/wav.scp \ - --cmvn_file=$exp/cmvn.ark \ - --feature_wspecifier=ark,t:$exp/fbank.ark - -echo "compute fbank feature." - -u2_nnet_main \ - --model_path=$model_dir/export.jit \ - --feature_rspecifier=ark,t:$exp/fbank.ark \ - --nnet_decoder_chunk=16 \ - --receptive_field_length=7 \ - --downsampling_rate=4 \ - --acoustic_scale=1.0 \ - --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \ - --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark - -echo "u2 nnet decode." +./local/decode.sh diff --git a/speechx/examples/codelab/u2nnet/.gitignore b/speechx/examples/codelab/u2nnet/.gitignore deleted file mode 100644 index d6fe69bc..00000000 --- a/speechx/examples/codelab/u2nnet/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -data -exp -*log diff --git a/speechx/examples/codelab/u2nnet/README.md b/speechx/examples/codelab/u2nnet/README.md deleted file mode 100644 index 772a58f0..00000000 --- a/speechx/examples/codelab/u2nnet/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Deepspeech2 Streaming NNet Test - -Using for ds2 streaming nnet inference test. diff --git a/speechx/examples/codelab/u2nnet/valgrind.sh b/speechx/examples/codelab/u2nnet/valgrind.sh deleted file mode 100755 index a5aab663..00000000 --- a/speechx/examples/codelab/u2nnet/valgrind.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# this script is for memory check, so please run ./run.sh first. - -set +x -set -e - -. ./path.sh - -if [ ! 
-d ${SPEECHX_TOOLS}/valgrind/install ]; then - echo "please install valgrind in the speechx tools dir.\n" - exit 1 -fi - -ckpt_dir=./data/model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ - -valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \ - ds2_model_test_main \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdparams diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index b08aaba5..8cf94a10 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -10,8 +10,9 @@ add_library(decoder STATIC ctc_tlg_decoder.cc recognizer.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) +target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) +# test set(BINS ctc_beam_search_decoder_main nnet_logprob_decoder_main @@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) endforeach() + + +# u2 +set(bin_name ctc_prefix_beam_search_decoder_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) +target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 76342b87..3f00ee35 100644 --- 
a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -82,8 +82,6 @@ void CTCBeamSearch::Decode( return; } -int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; } - // todo rename, refactor void CTCBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { @@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() { int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, vector& nbest_words) { kaldi::Timer timer; - timer.Reset(); AdvanceDecoding(probs); LOG(INFO) << "ctc decoding elapsed time(s) " << static_cast(timer.Elapsed()) / 1000.0f; return 0; } +vector> CTCBeamSearch::GetNBestPath(int n) { + int beam_size = n == -1 ? opts_.beam_size: std::min(n, opts_.beam_size); + return get_beam_search_result(prefixes_, vocabulary_, beam_size); +} + vector> CTCBeamSearch::GetNBestPath() { - return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); + return GetNBestPath(-1); } string CTCBeamSearch::GetBestPath() { diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 516f8b2c..479754c3 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface { void AdvanceDecode( const std::shared_ptr& decodable); + void Decode(std::shared_ptr decodable); + + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::vector> GetNBestPath(int n); std::string GetFinalBestPath(); std::string GetPartialResult() { @@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface { return {}; } - void Decode(std::shared_ptr decodable); - - std::string GetBestPath(); - std::vector> GetNBestPath(); - - - int NumFrameDecoded(); - int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h 
b/speechx/speechx/decoder/ctc_beam_search_opt.h index dcb62258..af92fad0 100644 --- a/speechx/speechx/decoder/ctc_beam_search_opt.h +++ b/speechx/speechx/decoder/ctc_beam_search_opt.h @@ -19,6 +19,7 @@ namespace ppspeech { + struct CTCBeamSearchOptions { // common int blank; @@ -75,4 +76,68 @@ struct CTCBeamSearchOptions { } }; + +// used by u2 model +struct CTCBeamSearchDecoderOptions { + // chunk_size is the frame number of one chunk after subsampling. + // e.g. if subsample rate is 4 and chunk_size = 16, the frames in + // one chunk are 67=16*4 + 3, stride is 64=16*4 + int chunk_size; + int num_left_chunks; + + // final_score = rescoring_weight * rescoring_score + ctc_weight * + // ctc_score; + // rescoring_score = left_to_right_score * (1 - reverse_weight) + + // right_to_left_score * reverse_weight + // Please note the concept of ctc_scores + // in the following two search methods are different. For + // CtcPrefixBeamSearch, + // it's a sum(prefix) score + context score. For CtcWfstBeamSearch, it's a + // max(viterbi) path score + context score. So we should carefully set + // ctc_weight according to the search methods. + float ctc_weight; + float rescoring_weight; + float reverse_weight; + + // CtcEndpointConfig ctc_endpoint_opts; + + CTCBeamSearchOptions ctc_prefix_search_opts; + + CTCBeamSearchDecoderOptions() + : chunk_size(16), + num_left_chunks(-1), + ctc_weight(0.5), + rescoring_weight(1.0), + reverse_weight(0.0) {} + + void Register(kaldi::OptionsItf* opts) { + std::string module = "DecoderConfig: "; + opts->Register( + "chunk-size", + &chunk_size, + module + "the frame number of one chunk after subsampling."); + opts->Register("num-left-chunks", + &num_left_chunks, + module + "the left history chunks number."); + opts->Register("ctc-weight", + &ctc_weight, + module + + "ctc weight for rescore. 
final_score = " + "rescoring_weight * rescoring_score + ctc_weight * " + "ctc_score."); + opts->Register("rescoring-weight", + &rescoring_weight, + module + + "attention score weight for rescore. final_score = " + "rescoring_weight * rescoring_score + ctc_weight * " + "ctc_score."); + opts->Register("reverse-weight", + &reverse_weight, + module + + "reverse decoder weight. rescoring_score = " + "left_to_right_score * (1 - reverse_weight) + " + "right_to_left_score * reverse_weight."); + } +}; + } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc index fd689023..f22bfea2 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc @@ -1,3 +1,5 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,11 +15,12 @@ // limitations under the License. 
+#include "decoder/ctc_prefix_beam_search_decoder.h" #include "base/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_score.h" -#include "decoder/ctc_prefix_beam_search_decoder.h" #include "utils/math.h" +#include "absl/strings/str_join.h" #ifdef USE_PROFILING #include "paddle/fluid/platform/profiler.h" @@ -29,85 +32,47 @@ namespace ppspeech { CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts) { - InitDecoder(); + Reset(); } -void CTCPrefixBeamSearch::InitDecoder() { +void CTCPrefixBeamSearch::Reset() { num_frame_decoded_ = 0; cur_hyps_.clear(); - hypotheses_.clear(); - likelihood_.clear(); - viterbi_likelihood_.clear(); - times_.clear(); - outputs_.clear(); + hypotheses_.clear(); + likelihood_.clear(); + viterbi_likelihood_.clear(); + times_.clear(); + outputs_.clear(); - abs_time_step_ = 0; + // empty hyp with Score + std::vector empty; + PrefixScore prefix_score; + prefix_score.b = 0.0f; // log(1) + prefix_score.nb = -kBaseFloatMax; // log(0) + prefix_score.v_b = 0.0f; // log(1) + prefix_score.v_nb = 0.0f; // log(1) + cur_hyps_[empty] = prefix_score; - // empty hyp with Score - std::vector empty; - PrefixScore prefix_score; - prefix_score.b = 0.0f; // log(1) - prefix_score.nb = -kBaseFloatMax; // log(0) - prefix_score.v_b = 0.0f; // log(1) - prefix_score.v_nb = 0.0f; // log(1) - cur_hyps_[empty] = prefix_score; + outputs_.emplace_back(empty); + hypotheses_.emplace_back(empty); + likelihood_.emplace_back(prefix_score.TotalScore()); + times_.emplace_back(empty); + } - outputs_.emplace_back(empty); - hypotheses_.emplace_back(empty); - likelihood_.emplace_back(prefix_score.TotalScore()); - times_.emplace_back(empty); - -} +void CTCPrefixBeamSearch::InitDecoder() { Reset(); } -void CTCPrefixBeamSearch::Reset() { - InitDecoder(); -} - -void CTCPrefixBeamSearch::Decode( - std::shared_ptr decodable) { - return; -} - -int32 CTCPrefixBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 
1; } - - -void CTCPrefixBeamSearch::UpdateOutputs( - const std::pair, PrefixScore>& prefix) { - const std::vector& input = prefix.first; - // const std::vector& start_boundaries = prefix.second.start_boundaries; - // const std::vector& end_boundaries = prefix.second.end_boundaries; - - std::vector output; - int s = 0; - int e = 0; - for (int i = 0; i < input.size(); ++i) { - // if (s < start_boundaries.size() && i == start_boundaries[s]){ - // // - // output.emplace_back(context_graph_->start_tag_id()); - // ++s; - // } - - output.emplace_back(input[i]); - - // if (e < end_boundaries.size() && i == end_boundaries[e]){ - // // - // output.emplace_back(context_graph_->end_tag_id()); - // ++e; - // } - } - - outputs_.emplace_back(output); -} void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { while (1) { + // forward frame by frame std::vector frame_prob; bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); if (flag == false) break; + std::vector> likelihood; likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); @@ -117,201 +82,279 @@ void CTCPrefixBeamSearch::AdvanceDecode( static bool PrefixScoreCompare( const std::pair, PrefixScore>& a, const std::pair, PrefixScore>& b) { - // log domain - return a.second.TotalScore() > b.second.TotalScore(); + // log domain + return a.second.TotalScore() > b.second.TotalScore(); } -void CTCPrefixBeamSearch::AdvanceDecoding(const std::vector>& logp) { +void CTCPrefixBeamSearch::AdvanceDecoding( + const std::vector>& logp) { #ifdef USE_PROFILING - RecordEvent event( - "CtcPrefixBeamSearch::AdvanceDecoding", TracerEventType::UserDefined, 1); + RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding", + TracerEventType::UserDefined, + 1); #endif - if (logp.size() == 0) return; - - int first_beam_size = - std::min(static_cast(logp[0].size()), opts_.first_beam_size); - - for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) { - const std::vector& logp_t = logp[t]; - 
std::unordered_map, PrefixScore, PrefixScoreHash> next_hyps; - - // 1. first beam prune, only select topk candidates - std::vector topk_score; - std::vector topk_index; - TopK(logp_t, first_beam_size, &topk_score, &topk_index); - - // 2. token passing - for (int i = 0; i < topk_index.size(); ++i) { - int id = topk_index[i]; - auto prob = topk_score[i]; - - for (const auto& it : cur_hyps_) { - const std::vector& prefix = it.first; - const PrefixScore& prefix_score = it.second; - - // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert - // PrefixScore(-inf, -inf) by default, since the default constructor - // of PrefixScore will set fields b(blank ending Score) and - // nb(none blank ending Score) to -inf, respectively. - - if (id == opts_.blank) { - // case 0: *a + => *a, *a + => *a, prefix not - // change - PrefixScore& next_score = next_hyps[prefix]; - next_score.b = LogSumExp(next_score.b, prefix_score.Score() + prob); - - // timestamp, blank is slince, not effact timestamp - next_score.v_b = prefix_score.ViterbiScore() + prob; - next_score.times_b = prefix_score.Times(); - - // Prefix not changed, copy the context from pefix - if (context_graph_ && !next_score.has_context) { - next_score.CopyContext(prefix_score); - next_score.has_context = true; - } - - } else if (!prefix.empty() && id == prefix.back()) { - // case 1: *a + a => *a, prefix not changed - PrefixScore& next_score1 = next_hyps[prefix]; - next_score1.nb = LogSumExp(next_score1.nb, prefix_score.nb + prob); - - // timestamp, non-blank symbol effact timestamp - if (next_score1.v_nb < prefix_score.v_nb + prob) { - // compute viterbi Score - next_score1.v_nb = prefix_score.v_nb + prob; - if (next_score1.cur_token_prob < prob) { - // store max token prob - next_score1.cur_token_prob = prob; - // update this timestamp as token appeared here. 
- next_score1.times_nb = prefix_score.times_nb; - assert(next_score1.times_nb.size() > 0); - next_score1.times_nb.back() = abs_time_step_; - } - } - - // Prefix not changed, copy the context from pefix - if (context_graph_ && !next_score1.has_context) { - next_score1.CopyContext(prefix_score); - next_score1.has_context = true; - } - - // case 2: *a + a => *aa, prefix changed. - std::vector new_prefix(prefix); - new_prefix.emplace_back(id); - PrefixScore& next_score2 = next_hyps[new_prefix]; - next_score2.nb = LogSumExp(next_score2.nb, prefix_score.b + prob); - - // timestamp, non-blank symbol effact timestamp - if (next_score2.v_nb < prefix_score.v_b + prob) { - // compute viterbi Score - next_score2.v_nb = prefix_score.v_b + prob; - // new token added - next_score2.cur_token_prob = prob; - next_score2.times_nb = prefix_score.times_b; - next_score2.times_nb.emplace_back(abs_time_step_); - } - - // Prefix changed, calculate the context Score. - if (context_graph_ && !next_score2.has_context) { - next_score2.UpdateContext( - context_graph_, prefix_score, id, prefix.size()); - next_score2.has_context = true; - } - - } else { - // id != prefix.back() - // case 3: *a + b => *ab, *a +b => *ab - std::vector new_prefix(prefix); - new_prefix.emplace_back(id); - PrefixScore& next_score = next_hyps[new_prefix]; - next_score.nb = LogSumExp(next_score.nb, prefix_score.Score() + prob); - - // timetamp, non-blank symbol effact timestamp - if (next_score.v_nb < prefix_score.ViterbiScore() + prob) { - next_score.v_nb = prefix_score.ViterbiScore() + prob; - - next_score.cur_token_prob = prob; - next_score.times_nb = prefix_score.Times(); - next_score.times_nb.emplace_back(abs_time_step_); - } - - // Prefix changed, calculate the context Score. 
- if (context_graph_ && !next_score.has_context) { - next_score.UpdateContext( - context_graph_, prefix_score, id, prefix.size()); - next_score.has_context = true; - } - } - } // end for (const auto& it : cur_hyps_) - } // end for (int i = 0; i < topk_index.size(); ++i) - - // 3. second beam prune, only keep top n best paths - std::vector, PrefixScore>> arr(next_hyps.begin(), - next_hyps.end()); - int second_beam_size = - std::min(static_cast(arr.size()), opts_.second_beam_size); - std::nth_element(arr.begin(), - arr.begin() + second_beam_size, - arr.end(), - PrefixScoreCompare); - arr.resize(second_beam_size); - std::sort(arr.begin(), arr.end(), PrefixScoreCompare); - - // 4. update cur_hyps by next_hyps, and get new result - UpdateHypotheses(arr); - - num_frame_decoded_++; - } // end for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) + if (logp.size() == 0) return; + + int first_beam_size = + std::min(static_cast(logp[0].size()), opts_.first_beam_size); + + for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_) { + const std::vector& logp_t = logp[t]; + std::unordered_map, PrefixScore, PrefixScoreHash> + next_hyps; + + // 1. first beam prune, only select topk candidates + std::vector topk_score; + std::vector topk_index; + TopK(logp_t, first_beam_size, &topk_score, &topk_index); + + // 2. token passing + for (int i = 0; i < topk_index.size(); ++i) { + int id = topk_index[i]; + auto prob = topk_score[i]; + + for (const auto& it : cur_hyps_) { + const std::vector& prefix = it.first; + const PrefixScore& prefix_score = it.second; + + // If prefix doesn't exist in next_hyps, next_hyps[prefix] will + // insert + // PrefixScore(-inf, -inf) by default, since the default + // constructor + // of PrefixScore will set fields b(blank ending Score) and + // nb(none blank ending Score) to -inf, respectively. 
+ + if (id == opts_.blank) { + // case 0: *a + <blank> => *a, *a<blank> + <blank> => *a, + // prefix not + // changed + PrefixScore& next_score = next_hyps[prefix]; + next_score.b = + LogSumExp(next_score.b, prefix_score.Score() + prob); + + // timestamp, blank is silence, does not affect timestamp + next_score.v_b = prefix_score.ViterbiScore() + prob; + next_score.times_b = prefix_score.Times(); + + // Prefix not changed, copy the context from prefix + if (context_graph_ && !next_score.has_context) { + next_score.CopyContext(prefix_score); + next_score.has_context = true; + } + + } else if (!prefix.empty() && id == prefix.back()) { + // case 1: *a + a => *a, prefix not changed + PrefixScore& next_score1 = next_hyps[prefix]; + next_score1.nb = + LogSumExp(next_score1.nb, prefix_score.nb + prob); + + // timestamp, non-blank symbol affects timestamp + if (next_score1.v_nb < prefix_score.v_nb + prob) { + // compute viterbi Score + next_score1.v_nb = prefix_score.v_nb + prob; + if (next_score1.cur_token_prob < prob) { + // store max token prob + next_score1.cur_token_prob = prob; + // update this timestamp as token appeared here. + next_score1.times_nb = prefix_score.times_nb; + assert(next_score1.times_nb.size() > 0); + next_score1.times_nb.back() = num_frame_decoded_; + } + } + + // Prefix not changed, copy the context from prefix + if (context_graph_ && !next_score1.has_context) { + next_score1.CopyContext(prefix_score); + next_score1.has_context = true; + } + + // case 2: *a<blank> + a => *aa, prefix changed. 
+ std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score2 = next_hyps[new_prefix]; + next_score2.nb = + LogSumExp(next_score2.nb, prefix_score.b + prob); + + // timestamp, non-blank symbol affects timestamp + if (next_score2.v_nb < prefix_score.v_b + prob) { + // compute viterbi Score + next_score2.v_nb = prefix_score.v_b + prob; + // new token added + next_score2.cur_token_prob = prob; + next_score2.times_nb = prefix_score.times_b; + next_score2.times_nb.emplace_back(num_frame_decoded_); + } + + // Prefix changed, calculate the context Score. + if (context_graph_ && !next_score2.has_context) { + next_score2.UpdateContext( + context_graph_, prefix_score, id, prefix.size()); + next_score2.has_context = true; + } + + } else { + // id != prefix.back() + // case 3: *a + b => *ab, *a<blank> + b => *ab + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score = next_hyps[new_prefix]; + next_score.nb = + LogSumExp(next_score.nb, prefix_score.Score() + prob); + + // timestamp, non-blank symbol affects timestamp + if (next_score.v_nb < prefix_score.ViterbiScore() + prob) { + next_score.v_nb = prefix_score.ViterbiScore() + prob; + + next_score.cur_token_prob = prob; + next_score.times_nb = prefix_score.Times(); + next_score.times_nb.emplace_back(num_frame_decoded_); + } + + // Prefix changed, calculate the context Score. + if (context_graph_ && !next_score.has_context) { + next_score.UpdateContext( + context_graph_, prefix_score, id, prefix.size()); + next_score.has_context = true; + } + } + } // end for (const auto& it : cur_hyps_) + } // end for (int i = 0; i < topk_index.size(); ++i) + + // 3. 
second beam prune, only keep top n best paths + std::vector, PrefixScore>> arr( + next_hyps.begin(), next_hyps.end()); + int second_beam_size = + std::min(static_cast(arr.size()), opts_.second_beam_size); + std::nth_element(arr.begin(), + arr.begin() + second_beam_size, + arr.end(), + PrefixScoreCompare); + arr.resize(second_beam_size); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // 4. update cur_hyps by next_hyps, and get new result + UpdateHypotheses(arr); + } // end for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_) } void CTCPrefixBeamSearch::UpdateHypotheses( const std::vector, PrefixScore>>& hyps) { - cur_hyps_.clear(); - - outputs_.clear(); - hypotheses_.clear(); - likelihood_.clear(); - viterbi_likelihood_.clear(); - times_.clear(); - - for (auto& item : hyps) { - cur_hyps_[item.first] = item.second; - - UpdateOutputs(item); - hypotheses_.emplace_back(std::move(item.first)); - likelihood_.emplace_back(item.second.TotalScore()); - viterbi_likelihood_.emplace_back(item.second.ViterbiScore()); - times_.emplace_back(item.second.Times()); - } + cur_hyps_.clear(); + + outputs_.clear(); + hypotheses_.clear(); + likelihood_.clear(); + viterbi_likelihood_.clear(); + times_.clear(); + + for (auto& item : hyps) { + cur_hyps_[item.first] = item.second; + + UpdateOutputs(item); + hypotheses_.emplace_back(std::move(item.first)); + likelihood_.emplace_back(item.second.TotalScore()); + viterbi_likelihood_.emplace_back(item.second.ViterbiScore()); + times_.emplace_back(item.second.Times()); + } } -void CTCPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); } +void CTCPrefixBeamSearch::UpdateOutputs( + const std::pair, PrefixScore>& prefix) { + const std::vector& input = prefix.first; + const std::vector& start_boundaries = prefix.second.start_boundaries; + const std::vector& end_boundaries = prefix.second.end_boundaries; + + // add tag + std::vector output; + int s = 0; + int e = 0; + for (int i = 0; i < input.size(); ++i) { + // if (s < 
start_boundaries.size() && i == start_boundaries[s]){ + // // + // output.emplace_back(context_graph_->start_tag_id()); + // ++s; + // } + + output.emplace_back(input[i]); + + // if (e < end_boundaries.size() && i == end_boundaries[e]){ + // // + // output.emplace_back(context_graph_->end_tag_id()); + // ++e; + // } + } + outputs_.emplace_back(output); +} + +void CTCPrefixBeamSearch::FinalizeSearch() { + UpdateFinalContext(); +} void CTCPrefixBeamSearch::UpdateFinalContext() { - if (context_graph_ == nullptr) return; - assert(hypotheses_.size() == cur_hyps_.size()); - assert(hypotheses_.size() == likelihood_.size()); - - // We should backoff the context Score/state when the context is - // not fully matched at the last time. - for (const auto& prefix : hypotheses_) { - PrefixScore& prefix_score = cur_hyps_[prefix]; - if (prefix_score.context_score != 0) { - // prefix_score.UpdateContext(context_graph_, prefix_score, 0, - // prefix.size()); + if (context_graph_ == nullptr) return; + + CHECK(hypotheses_.size() == cur_hyps_.size()); + CHECK(hypotheses_.size() == likelihood_.size()); + + // We should backoff the context Score/state when the context is + // not fully matched at the last time. 
+ for (const auto& prefix : hypotheses_) { + PrefixScore& prefix_score = cur_hyps_[prefix]; + if (prefix_score.context_score != 0) { + prefix_score.UpdateContext(context_graph_, prefix_score, 0, + prefix.size()); + } } + std::vector, PrefixScore>> arr(cur_hyps_.begin(), + cur_hyps_.end()); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // Update cur_hyps_ and get new result + UpdateHypotheses(arr); +} + + std::string CTCPrefixBeamSearch::GetBestPath(int index) { + int n_hyps = Outputs().size(); + CHECK(n_hyps > 0); + CHECK(index < n_hyps); + std::vector one = Outputs()[index]; + return std::string(absl::StrJoin(one, kSpaceSymbol)); + } + + std::string CTCPrefixBeamSearch::GetBestPath() { + return GetBestPath(0); + } + + std::vector> CTCPrefixBeamSearch::GetNBestPath(int n) { + int hyps_size = hypotheses_.size(); + CHECK(hyps_size > 0); + + int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size); + + std::vector> n_best; + n_best.reserve(min_n); + + for (int i = 0; i < min_n; i++){ + n_best.emplace_back(Likelihood()[i], GetBestPath(i) ); + } + return n_best; + } + + std::vector> CTCPrefixBeamSearch::GetNBestPath() { + return GetNBestPath(-1); } - std::vector, PrefixScore>> arr(cur_hyps_.begin(), - cur_hyps_.end()); - std::sort(arr.begin(), arr.end(), PrefixScoreCompare); - // Update cur_hyps_ and get new result - UpdateHypotheses(arr); +std::string CTCPrefixBeamSearch::GetFinalBestPath() { + return GetBestPath(); +} + +std::string CTCPrefixBeamSearch::GetPartialResult() { + return GetBestPath(); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h index b67733e8..ba44b0a2 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h @@ -15,6 +15,7 @@ #pragma once #include 
"decoder/ctc_beam_search_opt.h" +#include "decoder/ctc_prefix_beam_search_result.h" #include "decoder/ctc_prefix_beam_search_score.h" #include "decoder/decoder_itf.h" @@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface { explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); ~CTCPrefixBeamSearch() {} - void InitDecoder(); + void InitDecoder() override; - void Reset(); + void Reset() override; void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable) override; - std::string GetFinalBestPath(); + std::string GetFinalBestPath() override; + std::string GetPartialResult() override; - std::string GetPartialResult() { - CHECK(false) << "Not implement."; - return {}; - } - - void Decode(std::shared_ptr decodable); - - std::string GetBestPath(); - - std::vector> GetNBestPath(); - - - int NumFrameDecoded(); - - int DecodeLikelihoods(const std::vector>& probs, - std::vector& nbest_words); + void FinalizeSearch(); - const std::vector& ViterbiLikelihood() const { - return viterbi_likelihood_; - } + protected: + std::string GetBestPath() override; + std::vector> GetNBestPath() override; + std::vector> GetNBestPath(int n) override; const std::vector>& Inputs() const { return hypotheses_; } - const std::vector>& Outputs() const { return outputs_; } - const std::vector& Likelihood() const { return likelihood_; } + const std::vector& ViterbiLikelihood() const { + return viterbi_likelihood_; + } const std::vector>& Times() const { return times_; } private: - void AdvanceDecoding(const std::vector>& logp); + std::string GetBestPath(int index); - void FinalizeSearch(); + void AdvanceDecoding( + const std::vector>& logp); void UpdateOutputs(const std::pair, PrefixScore>& prefix); void UpdateHypotheses( @@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface { private: CTCBeamSearchOptions opts_; - int abs_time_step_ = 0; - std::unordered_map, PrefixScore, PrefixScoreHash> cur_hyps_; @@ -97,4 +85,29 @@ class 
CTCPrefixBeamSearch : public DecoderInterface { DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch); }; + +class CTCPrefixBeamSearchDecoder : public CTCPrefixBeamSearch { + public: + explicit CTCPrefixBeamSearchDecoder(const CTCBeamSearchDecoderOptions& opts) + : CTCPrefixBeamSearch(opts.ctc_prefix_search_opts), opts_(opts) {} + + ~CTCPrefixBeamSearchDecoder() {} + + private: + CTCBeamSearchDecoderOptions opts_; + + // cache feature + bool start_ = false; // false, this is first frame. + // for continuous decoding + int num_frames_ = 0; + int global_frame_offset_ = 0; + const int time_stamp_gap_ = + 100; // timestamp gap between words in a sentence + + // std::unique_ptr ctc_endpointer_; + + int num_frames_in_current_chunk_ = 0; + std::vector result_; +}; + } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc new file mode 100644 index 00000000..8927a5f4 --- /dev/null +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "base/common.h" +#include "decoder/ctc_prefix_beam_search_decoder.h" +#include "frontend/audio/data_cache.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/u2_nnet.h" +#include "absl/strings/str_split.h" +#include "fst/symbol-table.h" + +DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_string(vocab_path, "", "vocab path"); + +DEFINE_string(model_path, "", "paddle nnet model"); + +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=3) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=3) module downsampling rate."); + +DEFINE_int32(nnet_decoder_chunk, 16, "paddle nnet forward chunk"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +// test u2 ctc prefix beam search decoder by feeding speech feature +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + + CHECK(FLAGS_result_wspecifier != ""); + CHECK(FLAGS_feature_rspecifier != ""); + CHECK(FLAGS_vocab_path != ""); + CHECK(FLAGS_model_path != ""); + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + + LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; + fst::SymbolTable* unit_table = fst::SymbolTable::ReadText(FLAGS_vocab_path); + + // nnet + ppspeech::ModelOptions model_opts; + model_opts.model_path = FLAGS_model_path; + std::shared_ptr nnet( + new ppspeech::U2Nnet(model_opts)); + + // decodable + std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data)); + + // decoder +
ppspeech::CTCBeamSearchDecoderOptions opts; + opts.chunk_size = 16; + opts.num_left_chunks = -1; + opts.ctc_weight = 0.5; + opts.rescoring_weight = 1.0; + opts.reverse_weight = 0.3; + opts.ctc_prefix_search_opts.blank = 0; + opts.ctc_prefix_search_opts.first_beam_size = 10; + opts.ctc_prefix_search_opts.second_beam_size = 10; + ppspeech::CTCPrefixBeamSearchDecoder decoder(opts); + + + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; + + decoder.InitDecoder(); + + kaldi::Timer timer; + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + kaldi::Matrix feature = feature_reader.Value(); + + int nframes = feature.NumRows(); + int feat_dim = feature.NumCols(); + raw_data->SetDim(feat_dim); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim; + + raw_data->SetDim(feat_dim); + + int32 ori_feature_len = feature.NumRows(); + int32 num_chunks = feature.NumRows() / chunk_stride + 1; + LOG(INFO) << "num_chunks: " << num_chunks; + + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + int32 this_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + this_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (this_chunk_size < receptive_field_length) { + LOG(WARNING) << "utt: " << utt << " skip last " + << this_chunk_size << " frames, expect is " + << receptive_field_length; + break; + } + + + kaldi::Vector feature_chunk(this_chunk_size * + feat_dim); + int32 start = chunk_idx * chunk_stride; + for (int row_id = 0; row_id < this_chunk_size; ++row_id) { + 
kaldi::SubVector feat_row(feature, start); + kaldi::SubVector feature_chunk_row( + feature_chunk.Data() + row_id * feat_dim, feat_dim); + + feature_chunk_row.CopyFromVec(feat_row); + ++start; + } + + // feat to frontend pipeline cache + raw_data->Accept(feature_chunk); + + // send data finish signal + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + + // forward nnet + decoder.AdvanceDecode(decodable); + } + + decoder.FinalizeSearch(); + + // get 1-best result + std::string result_ints = decoder.GetFinalBestPath(); + std::vector tokenids = absl::StrSplit(result_ints, ppspeech::kSpaceSymbol); + std::string result; + for (int i = 0; i < tokenids.size(); i++){ + result += unit_table->Find(std::stoi(tokenids[i])); + } + + // after processing one utt, reset the state. + decodable->Reset(); + decoder.Reset(); + + if (result.empty()) { + // the TokenWriter cannot write an empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + continue; + } + + LOG(INFO) << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); + + ++num_done; + } + + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_result.h b/speechx/speechx/decoder/ctc_prefix_beam_search_result.h new file mode 100644 index 00000000..caa3e37e --- /dev/null +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_result.h @@ -0,0 +1,41 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" + +namespace ppspeech { + +struct WordPiece { + std::string word; + int start = -1; + int end = -1; + + WordPiece(std::string word, int start, int end) + : word(std::move(word)), start(start), end(end) {} +}; + +struct DecodeResult { + float score = -kBaseFloatMax; + std::string sentence; + std::vector word_pieces; + + static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) { + return a.score > b.score; + } +}; + +} // namespace ppspeech diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index de97f6ad..4d0a21d5 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -18,16 +18,23 @@ namespace ppspeech { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { fst_.reset(fst::Fst::Read(opts.fst_path)); CHECK(fst_ != nullptr); + word_symbol_table_.reset( fst::SymbolTable::ReadText(opts.word_symbol_table)); + decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); + + Reset(); +} + +void TLGDecoder::Reset() { decoder_->InitDecoding(); num_frame_decoded_ = 0; + return; } void TLGDecoder::InitDecoder() { - decoder_->InitDecoding(); - num_frame_decoded_ = 0; + Reset(); } void TLGDecoder::AdvanceDecode( @@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { num_frame_decoded_++; } -void TLGDecoder::Reset() { - InitDecoder(); - return; -} + std::string TLGDecoder::GetPartialResult() { if (num_frame_decoded_ == 0) { @@ -88,4 +92,5 @@ std::string 
TLGDecoder::GetFinalBestPath() { } return words; } + } diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index f3ecde73..2f1d6c10 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface { void AdvanceDecode( const std::shared_ptr& decodable); - - std::string GetFinalBestPath(); - std::string GetPartialResult(); - - void Decode(); - std::string GetBestPath(); - std::vector> GetNBestPath(); + std::string GetFinalBestPath() override; + std::string GetPartialResult() override; - int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); + protected: + std::string GetBestPath() override { + CHECK(false); + return {}; + } + std::vector> GetNBestPath() override { + CHECK(false); + return {}; + } + std::vector> GetNBestPath(int n) override { + CHECK(false); + return {}; + } private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); diff --git a/speechx/speechx/decoder/decoder_itf.h b/speechx/speechx/decoder/decoder_itf.h index 1bbc6b11..fe4e7408 100644 --- a/speechx/speechx/decoder/decoder_itf.h +++ b/speechx/speechx/decoder/decoder_itf.h @@ -28,27 +28,31 @@ class DecoderInterface { virtual void Reset() = 0; + // call AdvanceDecoding virtual void AdvanceDecode( const std::shared_ptr& decodable) = 0; + // call GetBestPath virtual std::string GetFinalBestPath() = 0; virtual std::string GetPartialResult() = 0; - // void Decode(); + protected: + // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0; - // std::string GetBestPath(); - // std::vector> GetNBestPath(); + // virtual void Decode() = 0; - // int NumFrameDecoded(); - // int DecodeLikelihoods(const std::vector>& probs, - // std::vector& nbest_words); + virtual std::string GetBestPath() = 0; + virtual std::vector> GetNBestPath() = 0; - protected: - // void 
AdvanceDecoding(kaldi::DecodableInterface* decodable); + virtual std::vector> GetNBestPath(int n) = 0; - // current decoding frame number + // start from one + int NumFrameDecoded() { return num_frame_decoded_ + 1; } + + protected: + // current decoding frame number, abs_time_step_ int32 num_frame_decoded_; }; diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/nnet/u2_nnet_main.cc index 2dd1fa0d..4b30f6b4 100644 --- a/speechx/speechx/nnet/u2_nnet_main.cc +++ b/speechx/speechx/nnet/u2_nnet_main.cc @@ -86,17 +86,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "utt: " << utt; LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim; - // // pad feats - // int32 padding_len = 0; - // if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - // padding_len = - // chunk_stride - (feature.NumRows() - chunk_size) % - // chunk_stride; - // feature.Resize(feature.NumRows() + padding_len, - // feature.NumCols(), - // kaldi::kCopyData); - // } - int32 frame_idx = 0; int vocab_dim = 0; std::vector> prob_vec; diff --git a/speechx/speechx/utils/math.cc b/speechx/speechx/utils/math.cc index 6a13f69b..c218990a 100644 --- a/speechx/speechx/utils/math.cc +++ b/speechx/speechx/utils/math.cc @@ -68,7 +68,7 @@ void TopK(const std::vector& data, for (int i = k; i < n; i++) { if (pq.top().first < data[i]) { pq.pop(); - pq.emplace_back(data[i], i); + pq.emplace(data[i], i); } } @@ -88,4 +88,9 @@ void TopK(const std::vector& data, } } +template void TopK(const std::vector& data, + int32_t k, + std::vector* values, + std::vector* indices) ; + } // namespace ppspeech \ No newline at end of file -- GitLab