Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7dc9cba3
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7dc9cba3
编写于
10月 13, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ctc prefix beam search for u2, test can run
上级
3c3aa6b5
变更
22
展开全部
隐藏空白更改
内联
并排
Showing
22 changed file
with
763 addition
and
365 deletion
+763
-365
speechx/examples/codelab/u2/.gitignore
speechx/examples/codelab/u2/.gitignore
+1
-0
speechx/examples/codelab/u2/README.md
speechx/examples/codelab/u2/README.md
+1
-0
speechx/examples/codelab/u2/local/decode.sh
speechx/examples/codelab/u2/local/decode.sh
+22
-0
speechx/examples/codelab/u2/local/feat.sh
speechx/examples/codelab/u2/local/feat.sh
+27
-0
speechx/examples/codelab/u2/local/nnet.sh
speechx/examples/codelab/u2/local/nnet.sh
+23
-0
speechx/examples/codelab/u2/path.sh
speechx/examples/codelab/u2/path.sh
+1
-2
speechx/examples/codelab/u2/run.sh
speechx/examples/codelab/u2/run.sh
+3
-24
speechx/examples/codelab/u2nnet/README.md
speechx/examples/codelab/u2nnet/README.md
+0
-3
speechx/examples/codelab/u2nnet/valgrind.sh
speechx/examples/codelab/u2nnet/valgrind.sh
+0
-21
speechx/speechx/decoder/CMakeLists.txt
speechx/speechx/decoder/CMakeLists.txt
+12
-1
speechx/speechx/decoder/ctc_beam_search_decoder.cc
speechx/speechx/decoder/ctc_beam_search_decoder.cc
+6
-4
speechx/speechx/decoder/ctc_beam_search_decoder.h
speechx/speechx/decoder/ctc_beam_search_decoder.h
+5
-8
speechx/speechx/decoder/ctc_beam_search_opt.h
speechx/speechx/decoder/ctc_beam_search_opt.h
+65
-0
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
+281
-238
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
+42
-29
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
...hx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
+188
-0
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
+41
-0
speechx/speechx/decoder/ctc_tlg_decoder.cc
speechx/speechx/decoder/ctc_tlg_decoder.cc
+11
-6
speechx/speechx/decoder/ctc_tlg_decoder.h
speechx/speechx/decoder/ctc_tlg_decoder.h
+15
-8
speechx/speechx/decoder/decoder_itf.h
speechx/speechx/decoder/decoder_itf.h
+13
-9
speechx/speechx/nnet/u2_nnet_main.cc
speechx/speechx/nnet/u2_nnet_main.cc
+0
-11
speechx/speechx/utils/math.cc
speechx/speechx/utils/math.cc
+6
-1
未找到文件。
speechx/examples/codelab/u2
nnet
/.gitignore
→
speechx/examples/codelab/u2/.gitignore
浏览文件 @
7dc9cba3
data
data
exp
*log
speechx/examples/codelab/u2/README.md
0 → 100644
浏览文件 @
7dc9cba3
# u2/u2pp Streaming Test
speechx/examples/codelab/u2/local/decode.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
# Run CTC prefix beam search decoding for the u2/u2pp model on the fbank
# features found in ${exp}/fbank.ark, writing results to ${exp}/result.ark.
set -x
set -e

# Brings the speechx binaries and paddle libraries into scope.
. path.sh

data=data
exp=exp
mkdir -p "${exp}"

ckpt_dir=./data/model
model_dir="${ckpt_dir}/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/"

ctc_prefix_beam_search_decoder_main \
    --model_path="${model_dir}/export.jit" \
    --nnet_decoder_chunk=16 \
    --receptive_field_length=7 \
    --downsampling_rate=4 \
    --vocab_path="${model_dir}/unit.txt" \
    --feature_rspecifier="ark,t:${exp}/fbank.ark" \
    --result_wspecifier="ark,t:${exp}/result.ark"

