Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7dc9cba3
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7dc9cba3
编写于
10月 13, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ctc prefix beam search for u2, test can run
上级
3c3aa6b5
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
763 addition
and
365 deletion
+763
-365
speechx/examples/codelab/u2/.gitignore
speechx/examples/codelab/u2/.gitignore
+1
-0
speechx/examples/codelab/u2/README.md
speechx/examples/codelab/u2/README.md
+1
-0
speechx/examples/codelab/u2/local/decode.sh
speechx/examples/codelab/u2/local/decode.sh
+22
-0
speechx/examples/codelab/u2/local/feat.sh
speechx/examples/codelab/u2/local/feat.sh
+27
-0
speechx/examples/codelab/u2/local/nnet.sh
speechx/examples/codelab/u2/local/nnet.sh
+23
-0
speechx/examples/codelab/u2/path.sh
speechx/examples/codelab/u2/path.sh
+1
-2
speechx/examples/codelab/u2/run.sh
speechx/examples/codelab/u2/run.sh
+3
-24
speechx/examples/codelab/u2nnet/README.md
speechx/examples/codelab/u2nnet/README.md
+0
-3
speechx/examples/codelab/u2nnet/valgrind.sh
speechx/examples/codelab/u2nnet/valgrind.sh
+0
-21
speechx/speechx/decoder/CMakeLists.txt
speechx/speechx/decoder/CMakeLists.txt
+12
-1
speechx/speechx/decoder/ctc_beam_search_decoder.cc
speechx/speechx/decoder/ctc_beam_search_decoder.cc
+6
-4
speechx/speechx/decoder/ctc_beam_search_decoder.h
speechx/speechx/decoder/ctc_beam_search_decoder.h
+5
-8
speechx/speechx/decoder/ctc_beam_search_opt.h
speechx/speechx/decoder/ctc_beam_search_opt.h
+65
-0
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
+281
-238
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
+42
-29
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
...hx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
+188
-0
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
+41
-0
speechx/speechx/decoder/ctc_tlg_decoder.cc
speechx/speechx/decoder/ctc_tlg_decoder.cc
+11
-6
speechx/speechx/decoder/ctc_tlg_decoder.h
speechx/speechx/decoder/ctc_tlg_decoder.h
+15
-8
speechx/speechx/decoder/decoder_itf.h
speechx/speechx/decoder/decoder_itf.h
+13
-9
speechx/speechx/nnet/u2_nnet_main.cc
speechx/speechx/nnet/u2_nnet_main.cc
+0
-11
speechx/speechx/utils/math.cc
speechx/speechx/utils/math.cc
+6
-1
未找到文件。
speechx/examples/codelab/u2
nnet
/.gitignore
→
speechx/examples/codelab/u2/.gitignore
浏览文件 @
7dc9cba3
data
data
exp
*log
speechx/examples/codelab/u2/README.md
0 → 100644
浏览文件 @
7dc9cba3
# u2/u2pp Streaming Test
speechx/examples/codelab/u2/local/decode.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
set
-x
set
-e
.
path.sh
data
=
data
exp
=
exp
mkdir
-p
$exp
ckpt_dir
=
./data/model
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
ctc_prefix_beam_search_decoder_main
\
--model_path
=
$model_dir
/export.jit
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--downsampling_rate
=
4
\
--vocab_path
=
$model_dir
/unit.txt
\
--feature_rspecifier
=
ark,t:
$exp
/fbank.ark
\
--result_wspecifier
=
ark,t:
$exp
/result.ark
echo
"u2 ctc prefix beam search decode."
speechx/examples/codelab/u2/local/feat.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
set
-x
set
-e
.
path.sh
data
=
data
exp
=
exp
mkdir
-p
$exp
ckpt_dir
=
./data/model
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main
\
--json_file
$model_dir
/mean_std.json
\
--cmvn_write_path
$exp
/cmvn.ark
\
--binary
=
false
echo
"convert json cmvn to kaldi ark."
compute_fbank_main
\
--num_bins
80
\
--wav_rspecifier
=
scp:
$data
/wav.scp
\
--cmvn_file
=
$exp
/cmvn.ark
\
--feature_wspecifier
=
ark,t:
$exp
/fbank.ark
echo
"compute fbank feature."
speechx/examples/codelab/u2/local/nnet.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
set
-x
set
-e
.
path.sh
data
=
data
exp
=
exp
mkdir
-p
$exp
ckpt_dir
=
./data/model
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main
\
--model_path
=
$model_dir
/export.jit
\
--feature_rspecifier
=
ark,t:
$exp
/fbank.ark
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--downsampling_rate
=
4
\
--acoustic_scale
=
1.0
\
--nnet_encoder_outs_wspecifier
=
ark,t:
$exp
/encoder_outs.ark
\
--nnet_prob_wspecifier
=
ark,t:
$exp
/logprobs.ark
echo
"u2 nnet decode."
speechx/examples/codelab/u2
nnet
/path.sh
→
speechx/examples/codelab/u2/path.sh
浏览文件 @
7dc9cba3
...
@@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
...
@@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export
LC_AL
=
C
export
LC_AL
=
C
SPEECHX_BIN
=
$SPEECHX_BUILD
/nnet
export
PATH
=
$PATH
:
$TOOLS_BIN
:
$SPEECHX_BUILD
/nnet:
$SPEECHX_BUILD
/decoder:
$SPEECHX_BUILD
/frontend/audio
export
PATH
=
$PATH
:
$SPEECHX_BIN
:
$TOOLS_BIN
PADDLE_LIB_PATH
=
$(
python
-c
"import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')"
)
PADDLE_LIB_PATH
=
$(
python
-c
"import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')"
)
export
LD_LIBRARY_PATH
=
$PADDLE_LIB_PATH
:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
$PADDLE_LIB_PATH
:
$LD_LIBRARY_PATH
speechx/examples/codelab/u2
nnet
/run.sh
→
speechx/examples/codelab/u2/run.sh
浏览文件 @
7dc9cba3
...
@@ -36,29 +36,8 @@ ckpt_dir=./data/model
...
@@ -36,29 +36,8 @@ ckpt_dir=./data/model
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main
\
./local/feat.sh
--json_file
$model_dir
/mean_std.json
\
--cmvn_write_path
$exp
/cmvn.ark
\
--binary
=
false
echo
"convert json cmvn to kaldi ark."
./local/nnet.sh
compute_fbank_main
\
./local/decode.sh
--num_bins
80
\
--wav_rspecifier
=
scp:
$data
/wav.scp
\
--cmvn_file
=
$exp
/cmvn.ark
\
--feature_wspecifier
=
ark,t:
$exp
/fbank.ark
echo
"compute fbank feature."
u2_nnet_main
\
--model_path
=
$model_dir
/export.jit
\
--feature_rspecifier
=
ark,t:
$exp
/fbank.ark
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--downsampling_rate
=
4
\
--acoustic_scale
=
1.0
\
--nnet_encoder_outs_wspecifier
=
ark,t:
$exp
/encoder_outs.ark
\
--nnet_prob_wspecifier
=
ark,t:
$exp
/logprobs.ark
echo
"u2 nnet decode."
speechx/examples/codelab/u2nnet/README.md
已删除
100644 → 0
浏览文件 @
3c3aa6b5
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.
speechx/examples/codelab/u2nnet/valgrind.sh
已删除
100755 → 0
浏览文件 @
3c3aa6b5
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set
+x
set
-e
.
./path.sh
if
[
!
-d
${
SPEECHX_TOOLS
}
/valgrind/install
]
;
then
echo
"please install valgrind in the speechx tools dir.
\n
"
exit
1
fi
ckpt_dir
=
./data/model
model_dir
=
$ckpt_dir
/exp/deepspeech2_online/checkpoints/
valgrind
--tool
=
memcheck
--track-origins
=
yes
--leak-check
=
full
--show-leak-kinds
=
all
\
ds2_model_test_main
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--param_path
=
$model_dir
/avg_1.jit.pdparams
speechx/speechx/decoder/CMakeLists.txt
浏览文件 @
7dc9cba3
...
@@ -10,8 +10,9 @@ add_library(decoder STATIC
...
@@ -10,8 +10,9 @@ add_library(decoder STATIC
ctc_tlg_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
recognizer.cc
)
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
absl::strings
)
# test
set
(
BINS
set
(
BINS
ctc_beam_search_decoder_main
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
nnet_logprob_decoder_main
...
@@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS)
...
@@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
target_link_libraries
(
${
bin_name
}
PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
endforeach
()
endforeach
()
# u2
set
(
bin_name ctc_prefix_beam_search_decoder_main
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
\ No newline at end of file
speechx/speechx/decoder/ctc_beam_search_decoder.cc
浏览文件 @
7dc9cba3
...
@@ -82,8 +82,6 @@ void CTCBeamSearch::Decode(
...
@@ -82,8 +82,6 @@ void CTCBeamSearch::Decode(
return
;
return
;
}
}
int32
CTCBeamSearch
::
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
// todo rename, refactor
// todo rename, refactor
void
CTCBeamSearch
::
AdvanceDecode
(
void
CTCBeamSearch
::
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
...
@@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() {
...
@@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() {
int
CTCBeamSearch
::
DecodeLikelihoods
(
const
vector
<
vector
<
float
>>&
probs
,
int
CTCBeamSearch
::
DecodeLikelihoods
(
const
vector
<
vector
<
float
>>&
probs
,
vector
<
string
>&
nbest_words
)
{
vector
<
string
>&
nbest_words
)
{
kaldi
::
Timer
timer
;
kaldi
::
Timer
timer
;
timer
.
Reset
();
AdvanceDecoding
(
probs
);
AdvanceDecoding
(
probs
);
LOG
(
INFO
)
<<
"ctc decoding elapsed time(s) "
LOG
(
INFO
)
<<
"ctc decoding elapsed time(s) "
<<
static_cast
<
float
>
(
timer
.
Elapsed
())
/
1000.0
f
;
<<
static_cast
<
float
>
(
timer
.
Elapsed
())
/
1000.0
f
;
return
0
;
return
0
;
}
}
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
(
int
n
)
{
int
beam_size
=
n
==
-
1
?
opts_
.
beam_size
:
std
::
min
(
n
,
opts_
.
beam_size
);
return
get_beam_search_result
(
prefixes_
,
vocabulary_
,
beam_size
);
}
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
()
{
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
()
{
return
get_beam_search_result
(
prefixes_
,
vocabulary_
,
opts_
.
beam_size
);
return
GetNBestPath
(
-
1
);
}
}
string
CTCBeamSearch
::
GetBestPath
()
{
string
CTCBeamSearch
::
GetBestPath
()
{
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.h
浏览文件 @
7dc9cba3
...
@@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface {
...
@@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface {
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
);
std
::
string
GetFinalBestPath
();
std
::
string
GetFinalBestPath
();
std
::
string
GetPartialResult
()
{
std
::
string
GetPartialResult
()
{
...
@@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface {
...
@@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface {
return
{};
return
{};
}
}
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
std
::
vector
<
std
::
string
>&
nbest_words
);
...
...
speechx/speechx/decoder/ctc_beam_search_opt.h
浏览文件 @
7dc9cba3
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
namespace
ppspeech
{
namespace
ppspeech
{
struct
CTCBeamSearchOptions
{
struct
CTCBeamSearchOptions
{
// common
// common
int
blank
;
int
blank
;
...
@@ -75,4 +76,68 @@ struct CTCBeamSearchOptions {
...
@@ -75,4 +76,68 @@ struct CTCBeamSearchOptions {
}
}
};
};
// used by u2 model
struct
CTCBeamSearchDecoderOptions
{
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int
chunk_size
;
int
num_left_chunks
;
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
// Please note the concept of ctc_scores
// in the following two search methods are different. For
// CtcPrefixBeamSerch,
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
// max(viterbi) path score + context score So we should carefully set
// ctc_weight accroding to the search methods.
float
ctc_weight
;
float
rescoring_weight
;
float
reverse_weight
;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions
ctc_prefix_search_opts
;
CTCBeamSearchDecoderOptions
()
:
chunk_size
(
16
),
num_left_chunks
(
-
1
),
ctc_weight
(
0.5
),
rescoring_weight
(
1.0
),
reverse_weight
(
0.0
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
std
::
string
module
=
"DecoderConfig: "
;
opts
->
Register
(
"chunk-size"
,
&
chunk_size
,
module
+
"the frame number of one chunk after subsampling."
);
opts
->
Register
(
"num-left-chunks"
,
&
num_left_chunks
,
module
+
"the left history chunks number."
);
opts
->
Register
(
"ctc-weight"
,
&
ctc_weight
,
module
+
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score."
);
opts
->
Register
(
"rescoring-weight"
,
&
rescoring_weight
,
module
+
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score."
);
opts
->
Register
(
"reverse-weight"
,
&
reverse_weight
,
module
+
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight."
);
}
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
浏览文件 @
7dc9cba3
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -13,11 +15,12 @@
...
@@ -13,11 +15,12 @@
// limitations under the License.
// limitations under the License.
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "base/common.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "utils/math.h"
#include "utils/math.h"
#include "absl/strings/str_join.h"
#ifdef USE_PROFILING
#ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -29,85 +32,47 @@ namespace ppspeech {
...
@@ -29,85 +32,47 @@ namespace ppspeech {
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
)
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
)
:
opts_
(
opts
)
{
:
opts_
(
opts
)
{
InitDecoder
();
Reset
();
}
}
void
CTCPrefixBeamSearch
::
InitDecoder
()
{
void
CTCPrefixBeamSearch
::
Reset
()
{
num_frame_decoded_
=
0
;
num_frame_decoded_
=
0
;
cur_hyps_
.
clear
();
cur_hyps_
.
clear
();
hypotheses_
.
clear
();
hypotheses_
.
clear
();
likelihood_
.
clear
();
likelihood_
.
clear
();
viterbi_likelihood_
.
clear
();
viterbi_likelihood_
.
clear
();
times_
.
clear
();
times_
.
clear
();
outputs_
.
clear
();
outputs_
.
clear
();
abs_time_step_
=
0
;
// empty hyp with Score
std
::
vector
<
int
>
empty
;
PrefixScore
prefix_score
;
prefix_score
.
b
=
0.0
f
;
// log(1)
prefix_score
.
nb
=
-
kBaseFloatMax
;
// log(0)
prefix_score
.
v_b
=
0.0
f
;
// log(1)
prefix_score
.
v_nb
=
0.0
f
;
// log(1)
cur_hyps_
[
empty
]
=
prefix_score
;
// empty hyp with Score
outputs_
.
emplace_back
(
empty
);
std
::
vector
<
int
>
empty
;
hypotheses_
.
emplace_back
(
empty
);
PrefixScore
prefix_score
;
likelihood_
.
emplace_back
(
prefix_score
.
TotalScore
());
prefix_score
.
b
=
0.0
f
;
// log(1)
times_
.
emplace_back
(
empty
);
prefix_score
.
nb
=
-
kBaseFloatMax
;
// log(0)
}
prefix_score
.
v_b
=
0.0
f
;
// log(1)
prefix_score
.
v_nb
=
0.0
f
;
// log(1)
cur_hyps_
[
empty
]
=
prefix_score
;
outputs_
.
emplace_back
(
empty
);
void
CTCPrefixBeamSearch
::
InitDecoder
()
{
Reset
();
}
hypotheses_
.
emplace_back
(
empty
);
likelihood_
.
emplace_back
(
prefix_score
.
TotalScore
());
times_
.
emplace_back
(
empty
);
}
void
CTCPrefixBeamSearch
::
Reset
()
{
InitDecoder
();
}
void
CTCPrefixBeamSearch
::
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
)
{
return
;
}
int32
CTCPrefixBeamSearch
::
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
void
CTCPrefixBeamSearch
::
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
)
{
const
std
::
vector
<
int
>&
input
=
prefix
.
first
;
// const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
// const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
std
::
vector
<
int
>
output
;
int
s
=
0
;
int
e
=
0
;
for
(
int
i
=
0
;
i
<
input
.
size
();
++
i
)
{
// if (s < start_boundaries.size() && i == start_boundaries[s]){
// // <context>
// output.emplace_back(context_graph_->start_tag_id());
// ++s;
// }
output
.
emplace_back
(
input
[
i
]);
// if (e < end_boundaries.size() && i == end_boundaries[e]){
// // </context>
// output.emplace_back(context_graph_->end_tag_id());
// ++e;
// }
}
outputs_
.
emplace_back
(
output
);
}
void
CTCPrefixBeamSearch
::
AdvanceDecode
(
void
CTCPrefixBeamSearch
::
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
while
(
1
)
{
while
(
1
)
{
// forward frame by frame
std
::
vector
<
kaldi
::
BaseFloat
>
frame_prob
;
std
::
vector
<
kaldi
::
BaseFloat
>
frame_prob
;
bool
flag
=
decodable
->
FrameLikelihood
(
num_frame_decoded_
,
&
frame_prob
);
bool
flag
=
decodable
->
FrameLikelihood
(
num_frame_decoded_
,
&
frame_prob
);
if
(
flag
==
false
)
break
;
if
(
flag
==
false
)
break
;
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>
likelihood
;
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>
likelihood
;
likelihood
.
push_back
(
frame_prob
);
likelihood
.
push_back
(
frame_prob
);
AdvanceDecoding
(
likelihood
);
AdvanceDecoding
(
likelihood
);
...
@@ -117,201 +82,279 @@ void CTCPrefixBeamSearch::AdvanceDecode(
...
@@ -117,201 +82,279 @@ void CTCPrefixBeamSearch::AdvanceDecode(
static
bool
PrefixScoreCompare
(
static
bool
PrefixScoreCompare
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
a
,
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
a
,
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
b
)
{
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
b
)
{
// log domain
// log domain
return
a
.
second
.
TotalScore
()
>
b
.
second
.
TotalScore
();
return
a
.
second
.
TotalScore
()
>
b
.
second
.
TotalScore
();
}
}
void
CTCPrefixBeamSearch
::
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
float
>>&
logp
)
{
void
CTCPrefixBeamSearch
::
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>&
logp
)
{
#ifdef USE_PROFILING
#ifdef USE_PROFILING
RecordEvent
event
(
RecordEvent
event
(
"CtcPrefixBeamSearch::AdvanceDecoding"
,
"CtcPrefixBeamSearch::AdvanceDecoding"
,
TracerEventType
::
UserDefined
,
1
);
TracerEventType
::
UserDefined
,
1
);
#endif
#endif
if
(
logp
.
size
()
==
0
)
return
;
if
(
logp
.
size
()
==
0
)
return
;
int
first_beam_size
=
int
first_beam_size
=
std
::
min
(
static_cast
<
int
>
(
logp
[
0
].
size
()),
opts_
.
first_beam_size
);
std
::
min
(
static_cast
<
int
>
(
logp
[
0
].
size
()),
opts_
.
first_beam_size
);
for
(
int
t
=
0
;
t
<
logp
.
size
();
++
t
,
++
abs_time_step_
)
{
for
(
int
t
=
0
;
t
<
logp
.
size
();
++
t
,
++
num_frame_decoded_
)
{
const
std
::
vector
<
float
>&
logp_t
=
logp
[
t
];
const
std
::
vector
<
float
>&
logp_t
=
logp
[
t
];
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
next_hyps
;
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
next_hyps
;
// 1. first beam prune, only select topk candidates
std
::
vector
<
float
>
topk_score
;
// 1. first beam prune, only select topk candidates
std
::
vector
<
int32_t
>
topk_index
;
std
::
vector
<
float
>
topk_score
;
TopK
(
logp_t
,
first_beam_size
,
&
topk_score
,
&
topk_index
);
std
::
vector
<
int32_t
>
topk_index
;
TopK
(
logp_t
,
first_beam_size
,
&
topk_score
,
&
topk_index
);
// 2. token passing
for
(
int
i
=
0
;
i
<
topk_index
.
size
();
++
i
)
{
// 2. token passing
int
id
=
topk_index
[
i
];
for
(
int
i
=
0
;
i
<
topk_index
.
size
();
++
i
)
{
auto
prob
=
topk_score
[
i
];
int
id
=
topk_index
[
i
];
auto
prob
=
topk_score
[
i
];
for
(
const
auto
&
it
:
cur_hyps_
)
{
const
std
::
vector
<
int
>&
prefix
=
it
.
first
;
for
(
const
auto
&
it
:
cur_hyps_
)
{
const
PrefixScore
&
prefix_score
=
it
.
second
;
const
std
::
vector
<
int
>&
prefix
=
it
.
first
;
const
PrefixScore
&
prefix_score
=
it
.
second
;
// If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert
// PrefixScore(-inf, -inf) by default, since the default constructor
// If prefix doesn't exist in next_hyps, next_hyps[prefix] will
// of PrefixScore will set fields b(blank ending Score) and
// insert
// nb(none blank ending Score) to -inf, respectively.
// PrefixScore(-inf, -inf) by default, since the default
// constructor
if
(
id
==
opts_
.
blank
)
{
// of PrefixScore will set fields b(blank ending Score) and
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a, prefix not
// nb(none blank ending Score) to -inf, respectively.
// change
PrefixScore
&
next_score
=
next_hyps
[
prefix
];
if
(
id
==
opts_
.
blank
)
{
next_score
.
b
=
LogSumExp
(
next_score
.
b
,
prefix_score
.
Score
()
+
prob
);
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a,
// prefix not
// timestamp, blank is slince, not effact timestamp
// change
next_score
.
v_b
=
prefix_score
.
ViterbiScore
()
+
prob
;
PrefixScore
&
next_score
=
next_hyps
[
prefix
];
next_score
.
times_b
=
prefix_score
.
Times
();
next_score
.
b
=
LogSumExp
(
next_score
.
b
,
prefix_score
.
Score
()
+
prob
);
// Prefix not changed, copy the context from pefix
if
(
context_graph_
&&
!
next_score
.
has_context
)
{
// timestamp, blank is slince, not effact timestamp
next_score
.
CopyContext
(
prefix_score
);
next_score
.
v_b
=
prefix_score
.
ViterbiScore
()
+
prob
;
next_score
.
has_context
=
true
;
next_score
.
times_b
=
prefix_score
.
Times
();
}
// Prefix not changed, copy the context from pefix
}
else
if
(
!
prefix
.
empty
()
&&
id
==
prefix
.
back
())
{
if
(
context_graph_
&&
!
next_score
.
has_context
)
{
// case 1: *a + a => *a, prefix not changed
next_score
.
CopyContext
(
prefix_score
);
PrefixScore
&
next_score1
=
next_hyps
[
prefix
];
next_score
.
has_context
=
true
;
next_score1
.
nb
=
LogSumExp
(
next_score1
.
nb
,
prefix_score
.
nb
+
prob
);
}
// timestamp, non-blank symbol effact timestamp
}
else
if
(
!
prefix
.
empty
()
&&
id
==
prefix
.
back
())
{
if
(
next_score1
.
v_nb
<
prefix_score
.
v_nb
+
prob
)
{
// case 1: *a + a => *a, prefix not changed
// compute viterbi Score
PrefixScore
&
next_score1
=
next_hyps
[
prefix
];
next_score1
.
v_nb
=
prefix_score
.
v_nb
+
prob
;
next_score1
.
nb
=
if
(
next_score1
.
cur_token_prob
<
prob
)
{
LogSumExp
(
next_score1
.
nb
,
prefix_score
.
nb
+
prob
);
// store max token prob
next_score1
.
cur_token_prob
=
prob
;
// timestamp, non-blank symbol effact timestamp
// update this timestamp as token appeared here.
if
(
next_score1
.
v_nb
<
prefix_score
.
v_nb
+
prob
)
{
next_score1
.
times_nb
=
prefix_score
.
times_nb
;
// compute viterbi Score
assert
(
next_score1
.
times_nb
.
size
()
>
0
);
next_score1
.
v_nb
=
prefix_score
.
v_nb
+
prob
;
next_score1
.
times_nb
.
back
()
=
abs_time_step_
;
if
(
next_score1
.
cur_token_prob
<
prob
)
{
}
// store max token prob
}
next_score1
.
cur_token_prob
=
prob
;
// update this timestamp as token appeared here.
// Prefix not changed, copy the context from pefix
next_score1
.
times_nb
=
prefix_score
.
times_nb
;
if
(
context_graph_
&&
!
next_score1
.
has_context
)
{
assert
(
next_score1
.
times_nb
.
size
()
>
0
);
next_score1
.
CopyContext
(
prefix_score
);
next_score1
.
times_nb
.
back
()
=
num_frame_decoded_
;
next_score1
.
has_context
=
true
;
}
}
}
// case 2: *a<blank> + a => *aa, prefix changed.
// Prefix not changed, copy the context from pefix
std
::
vector
<
int
>
new_prefix
(
prefix
);
if
(
context_graph_
&&
!
next_score1
.
has_context
)
{
new_prefix
.
emplace_back
(
id
);
next_score1
.
CopyContext
(
prefix_score
);
PrefixScore
&
next_score2
=
next_hyps
[
new_prefix
];
next_score1
.
has_context
=
true
;
next_score2
.
nb
=
LogSumExp
(
next_score2
.
nb
,
prefix_score
.
b
+
prob
);
}
// timestamp, non-blank symbol effact timestamp
// case 2: *a<blank> + a => *aa, prefix changed.
if
(
next_score2
.
v_nb
<
prefix_score
.
v_b
+
prob
)
{
std
::
vector
<
int
>
new_prefix
(
prefix
);
// compute viterbi Score
new_prefix
.
emplace_back
(
id
);
next_score2
.
v_nb
=
prefix_score
.
v_b
+
prob
;
PrefixScore
&
next_score2
=
next_hyps
[
new_prefix
];
// new token added
next_score2
.
nb
=
next_score2
.
cur_token_prob
=
prob
;
LogSumExp
(
next_score2
.
nb
,
prefix_score
.
b
+
prob
);
next_score2
.
times_nb
=
prefix_score
.
times_b
;
next_score2
.
times_nb
.
emplace_back
(
abs_time_step_
);
// timestamp, non-blank symbol effact timestamp
}
if
(
next_score2
.
v_nb
<
prefix_score
.
v_b
+
prob
)
{
// compute viterbi Score
// Prefix changed, calculate the context Score.
next_score2
.
v_nb
=
prefix_score
.
v_b
+
prob
;
if
(
context_graph_
&&
!
next_score2
.
has_context
)
{
// new token added
next_score2
.
UpdateContext
(
next_score2
.
cur_token_prob
=
prob
;
context_graph_
,
prefix_score
,
id
,
prefix
.
size
());
next_score2
.
times_nb
=
prefix_score
.
times_b
;
next_score2
.
has_context
=
true
;
next_score2
.
times_nb
.
emplace_back
(
num_frame_decoded_
);
}
}
}
else
{
// Prefix changed, calculate the context Score.
// id != prefix.back()
if
(
context_graph_
&&
!
next_score2
.
has_context
)
{
// case 3: *a + b => *ab, *a<blank> +b => *ab
next_score2
.
UpdateContext
(
std
::
vector
<
int
>
new_prefix
(
prefix
);
context_graph_
,
prefix_score
,
id
,
prefix
.
size
());
new_prefix
.
emplace_back
(
id
);
next_score2
.
has_context
=
true
;
PrefixScore
&
next_score
=
next_hyps
[
new_prefix
];
}
next_score
.
nb
=
LogSumExp
(
next_score
.
nb
,
prefix_score
.
Score
()
+
prob
);
}
else
{
// timetamp, non-blank symbol effact timestamp
// id != prefix.back()
if
(
next_score
.
v_nb
<
prefix_score
.
ViterbiScore
()
+
prob
)
{
// case 3: *a + b => *ab, *a<blank> +b => *ab
next_score
.
v_nb
=
prefix_score
.
ViterbiScore
()
+
prob
;
std
::
vector
<
int
>
new_prefix
(
prefix
);
new_prefix
.
emplace_back
(
id
);
next_score
.
cur_token_prob
=
prob
;
PrefixScore
&
next_score
=
next_hyps
[
new_prefix
];
next_score
.
times_nb
=
prefix_score
.
Times
();
next_score
.
nb
=
next_score
.
times_nb
.
emplace_back
(
abs_time_step_
);
LogSumExp
(
next_score
.
nb
,
prefix_score
.
Score
()
+
prob
);
}
// timetamp, non-blank symbol effact timestamp
// Prefix changed, calculate the context Score.
if
(
next_score
.
v_nb
<
prefix_score
.
ViterbiScore
()
+
prob
)
{
if
(
context_graph_
&&
!
next_score
.
has_context
)
{
next_score
.
v_nb
=
prefix_score
.
ViterbiScore
()
+
prob
;
next_score
.
UpdateContext
(
context_graph_
,
prefix_score
,
id
,
prefix
.
size
());
next_score
.
cur_token_prob
=
prob
;
next_score
.
has_context
=
true
;
next_score
.
times_nb
=
prefix_score
.
Times
();
}
next_score
.
times_nb
.
emplace_back
(
num_frame_decoded_
);
}
}
}
// end for (const auto& it : cur_hyps_)
}
// end for (int i = 0; i < topk_index.size(); ++i)
// Prefix changed, calculate the context Score.
if
(
context_graph_
&&
!
next_score
.
has_context
)
{
// 3. second beam prune, only keep top n best paths
next_score
.
UpdateContext
(
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>
arr
(
next_hyps
.
begin
(),
context_graph_
,
prefix_score
,
id
,
prefix
.
size
());
next_hyps
.
end
());
next_score
.
has_context
=
true
;
int
second_beam_size
=
}
std
::
min
(
static_cast
<
int
>
(
arr
.
size
()),
opts_
.
second_beam_size
);
}
std
::
nth_element
(
arr
.
begin
(),
}
// end for (const auto& it : cur_hyps_)
arr
.
begin
()
+
second_beam_size
,
}
// end for (int i = 0; i < topk_index.size(); ++i)
arr
.
end
(),
PrefixScoreCompare
);
// 3. second beam prune, only keep top n best paths
arr
.
resize
(
second_beam_size
);
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>
arr
(
std
::
sort
(
arr
.
begin
(),
arr
.
end
(),
PrefixScoreCompare
);
next_hyps
.
begin
(),
next_hyps
.
end
());
int
second_beam_size
=
// 4. update cur_hyps by next_hyps, and get new result
std
::
min
(
static_cast
<
int
>
(
arr
.
size
()),
opts_
.
second_beam_size
);
UpdateHypotheses
(
arr
);
std
::
nth_element
(
arr
.
begin
(),
arr
.
begin
()
+
second_beam_size
,
num_frame_decoded_
++
;
arr
.
end
(),
}
// end for (int t = 0; t < logp.size(); ++t, ++abs_time_step_)
PrefixScoreCompare
);
arr
.
resize
(
second_beam_size
);
std
::
sort
(
arr
.
begin
(),
arr
.
end
(),
PrefixScoreCompare
);
// 4. update cur_hyps by next_hyps, and get new result
UpdateHypotheses
(
arr
);
}
// end for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_)
}
}
void
CTCPrefixBeamSearch
::
UpdateHypotheses
(
void
CTCPrefixBeamSearch
::
UpdateHypotheses
(
const
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>&
hyps
)
{
const
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>&
hyps
)
{
cur_hyps_
.
clear
();
cur_hyps_
.
clear
();
outputs_
.
clear
();
outputs_
.
clear
();
hypotheses_
.
clear
();
hypotheses_
.
clear
();
likelihood_
.
clear
();
likelihood_
.
clear
();
viterbi_likelihood_
.
clear
();
viterbi_likelihood_
.
clear
();
times_
.
clear
();
times_
.
clear
();
for
(
auto
&
item
:
hyps
)
{
for
(
auto
&
item
:
hyps
)
{
cur_hyps_
[
item
.
first
]
=
item
.
second
;
cur_hyps_
[
item
.
first
]
=
item
.
second
;
UpdateOutputs
(
item
);
UpdateOutputs
(
item
);
hypotheses_
.
emplace_back
(
std
::
move
(
item
.
first
));
hypotheses_
.
emplace_back
(
std
::
move
(
item
.
first
));
likelihood_
.
emplace_back
(
item
.
second
.
TotalScore
());
likelihood_
.
emplace_back
(
item
.
second
.
TotalScore
());
viterbi_likelihood_
.
emplace_back
(
item
.
second
.
ViterbiScore
());
viterbi_likelihood_
.
emplace_back
(
item
.
second
.
ViterbiScore
());
times_
.
emplace_back
(
item
.
second
.
Times
());
times_
.
emplace_back
(
item
.
second
.
Times
());
}
}
}
}
void
CTCPrefixBeamSearch
::
FinalizeSearch
()
{
UpdateFinalContext
();
}
void
CTCPrefixBeamSearch
::
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
)
{
const
std
::
vector
<
int
>&
input
=
prefix
.
first
;
const
std
::
vector
<
int
>&
start_boundaries
=
prefix
.
second
.
start_boundaries
;
const
std
::
vector
<
int
>&
end_boundaries
=
prefix
.
second
.
end_boundaries
;
// add <context> </context> tag
std
::
vector
<
int
>
output
;
int
s
=
0
;
int
e
=
0
;
for
(
int
i
=
0
;
i
<
input
.
size
();
++
i
)
{
// if (s < start_boundaries.size() && i == start_boundaries[s]){
// // <context>
// output.emplace_back(context_graph_->start_tag_id());
// ++s;
// }
output
.
emplace_back
(
input
[
i
]);
// if (e < end_boundaries.size() && i == end_boundaries[e]){
// // </context>
// output.emplace_back(context_graph_->end_tag_id());
// ++e;
// }
}
outputs_
.
emplace_back
(
output
);
}
void
CTCPrefixBeamSearch
::
FinalizeSearch
()
{
UpdateFinalContext
();
}
void
CTCPrefixBeamSearch
::
UpdateFinalContext
()
{
void
CTCPrefixBeamSearch
::
UpdateFinalContext
()
{
if
(
context_graph_
==
nullptr
)
return
;
if
(
context_graph_
==
nullptr
)
return
;
assert
(
hypotheses_
.
size
()
==
cur_hyps_
.
size
());
assert
(
hypotheses_
.
size
()
==
likelihood_
.
size
());
CHECK
(
hypotheses_
.
size
()
==
cur_hyps_
.
size
());
CHECK
(
hypotheses_
.
size
()
==
likelihood_
.
size
());
// We should backoff the context Score/state when the context is
// not fully matched at the last time.
// We should backoff the context Score/state when the context is
for
(
const
auto
&
prefix
:
hypotheses_
)
{
// not fully matched at the last time.
PrefixScore
&
prefix_score
=
cur_hyps_
[
prefix
];
for
(
const
auto
&
prefix
:
hypotheses_
)
{
if
(
prefix_score
.
context_score
!=
0
)
{
PrefixScore
&
prefix_score
=
cur_hyps_
[
prefix
];
// prefix_score.UpdateContext(context_graph_, prefix_score, 0,
if
(
prefix_score
.
context_score
!=
0
)
{
// prefix.size());
prefix_score
.
UpdateContext
(
context_graph_
,
prefix_score
,
0
,
prefix
.
size
());
}
}
}
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>
arr
(
cur_hyps_
.
begin
(),
cur_hyps_
.
end
());
std
::
sort
(
arr
.
begin
(),
arr
.
end
(),
PrefixScoreCompare
);
// Update cur_hyps_ and get new result
UpdateHypotheses
(
arr
);
}
std
::
string
CTCPrefixBeamSearch
::
GetBestPath
(
int
index
)
{
int
n_hyps
=
Outputs
().
size
();
CHECK
(
n_hyps
>
0
);
CHECK
(
index
<
n_hyps
);
std
::
vector
<
int
>
one
=
Outputs
()[
index
];
return
std
::
string
(
absl
::
StrJoin
(
one
,
kSpaceSymbol
));
}
std
::
string
CTCPrefixBeamSearch
::
GetBestPath
()
{
return
GetBestPath
(
0
);
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
CTCPrefixBeamSearch
::
GetNBestPath
(
int
n
)
{
int
hyps_size
=
hypotheses_
.
size
();
CHECK
(
hyps_size
>
0
);
int
min_n
=
n
==
-
1
?
hypotheses_
.
size
()
:
std
::
min
(
n
,
hyps_size
);
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
n_best
;
n_best
.
reserve
(
min_n
);
for
(
int
i
=
0
;
i
<
min_n
;
i
++
){
n_best
.
emplace_back
(
Likelihood
()[
i
],
GetBestPath
(
i
)
);
}
return
n_best
;
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
CTCPrefixBeamSearch
::
GetNBestPath
()
{
return
GetNBestPath
(
-
1
);
}
}
std
::
vector
<
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>>
arr
(
cur_hyps_
.
begin
(),
cur_hyps_
.
end
());
std
::
sort
(
arr
.
begin
(),
arr
.
end
(),
PrefixScoreCompare
);
// Update cur_hyps_ and get new result
std
::
string
CTCPrefixBeamSearch
::
GetFinalBestPath
()
{
UpdateHypotheses
(
arr
);
return
GetBestPath
();
}
std
::
string
CTCPrefixBeamSearch
::
GetPartialResult
()
{
return
GetBestPath
();
}
}
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
浏览文件 @
7dc9cba3
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_result.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "decoder/decoder_itf.h"
...
@@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface {
explicit
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
explicit
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
~
CTCPrefixBeamSearch
()
{}
~
CTCPrefixBeamSearch
()
{}
void
InitDecoder
();
void
InitDecoder
()
override
;
void
Reset
();
void
Reset
()
override
;
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
override
;
std
::
string
GetFinalBestPath
();
std
::
string
GetFinalBestPath
()
override
;
std
::
string
GetPartialResult
()
override
;
std
::
string
GetPartialResult
()
{
void
FinalizeSearch
();
CHECK
(
false
)
<<
"Not implement."
;
return
{};
}
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
const
std
::
vector
<
float
>&
ViterbiLikelihood
()
const
{
protected:
return
viterbi_likelihood_
;
std
::
string
GetBestPath
()
override
;
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
;
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
vector
<
float
>&
Likelihood
()
const
{
return
likelihood_
;
}
const
std
::
vector
<
float
>&
Likelihood
()
const
{
return
likelihood_
;
}
const
std
::
vector
<
float
>&
ViterbiLikelihood
()
const
{
return
viterbi_likelihood_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
private:
private:
void
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
logp
);
std
::
string
GetBestPath
(
int
index
);
void
FinalizeSearch
();
void
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>&
logp
);
void
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
);
void
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
);
void
UpdateHypotheses
(
void
UpdateHypotheses
(
...
@@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface {
private:
private:
CTCBeamSearchOptions
opts_
;
CTCBeamSearchOptions
opts_
;
int
abs_time_step_
=
0
;
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
cur_hyps_
;
cur_hyps_
;
...
@@ -97,4 +85,29 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -97,4 +85,29 @@ class CTCPrefixBeamSearch : public DecoderInterface {
DISALLOW_COPY_AND_ASSIGN
(
CTCPrefixBeamSearch
);
DISALLOW_COPY_AND_ASSIGN
(
CTCPrefixBeamSearch
);
};
};
class
CTCPrefixBeamSearchDecoder
:
public
CTCPrefixBeamSearch
{
public:
explicit
CTCPrefixBeamSearchDecoder
(
const
CTCBeamSearchDecoderOptions
&
opts
)
:
CTCPrefixBeamSearch
(
opts
.
ctc_prefix_search_opts
),
opts_
(
opts
)
{}
~
CTCPrefixBeamSearchDecoder
()
{}
private:
CTCBeamSearchDecoderOptions
opts_
;
// cache feature
bool
start_
=
false
;
// false, this is first frame.
// for continues decoding
int
num_frames_
=
0
;
int
global_frame_offset_
=
0
;
const
int
time_stamp_gap_
=
100
;
// timestamp gap between words in a sentence
// std::unique_ptr<CtcEndpoint> ctc_endpointer_;
int
num_frames_in_current_chunk_
=
0
;
std
::
vector
<
DecodeResult
>
result_
;
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
0 → 100644
浏览文件 @
7dc9cba3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "absl/strings/str_split.h"
#include "fst/symbol-table.h"
DEFINE_string
(
feature_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
vocab_path
,
""
,
"vocab path"
);
DEFINE_string
(
model_path
,
""
,
"paddle nnet model"
);
DEFINE_int32
(
receptive_field_length
,
7
,
"receptive field of two CNN(kernel=3) downsampling module."
);
DEFINE_int32
(
downsampling_rate
,
4
,
"two CNN(kernel=3) module downsampling rate."
);
DEFINE_int32
(
nnet_decoder_chunk
,
16
,
"paddle nnet forward chunk"
);
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
// test ds2 online decoder by feeding speech feature
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int32
num_done
=
0
,
num_err
=
0
;
CHECK
(
FLAGS_result_wspecifier
!=
""
);
CHECK
(
FLAGS_feature_rspecifier
!=
""
);
CHECK
(
FLAGS_vocab_path
!=
""
);
CHECK
(
FLAGS_model_path
!=
""
);
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
FLAGS_feature_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_vocab_path
;
fst
::
SymbolTable
*
unit_table
=
fst
::
SymbolTable
::
ReadText
(
FLAGS_vocab_path
);
// nnet
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet
(
new
ppspeech
::
U2Nnet
(
model_opts
));
// decodeable
std
::
shared_ptr
<
ppspeech
::
DataCache
>
raw_data
(
new
ppspeech
::
DataCache
());
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
// decoder
ppspeech
::
CTCBeamSearchDecoderOptions
opts
;
opts
.
chunk_size
=
16
;
opts
.
num_left_chunks
=
-
1
;
opts
.
ctc_weight
=
0.5
;
opts
.
rescoring_weight
=
1.0
;
opts
.
reverse_weight
=
0.3
;
opts
.
ctc_prefix_search_opts
.
blank
=
0
;
opts
.
ctc_prefix_search_opts
.
first_beam_size
=
10
;
opts
.
ctc_prefix_search_opts
.
second_beam_size
=
10
;
ppspeech
::
CTCPrefixBeamSearchDecoder
decoder
(
opts
);
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_stride
=
FLAGS_downsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
LOG
(
INFO
)
<<
"chunk stride (frame): "
<<
chunk_stride
;
LOG
(
INFO
)
<<
"receptive field (frame): "
<<
receptive_field_length
;
decoder
.
InitDecoder
();
kaldi
::
Timer
timer
;
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
string
utt
=
feature_reader
.
Key
();
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
int
nframes
=
feature
.
NumRows
();
int
feat_dim
=
feature
.
NumCols
();
raw_data
->
SetDim
(
feat_dim
);
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
raw_data
->
SetDim
(
feat_dim
);
int32
ori_feature_len
=
feature
.
NumRows
();
int32
num_chunks
=
feature
.
NumRows
()
/
chunk_stride
+
1
;
LOG
(
INFO
)
<<
"num_chunks: "
<<
num_chunks
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunks
;
++
chunk_idx
)
{
int32
this_chunk_size
=
0
;
if
(
ori_feature_len
>
chunk_idx
*
chunk_stride
)
{
this_chunk_size
=
std
::
min
(
ori_feature_len
-
chunk_idx
*
chunk_stride
,
chunk_size
);
}
if
(
this_chunk_size
<
receptive_field_length
)
{
LOG
(
WARNING
)
<<
"utt: "
<<
utt
<<
" skip last "
<<
this_chunk_size
<<
" frames, expect is "
<<
receptive_field_length
;
break
;
}
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
feature_chunk
(
this_chunk_size
*
feat_dim
);
int32
start
=
chunk_idx
*
chunk_stride
;
for
(
int
row_id
=
0
;
row_id
<
this_chunk_size
;
++
row_id
)
{
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
feat_row
(
feature
,
start
);
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
feature_chunk_row
(
feature_chunk
.
Data
()
+
row_id
*
feat_dim
,
feat_dim
);
feature_chunk_row
.
CopyFromVec
(
feat_row
);
++
start
;
}
// feat to frontend pipeline cache
raw_data
->
Accept
(
feature_chunk
);
// send data finish signal
if
(
chunk_idx
==
num_chunks
-
1
)
{
raw_data
->
SetFinished
();
}
// forward nnet
decoder
.
AdvanceDecode
(
decodable
);
}
decoder
.
FinalizeSearch
();
// get 1-best result
std
::
string
result_ints
=
decoder
.
GetFinalBestPath
();
std
::
vector
<
std
::
string
>
tokenids
=
absl
::
StrSplit
(
result_ints
,
ppspeech
::
kSpaceSymbol
);
std
::
string
result
;
for
(
int
i
=
0
;
i
<
tokenids
.
size
();
i
++
){
result
+=
unit_table
->
Find
(
std
::
stoi
(
tokenids
[
i
]));
}
// after process one utt, then reset state.
decodable
->
Reset
();
decoder
.
Reset
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
continue
;
}
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is "
<<
result
;
result_writer
.
Write
(
utt
,
result
);
++
num_done
;
}
double
elapsed
=
timer
.
Elapsed
();
LOG
(
INFO
)
<<
"Program cost:"
<<
elapsed
<<
" sec"
;
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" utterances, "
<<
num_err
<<
" with errors."
;
return
(
num_done
!=
0
?
0
:
1
);
}
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
0 → 100644
浏览文件 @
7dc9cba3
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
namespace
ppspeech
{
struct
WordPiece
{
std
::
string
word
;
int
start
=
-
1
;
int
end
=
-
1
;
WordPiece
(
std
::
string
word
,
int
start
,
int
end
)
:
word
(
std
::
move
(
word
)),
start
(
start
),
end
(
end
)
{}
};
struct
DecodeResult
{
float
score
=
-
kBaseFloatMax
;
std
::
string
sentence
;
std
::
vector
<
WordPiece
>
word_pieces
;
static
bool
CompareFunc
(
const
DecodeResult
&
a
,
const
DecodeResult
&
b
)
{
return
a
.
score
>
b
.
score
;
}
};
}
// namespace ppspeech
speechx/speechx/decoder/ctc_tlg_decoder.cc
浏览文件 @
7dc9cba3
...
@@ -18,16 +18,23 @@ namespace ppspeech {
...
@@ -18,16 +18,23 @@ namespace ppspeech {
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
{
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
CHECK
(
fst_
!=
nullptr
);
CHECK
(
fst_
!=
nullptr
);
word_symbol_table_
.
reset
(
word_symbol_table_
.
reset
(
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
decoder_
.
reset
(
new
kaldi
::
LatticeFasterOnlineDecoder
(
*
fst_
,
opts
.
opts
));
decoder_
.
reset
(
new
kaldi
::
LatticeFasterOnlineDecoder
(
*
fst_
,
opts
.
opts
));
Reset
();
}
void
TLGDecoder
::
Reset
()
{
decoder_
->
InitDecoding
();
decoder_
->
InitDecoding
();
num_frame_decoded_
=
0
;
num_frame_decoded_
=
0
;
return
;
}
}
void
TLGDecoder
::
InitDecoder
()
{
void
TLGDecoder
::
InitDecoder
()
{
decoder_
->
InitDecoding
();
Reset
();
num_frame_decoded_
=
0
;
}
}
void
TLGDecoder
::
AdvanceDecode
(
void
TLGDecoder
::
AdvanceDecode
(
...
@@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
...
@@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
num_frame_decoded_
++
;
num_frame_decoded_
++
;
}
}
void
TLGDecoder
::
Reset
()
{
InitDecoder
();
return
;
}
std
::
string
TLGDecoder
::
GetPartialResult
()
{
std
::
string
TLGDecoder
::
GetPartialResult
()
{
if
(
num_frame_decoded_
==
0
)
{
if
(
num_frame_decoded_
==
0
)
{
...
@@ -88,4 +92,5 @@ std::string TLGDecoder::GetFinalBestPath() {
...
@@ -88,4 +92,5 @@ std::string TLGDecoder::GetFinalBestPath() {
}
}
return
words
;
return
words
;
}
}
}
}
speechx/speechx/decoder/ctc_tlg_decoder.h
浏览文件 @
7dc9cba3
...
@@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface {
...
@@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface {
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
std
::
string
GetFinalBestPath
();
std
::
string
GetPartialResult
();
void
Decode
();
void
Decode
();
std
::
string
Get
BestPath
()
;
std
::
string
Get
FinalBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
;
std
::
string
GetPartialResult
()
override
;
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
std
::
vector
<
std
::
string
>&
nbest_words
);
protected:
std
::
string
GetBestPath
()
override
{
CHECK
(
false
);
return
{};
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
{
CHECK
(
false
);
return
{};
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
{
CHECK
(
false
);
return
{};
}
private:
private:
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
...
...
speechx/speechx/decoder/decoder_itf.h
浏览文件 @
7dc9cba3
...
@@ -28,27 +28,31 @@ class DecoderInterface {
...
@@ -28,27 +28,31 @@ class DecoderInterface {
virtual
void
Reset
()
=
0
;
virtual
void
Reset
()
=
0
;
// call AdvanceDecoding
virtual
void
AdvanceDecode
(
virtual
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
=
0
;
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
=
0
;
// call GetBestPath
virtual
std
::
string
GetFinalBestPath
()
=
0
;
virtual
std
::
string
GetFinalBestPath
()
=
0
;
virtual
std
::
string
GetPartialResult
()
=
0
;
virtual
std
::
string
GetPartialResult
()
=
0
;
// void Decode();
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
// std::string GetBestPath();
// virtual void Decode() = 0;
// std::vector<std::pair<double, std::string>> GetNBestPath();
// int NumFrameDecoded();
virtual
std
::
string
GetBestPath
()
=
0
;
// int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
// std::vector<std::string>& nbest_words);
virtual
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
=
0
;
protected:
virtual
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
=
0
;
// void AdvanceDecoding(kaldi::DecodableInterface* decodable);
// current decoding frame number
// start from one
int
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
protected:
// current decoding frame number, abs_time_step_
int32
num_frame_decoded_
;
int32
num_frame_decoded_
;
};
};
...
...
speechx/speechx/nnet/u2_nnet_main.cc
浏览文件 @
7dc9cba3
...
@@ -86,17 +86,6 @@ int main(int argc, char* argv[]) {
...
@@ -86,17 +86,6 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
// // pad feats
// int32 padding_len = 0;
// if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
// padding_len =
// chunk_stride - (feature.NumRows() - chunk_size) %
// chunk_stride;
// feature.Resize(feature.NumRows() + padding_len,
// feature.NumCols(),
// kaldi::kCopyData);
// }
int32
frame_idx
=
0
;
int32
frame_idx
=
0
;
int
vocab_dim
=
0
;
int
vocab_dim
=
0
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
...
...
speechx/speechx/utils/math.cc
浏览文件 @
7dc9cba3
...
@@ -68,7 +68,7 @@ void TopK(const std::vector<T>& data,
...
@@ -68,7 +68,7 @@ void TopK(const std::vector<T>& data,
for
(
int
i
=
k
;
i
<
n
;
i
++
)
{
for
(
int
i
=
k
;
i
<
n
;
i
++
)
{
if
(
pq
.
top
().
first
<
data
[
i
])
{
if
(
pq
.
top
().
first
<
data
[
i
])
{
pq
.
pop
();
pq
.
pop
();
pq
.
emplace
_back
(
data
[
i
],
i
);
pq
.
emplace
(
data
[
i
],
i
);
}
}
}
}
...
@@ -88,4 +88,9 @@ void TopK(const std::vector<T>& data,
...
@@ -88,4 +88,9 @@ void TopK(const std::vector<T>& data,
}
}
}
}
template
void
TopK
<
float
>(
const
std
::
vector
<
float
>&
data
,
int32_t
k
,
std
::
vector
<
float
>*
values
,
std
::
vector
<
int
>*
indices
)
;
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录