Commit 7dc9cba3 authored by: H Hui Zhang

ctc prefix beam search for u2, test can run

Parent 3c3aa6b5
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--result_wspecifier=ark,t:$exp/result.ark
echo "u2 ctc prefix beam search decode."
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
compute_fbank_main \
--num_bins 80 \
--wav_rspecifier=scp:$data/wav.scp \
--cmvn_file=$exp/cmvn.ark \
--feature_wspecifier=ark,t:$exp/fbank.ark
echo "compute fbank feature."
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."
@@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
 export LC_AL=C
-SPEECHX_BIN=$SPEECHX_BUILD/nnet
-export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
+export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio
 PADDLE_LIB_PATH=$(python -c "import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')")
 export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
@@ -36,29 +36,8 @@ ckpt_dir=./data/model
 model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
-cmvn_json2kaldi_main \
-    --json_file $model_dir/mean_std.json \
-    --cmvn_write_path $exp/cmvn.ark \
-    --binary=false
-echo "convert json cmvn to kaldi ark."
-
-compute_fbank_main \
-    --num_bins 80 \
-    --wav_rspecifier=scp:$data/wav.scp \
-    --cmvn_file=$exp/cmvn.ark \
-    --feature_wspecifier=ark,t:$exp/fbank.ark
-echo "compute fbank feature."
-
-u2_nnet_main \
-    --model_path=$model_dir/export.jit \
-    --feature_rspecifier=ark,t:$exp/fbank.ark \
-    --nnet_decoder_chunk=16 \
-    --receptive_field_length=7 \
-    --downsampling_rate=4 \
-    --acoustic_scale=1.0 \
-    --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
-    --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
-echo "u2 nnet decode."
+./local/feat.sh
+
+./local/nnet.sh
+
+./local/decode.sh
# Deepspeech2 Streaming NNet Test
Used for the DeepSpeech2 streaming nnet inference test.
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set +x
set -e
. ./path.sh
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
echo "please install valgrind in the speechx tools dir.\n"
exit 1
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
ds2_model_test_main \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams
@@ -10,8 +10,9 @@ add_library(decoder STATIC
     ctc_tlg_decoder.cc
     recognizer.cc
 )
-target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)
+target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)

+# test
 set(BINS
     ctc_beam_search_decoder_main
     nnet_logprob_decoder_main
@@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS)
     target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
     target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
 endforeach()
+
+# u2
+set(bin_name ctc_prefix_beam_search_decoder_main)
+add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
+target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
+target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
+target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
+target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
\ No newline at end of file
@@ -82,8 +82,6 @@ void CTCBeamSearch::Decode(
     return;
 }

-int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
-
 // todo rename, refactor
 void CTCBeamSearch::AdvanceDecode(
     const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() {
 int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                      vector<string>& nbest_words) {
     kaldi::Timer timer;
-    timer.Reset();
     AdvanceDecoding(probs);
     LOG(INFO) << "ctc decoding elapsed time(s) "
               << static_cast<float>(timer.Elapsed()) / 1000.0f;
     return 0;
 }

+vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
+    int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
+    return get_beam_search_result(prefixes_, vocabulary_, beam_size);
+}
+
 vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
-    return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
+    return GetNBestPath(-1);
 }

 string CTCBeamSearch::GetBestPath() {
......
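The new GetNBestPath(int n) overload treats n == -1 as "return the full beam" and otherwise caps n at the configured beam size. A minimal standalone sketch of that clamping rule (the helper name is hypothetical, not part of this commit):

#include <algorithm>
#include <cassert>

// Hypothetical helper mirroring the clamp in CTCBeamSearch::GetNBestPath(int n).
int EffectiveBeamSize(int n, int beam_size) {
    return n == -1 ? beam_size : std::min(n, beam_size);
}

int main() {
    assert(EffectiveBeamSize(-1, 10) == 10);  // -1 -> full beam
    assert(EffectiveBeamSize(3, 10) == 3);    // smaller request honored
    assert(EffectiveBeamSize(50, 10) == 10);  // never more than the beam holds
    return 0;
}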
@@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface {
     void AdvanceDecode(
         const std::shared_ptr<kaldi::DecodableInterface>& decodable);

+    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
+
+    std::string GetBestPath();
+    std::vector<std::pair<double, std::string>> GetNBestPath();
+    std::vector<std::pair<double, std::string>> GetNBestPath(int n);
+
     std::string GetFinalBestPath();

     std::string GetPartialResult() {
@@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface {
         return {};
     }

-    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
-
-    std::string GetBestPath();
-
-    std::vector<std::pair<double, std::string>> GetNBestPath();
-
-    int NumFrameDecoded();
-
     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                           std::vector<std::string>& nbest_words);
......
@@ -19,6 +19,7 @@
namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
@@ -75,4 +76,68 @@ struct CTCBeamSearchOptions {
}
};
// used by u2 model
struct CTCBeamSearchDecoderOptions {
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int chunk_size;
int num_left_chunks;
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
// Please note that the concept of ctc_score differs between the two
// search methods: for CtcPrefixBeamSearch it is a sum (prefix) score plus
// a context score, while for CtcWfstBeamSearch it is a max (Viterbi) path
// score plus a context score. So ctc_weight should be set carefully
// according to the search method.
float ctc_weight;
float rescoring_weight;
float reverse_weight;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts;
CTCBeamSearchDecoderOptions()
: chunk_size(16),
num_left_chunks(-1),
ctc_weight(0.5),
rescoring_weight(1.0),
reverse_weight(0.0) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "DecoderConfig: ";
opts->Register(
"chunk-size",
&chunk_size,
module + "the frame number of one chunk after subsampling.");
opts->Register("num-left-chunks",
&num_left_chunks,
module + "the left history chunks number.");
opts->Register("ctc-weight",
&ctc_weight,
module +
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("rescoring-weight",
&rescoring_weight,
module +
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("reverse-weight",
&reverse_weight,
module +
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight.");
}
};
} // namespace ppspeech
\ No newline at end of file
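To make the weighting in the comments above concrete, here is a standalone sketch of the score fusion using the default weights from CTCBeamSearchDecoderOptions (the helper names are illustrative, not part of this commit):

struct FusionWeights {
    // Defaults from CTCBeamSearchDecoderOptions.
    float ctc_weight = 0.5f;
    float rescoring_weight = 1.0f;
    float reverse_weight = 0.0f;
};

// rescoring_score = left_to_right_score * (1 - reverse_weight) +
//                   right_to_left_score * reverse_weight
float RescoringScore(float l2r_score, float r2l_score, const FusionWeights& w) {
    return l2r_score * (1.0f - w.reverse_weight) + r2l_score * w.reverse_weight;
}

// final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score
float FinalScore(float ctc_score, float l2r_score, float r2l_score,
                 const FusionWeights& w) {
    return w.rescoring_weight * RescoringScore(l2r_score, r2l_score, w) +
           w.ctc_weight * ctc_score;
}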
@@ -15,6 +15,7 @@
 #pragma once

 #include "decoder/ctc_beam_search_opt.h"
+#include "decoder/ctc_prefix_beam_search_result.h"
 #include "decoder/ctc_prefix_beam_search_score.h"
 #include "decoder/decoder_itf.h"
@@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface {
     explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
     ~CTCPrefixBeamSearch() {}

-    void InitDecoder();
+    void InitDecoder() override;

-    void Reset();
+    void Reset() override;

     void AdvanceDecode(
-        const std::shared_ptr<kaldi::DecodableInterface>& decodable);
+        const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;

-    std::string GetFinalBestPath();
+    std::string GetFinalBestPath() override;
+    std::string GetPartialResult() override;

-    std::string GetPartialResult() {
-        CHECK(false) << "Not implement.";
-        return {};
-    }
-
-    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
-
-    std::string GetBestPath();
-
-    std::vector<std::pair<double, std::string>> GetNBestPath();
-
-    int NumFrameDecoded();
-
-    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-                          std::vector<std::string>& nbest_words);
+    void FinalizeSearch();

-    const std::vector<float>& ViterbiLikelihood() const {
-        return viterbi_likelihood_;
-    }
+  protected:
+    std::string GetBestPath() override;
+    std::vector<std::pair<double, std::string>> GetNBestPath() override;
+    std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;

     const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
     const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
     const std::vector<float>& Likelihood() const { return likelihood_; }
+    const std::vector<float>& ViterbiLikelihood() const {
+        return viterbi_likelihood_;
+    }
     const std::vector<std::vector<int>>& Times() const { return times_; }

   private:
-    void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& logp);
-    void FinalizeSearch();
+    std::string GetBestPath(int index);
+    void AdvanceDecoding(
+        const std::vector<std::vector<kaldi::BaseFloat>>& logp);

     void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
     void UpdateHypotheses(
@@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface {
   private:
     CTCBeamSearchOptions opts_;

-    int abs_time_step_ = 0;
-
     std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
         cur_hyps_;
@@ -97,4 +85,29 @@ class CTCPrefixBeamSearch : public DecoderInterface {
     DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
 };

+class CTCPrefixBeamSearchDecoder : public CTCPrefixBeamSearch {
+  public:
+    explicit CTCPrefixBeamSearchDecoder(const CTCBeamSearchDecoderOptions& opts)
+        : CTCPrefixBeamSearch(opts.ctc_prefix_search_opts), opts_(opts) {}
+
+    ~CTCPrefixBeamSearchDecoder() {}
+
+  private:
+    CTCBeamSearchDecoderOptions opts_;
+
+    // cache feature
+    bool start_ = false;  // false, this is the first frame.
+
+    // for continuous decoding
+    int num_frames_ = 0;
+    int global_frame_offset_ = 0;
+    const int time_stamp_gap_ =
+        100;  // timestamp gap between words in a sentence
+
+    // std::unique_ptr<CtcEndpoint> ctc_endpointer_;
+
+    int num_frames_in_current_chunk_ = 0;
+    std::vector<DecodeResult> result_;
+};
+
 }  // namespace ppspeech
\ No newline at end of file
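For orientation, the call order on the new decoder is InitDecoder(), AdvanceDecode() per feature chunk, FinalizeSearch(), then GetFinalBestPath() and Reset(). A condensed, non-streaming sketch of that flow follows (the helper function is hypothetical; the real chunk-by-chunk loop is in the test binary added below):

#include <memory>
#include <string>

#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "nnet/decodable.h"

// Hypothetical helper: decode one utterance whose features have already been
// pushed into the frontend behind `decodable` (and SetFinished() called).
std::string DecodeUtterance(
    ppspeech::CTCPrefixBeamSearchDecoder* decoder,
    const std::shared_ptr<ppspeech::Decodable>& decodable) {
    decoder->AdvanceDecode(decodable);               // forward nnet + prefix beam search
    decoder->FinalizeSearch();                       // settle hypotheses
    std::string best = decoder->GetFinalBestPath();  // space-separated token ids
    decoder->Reset();                                // ready for the next utterance
    return best;
}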
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "absl/strings/str_split.h"
#include "fst/symbol-table.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(vocab_path, "", "vocab path");
DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 16, "paddle nnet forward chunk");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test u2 CTC prefix beam search decoder by feeding speech features
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
CHECK(FLAGS_vocab_path != "");
CHECK(FLAGS_model_path != "");
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
fst::SymbolTable* unit_table = fst::SymbolTable::ReadText(FLAGS_vocab_path);
// nnet
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet(
new ppspeech::U2Nnet(model_opts));
// decodeable
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
// decoder
ppspeech::CTCBeamSearchDecoderOptions opts;
opts.chunk_size = 16;
opts.num_left_chunks = -1;
opts.ctc_weight = 0.5;
opts.rescoring_weight = 1.0;
opts.reverse_weight = 0.3;
opts.ctc_prefix_search_opts.blank = 0;
opts.ctc_prefix_search_opts.first_beam_size = 10;
opts.ctc_prefix_search_opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearchDecoder decoder(opts);
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
int nframes = feature.NumRows();
int feat_dim = feature.NumCols();
raw_data->SetDim(feat_dim);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;
int32 ori_feature_len = feature.NumRows();
int32 num_chunks = feature.NumRows() / chunk_stride + 1;
LOG(INFO) << "num_chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
int32 this_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
this_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last "
<< this_chunk_size << " frames, expect is "
<< receptive_field_length;
break;
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row);
++start;
}
// feat to frontend pipeline cache
raw_data->Accept(feature_chunk);
// send data finish signal
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// forward nnet
decoder.AdvanceDecode(decodable);
}
decoder.FinalizeSearch();
// get 1-best result
std::string result_ints = decoder.GetFinalBestPath();
std::vector<std::string> tokenids = absl::StrSplit(result_ints, ppspeech::kSpaceSymbol);
std::string result;
for (size_t i = 0; i < tokenids.size(); ++i) {
result += unit_table->Find(std::stoi(tokenids[i]));
}
// after process one utt, then reset state.
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
continue;
}
LOG(INFO) << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
namespace ppspeech {
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
};
struct DecodeResult {
float score = -kBaseFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
} // namespace ppspeech
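DecodeResult::CompareFunc orders results best-first (higher score first), so an n-best list can be sorted with std::sort. A hypothetical usage sketch, not part of this commit:

#include <algorithm>
#include <vector>

#include "decoder/ctc_prefix_beam_search_result.h"

// Sort an n-best list so that results[0] is the highest-scoring hypothesis.
void SortNBest(std::vector<ppspeech::DecodeResult>* results) {
    std::sort(results->begin(), results->end(),
              ppspeech::DecodeResult::CompareFunc);
}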
@@ -18,16 +18,23 @@ namespace ppspeech {
 TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
     fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
     CHECK(fst_ != nullptr);

     word_symbol_table_.reset(
         fst::SymbolTable::ReadText(opts.word_symbol_table));

     decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));

+    Reset();
+}
+
+void TLGDecoder::Reset() {
     decoder_->InitDecoding();
     num_frame_decoded_ = 0;
+    return;
 }

 void TLGDecoder::InitDecoder() {
-    decoder_->InitDecoding();
-    num_frame_decoded_ = 0;
+    Reset();
 }

 void TLGDecoder::AdvanceDecode(
@@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
     num_frame_decoded_++;
 }

-void TLGDecoder::Reset() {
-    InitDecoder();
-    return;
-}

 std::string TLGDecoder::GetPartialResult() {
     if (num_frame_decoded_ == 0) {
@@ -88,4 +92,5 @@ std::string TLGDecoder::GetFinalBestPath() {
     }
     return words;
 }
+
 }
@@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface {
     void AdvanceDecode(
         const std::shared_ptr<kaldi::DecodableInterface>& decodable);

-    std::string GetFinalBestPath();
-    std::string GetPartialResult();
-
     void Decode();

-    std::string GetBestPath();
-    std::vector<std::pair<double, std::string>> GetNBestPath();
-
-    int NumFrameDecoded();
+    std::string GetFinalBestPath() override;
+    std::string GetPartialResult() override;

     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                           std::vector<std::string>& nbest_words);

+  protected:
+    std::string GetBestPath() override {
+        CHECK(false);
+        return {};
+    }
+    std::vector<std::pair<double, std::string>> GetNBestPath() override {
+        CHECK(false);
+        return {};
+    }
+    std::vector<std::pair<double, std::string>> GetNBestPath(int n) override {
+        CHECK(false);
+        return {};
+    }
+
   private:
     void AdvanceDecoding(kaldi::DecodableInterface* decodable);
......
@@ -28,27 +28,31 @@ class DecoderInterface {
     virtual void Reset() = 0;

+    // call AdvanceDecoding
     virtual void AdvanceDecode(
         const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;

+    // call GetBestPath
     virtual std::string GetFinalBestPath() = 0;

     virtual std::string GetPartialResult() = 0;

-    // void Decode();
-    // std::string GetBestPath();
-    // std::vector<std::pair<double, std::string>> GetNBestPath();
-    // int NumFrameDecoded();
-    // int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-    //                       std::vector<std::string>& nbest_words);
+  protected:
+    // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
+    // virtual void Decode() = 0;

-  protected:
-    // void AdvanceDecoding(kaldi::DecodableInterface* decodable);
-    // current decoding frame number
+    virtual std::string GetBestPath() = 0;
+    virtual std::vector<std::pair<double, std::string>> GetNBestPath() = 0;
+    virtual std::vector<std::pair<double, std::string>> GetNBestPath(int n) = 0;
+
+    // start from one
+    int NumFrameDecoded() { return num_frame_decoded_ + 1; }
+
+  protected:
+    // current decoding frame number, abs_time_step_
     int32 num_frame_decoded_;
 };
......
@@ -86,17 +86,6 @@ int main(int argc, char* argv[]) {
         LOG(INFO) << "utt: " << utt;
         LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;

-        // // pad feats
-        // int32 padding_len = 0;
-        // if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
-        //     padding_len =
-        //         chunk_stride - (feature.NumRows() - chunk_size) %
-        //         chunk_stride;
-        //     feature.Resize(feature.NumRows() + padding_len,
-        //                    feature.NumCols(),
-        //                    kaldi::kCopyData);
-        // }
-
         int32 frame_idx = 0;
         int vocab_dim = 0;
         std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
......
@@ -68,7 +68,7 @@ void TopK(const std::vector<T>& data,
     for (int i = k; i < n; i++) {
         if (pq.top().first < data[i]) {
             pq.pop();
-            pq.emplace_back(data[i], i);
+            pq.emplace(data[i], i);
         }
     }
@@ -88,4 +88,9 @@ void TopK(const std::vector<T>& data,
     }
 }

+template void TopK<float>(const std::vector<float>& data,
+                          int32_t k,
+                          std::vector<float>* values,
+                          std::vector<int>* indices);
+
 }  // namespace ppspeech
\ No newline at end of file
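Two things change in utils/math.cc: std::priority_queue has no emplace_back, so the call becomes emplace (this is the actual bug fix), and an explicit instantiation template void TopK<float>(...) is added because the template is defined in the .cc file, so the float version must be emitted there for other translation units to link against it. For reference, a standalone sketch of the same min-heap top-k pattern (an assumed simplification, not the repository implementation):

#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Keep the k largest values (and their indices) using a min-heap: the
// smallest element of the current top-k always sits at pq.top().
void TopKSketch(const std::vector<float>& data, int k,
                std::vector<float>* values, std::vector<int>* indices) {
    std::priority_queue<std::pair<float, int>,
                        std::vector<std::pair<float, int>>,
                        std::greater<std::pair<float, int>>> pq;
    for (int i = 0; i < static_cast<int>(data.size()); ++i) {
        if (static_cast<int>(pq.size()) < k) {
            pq.emplace(data[i], i);
        } else if (!pq.empty() && pq.top().first < data[i]) {
            pq.pop();
            pq.emplace(data[i], i);  // priority_queue has emplace, not emplace_back
        }
    }
    // Drain the heap smallest-first and prepend, so the output is largest-first.
    values->clear();
    indices->clear();
    while (!pq.empty()) {
        values->insert(values->begin(), pq.top().first);
        indices->insert(indices->begin(), pq.top().second);
        pq.pop();
    }
}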