echo "u2 ctc prefix beam search decode."
speechx/examples/codelab/u2/local/feat.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
# Convert the model's JSON CMVN statistics to a kaldi ark, then compute
# 80-bin fbank features for the utterances listed in ${data}/wav.scp.
set -x
set -e

# Brings the speechx binaries and paddle libraries into scope.
. path.sh

data=data
exp=exp
mkdir -p "${exp}"

ckpt_dir=./data/model
model_dir="${ckpt_dir}/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/"

# Step 1: CMVN statistics, JSON -> kaldi text ark.
cmvn_json2kaldi_main \
    --json_file "${model_dir}/mean_std.json" \
    --cmvn_write_path "${exp}/cmvn.ark" \
    --binary=false

echo "convert json cmvn to kaldi ark."

# Step 2: fbank features, using the CMVN ark written above.
compute_fbank_main \
    --num_bins 80 \
    --wav_rspecifier="scp:${data}/wav.scp" \
    --cmvn_file="${exp}/cmvn.ark" \
    --feature_wspecifier="ark,t:${exp}/fbank.ark"

echo "compute fbank feature."
speechx/examples/codelab/u2/local/nnet.sh
0 → 100755
浏览文件 @
7dc9cba3
#!/bin/bash
# Run the u2 nnet forward pass on ${exp}/fbank.ark and dump the encoder
# outputs and log-probabilities as text arks under ${exp}.
set -x
set -e

# Brings the speechx binaries and paddle libraries into scope.
. path.sh

data=data
exp=exp
mkdir -p "${exp}"

ckpt_dir=./data/model
model_dir="${ckpt_dir}/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/"

u2_nnet_main \
    --model_path="${model_dir}/export.jit" \
    --feature_rspecifier="ark,t:${exp}/fbank.ark" \
    --nnet_decoder_chunk=16 \
    --receptive_field_length=7 \
    --downsampling_rate=4 \
    --acoustic_scale=1.0 \
    --nnet_encoder_outs_wspecifier="ark,t:${exp}/encoder_outs.ark" \
    --nnet_prob_wspecifier="ark,t:${exp}/logprobs.ark"

