Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
21183d48
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
21183d48
编写于
2月 07, 2023
作者:
Y
YangZhou
提交者:
GitHub
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wfst decoder (#2886)
上级
5042a168
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
154 addition
and
115 deletion
+154
-115
speechx/CMakeLists.txt
speechx/CMakeLists.txt
+1
-1
speechx/speechx/asr/decoder/CMakeLists.txt
speechx/speechx/asr/decoder/CMakeLists.txt
+2
-0
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h
+1
-2
speechx/speechx/asr/decoder/ctc_tlg_decoder.cc
speechx/speechx/asr/decoder/ctc_tlg_decoder.cc
+41
-3
speechx/speechx/asr/decoder/ctc_tlg_decoder.h
speechx/speechx/asr/decoder/ctc_tlg_decoder.h
+32
-5
speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc
speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc
+15
-62
speechx/speechx/asr/decoder/decoder_itf.h
speechx/speechx/asr/decoder/decoder_itf.h
+9
-1
speechx/speechx/asr/decoder/param.h
speechx/speechx/asr/decoder/param.h
+2
-2
speechx/speechx/asr/nnet/decodable.h
speechx/speechx/asr/nnet/decodable.h
+0
-2
speechx/speechx/asr/nnet/nnet_producer.cc
speechx/speechx/asr/nnet/nnet_producer.cc
+8
-8
speechx/speechx/asr/nnet/nnet_producer.h
speechx/speechx/asr/nnet/nnet_producer.h
+5
-5
speechx/speechx/asr/recognizer/u2_recognizer.cc
speechx/speechx/asr/recognizer/u2_recognizer.cc
+25
-21
speechx/speechx/asr/recognizer/u2_recognizer.h
speechx/speechx/asr/recognizer/u2_recognizer.h
+13
-3
未找到文件。
speechx/CMakeLists.txt
浏览文件 @
21183d48
...
@@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
...
@@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
# compiler option
# compiler option
# Keep the same with openfst, -fPIC or -fpic
# Keep the same with openfst, -fPIC or -fpic
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--std=c++14 -pthread -fPIC -O0 -Wall -g"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--std=c++14 -pthread -fPIC -O0 -Wall -g
-ldl
"
)
SET
(
CMAKE_CXX_FLAGS_DEBUG
"$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb"
)
SET
(
CMAKE_CXX_FLAGS_DEBUG
"$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb"
)
SET
(
CMAKE_CXX_FLAGS_RELEASE
"$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall"
)
SET
(
CMAKE_CXX_FLAGS_RELEASE
"$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall"
)
...
...
speechx/speechx/asr/decoder/CMakeLists.txt
浏览文件 @
21183d48
set
(
srcs
)
set
(
srcs
)
list
(
APPEND srcs
list
(
APPEND srcs
ctc_prefix_beam_search_decoder.cc
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
)
)
add_library
(
decoder STATIC
${
srcs
}
)
add_library
(
decoder STATIC
${
srcs
}
)
...
@@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
...
@@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test
# test
set
(
TEST_BINS
set
(
TEST_BINS
ctc_prefix_beam_search_decoder_main
ctc_prefix_beam_search_decoder_main
ctc_tlg_decoder_main
)
)
foreach
(
bin_name IN LISTS TEST_BINS
)
foreach
(
bin_name IN LISTS TEST_BINS
)
...
...
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h
浏览文件 @
21183d48
...
@@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
...
@@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
void
FinalizeSearch
();
void
FinalizeSearch
();
const
std
::
shared_ptr
<
fst
::
SymbolTable
>
VocabTable
()
const
{
const
std
::
shared_ptr
<
fst
::
SymbolTable
>
WordSymbolTable
()
const
override
{
return
unit_table_
;
return
unit_table_
;
}
}
...
@@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
...
@@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
}
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
protected:
protected:
std
::
string
GetBestPath
()
override
;
std
::
string
GetBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
...
...
speechx/speechx/asr/decoder/ctc_tlg_decoder.cc
浏览文件 @
21183d48
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
namespace
ppspeech
{
namespace
ppspeech
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
:
opts_
(
opts
)
{
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
));
CHECK
(
fst_
!=
nullptr
);
CHECK
(
fst_
!=
nullptr
);
...
@@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() {
...
@@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() {
return
words
;
return
words
;
}
}
void
TLGDecoder
::
FinalizeSearch
()
{
decoder_
->
FinalizeDecoding
();
kaldi
::
CompactLattice
clat
;
decoder_
->
GetLattice
(
&
clat
,
true
);
kaldi
::
Lattice
lat
,
nbest_lat
;
fst
::
ConvertLattice
(
clat
,
&
lat
);
fst
::
ShortestPath
(
lat
,
&
nbest_lat
,
opts_
.
nbest
);
std
::
vector
<
kaldi
::
Lattice
>
nbest_lats
;
fst
::
ConvertNbestToVector
(
nbest_lat
,
&
nbest_lats
);
hypotheses_
.
clear
();
hypotheses_
.
reserve
(
nbest_lats
.
size
());
likelihood_
.
clear
();
likelihood_
.
reserve
(
nbest_lats
.
size
());
times_
.
clear
();
times_
.
reserve
(
nbest_lats
.
size
());
for
(
auto
lat
:
nbest_lats
)
{
kaldi
::
LatticeWeight
weight
;
std
::
vector
<
int
>
hypothese
;
std
::
vector
<
int
>
time
;
std
::
vector
<
int
>
alignment
;
std
::
vector
<
int
>
words_id
;
fst
::
GetLinearSymbolSequence
(
lat
,
&
alignment
,
&
words_id
,
&
weight
);
int
idx
=
0
;
for
(;
idx
<
alignment
.
size
()
-
1
;
++
idx
)
{
if
(
alignment
[
idx
]
==
0
)
continue
;
if
(
alignment
[
idx
]
!=
alignment
[
idx
+
1
])
{
hypothese
.
push_back
(
alignment
[
idx
]
-
1
);
time
.
push_back
(
idx
);
// fake time, todo later
}
}
hypothese
.
push_back
(
alignment
[
idx
]
-
1
);
time
.
push_back
(
idx
);
// fake time, todo later
hypotheses_
.
push_back
(
hypothese
);
times_
.
push_back
(
time
);
olabels
.
push_back
(
words_id
);
likelihood_
.
push_back
(
-
(
weight
.
Value2
()
+
weight
.
Value1
()));
}
}
std
::
string
TLGDecoder
::
GetFinalBestPath
()
{
std
::
string
TLGDecoder
::
GetFinalBestPath
()
{
if
(
num_frame_decoded_
==
0
)
{
if
(
num_frame_decoded_
==
0
)
{
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
// BestPathEnd if no frames were decoded.")
return
std
::
string
(
""
);
return
std
::
string
(
""
);
}
}
decoder_
->
FinalizeDecoding
();
kaldi
::
Lattice
lat
;
kaldi
::
Lattice
lat
;
kaldi
::
LatticeWeight
weight
;
kaldi
::
LatticeWeight
weight
;
std
::
vector
<
int
>
alignment
;
std
::
vector
<
int
>
alignment
;
...
...
speechx/speechx/asr/decoder/ctc_tlg_decoder.h
浏览文件 @
21183d48
...
@@ -19,9 +19,8 @@
...
@@ -19,9 +19,8 @@
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "util/parse-options.h"
DECLARE_string
(
graph_path
);
DECLARE_string
(
word_symbol_table
);
DECLARE_string
(
word_symbol_table
);
DECLARE_string
(
graph_path
);
DECLARE_int32
(
max_active
);
DECLARE_int32
(
max_active
);
DECLARE_double
(
beam
);
DECLARE_double
(
beam
);
DECLARE_double
(
lattice_beam
);
DECLARE_double
(
lattice_beam
);
...
@@ -33,6 +32,9 @@ struct TLGDecoderOptions {
...
@@ -33,6 +32,9 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
// todo remove later, add into decode resource
std
::
string
word_symbol_table
;
std
::
string
word_symbol_table
;
std
::
string
fst_path
;
std
::
string
fst_path
;
int
nbest
;
TLGDecoderOptions
()
:
word_symbol_table
(
""
),
fst_path
(
""
),
nbest
(
10
)
{}
static
TLGDecoderOptions
InitFromFlags
()
{
static
TLGDecoderOptions
InitFromFlags
()
{
TLGDecoderOptions
decoder_opts
;
TLGDecoderOptions
decoder_opts
;
...
@@ -44,6 +46,7 @@ struct TLGDecoderOptions {
...
@@ -44,6 +46,7 @@ struct TLGDecoderOptions {
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
// decoder_opts.nbest = FLAGS_lattice_nbest;
LOG
(
INFO
)
<<
"LatticeFasterDecoder max active: "
LOG
(
INFO
)
<<
"LatticeFasterDecoder max active: "
<<
decoder_opts
.
opts
.
max_active
;
<<
decoder_opts
.
opts
.
max_active
;
LOG
(
INFO
)
<<
"LatticeFasterDecoder beam: "
<<
decoder_opts
.
opts
.
beam
;
LOG
(
INFO
)
<<
"LatticeFasterDecoder beam: "
<<
decoder_opts
.
opts
.
beam
;
...
@@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase {
...
@@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase {
explicit
TLGDecoder
(
TLGDecoderOptions
opts
);
explicit
TLGDecoder
(
TLGDecoderOptions
opts
);
~
TLGDecoder
()
=
default
;
~
TLGDecoder
()
=
default
;
void
InitDecoder
();
void
InitDecoder
()
override
;
void
Reset
();
void
Reset
()
override
;
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
override
;
void
Decode
();
void
Decode
();
std
::
string
GetFinalBestPath
()
override
;
std
::
string
GetFinalBestPath
()
override
;
std
::
string
GetPartialResult
()
override
;
std
::
string
GetPartialResult
()
override
;
const
std
::
shared_ptr
<
fst
::
SymbolTable
>
WordSymbolTable
()
const
override
{
return
word_symbol_table_
;
}
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
const
std
::
vector
<
std
::
string
>&
nbest_words
);
const
std
::
vector
<
std
::
string
>&
nbest_words
);
void
FinalizeSearch
()
override
;
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
override
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
override
{
return
olabels
;
}
// outputs_; }
const
std
::
vector
<
float
>&
Likelihood
()
const
override
{
return
likelihood_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
override
{
return
times_
;
}
protected:
protected:
std
::
string
GetBestPath
()
override
{
std
::
string
GetBestPath
()
override
{
CHECK
(
false
);
CHECK
(
false
);
...
@@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase {
...
@@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase {
private:
private:
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
void
AdvanceDecoding
(
kaldi
::
DecodableInterface
*
decodable
);
std
::
vector
<
std
::
vector
<
int
>>
hypotheses_
;
std
::
vector
<
std
::
vector
<
int
>>
olabels
;
std
::
vector
<
float
>
likelihood_
;
std
::
vector
<
std
::
vector
<
int
>>
times_
;
std
::
shared_ptr
<
kaldi
::
LatticeFasterOnlineDecoder
>
decoder_
;
std
::
shared_ptr
<
kaldi
::
LatticeFasterOnlineDecoder
>
decoder_
;
std
::
shared_ptr
<
fst
::
Fst
<
fst
::
StdArc
>>
fst_
;
std
::
shared_ptr
<
fst
::
Fst
<
fst
::
StdArc
>>
fst_
;
std
::
shared_ptr
<
fst
::
SymbolTable
>
word_symbol_table_
;
std
::
shared_ptr
<
fst
::
SymbolTable
>
word_symbol_table_
;
TLGDecoderOptions
opts_
;
};
};
...
...
speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc
浏览文件 @
21183d48
...
@@ -14,16 +14,16 @@
...
@@ -14,16 +14,16 @@
// todo refactor, repalce with gtest
// todo refactor, repalce with gtest
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "base/common.h"
#include "decoder/param.h"
#include "decoder/param.h"
#include "frontend/
audio/
data_cache.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/decodable.h"
#include "nnet/
ds2_nnet
.h"
#include "nnet/
nnet_producer
.h"
DEFINE_string
(
feature
_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
nnet_prob
_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
...
@@ -39,8 +39,8 @@ int main(int argc, char* argv[]) {
...
@@ -39,8 +39,8 @@ int main(int argc, char* argv[]) {
google
::
InstallFailureSignalHandler
();
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
FLAGS_logtostderr
=
1
;
kaldi
::
SequentialBaseFloatMatrixReader
feature
_reader
(
kaldi
::
SequentialBaseFloatMatrixReader
nnet_prob
_reader
(
FLAGS_
feature
_rspecifier
);
FLAGS_
nnet_prob
_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int32
num_done
=
0
,
num_err
=
0
;
int32
num_done
=
0
,
num_err
=
0
;
...
@@ -53,66 +53,19 @@ int main(int argc, char* argv[]) {
...
@@ -53,66 +53,19 @@ int main(int argc, char* argv[]) {
ppspeech
::
ModelOptions
model_opts
=
ppspeech
::
ModelOptions
::
InitFromFlags
();
ppspeech
::
ModelOptions
model_opts
=
ppspeech
::
ModelOptions
::
InitFromFlags
();
std
::
shared_ptr
<
ppspeech
::
PaddleNnet
>
nnet
(
std
::
shared_ptr
<
ppspeech
::
NnetProducer
>
nnet_producer
=
new
ppspeech
::
PaddleNnet
(
model_opts
));
std
::
make_shared
<
ppspeech
::
NnetProducer
>
(
nullptr
);
std
::
shared_ptr
<
ppspeech
::
DataCache
>
raw_data
(
new
ppspeech
::
DataCache
());
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
,
raw_data
,
FLAGS_acoustic_scale
));
new
ppspeech
::
Decodable
(
nnet_producer
,
FLAGS_acoustic_scale
));
int32
chunk_size
=
FLAGS_receptive_field_length
+
(
FLAGS_nnet_decoder_chunk
-
1
)
*
FLAGS_subsampling_rate
;
int32
chunk_stride
=
FLAGS_subsampling_rate
*
FLAGS_nnet_decoder_chunk
;
int32
receptive_field_length
=
FLAGS_receptive_field_length
;
LOG
(
INFO
)
<<
"chunk size (frame): "
<<
chunk_size
;
LOG
(
INFO
)
<<
"chunk stride (frame): "
<<
chunk_stride
;
LOG
(
INFO
)
<<
"receptive field (frame): "
<<
receptive_field_length
;
decoder
.
InitDecoder
();
decoder
.
InitDecoder
();
kaldi
::
Timer
timer
;
kaldi
::
Timer
timer
;
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
string
utt
=
feature_reader
.
Key
();
for
(;
!
nnet_prob_reader
.
Done
();
nnet_prob_reader
.
Next
())
{
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
string
utt
=
nnet_prob_reader
.
Key
();
raw_data
->
SetDim
(
feature
.
NumCols
());
kaldi
::
Matrix
<
BaseFloat
>
prob
=
nnet_prob_reader
.
Value
();
LOG
(
INFO
)
<<
"process utt: "
<<
utt
;
decodable
->
Acceptlikelihood
(
prob
);
LOG
(
INFO
)
<<
"rows: "
<<
feature
.
NumRows
();
LOG
(
INFO
)
<<
"cols: "
<<
feature
.
NumCols
();
int32
row_idx
=
0
;
int32
padding_len
=
0
;
int32
ori_feature_len
=
feature
.
NumRows
();
if
((
feature
.
NumRows
()
-
chunk_size
)
%
chunk_stride
!=
0
)
{
padding_len
=
chunk_stride
-
(
feature
.
NumRows
()
-
chunk_size
)
%
chunk_stride
;
feature
.
Resize
(
feature
.
NumRows
()
+
padding_len
,
feature
.
NumCols
(),
kaldi
::
kCopyData
);
}
int32
num_chunks
=
(
feature
.
NumRows
()
-
chunk_size
)
/
chunk_stride
+
1
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunks
;
++
chunk_idx
)
{
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
feature_chunk
(
chunk_size
*
feature
.
NumCols
());
int32
feature_chunk_size
=
0
;
if
(
ori_feature_len
>
chunk_idx
*
chunk_stride
)
{
feature_chunk_size
=
std
::
min
(
ori_feature_len
-
chunk_idx
*
chunk_stride
,
chunk_size
);
}
if
(
feature_chunk_size
<
receptive_field_length
)
break
;
int32
start
=
chunk_idx
*
chunk_stride
;
for
(
int
row_id
=
0
;
row_id
<
chunk_size
;
++
row_id
)
{
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
tmp
(
feature
,
start
);
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
f_chunk_tmp
(
feature_chunk
.
Data
()
+
row_id
*
feature
.
NumCols
(),
feature
.
NumCols
());
f_chunk_tmp
.
CopyFromVec
(
tmp
);
++
start
;
}
raw_data
->
Accept
(
feature_chunk
);
if
(
chunk_idx
==
num_chunks
-
1
)
{
raw_data
->
SetFinished
();
}
decoder
.
AdvanceDecode
(
decodable
);
decoder
.
AdvanceDecode
(
decodable
);
}
std
::
string
result
;
std
::
string
result
;
result
=
decoder
.
GetFinalBestPath
();
result
=
decoder
.
GetFinalBestPath
();
decodable
->
Reset
();
decodable
->
Reset
();
...
...
speechx/speechx/asr/decoder/decoder_itf.h
浏览文件 @
21183d48
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -16,6 +15,7 @@
...
@@ -16,6 +15,7 @@
#pragma once
#pragma once
#include "base/common.h"
#include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/decoder/decodable-itf.h"
namespace
ppspeech
{
namespace
ppspeech
{
...
@@ -41,6 +41,14 @@ class DecoderInterface {
...
@@ -41,6 +41,14 @@ class DecoderInterface {
virtual
std
::
string
GetPartialResult
()
=
0
;
virtual
std
::
string
GetPartialResult
()
=
0
;
virtual
const
std
::
shared_ptr
<
fst
::
SymbolTable
>
WordSymbolTable
()
const
=
0
;
virtual
void
FinalizeSearch
()
=
0
;
virtual
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
=
0
;
virtual
const
std
::
vector
<
float
>&
Likelihood
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
=
0
;
protected:
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
...
...
speechx/speechx/asr/decoder/param.h
浏览文件 @
21183d48
...
@@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
...
@@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder
// decoder
DEFINE_double
(
acoustic_scale
,
1
.
0
,
"acoustic scale"
);
DEFINE_double
(
acoustic_scale
,
1
.
0
,
"acoustic scale"
);
DEFINE_string
(
graph_path
,
"
TLG
"
,
"decoder graph"
);
DEFINE_string
(
graph_path
,
""
,
"decoder graph"
);
DEFINE_string
(
word_symbol_table
,
"
words.txt
"
,
"word symbol table"
);
DEFINE_string
(
word_symbol_table
,
""
,
"word symbol table"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
DEFINE_double
(
beam
,
15
.
0
,
"decoder beam"
);
DEFINE_double
(
beam
,
15
.
0
,
"decoder beam"
);
DEFINE_double
(
lattice_beam
,
7
.
5
,
"decoder beam"
);
DEFINE_double
(
lattice_beam
,
7
.
5
,
"decoder beam"
);
...
...
speechx/speechx/asr/nnet/decodable.h
浏览文件 @
21183d48
...
@@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface {
...
@@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface {
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetProducer
>&
nnet_producer
,
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetProducer
>&
nnet_producer
,
kaldi
::
BaseFloat
acoustic_scale
=
1.0
);
kaldi
::
BaseFloat
acoustic_scale
=
1.0
);
// void Init(DecodableOpts config);
// nnet logprob output, used by wfst
// nnet logprob output, used by wfst
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
...
...
speechx/speechx/asr/nnet/nnet_producer.cc
浏览文件 @
21183d48
...
@@ -25,15 +25,15 @@ NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
...
@@ -25,15 +25,15 @@ NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{
abort_
=
false
;
abort_
=
false
;
Reset
();
Reset
();
thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
if
(
nnet_
!=
nullptr
)
thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
}
}
void
NnetProducer
::
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
)
{
void
NnetProducer
::
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
)
{
frontend_
->
Accept
(
inputs
);
frontend_
->
Accept
(
inputs
);
condition_variable_
.
notify_one
();
condition_variable_
.
notify_one
();
}
}
void
NnetProducer
::
UnLock
()
{
void
NnetProducer
::
WaitProduce
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
read_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
read_mutex_
);
while
(
frontend_
->
IsFinished
()
==
false
&&
cache_
.
empty
())
{
while
(
frontend_
->
IsFinished
()
==
false
&&
cache_
.
empty
())
{
condition_read_ready_
.
wait
(
lock
);
condition_read_ready_
.
wait
(
lock
);
...
@@ -41,7 +41,7 @@ void NnetProducer::UnLock() {
...
@@ -41,7 +41,7 @@ void NnetProducer::UnLock() {
return
;
return
;
}
}
void
NnetProducer
::
RunNnetEvaluation
(
NnetProducer
*
me
)
{
void
NnetProducer
::
RunNnetEvaluation
(
NnetProducer
*
me
)
{
me
->
RunNnetEvaluationInteral
();
me
->
RunNnetEvaluationInteral
();
}
}
...
...
speechx/speechx/asr/nnet/nnet_producer.h
浏览文件 @
21183d48
...
@@ -34,9 +34,9 @@ class NnetProducer {
...
@@ -34,9 +34,9 @@ class NnetProducer {
// nnet
// nnet
bool
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
static
void
RunNnetEvaluation
(
NnetProducer
*
me
);
static
void
RunNnetEvaluation
(
NnetProducer
*
me
);
void
RunNnetEvaluationInteral
();
void
RunNnetEvaluationInteral
();
void
UnLock
();
void
WaitProduce
();
void
Wait
()
{
void
Wait
()
{
abort_
=
true
;
abort_
=
true
;
...
@@ -60,8 +60,8 @@ class NnetProducer {
...
@@ -60,8 +60,8 @@ class NnetProducer {
}
}
void
Reset
()
{
void
Reset
()
{
frontend_
->
Reset
();
if
(
frontend_
!=
NULL
)
frontend_
->
Reset
();
nnet_
->
Reset
();
if
(
nnet_
!=
NULL
)
nnet_
->
Reset
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
cache_
.
clear
();
cache_
.
clear
();
finished_
=
false
;
finished_
=
false
;
...
...
speechx/speechx/asr/recognizer/u2_recognizer.cc
浏览文件 @
21183d48
...
@@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
...
@@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
CHECK_NE
(
resource
.
vocab_path
,
""
);
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
==
""
)
{
LOG
(
INFO
)
<<
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
;
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
}
else
{
decoder_
.
reset
(
new
TLGDecoder
(
resource
.
decoder_opts
.
tlg_decoder_opts
));
}
unit_table_
=
decoder_
->
VocabTable
();
symbol_table_
=
decoder_
->
WordSymbolTable
();
symbol_table_
=
unit_table_
;
global_frame_offset_
=
0
;
global_frame_offset_
=
0
;
input_finished_
=
false
;
input_finished_
=
false
;
...
@@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
...
@@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
CHECK_NE
(
resource
.
vocab_path
,
""
);
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
==
""
)
{
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
}
else
{
decoder_
.
reset
(
new
TLGDecoder
(
resource
.
decoder_opts
.
tlg_decoder_opts
));
}
unit_table_
=
decoder_
->
VocabTable
();
symbol_table_
=
decoder_
->
WordSymbolTable
();
symbol_table_
=
unit_table_
;
global_frame_offset_
=
0
;
global_frame_offset_
=
0
;
input_finished_
=
false
;
input_finished_
=
false
;
...
@@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
...
@@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
void
U2Recognizer
::
RunDecoderSearchInternal
()
{
void
U2Recognizer
::
RunDecoderSearchInternal
()
{
LOG
(
INFO
)
<<
"DecoderSearchInteral begin"
;
LOG
(
INFO
)
<<
"DecoderSearchInteral begin"
;
while
(
!
nnet_producer_
->
IsFinished
())
{
while
(
!
nnet_producer_
->
IsFinished
())
{
nnet_producer_
->
UnLock
();
nnet_producer_
->
WaitProduce
();
decoder_
->
AdvanceDecode
(
decodable_
);
decoder_
->
AdvanceDecode
(
decodable_
);
}
}
Decode
();
decoder_
->
AdvanceDecode
(
decodable_
);
UpdateResult
(
false
);
LOG
(
INFO
)
<<
"DecoderSearchInteral exit"
;
LOG
(
INFO
)
<<
"DecoderSearchInteral exit"
;
}
}
...
@@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) {
...
@@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) {
const
auto
&
times
=
decoder_
->
Times
();
const
auto
&
times
=
decoder_
->
Times
();
result_
.
clear
();
result_
.
clear
();
CHECK_EQ
(
hypothese
s
.
size
(),
likelihood
.
size
());
CHECK_EQ
(
input
s
.
size
(),
likelihood
.
size
());
for
(
size_t
i
=
0
;
i
<
hypotheses
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
hypotheses
.
size
();
i
++
)
{
const
std
::
vector
<
int
>&
hypothesis
=
hypotheses
[
i
];
const
std
::
vector
<
int
>&
hypothesis
=
hypotheses
[
i
];
...
@@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) {
...
@@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) {
path
.
score
=
likelihood
[
i
];
path
.
score
=
likelihood
[
i
];
for
(
size_t
j
=
0
;
j
<
hypothesis
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
hypothesis
.
size
();
j
++
)
{
std
::
string
word
=
symbol_table_
->
Find
(
hypothesis
[
j
]);
std
::
string
word
=
symbol_table_
->
Find
(
hypothesis
[
j
]);
// A detailed explanation of this if-else branch can be found in
// path.sentence += (" " + word); // todo SmileGoat: add blank
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
// processor
if
(
decoder_
->
Type
()
==
kWfstBeamSearch
)
{
path
.
sentence
+=
word
;
// todo SmileGoat: add blank processor
path
.
sentence
+=
(
" "
+
word
);
}
else
{
path
.
sentence
+=
(
word
);
}
}
}
// TimeStamp is only supported in final result
// TimeStamp is only supported in final result
...
@@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) {
...
@@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) {
// various FST operations when building the decoding graph. So here we
// various FST operations when building the decoding graph. So here we
// use time stamp of the input(e2e model unit), which is more accurate,
// use time stamp of the input(e2e model unit), which is more accurate,
// and it requires the symbol table of the e2e model used in training.
// and it requires the symbol table of the e2e model used in training.
if
(
unit
_table_
!=
nullptr
&&
finish
)
{
if
(
symbol
_table_
!=
nullptr
&&
finish
)
{
int
offset
=
global_frame_offset_
*
FrameShiftInMs
();
int
offset
=
global_frame_offset_
*
FrameShiftInMs
();
const
std
::
vector
<
int
>&
input
=
inputs
[
i
];
const
std
::
vector
<
int
>&
input
=
inputs
[
i
];
...
@@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) {
...
@@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) {
CHECK_EQ
(
input
.
size
(),
time_stamp
.
size
());
CHECK_EQ
(
input
.
size
(),
time_stamp
.
size
());
for
(
size_t
j
=
0
;
j
<
input
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
input
.
size
();
j
++
)
{
std
::
string
word
=
unit
_table_
->
Find
(
input
[
j
]);
std
::
string
word
=
symbol
_table_
->
Find
(
input
[
j
]);
int
start
=
int
start
=
time_stamp
[
j
]
*
FrameShiftInMs
()
-
time_stamp_gap_
>
0
time_stamp
[
j
]
*
FrameShiftInMs
()
-
time_stamp_gap_
>
0
...
@@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) {
...
@@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) {
void
U2Recognizer
::
AttentionRescoring
()
{
void
U2Recognizer
::
AttentionRescoring
()
{
decoder_
->
FinalizeSearch
();
decoder_
->
FinalizeSearch
();
UpdateResult
(
tru
e
);
UpdateResult
(
fals
e
);
// No need to do rescoring
// No need to do rescoring
if
(
0.0
==
opts_
.
decoder_opts
.
rescoring_weight
)
{
if
(
0.0
==
opts_
.
decoder_opts
.
rescoring_weight
)
{
...
...
speechx/speechx/asr/recognizer/u2_recognizer.h
浏览文件 @
21183d48
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "decoder/common.h"
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/decoder_itf.h"
#include "decoder/decoder_itf.h"
#include "frontend/feature_pipeline.h"
#include "frontend/feature_pipeline.h"
#include "fst/fstlib.h"
#include "fst/fstlib.h"
...
@@ -33,6 +34,8 @@ DECLARE_int32(blank);
...
@@ -33,6 +34,8 @@ DECLARE_int32(blank);
DECLARE_double
(
acoustic_scale
);
DECLARE_double
(
acoustic_scale
);
DECLARE_string
(
vocab_path
);
DECLARE_string
(
vocab_path
);
DECLARE_string
(
word_symbol_table
);
// DECLARE_string(fst_path);
namespace
ppspeech
{
namespace
ppspeech
{
...
@@ -59,6 +62,7 @@ struct DecodeOptions {
...
@@ -59,6 +62,7 @@ struct DecodeOptions {
// CtcEndpointConfig ctc_endpoint_opts;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions
ctc_prefix_search_opts
{};
CTCBeamSearchOptions
ctc_prefix_search_opts
{};
TLGDecoderOptions
tlg_decoder_opts
{};
static
DecodeOptions
InitFromFlags
()
{
static
DecodeOptions
InitFromFlags
()
{
DecodeOptions
decoder_opts
;
DecodeOptions
decoder_opts
;
...
@@ -70,6 +74,13 @@ struct DecodeOptions {
...
@@ -70,6 +74,13 @@ struct DecodeOptions {
decoder_opts
.
ctc_prefix_search_opts
.
blank
=
FLAGS_blank
;
decoder_opts
.
ctc_prefix_search_opts
.
blank
=
FLAGS_blank
;
decoder_opts
.
ctc_prefix_search_opts
.
first_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
first_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
second_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
second_beam_size
=
FLAGS_nbest
;
// decoder_opts.tlg_decoder_opts.fst_path = "";//FLAGS_fst_path;
// decoder_opts.tlg_decoder_opts.word_symbol_table =
// FLAGS_word_symbol_table;
// decoder_opts.tlg_decoder_opts.nbest = FLAGS_nbest;
decoder_opts
.
tlg_decoder_opts
=
ppspeech
::
TLGDecoderOptions
::
InitFromFlags
();
LOG
(
INFO
)
<<
"chunk_size: "
<<
decoder_opts
.
chunk_size
;
LOG
(
INFO
)
<<
"chunk_size: "
<<
decoder_opts
.
chunk_size
;
LOG
(
INFO
)
<<
"num_left_chunks: "
<<
decoder_opts
.
num_left_chunks
;
LOG
(
INFO
)
<<
"num_left_chunks: "
<<
decoder_opts
.
num_left_chunks
;
LOG
(
INFO
)
<<
"ctc_weight: "
<<
decoder_opts
.
ctc_weight
;
LOG
(
INFO
)
<<
"ctc_weight: "
<<
decoder_opts
.
ctc_weight
;
...
@@ -154,10 +165,9 @@ class U2Recognizer {
...
@@ -154,10 +165,9 @@ class U2Recognizer {
std
::
shared_ptr
<
NnetProducer
>
nnet_producer_
;
std
::
shared_ptr
<
NnetProducer
>
nnet_producer_
;
std
::
shared_ptr
<
Decodable
>
decodable_
;
std
::
shared_ptr
<
Decodable
>
decodable_
;
std
::
unique_ptr
<
CTCPrefixBeamSearch
>
decoder_
;
std
::
unique_ptr
<
DecoderBase
>
decoder_
;
// e2e unit symbol table
// e2e unit symbol table
std
::
shared_ptr
<
fst
::
SymbolTable
>
unit_table_
=
nullptr
;
std
::
shared_ptr
<
fst
::
SymbolTable
>
symbol_table_
=
nullptr
;
std
::
shared_ptr
<
fst
::
SymbolTable
>
symbol_table_
=
nullptr
;
std
::
vector
<
DecodeResult
>
result_
;
std
::
vector
<
DecodeResult
>
result_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录