未验证 提交 8ea289a2 编写于 作者: H Hui Zhang 提交者: GitHub

[speechx] fix compile and add more doc (#2591)

* update u2pp latest static graph

* add requirement

* add utils

* update doc

* update result

* update result

* update result

* fix cmake

* imporove scripts

* update result

* add quant model and script

* add profiling timer
上级 79879451
......@@ -134,7 +134,7 @@ string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
execute_process(
COMMAND python -c " \
COMMAND python -c "\
import os; \
import paddle; \
include_dir=paddle.sysconfig.get_include(); \
......
......@@ -70,3 +70,46 @@ popd
### Deepspeech2 with linear feature
* DecibelNormalizer: there is a small difference between the offline and online db norm. The computation of online db norm reads features chunk by chunk, which causes the feature size to be different different with offline db norm. In `normalizer.cc:73`, the `samples.size()` is different, which causes the different result.
## FAQ
1. No moudle named `paddle`.
```
CMake Error at CMakeLists.txt:119 (string):
string sub-command STRIP requires two arguments.
Traceback (most recent call last):
File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'paddle'
-- PADDLE_COMPILE_FLAGS=
CMake Error at CMakeLists.txt:131 (string):
string sub-command STRIP requires two arguments.
File "<string>", line 1
import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);
^
```
please install paddlepaddle >= 2.4rc
2. `u2_recognizer_main: error while loading shared libraries: liblibpaddle.so: cannot open shared object file: No such file or directory`
```
cd $YOUR_ENV_PATH/lib/python3.7/site-packages/paddle/fluid
patchelf --set-soname libpaddle.so libpaddle.so
```
3. `u2_recognizer_main: error while loading shared libraries: libgfortran.so.5: cannot open shared object file: No such file or directory`
```
# my gcc version is 8.2
apt-get install gfortran-8
```
4. `Undefined reference to '_gfortran_concat_string'`
using gcc 8.2, gfortran 8.2.
......@@ -20,4 +20,4 @@ fi
mkdir -p build
cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
cmake --build build -j
cmake --build build
# aishell test
7176 utts, duration 36108.9 sec.
## Attention Rescore
### u2++ FP32
#### CER
```
Overall -> 5.75 % N=104765 C=99035 S=5587 D=143 I=294
Mandarin -> 5.75 % N=104762 C=99035 S=5584 D=143 I=294
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
```
#### RTF
> RTF with feature and decoder which is more end to end.
* Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz, support `avx512_vnni`
```
I1027 10:52:38.662868 51665 u2_recognizer_main.cc:122] total wav duration is: 36108.9 sec
I1027 10:52:38.662858 51665 u2_recognizer_main.cc:121] total cost:11169.1 sec
I1027 10:52:38.662876 51665 u2_recognizer_main.cc:123] RTF is: 0.309318
```
* Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz, not support `avx512_vnni`
```
I1026 16:13:26.247121 48038 u2_recognizer_main.cc:123] total wav duration is: 36108.9 sec
I1026 16:13:26.247130 48038 u2_recognizer_main.cc:124] total decode cost:13656.7 sec
I1026 16:13:26.247138 48038 u2_recognizer_main.cc:125] RTF is: 0.378208
```
......@@ -8,7 +8,7 @@ exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \
......
......@@ -3,29 +3,40 @@ set -e
. path.sh
nj=20
stage=-1
stop_stage=100
. utils/parse_options.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--num_bins 80 \
--cmvn_file=$exp/cmvn.ark \
--streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
fi
echo "compute fbank feature."
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--num_bins 80 \
--cmvn_file=$exp/cmvn.ark \
--streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
echo "compute fbank feature."
fi
......@@ -8,7 +8,7 @@ data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
......
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
. utils/parse_options.sh
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
......
#!/bin/bash
set -e
data=data
exp=exp
nj=20
. utils/parse_options.sh
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--result_wspecifier=ark,t:$data/split${nj}/JOB/recognizer.quant.rsl.ark
cat $data/split${nj}/*/recognizer.quant.rsl.ark > $exp/aishell.recognizer.quant.rsl
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell.recognizer.quant.rsl > $exp/aishell.recognizer.quant.err
echo "recognizer quant test have finished!!!"
echo "please checkout in $exp/aishell.recognizer.quant.err"
tail -n 7 $exp/aishell.recognizer.quant.err
#!/bin/bash
set +x
set -e
. path.sh
nj=40
stage=0
stop_stage=5
stage=-1
stop_stage=100
. utils/parse_options.sh
......@@ -14,7 +13,7 @@ stop_stage=5
data=data
exp=exp
mkdir -p $exp $data
aishell_wav_scp=aishell_test.scp
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
......@@ -25,17 +24,28 @@ fi
ckpt_dir=$data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
# download model
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
# download u2pp model
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz
popd
fi
# download u2pp quant model
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz
popd
fi
......@@ -73,4 +83,4 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./loca/recognizer.sh
fi
\ No newline at end of file
fi
../../../../utils/
\ No newline at end of file
paddlepaddle>=2.4rc
......@@ -69,20 +69,30 @@ void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
double search_cost = 0.0;
double feat_nnet_cost = 0.0;
while (1) {
// forward frame by frame
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> frame_prob;
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed();
if (flag == false) {
VLOG(1) << "decoder advance decode exit." << frame_prob.size();
VLOG(3) << "decoder advance decode exit." << frame_prob.size();
break;
}
timer.Reset();
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";
VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
}
static bool PrefixScoreCompare(
......
......@@ -40,7 +40,9 @@ void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
bool result = Compute(feats);
VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec.";
return result;
}
......@@ -51,14 +53,14 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) {
VLOG(1) << "result: " << result
VLOG(3) << "result: " << result
<< " feature dim: " << feature.Dim();
if (IsFinished() == false) {
VLOG(1) << "finished reading feature. cache size: "
VLOG(3) << "finished reading feature. cache size: "
<< feature_cache_.size();
return false;
} else {
VLOG(1) << "break";
VLOG(3) << "break";
break;
}
}
......@@ -67,11 +69,11 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
feature_cache_.push(feature);
nframes_ += 1;
VLOG(1) << "nframes: " << nframes_;
VLOG(3) << "nframes: " << nframes_;
}
if (feature_cache_.size() < receptive_filed_length_) {
VLOG(1) << "feature_cache less than receptive_filed_lenght. "
VLOG(3) << "feature_cache less than receptive_filed_lenght. "
<< feature_cache_.size() << ": " << receptive_filed_length_;
return false;
}
......@@ -87,7 +89,7 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
int32 this_chunk_size =
std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
VLOG(1) << "read " << this_chunk_size << " feat.";
VLOG(3) << "read " << this_chunk_size << " feat.";
int32 counter = 0;
while (counter < this_chunk_size) {
......
......@@ -38,6 +38,7 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) {
}
void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (size_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
......@@ -48,11 +49,13 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
}
size_ += waves.Dim();
VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. "
<< waves.Dim() << " samples.";
}
bool AudioCache::Read(Vector<BaseFloat>* waves) {
size_t chunk_size = waves->Dim();
kaldi::Timer timer;
size_t chunk_size = waves->Dim();
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > size_) {
// when audio is empty and no more data feed
......@@ -86,9 +89,11 @@ bool AudioCache::Read(Vector<BaseFloat>* waves) {
offset_ = (offset_ + chunk_size) % ring_buffer_.size();
nsamples_ += chunk_size;
VLOG(1) << "nsamples readed: " << nsamples_;
VLOG(3) << "nsamples readed: " << nsamples_;
ready_feed_condition_.notify_one();
VLOG(1) << "AudioCache::Read cost: " << timer.Elapsed() << " sec. "
<< chunk_size << " samples.";
return true;
}
......
......@@ -50,8 +50,11 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
if (base_extractor_->Read(feats) == false || feats->Dim() == 0) {
return false;
}
// appply cmvn
kaldi::Timer timer;
Compute(feats);
VLOG(1) << "CMVN::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
......
......@@ -27,27 +27,32 @@ namespace ppspeech {
// pre-recorded audio/feature
class DataCache : public FrontendInterface {
public:
DataCache() { finished_ = false; }
DataCache() : finished_{false}, dim_{0} {}
// accept waves/feats
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override {
data_ = inputs;
SetDim(data_.Dim());
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override {
if (data_.Dim() == 0) {
return false;
}
(*feats) = data_;
data_.Resize(0);
SetDim(data_.Dim());
return true;
}
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
virtual size_t Dim() const { return dim_; }
void SetFinished() override { finished_ = true; }
bool IsFinished() const override { return finished_; }
size_t Dim() const override { return dim_; }
void SetDim(int32 dim) { dim_ = dim; }
virtual void Reset() { finished_ = true; }
void Reset() override {
finished_ = true;
dim_ = 0;
}
private:
kaldi::Vector<kaldi::BaseFloat> data_;
......
......@@ -34,6 +34,7 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
// feed current data
bool result = false;
do {
......@@ -62,6 +63,7 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
feats->CopyFromVec(cache_.front());
cache_.pop();
ready_feed_condition_.notify_one();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
......@@ -72,9 +74,11 @@ bool FeatureCache::Compute() {
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
kaldi::Timer timer;
int32 num_chunk = feature.Dim() / dim_;
nframe_ += num_chunk;
VLOG(1) << "nframe computed: " << nframe_;
VLOG(3) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
......@@ -92,7 +96,10 @@ bool FeatureCache::Compute() {
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
}
return result;
VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
<< num_chunk << " feats.";
return true;
}
} // namespace ppspeech
\ No newline at end of file
......@@ -58,7 +58,7 @@ class FeatureCache : public FrontendInterface {
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset();
VLOG(1) << "feature cache reset: cache size: " << cache_.size();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
}
private:
......
......@@ -34,6 +34,7 @@ bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
bool flag = base_extractor_->Read(&wav);
if (flag == false || wav.Dim() == 0) return false;
kaldi::Timer timer;
// append remaned waves
int32 wav_len = wav.Dim();
int32 left_len = remained_wav_.Dim();
......@@ -52,6 +53,8 @@ bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
VLOG(1) << "StreamingFeatureTpl<F>::Read cost: " << timer.Elapsed()
<< " sec.";
return true;
}
......
......@@ -68,9 +68,10 @@ bool Decodable::AdvanceChunk() {
Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(1) << "decodable exit;";
VLOG(3) << "decodable exit;";
return false;
}
VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec.";
VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
// forward feats
......@@ -88,7 +89,8 @@ bool Decodable::AdvanceChunk() {
// update state, decoding frame.
frame_offset_ = frames_ready_;
frames_ready_ += nnet_out_cache_.NumRows();
VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec.";
VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
<< " sec.";
return true;
}
......@@ -115,7 +117,7 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
// read one frame likelihood
bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
if (EnsureFrameHaveComputed(frame) == false) {
VLOG(1) << "framelikehood exit.";
VLOG(3) << "framelikehood exit.";
return false;
}
......@@ -168,7 +170,9 @@ void Decodable::Reset() {
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
kaldi::Timer timer;
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}
} // namespace ppspeech
\ No newline at end of file
......@@ -154,7 +154,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear();
VLOG(1) << "u2nnet reset";
VLOG(3) << "u2nnet reset";
}
// Debug API
......@@ -168,6 +168,7 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
void U2Nnet::FeedForward(const kaldi::Vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> chunk_feats(features.Data(),
features.Data() + features.Dim());
......@@ -179,6 +180,8 @@ void U2Nnet::FeedForward(const kaldi::Vector<BaseFloat>& features,
std::memcpy(out->logprobs.Data(),
ctc_probs.data(),
ctc_probs.size() * sizeof(kaldi::BaseFloat));
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< chunk_feats.size() / feature_dim << " frames.";
}
......@@ -415,7 +418,6 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
#ifdef USE_PROFILING
RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1);
#endif
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
......@@ -627,7 +629,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
// combinded left-to-right and right-to-lfet score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
VLOG(1) << "hyp " << i << " " << hyp.size() << " score: " << score
VLOG(3) << "hyp " << i << " " << hyp.size() << " score: " << score
<< " r_score: " << r_score
<< " reverse_weight: " << reverse_weight
<< " final score: " << (*rescoring_score)[i];
......@@ -639,7 +641,7 @@ void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
VLOG(1) << "encoder_outs_ size: " << size;
VLOG(3) << "encoder_outs_ size: " << size;
for (int i = 0; i < size; i++) {
const paddle::Tensor& item = encoder_outs_[i];
......@@ -649,7 +651,7 @@ void U2Nnet::EncoderOuts(
const int& T = shape[1];
const int& D = shape[2];
CHECK(B == 1) << "Only support batch one.";
VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << ","
VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
const float* this_tensor_ptr = item.data<float>();
......
......@@ -67,7 +67,10 @@ void U2Recognizer::ResetContinuousDecoding() {
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
kaldi::Timer timer;
feature_pipeline_->Accept(waves);
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim()
<< " samples.";
}
......@@ -78,9 +81,7 @@ void U2Recognizer::Decode() {
void U2Recognizer::Rescoring() {
// Do attention Rescoring
kaldi::Timer timer;
AttentionRescoring();
VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
}
void U2Recognizer::UpdateResult(bool finish) {
......@@ -181,15 +182,13 @@ void U2Recognizer::AttentionRescoring() {
return;
}
kaldi::Timer timer;
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
......@@ -197,12 +196,12 @@ void U2Recognizer::AttentionRescoring() {
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
VLOG(1) << "hyp: " << result_[0].sentence
VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(1) << "result: " << result_[0].sentence
VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
......
......@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
......@@ -47,9 +48,7 @@ int main(int argc, char* argv[]) {
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
kaldi::Timer local_timer;
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
......@@ -65,6 +64,8 @@ int main(int argc, char* argv[]) {
int sample_offset = 0;
int cnt = 0;
kaldi::Timer timer;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
......@@ -95,6 +96,8 @@ int main(int argc, char* argv[]) {
// second pass decoding
recognizer.Rescoring();
tot_decode_time += timer.Elapsed();
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
......@@ -115,10 +118,8 @@ int main(int argc, char* argv[]) {
++num_done;
}
double elapsed = timer.Elapsed();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total cost:" << elapsed << " sec";
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "RTF is: " << elapsed / tot_wav_duration;
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册