echo "u2 nnet decode."
speechx/examples/codelab/u2
nnet
/path.sh
→
speechx/examples/codelab/u2/path.sh
浏览文件 @
7dc9cba3
...
@@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
...
@@ -12,8 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export
LC_AL
=
C
export
LC_AL
=
C
SPEECHX_BIN
=
$SPEECHX_BUILD
/nnet
export
PATH
=
$PATH
:
$TOOLS_BIN
:
$SPEECHX_BUILD
/nnet:
$SPEECHX_BUILD
/decoder:
$SPEECHX_BUILD
/frontend/audio
export
PATH
=
$PATH
:
$SPEECHX_BIN
:
$TOOLS_BIN
PADDLE_LIB_PATH
=
$(
python
-c
"import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')"
)
PADDLE_LIB_PATH
=
$(
python
-c
"import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')"
)
export
LD_LIBRARY_PATH
=
$PADDLE_LIB_PATH
:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
$PADDLE_LIB_PATH
:
$LD_LIBRARY_PATH
speechx/examples/codelab/u2
nnet
/run.sh
→
speechx/examples/codelab/u2/run.sh
浏览文件 @
7dc9cba3
...
@@ -36,29 +36,8 @@ ckpt_dir=./data/model
...
@@ -36,29 +36,8 @@ ckpt_dir=./data/model
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
model_dir
=
$ckpt_dir
/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main
\
./local/feat.sh
--json_file
$model_dir
/mean_std.json
\
--cmvn_write_path
$exp
/cmvn.ark
\
--binary
=
false
echo
"convert json cmvn to kaldi ark."
./local/nnet.sh
compute_fbank_main
\
./local/decode.sh
--num_bins
80
\
--wav_rspecifier
=
scp:
$data
/wav.scp
\
--cmvn_file
=
$exp
/cmvn.ark
\
--feature_wspecifier
=
ark,t:
$exp
/fbank.ark
echo
"compute fbank feature."
u2_nnet_main
\
--model_path
=
$model_dir
/export.jit
\
--feature_rspecifier
=
ark,t:
$exp
/fbank.ark
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--downsampling_rate
=
4
\
--acoustic_scale
=
1.0
\
--nnet_encoder_outs_wspecifier
=
ark,t:
$exp
/encoder_outs.ark
\
--nnet_prob_wspecifier
=
ark,t:
$exp
/logprobs.ark
echo
"u2 nnet decode."
speechx/examples/codelab/u2nnet/README.md
已删除
100644 → 0
浏览文件 @
3c3aa6b5
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.
speechx/examples/codelab/u2nnet/valgrind.sh
已删除
100755 → 0
浏览文件 @
3c3aa6b5
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set
+x
set
-e
.
./path.sh
if
[
!
-d
${
SPEECHX_TOOLS
}
/valgrind/install
]
;
then
echo
"please install valgrind in the speechx tools dir.
\n
"
exit
1
fi
ckpt_dir
=
./data/model
model_dir
=
$ckpt_dir
/exp/deepspeech2_online/checkpoints/
valgrind
--tool
=
memcheck
--track-origins
=
yes
--leak-check
=
full
--show-leak-kinds
=
all
\
ds2_model_test_main
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--param_path
=
$model_dir
/avg_1.jit.pdparams
speechx/speechx/decoder/CMakeLists.txt
浏览文件 @
7dc9cba3
...
@@ -10,8 +10,9 @@ add_library(decoder STATIC
...
@@ -10,8 +10,9 @@ add_library(decoder STATIC
ctc_tlg_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
recognizer.cc
)
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
absl::strings
)
# test
set
(
BINS
set
(
BINS
ctc_beam_search_decoder_main
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
nnet_logprob_decoder_main
...
@@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS)
...
@@ -24,3 +25,13 @@ foreach(bin_name IN LISTS BINS)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
target_link_libraries
(
${
bin_name
}
PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
endforeach
()
endforeach
()
# u2
set
(
bin_name ctc_prefix_beam_search_decoder_main
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
\ No newline at end of file
speechx/speechx/decoder/ctc_beam_search_decoder.cc
浏览文件 @
7dc9cba3
...
@@ -82,8 +82,6 @@ void CTCBeamSearch::Decode(
...
@@ -82,8 +82,6 @@ void CTCBeamSearch::Decode(
return
;
return
;
}
}
int32
CTCBeamSearch
::
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
// todo rename, refactor
// todo rename, refactor
void
CTCBeamSearch
::
AdvanceDecode
(
void
CTCBeamSearch
::
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
...
@@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() {
...
@@ -110,15 +108,19 @@ void CTCBeamSearch::ResetPrefixes() {
int
CTCBeamSearch
::
DecodeLikelihoods
(
const
vector
<
vector
<
float
>>&
probs
,
int
CTCBeamSearch
::
DecodeLikelihoods
(
const
vector
<
vector
<
float
>>&
probs
,
vector
<
string
>&
nbest_words
)
{
vector
<
string
>&
nbest_words
)
{
kaldi
::
Timer
timer
;
kaldi
::
Timer
timer
;
timer
.
Reset
();
AdvanceDecoding
(
probs
);
AdvanceDecoding
(
probs
);
LOG
(
INFO
)
<<
"ctc decoding elapsed time(s) "
LOG
(
INFO
)
<<
"ctc decoding elapsed time(s) "
<<
static_cast
<
float
>
(
timer
.
Elapsed
())
/
1000.0
f
;
<<
static_cast
<
float
>
(
timer
.
Elapsed
())
/
1000.0
f
;
return
0
;
return
0
;
}
}
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
(
int
n
)
{
int
beam_size
=
n
==
-
1
?
opts_
.
beam_size
:
std
::
min
(
n
,
opts_
.
beam_size
);
return
get_beam_search_result
(
prefixes_
,
vocabulary_
,
beam_size
);
}
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
()
{
vector
<
std
::
pair
<
double
,
string
>>
CTCBeamSearch
::
GetNBestPath
()
{
return
get_beam_search_result
(
prefixes_
,
vocabulary_
,
opts_
.
beam_size
);
return
GetNBestPath
(
-
1
);
}
}
string
CTCBeamSearch
::
GetBestPath
()
{
string
CTCBeamSearch
::
GetBestPath
()
{
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.h
浏览文件 @
7dc9cba3
...
@@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface {
...
@@ -35,6 +35,11 @@ class CTCBeamSearch : public DecoderInterface {
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
);
std
::
string
GetFinalBestPath
();
std
::
string
GetFinalBestPath
();
std
::
string
GetPartialResult
()
{
std
::
string
GetPartialResult
()
{
...
@@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface {
...
@@ -42,14 +47,6 @@ class CTCBeamSearch : public DecoderInterface {
return
{};
return
{};
}
}
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
std
::
vector
<
std
::
string
>&
nbest_words
);
...
...
speechx/speechx/decoder/ctc_beam_search_opt.h
浏览文件 @
7dc9cba3
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
namespace
ppspeech
{
namespace
ppspeech
{
struct
CTCBeamSearchOptions
{
struct
CTCBeamSearchOptions
{
// common
// common
int
blank
;
int
blank
;
...
@@ -75,4 +76,68 @@ struct CTCBeamSearchOptions {
...
@@ -75,4 +76,68 @@ struct CTCBeamSearchOptions {
}
}
};
};
// Decoder-level options used by the u2 model. Wraps the CTC prefix beam
// search options together with chunking and rescoring weights.
struct CTCBeamSearchDecoderOptions {
    // Number of frames in one chunk after subsampling.
    // e.g. with a subsample rate of 4 and chunk_size = 16, one chunk holds
    // 67 = 16*4 + 3 input frames and the stride is 64 = 16*4.
    int chunk_size = 16;
    // Number of left history chunks.
    int num_left_chunks = -1;

    // final_score = rescoring_weight * rescoring_score +
    //               ctc_weight * ctc_score;
    // rescoring_score = left_to_right_score * (1 - reverse_weight) +
    //                   right_to_left_score * reverse_weight
    // Please note that the meaning of ctc_score differs between the two
    // search methods: for CtcPrefixBeamSearch it is a sum(prefix) score +
    // context score, for CtcWfstBeamSearch it is a max(viterbi) path score +
    // context score. So ctc_weight should be set carefully according to the
    // search method.
    float ctc_weight = 0.5f;
    float rescoring_weight = 1.0f;
    float reverse_weight = 0.0f;

    // CtcEndpointConfig ctc_endpoint_opts;

    CTCBeamSearchOptions ctc_prefix_search_opts;

    // Defaults come from the in-class member initializers above.
    CTCBeamSearchDecoderOptions() {}

    // Registers every scalar option on `opts` under its command line name.
    void Register(kaldi::OptionsItf* opts) {
        const std::string prefix = "DecoderConfig: ";

        opts->Register(
            "chunk-size",
            &chunk_size,
            prefix + "the frame number of one chunk after subsampling.");
        opts->Register("num-left-chunks",
                       &num_left_chunks,
                       prefix + "the left history chunks number.");
        opts->Register("ctc-weight",
                       &ctc_weight,
                       prefix +
                           "ctc weight for rescore. final_score = "
                           "rescoring_weight * rescoring_score + ctc_weight * "
                           "ctc_score.");
        opts->Register("rescoring-weight",
                       &rescoring_weight,
                       prefix +
                           "attention score weight for rescore. final_score = "
                           "rescoring_weight * rescoring_score + ctc_weight * "
                           "ctc_score.");
        opts->Register("reverse-weight",
                       &reverse_weight,
                       prefix +
                           "reverse decoder weight. rescoring_score = "
                           "left_to_right_score * (1 - reverse_weight) + "
                           "right_to_left_score * reverse_weight.");
    }
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
浏览文件 @
7dc9cba3
此差异已折叠。
点击以展开。
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
浏览文件 @
7dc9cba3
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_result.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "decoder/decoder_itf.h"
...
@@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -25,48 +26,37 @@ class CTCPrefixBeamSearch : public DecoderInterface {
explicit
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
explicit
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
~
CTCPrefixBeamSearch
()
{}
~
CTCPrefixBeamSearch
()
{}
void
InitDecoder
();
void
InitDecoder
()
override
;
void
Reset
();
void
Reset
()
override
;
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
override
;
std
::
string
GetFinalBestPath
();
std
::
string
GetFinalBestPath
()
override
;
std
::
string
GetPartialResult
()
override
;
std
::
string
GetPartialResult
()
{
void
FinalizeSearch
();
CHECK
(
false
)
<<
"Not implement."
;
return
{};
}
void
Decode
(
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>
decodable
);
std
::
string
GetBestPath
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
();
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
const
std
::
vector
<
float
>&
ViterbiLikelihood
()
const
{
protected:
return
viterbi_likelihood_
;
std
::
string
GetBestPath
()
override
;
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
;
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
vector
<
float
>&
Likelihood
()
const
{
return
likelihood_
;
}
const
std
::
vector
<
float
>&
Likelihood
()
const
{
return
likelihood_
;
}
const
std
::
vector
<
float
>&
ViterbiLikelihood
()
const
{
return
viterbi_likelihood_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
private:
private:
void
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
logp
);
std
::
string
GetBestPath
(
int
index
);
void
FinalizeSearch
();
void
AdvanceDecoding
(
const
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>&
logp
);
void
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
);
void
UpdateOutputs
(
const
std
::
pair
<
std
::
vector
<
int
>
,
PrefixScore
>&
prefix
);
void
UpdateHypotheses
(
void
UpdateHypotheses
(
...
@@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -77,8 +67,6 @@ class CTCPrefixBeamSearch : public DecoderInterface {
private:
private:
CTCBeamSearchOptions
opts_
;
CTCBeamSearchOptions
opts_
;
int
abs_time_step_
=
0
;
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
cur_hyps_
;
cur_hyps_
;
...
@@ -97,4 +85,29 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -97,4 +85,29 @@ class CTCPrefixBeamSearch : public DecoderInterface {
DISALLOW_COPY_AND_ASSIGN
(
CTCPrefixBeamSearch
);
DISALLOW_COPY_AND_ASSIGN
(
CTCPrefixBeamSearch
);
};
};
// u2 decoder built on top of CTCPrefixBeamSearch. It forwards the prefix
// search options to the base class and keeps the full decoder options plus
// streaming bookkeeping state.
class CTCPrefixBeamSearchDecoder : public CTCPrefixBeamSearch {
  public:
    // Builds the base search from opts.ctc_prefix_search_opts and stores a
    // copy of the complete option set.
    explicit CTCPrefixBeamSearchDecoder(const CTCBeamSearchDecoderOptions& opts)
        : CTCPrefixBeamSearch(opts.ctc_prefix_search_opts), opts_(opts) {}

    ~CTCPrefixBeamSearchDecoder() {}

  private:
    CTCBeamSearchDecoderOptions opts_;

    // cache feature
    // false means the first frame has not been seen yet.
    bool start_{false};

    // for continuous decoding
    int num_frames_{0};
    int global_frame_offset_{0};

    // timestamp gap between words in a sentence
    const int time_stamp_gap_{100};

    // std::unique_ptr<CtcEndpoint> ctc_endpointer_;

    int num_frames_in_current_chunk_{0};
    std::vector<DecodeResult> result_;
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
0 → 100644
浏览文件 @
7dc9cba3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "absl/strings/str_split.h"
#include "fst/symbol-table.h"
DEFINE_string
(
feature_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
vocab_path
,
""
,
"vocab path"
);
DEFINE_string
(
model_path
,
""
,
"paddle nnet model"
);
DEFINE_int32
(
receptive_field_length
,
7
,
"receptive field of two CNN(kernel=3) downsampling module."
);
DEFINE_int32
(
downsampling_rate
,
4
,
"two CNN(kernel=3) module downsampling rate."
);
DEFINE_int32
(
nnet_decoder_chunk
,
16
,
"paddle nnet forward chunk"
);
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
// test ds2 online decoder by feeding speech feature
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int32
num_done
=
0
,
num_err
=
0
;
CHECK
(
FLAGS_result_wspecifier
!=
""
);
CHECK
(
FLAGS_feature_rspecifier
!=
""
);
CHECK
(
FLAGS_vocab_path
!=
""
);
CHECK
(
FLAGS_model_path
!=
""
);
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
FLAGS_feature_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_vocab_path
;
fst
::
SymbolTable
*
unit_table
=
fst
::
SymbolTable
::
ReadText
(
FLAGS_vocab_path
);
// nnet
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet
(
new
ppspeech
::
U2Nnet
(
model_opts
));
// decodeable
std
::
shared_ptr
<
ppspeech
::
DataCache
>
raw_data
(
new
ppspeech
::
DataCache
());
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
// decoder
ppspeech
::
CTCBeamSearchDecoderOptions
opts
;
opts
.
chunk_size
=
16
;
opts
.
num_left_chunks
=
-
1
;
opts
.
ctc_weight
=
0.5
;
opts
.
rescoring_weight
=
1.0
;
opts
.
reverse_weight
=
0.3
;
opts
.
ctc_prefix_search_opts
.
blank
=
0
;
opts
.
ctc_prefix_search_opts
.
first_beam_size
=
10
;
opts
.
ctc_prefix_search_opts
.
second_beam_size
=
10
;
ppspeech
::
CTCPrefixBeamSearchDecoder
decoder
(
opts
);
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_downsampling_rate
;
int32
chunk_stride
=
FLAGS_downsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
LOG
(
INFO
)
<<
"chunk stride (frame): "
<<
chunk_stride
;
LOG
(
INFO
)
<<
"receptive field (frame): "
<<
receptive_field_length
;
decoder
.
InitDecoder
();
kaldi
::
Timer
timer
;
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
string
utt
=
feature_reader
.
Key
();
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
int
nframes
=
feature
.
NumRows
();
int
feat_dim
=
feature
.
NumCols
();
raw_data
->
SetDim
(
feat_dim
);
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
raw_data
->
SetDim
(
feat_dim
);
int32
ori_feature_len
=
feature
.
NumRows
();
int32
num_chunks
=
feature
.
NumRows
()
/
chunk_stride
+
1
;
LOG
(
INFO
)
<<
"num_chunks: "
<<
num_chunks
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunks
;
++
chunk_idx
)
{
int32
this_chunk_size
=
0
;
if
(
ori_feature_len
>
chunk_idx
*
chunk_stride
)
{
this_chunk_size
=
std
::
min
(
ori_feature_len
-
chunk_idx
*
chunk_stride
,
chunk_size
);
}
if
(
this_chunk_size
<
receptive_field_length
)
{
LOG
(
WARNING
)
<<
"utt: "
<<
utt
<<
" skip last "
<<
this_chunk_size
<<
" frames, expect is "
<<
receptive_field_length
;
break
;
}
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
feature_chunk
(
this_chunk_size
*
feat_dim
);
int32
start
=
chunk_idx
*
chunk_stride
;
for
(
int
row_id
=
0
;
row_id
<
this_chunk_size
;
++
row_id
)
{
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
feat_row
(
feature
,
start
);
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
feature_chunk_row
(
feature_chunk
.
Data
()
+
row_id
*
feat_dim
,
feat_dim
);
feature_chunk_row
.
CopyFromVec
(
feat_row
);
++
start
;
}
// feat to frontend pipeline cache
raw_data
->
Accept
(
feature_chunk
);
// send data finish signal
if
(
chunk_idx
==
num_chunks
-
1
)
{
raw_data
->
SetFinished
();
}
// forward nnet
decoder
.
AdvanceDecode
(
decodable
);
}
decoder
.
FinalizeSearch
();
// get 1-best result
std
::
string
result_ints
=
decoder
.
GetFinalBestPath
();
std
::
vector
<
std
::
string
>
tokenids
=
absl
::
StrSplit
(
result_ints
,
ppspeech
::
kSpaceSymbol
);
std
::
string
result
;
for
(
int
i
=
0
;
i
<
tokenids
.
size
();
i
++
){
result
+=
unit_table
->
Find
(
std
::
stoi
(
tokenids
[
i
]));
}
// after process one utt, then reset state.
decodable
->
Reset
();
decoder
.
Reset
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
continue
;
}
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is "
<<
result
;
result_writer
.
Write
(
utt
,
result
);
++
num_done
;
}
double
elapsed
=
timer
.
Elapsed
();
LOG
(
INFO
)
<<
"Program cost:"
<<
elapsed
<<
" sec"
;
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" utterances, "
<<
num_err
<<
" with errors."
;
return
(
num_done
!=
0
?
0
:
1
);
}
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
0 → 100644
浏览文件 @
7dc9cba3
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
namespace
ppspeech
{
// One recognized word together with its start and end positions.
// Both positions default to -1.
struct WordPiece {
    std::string word;
    int start{-1};
    int end{-1};

    // Takes the word by value and moves it into place; the positions are
    // stored as given.
    WordPiece(std::string word_text, int start_pos, int end_pos)
        : word(std::move(word_text)), start(start_pos), end(end_pos) {}
};
struct
DecodeResult
{
float
score
=
-
kBaseFloatMax
;
std
::
string
sentence
;
std
::
vector
<
WordPiece
>
word_pieces
;
static
bool
CompareFunc
(
const
DecodeResult
&
a
,
const
DecodeResult
&
b
)
{
return
a
.
score
>
b
.
score
;
}
};
}
// namespace ppspeech
speechx/speechx/decoder/ctc_tlg_decoder.cc
浏览文件 @
7dc9cba3
...
@@ -18,16 +18,23 @@ namespace ppspeech {
...
@@ -18,16 +18,23 @@ namespace ppspeech {
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
{
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
CHECK
(
fst_
!=
nullptr
);
CHECK
(
fst_
!=
nullptr
);
word_symbol_table_
.
reset
(
word_symbol_table_
.
reset
(
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
decoder_
.
reset
(
new
kaldi
::
LatticeFasterOnlineDecoder
(
*
fst_
,
opts
.
opts
));
decoder_
.
reset
(
new
kaldi
::
LatticeFasterOnlineDecoder
(
*
fst_
,
opts
.
opts
));
Reset
();
}
void
TLGDecoder
::
Reset
()
{
decoder_
->
InitDecoding
();
decoder_
->
InitDecoding
();
num_frame_decoded_
=
0
;
num_frame_decoded_
=
0
;
return
;
}
}
void
TLGDecoder
::
InitDecoder
()
{
void
TLGDecoder
::
InitDecoder
()
{
decoder_
->
InitDecoding
();
Reset
();
num_frame_decoded_
=
0
;
}
}
void
TLGDecoder
::
AdvanceDecode
(
void
TLGDecoder
::
AdvanceDecode
(
...
@@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
...
@@ -42,10 +49,7 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
num_frame_decoded_
++
;
num_frame_decoded_
++
;
}
}
void
TLGDecoder
::
Reset
()
{
InitDecoder
();
return
;
}
std
::
string
TLGDecoder
::
GetPartialResult
()
{
std
::
string
TLGDecoder
::
GetPartialResult
()
{
if
(
num_frame_decoded_
==
0
)
{
if
(
num_frame_decoded_
==
0
)
{
...
@@ -88,4 +92,5 @@ std::string TLGDecoder::GetFinalBestPath() {
...
@@ -88,4 +92,5 @@ std::string TLGDecoder::GetFinalBestPath() {
}
}
return
words
;
return
words
;
}
}
}
}
speechx/speechx/decoder/ctc_tlg_decoder.h
浏览文件 @
7dc9cba3
...
@@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface {
...
@@ -42,20 +42,27 @@ class TLGDecoder : public DecoderInterface {
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
std
::
string
GetFinalBestPath
();
std
::
string
GetPartialResult
();
void
Decode
();
void
Decode
();
std
::
string
Get
BestPath
()
;
std
::
string
Get
FinalBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
;
std
::
string
GetPartialResult
()
override
;
int
NumFrameDecoded
();
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
std
::
vector
<
std
::
string
>&
nbest_words
);
protected:
std
::
string
GetBestPath
()
override
{
CHECK
(
false
);
return
{};
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
{
CHECK
(
false
);
return
{};
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
{
CHECK
(
false
);
return
{};
}
private:
private:
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
...
...
speechx/speechx/decoder/decoder_itf.h
浏览文件 @
7dc9cba3
...
@@ -28,27 +28,31 @@ class DecoderInterface {
...
@@ -28,27 +28,31 @@ class DecoderInterface {
virtual
void
Reset
()
=
0
;
virtual
void
Reset
()
=
0
;
// call AdvanceDecoding
virtual
void
AdvanceDecode
(
virtual
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
=
0
;
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
=
0
;
// call GetBestPath
virtual
std
::
string
GetFinalBestPath
()
=
0
;
virtual
std
::
string
GetFinalBestPath
()
=
0
;
virtual
std
::
string
GetPartialResult
()
=
0
;
virtual
std
::
string
GetPartialResult
()
=
0
;
// void Decode();
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
// std::string GetBestPath();
// virtual void Decode() = 0;
// std::vector<std::pair<double, std::string>> GetNBestPath();
// int NumFrameDecoded();
virtual
std
::
string
GetBestPath
()
=
0
;
// int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
// std::vector<std::string>& nbest_words);
virtual
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
=
0
;
protected:
virtual
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
=
0
;
// void AdvanceDecoding(kaldi::DecodableInterface* decodable);
// current decoding frame number
// start from one
int
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
protected:
// current decoding frame number, abs_time_step_
int32
num_frame_decoded_
;
int32
num_frame_decoded_
;
};
};
...
...
speechx/speechx/nnet/u2_nnet_main.cc
浏览文件 @
7dc9cba3
...
@@ -86,17 +86,6 @@ int main(int argc, char* argv[]) {
...
@@ -86,17 +86,6 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
LOG
(
INFO
)
<<
"feat shape: "
<<
nframes
<<
", "
<<
feat_dim
;
// // pad feats
// int32 padding_len = 0;
// if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
// padding_len =
// chunk_stride - (feature.NumRows() - chunk_size) %
// chunk_stride;
// feature.Resize(feature.NumRows() + padding_len,
// feature.NumCols(),
// kaldi::kCopyData);
// }
int32
frame_idx
=
0
;
int32
frame_idx
=
0
;
int
vocab_dim
=
0
;
int
vocab_dim
=
0
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
...
...
speechx/speechx/utils/math.cc
浏览文件 @
7dc9cba3
...
@@ -68,7 +68,7 @@ void TopK(const std::vector<T>& data,
...
@@ -68,7 +68,7 @@ void TopK(const std::vector<T>& data,
for
(
int
i
=
k
;
i
<
n
;
i
++
)
{
for
(
int
i
=
k
;
i
<
n
;
i
++
)
{
if
(
pq
.
top
().
first
<
data
[
i
])
{
if
(
pq
.
top
().
first
<
data
[
i
])
{
pq
.
pop
();
pq
.
pop
();
pq
.
emplace
_back
(
data
[
i
],
i
);
pq
.
emplace
(
data
[
i
],
i
);
}
}
}
}
...
@@ -88,4 +88,9 @@ void TopK(const std::vector<T>& data,
...
@@ -88,4 +88,9 @@ void TopK(const std::vector<T>& data,
}
}
}
}
template
void
TopK
<
float
>(
const
std
::
vector
<
float
>&
data
,
int32_t
k
,
std
::
vector
<
float
>*
values
,
std
::
vector
<
int
>*
indices
)
;
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录