Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8a225b17
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8a225b17
编写于
1月 18, 2023
作者:
Y
YangZhou
提交者:
GitHub
1月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[speechx] thread decode (#2839)
* fix nnet thread crash && rescore cost time * add nnet thread main
上级
ee7c266f
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
303 addition
and
110 deletion
+303
-110
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc
...chx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc
+5
-4
speechx/speechx/asr/nnet/CMakeLists.txt
speechx/speechx/asr/nnet/CMakeLists.txt
+17
-10
speechx/speechx/asr/nnet/decodable.cc
speechx/speechx/asr/nnet/decodable.cc
+0
-1
speechx/speechx/asr/nnet/nnet_producer.cc
speechx/speechx/asr/nnet/nnet_producer.cc
+46
-8
speechx/speechx/asr/nnet/nnet_producer.h
speechx/speechx/asr/nnet/nnet_producer.h
+26
-8
speechx/speechx/asr/nnet/u2_nnet_thread_main.cc
speechx/speechx/asr/nnet/u2_nnet_thread_main.cc
+137
-0
speechx/speechx/asr/recognizer/u2_recognizer.cc
speechx/speechx/asr/recognizer/u2_recognizer.cc
+34
-5
speechx/speechx/asr/recognizer/u2_recognizer.h
speechx/speechx/asr/recognizer/u2_recognizer.h
+9
-7
speechx/speechx/asr/recognizer/u2_recognizer_main.cc
speechx/speechx/asr/recognizer/u2_recognizer_main.cc
+2
-2
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
+11
-15
speechx/speechx/common/frontend/compute_fbank_main.cc
speechx/speechx/common/frontend/compute_fbank_main.cc
+1
-2
speechx/speechx/common/frontend/feature_cache.cc
speechx/speechx/common/frontend/feature_cache.cc
+11
-29
speechx/speechx/common/frontend/feature_cache.h
speechx/speechx/common/frontend/feature_cache.h
+3
-17
speechx/speechx/common/frontend/feature_pipeline.cc
speechx/speechx/common/frontend/feature_pipeline.cc
+1
-1
speechx/speechx/common/frontend/feature_pipeline.h
speechx/speechx/common/frontend/feature_pipeline.h
+0
-1
未找到文件。
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc
浏览文件 @
8a225b17
...
...
@@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() {
times_
.
emplace_back
(
empty
);
}
void
CTCPrefixBeamSearch
::
InitDecoder
()
{
Reset
();
}
void
CTCPrefixBeamSearch
::
InitDecoder
()
{
Reset
();
}
void
CTCPrefixBeamSearch
::
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
...
...
@@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool
flag
=
decodable
->
FrameLikelihood
(
num_frame_decoded_
,
&
frame_prob
);
feat_nnet_cost
+=
timer
.
Elapsed
();
if
(
flag
==
false
)
{
VLOG
(
3
)
<<
"decoder advance decode exit."
<<
frame_prob
.
size
();
VLOG
(
2
)
<<
"decoder advance decode exit."
<<
frame_prob
.
size
();
break
;
}
...
...
@@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
AdvanceDecoding
(
likelihood
);
search_cost
+=
timer
.
Elapsed
();
VLOG
(
2
)
<<
"num_frame_decoded_: "
<<
num_frame_decoded_
;
VLOG
(
1
)
<<
"num_frame_decoded_: "
<<
num_frame_decoded_
;
}
VLOG
(
1
)
<<
"AdvanceDecode feat + forward cost: "
<<
feat_nnet_cost
<<
" sec."
;
...
...
speechx/speechx/asr/nnet/CMakeLists.txt
浏览文件 @
8a225b17
...
...
@@ -8,14 +8,21 @@ target_link_libraries(nnet utils)
target_compile_options
(
nnet PUBLIC
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
nnet PUBLIC
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
# test bin
#if(USING_U2)
# set(bin_name u2_nnet_main)
# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
#endif()
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set
(
bin_name u2_nnet_thread_main
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
utils kaldi-util kaldi-matrix gflags glog nnet frontend
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
speechx/speechx/asr/nnet/decodable.cc
浏览文件 @
8a225b17
...
...
@@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_producer_
->
Acceptlikelihood
(
likelihood
);
}
// Return the number of frames that have been computed.
int32
Decodable
::
NumFramesReady
()
const
{
return
frames_ready_
;
}
...
...
speechx/speechx/asr/nnet/nnet_producer.cc
浏览文件 @
8a225b17
...
...
@@ -22,14 +22,43 @@ using kaldi::BaseFloat;
NnetProducer
::
NnetProducer
(
std
::
shared_ptr
<
NnetBase
>
nnet
,
std
::
shared_ptr
<
FrontendInterface
>
frontend
)
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{}
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{
abort_
=
false
;
Reset
();
thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
}
void
NnetProducer
::
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
)
{
frontend_
->
Accept
(
inputs
);
condition_variable_
.
notify_one
();
}
void
NnetProducer
::
UnLock
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
read_mutex_
);
while
(
frontend_
->
IsFinished
()
==
false
&&
cache_
.
empty
())
{
condition_read_ready_
.
wait
(
lock
);
}
return
;
}
void
NnetProducer
::
RunNnetEvaluation
(
NnetProducer
*
me
)
{
me
->
RunNnetEvaluationInteral
();
}
void
NnetProducer
::
RunNnetEvaluationInteral
()
{
bool
result
=
false
;
do
{
result
=
Compute
();
}
while
(
result
);
LOG
(
INFO
)
<<
"NnetEvaluationInteral begin"
;
while
(
!
abort_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
condition_variable_
.
wait
(
lock
);
do
{
result
=
Compute
();
}
while
(
result
);
if
(
frontend_
->
IsFinished
()
==
true
)
{
if
(
cache_
.
empty
())
finished_
=
true
;
}
}
LOG
(
INFO
)
<<
"NnetEvaluationInteral exit"
;
}
void
NnetProducer
::
Acceptlikelihood
(
...
...
@@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood(
for
(
size_t
idx
=
0
;
idx
<
likelihood
.
NumRows
();
++
idx
)
{
for
(
size_t
col
=
0
;
col
<
likelihood
.
NumCols
();
++
col
)
{
prob
[
col
]
=
likelihood
(
idx
,
col
);
cache_
.
push_back
(
prob
);
}
cache_
.
push_back
(
prob
);
}
}
bool
NnetProducer
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
bool
flag
=
cache_
.
pop
(
nnet_prob
);
condition_variable_
.
notify_one
();
return
flag
;
}
bool
NnetProducer
::
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
Compute
();
if
(
frontend_
->
IsFinished
()
&&
cache_
.
empty
())
finished_
=
true
;
bool
flag
=
cache_
.
pop
(
nnet_prob
);
return
flag
;
}
...
...
@@ -53,22 +90,23 @@ bool NnetProducer::Compute() {
vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
VLOG
(
3
)
<<
"no feat avalible"
;
VLOG
(
2
)
<<
"no feat avalible"
;
return
false
;
}
CHECK_GE
(
frontend_
->
Dim
(),
0
);
VLOG
(
2
)
<<
"Forward in "
<<
features
.
size
()
/
frontend_
->
Dim
()
<<
" feats."
;
VLOG
(
1
)
<<
"Forward in "
<<
features
.
size
()
/
frontend_
->
Dim
()
<<
" feats."
;
NnetOut
out
;
nnet_
->
FeedForward
(
features
,
frontend_
->
Dim
(),
&
out
);
int32
&
vocab_dim
=
out
.
vocab_dim
;
size_t
nframes
=
out
.
logprobs
.
size
()
/
vocab_dim
;
VLOG
(
2
)
<<
"Forward out "
<<
nframes
<<
" decoder frames."
;
VLOG
(
1
)
<<
"Forward out "
<<
nframes
<<
" decoder frames."
;
for
(
size_t
idx
=
0
;
idx
<
nframes
;
++
idx
)
{
std
::
vector
<
BaseFloat
>
logprob
(
out
.
logprobs
.
data
()
+
idx
*
vocab_dim
,
out
.
logprobs
.
data
()
+
(
idx
+
1
)
*
vocab_dim
);
cache_
.
push_back
(
logprob
);
condition_read_ready_
.
notify_one
();
}
return
true
;
}
...
...
speechx/speechx/asr/nnet/nnet_producer.h
浏览文件 @
8a225b17
...
...
@@ -33,27 +33,38 @@ class NnetProducer {
// nnet
bool
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
static
void
RunNnetEvaluation
(
NnetProducer
*
me
);
void
RunNnetEvaluationInteral
();
void
UnLock
();
void
Wait
()
{
abort_
=
true
;
condition_variable_
.
notify_one
();
if
(
thread_
.
joinable
())
thread_
.
join
();
}
bool
Empty
()
const
{
return
cache_
.
empty
();
}
void
SetFinished
()
{
void
Set
Input
Finished
()
{
LOG
(
INFO
)
<<
"set finished"
;
// std::unique_lock<std::mutex> lock(mutex_);
frontend_
->
SetFinished
();
// read the last chunk data
Compute
();
// ready_feed_condition_.notify_one();
LOG
(
INFO
)
<<
"compute last feats done."
;
condition_variable_
.
notify_one
();
}
bool
IsFinished
()
const
{
return
frontend_
->
IsFinished
();
}
// the compute thread exit
bool
IsFinished
()
const
{
return
finished_
;
}
~
NnetProducer
()
{
if
(
thread_
.
joinable
())
thread_
.
join
();
}
void
Reset
()
{
frontend_
->
Reset
();
nnet_
->
Reset
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
cache_
.
clear
();
finished_
=
false
;
}
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
...
...
@@ -66,6 +77,13 @@ class NnetProducer {
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
NnetBase
>
nnet_
;
SafeQueue
<
std
::
vector
<
kaldi
::
BaseFloat
>>
cache_
;
std
::
mutex
mutex_
;
std
::
mutex
read_mutex_
;
std
::
condition_variable
condition_variable_
;
std
::
condition_variable
condition_read_ready_
;
std
::
thread
thread_
;
bool
finished_
;
bool
abort_
;
DISALLOW_COPY_AND_ASSIGN
(
NnetProducer
);
};
...
...
speechx/speechx/asr/nnet/u2_nnet_thread_main.cc
0 → 100644
浏览文件 @
8a225b17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "frontend/feature_pipeline.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
// Command-line flags for this tool.
//
// wav_rspecifier:       rspecifier handed to the wave table reader in main().
// nnet_prob_wspecifier: wspecifier handed to the matrix writer in main().
// streaming_chunk:      multiplied by sample_rate in main() to obtain the
//                       number of samples fed per chunk.
// sample_rate:          see streaming_chunk.
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int32
num_done
=
0
,
num_err
=
0
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
CHECK_GT
(
FLAGS_wav_rspecifier
.
size
(),
0
);
CHECK_GT
(
FLAGS_nnet_prob_wspecifier
.
size
(),
0
);
CHECK_GT
(
FLAGS_model_path
.
size
(),
0
);
LOG
(
INFO
)
<<
"input rspecifier: "
<<
FLAGS_wav_rspecifier
;
LOG
(
INFO
)
<<
"output wspecifier: "
<<
FLAGS_nnet_prob_wspecifier
;
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
FLAGS_wav_rspecifier
);
kaldi
::
BaseFloatMatrixWriter
nnet_out_writer
(
FLAGS_nnet_prob_wspecifier
);
ppspeech
::
ModelOptions
model_opts
=
ppspeech
::
ModelOptions
::
InitFromFlags
();
ppspeech
::
FeaturePipelineOptions
feature_opts
=
ppspeech
::
FeaturePipelineOptions
::
InitFromFlags
();
feature_opts
.
assembler_opts
.
fill_zero
=
false
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet
(
new
ppspeech
::
U2Nnet
(
model_opts
));
std
::
shared_ptr
<
ppspeech
::
FeaturePipeline
>
feature_pipeline
(
new
ppspeech
::
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
ppspeech
::
NnetProducer
>
nnet_producer
(
new
ppspeech
::
NnetProducer
(
nnet
,
feature_pipeline
));
kaldi
::
Timer
timer
;
float
tot_wav_duration
=
0
;
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
tot_wav_duration
+=
dur
;
int32
this_channel
=
0
;
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
waveform
(
wave_data
.
Data
(),
this_channel
);
int
tot_samples
=
waveform
.
Dim
();
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
kaldi
::
Timer
timer
;
while
(
sample_offset
<
tot_samples
)
{
int
cur_chunk_size
=
std
::
min
(
chunk_sample_size
,
tot_samples
-
sample_offset
);
std
::
vector
<
kaldi
::
BaseFloat
>
wav_chunk
(
cur_chunk_size
);
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
nnet_producer
->
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
nnet_producer
->
SetInputFinished
();
}
// no overlap
sample_offset
+=
cur_chunk_size
;
}
CHECK
(
sample_offset
==
tot_samples
);
std
::
vector
<
std
::
vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
while
(
1
)
{
std
::
vector
<
kaldi
::
BaseFloat
>
logprobs
;
bool
isok
=
nnet_producer
->
Read
(
&
logprobs
);
if
(
nnet_producer
->
IsFinished
())
break
;
if
(
isok
==
false
)
continue
;
prob_vec
.
push_back
(
logprobs
);
}
{
// writer nnet output
kaldi
::
MatrixIndexT
nrow
=
prob_vec
.
size
();
kaldi
::
MatrixIndexT
ncol
=
prob_vec
[
0
].
size
();
LOG
(
INFO
)
<<
"nnet out shape: "
<<
nrow
<<
", "
<<
ncol
;
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
nnet_out
(
nrow
,
ncol
);
for
(
int32
row_idx
=
0
;
row_idx
<
nrow
;
++
row_idx
)
{
for
(
int32
col_idx
=
0
;
col_idx
<
ncol
;
++
col_idx
)
{
nnet_out
(
row_idx
,
col_idx
)
=
prob_vec
[
row_idx
][
col_idx
];
}
}
nnet_out_writer
.
Write
(
utt
,
nnet_out
);
}
nnet_producer
->
Reset
();
}
nnet_producer
->
Wait
();
double
elapsed
=
timer
.
Elapsed
();
LOG
(
INFO
)
<<
"Program cost:"
<<
elapsed
<<
" sec"
;
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" utterances, "
<<
num_err
<<
" with errors."
;
return
(
num_done
!=
0
?
0
:
1
);
}
speechx/speechx/asr/recognizer/u2_recognizer.cc
浏览文件 @
8a225b17
...
...
@@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
unit_table_
=
decoder_
->
VocabTable
();
symbol_table_
=
unit_table_
;
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
}
U2Recognizer
::~
U2Recognizer
()
{
SetInputFinished
();
WaitDecodeFinished
();
}
Reset
();
void
U2Recognizer
::
WaitDecodeFinished
()
{
if
(
thread_
.
joinable
())
thread_
.
join
();
}
void
U2Recognizer
::
Reset
()
{
void
U2Recognizer
::
WaitFinished
()
{
if
(
thread_
.
joinable
())
thread_
.
join
();
nnet_producer_
->
Wait
();
}
void
U2Recognizer
::
InitDecoder
()
{
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
...
...
@@ -52,6 +68,7 @@ void U2Recognizer::Reset() {
decodable_
->
Reset
();
decoder_
->
Reset
();
thread_
=
std
::
thread
(
RunDecoderSearch
,
this
);
}
void
U2Recognizer
::
ResetContinuousDecoding
()
{
...
...
@@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() {
decoder_
->
Reset
();
}
void
U2Recognizer
::
RunDecoderSearch
(
U2Recognizer
*
me
)
{
me
->
RunDecoderSearchInternal
();
}
void
U2Recognizer
::
RunDecoderSearchInternal
()
{
LOG
(
INFO
)
<<
"DecoderSearchInteral begin"
;
while
(
!
nnet_producer_
->
IsFinished
())
{
nnet_producer_
->
UnLock
();
decoder_
->
AdvanceDecode
(
decodable_
);
}
Decode
();
LOG
(
INFO
)
<<
"DecoderSearchInteral exit"
;
}
void
U2Recognizer
::
Accept
(
const
vector
<
BaseFloat
>&
waves
)
{
kaldi
::
Timer
timer
;
...
...
@@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
<<
" samples."
;
}
void
U2Recognizer
::
Decode
()
{
decoder_
->
AdvanceDecode
(
decodable_
);
UpdateResult
(
false
);
...
...
@@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std
::
string
U2Recognizer
::
GetPartialResult
()
{
return
result_
[
0
].
sentence
;
}
void
U2Recognizer
::
SetFinished
()
{
nnet_producer_
->
SetFinished
();
void
U2Recognizer
::
Set
Input
Finished
()
{
nnet_producer_
->
Set
Input
Finished
();
input_finished_
=
true
;
}
...
...
speechx/speechx/asr/recognizer/u2_recognizer.h
浏览文件 @
8a225b17
...
...
@@ -112,19 +112,21 @@ struct U2RecognizerResource {
class
U2Recognizer
{
public:
explicit
U2Recognizer
(
const
U2RecognizerResource
&
resouce
);
void
Reset
();
~
U2Recognizer
();
void
InitDecoder
();
void
ResetContinuousDecoding
();
void
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
waves
);
void
Decode
();
void
Rescoring
();
std
::
string
GetFinalResult
();
std
::
string
GetPartialResult
();
void
SetFinished
();
void
Set
Input
Finished
();
bool
IsFinished
()
{
return
input_finished_
;
}
void
WaitDecodeFinished
();
void
WaitFinished
();
bool
DecodedSomething
()
const
{
return
!
result_
.
empty
()
&&
!
result_
[
0
].
sentence
.
empty
();
...
...
@@ -137,18 +139,17 @@ class U2Recognizer {
// feature_pipeline_->FrameShift();
}
const
std
::
vector
<
DecodeResult
>&
Result
()
const
{
return
result_
;
}
void
AttentionRescoring
();
private:
void
AttentionRescoring
();
static
void
RunDecoderSearch
(
U2Recognizer
*
me
);
void
RunDecoderSearchInternal
();
void
UpdateResult
(
bool
finish
=
false
);
private:
U2RecognizerResource
opts_
;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std
::
shared_ptr
<
NnetProducer
>
nnet_producer_
;
std
::
shared_ptr
<
Decodable
>
decodable_
;
std
::
unique_ptr
<
CTCPrefixBeamSearch
>
decoder_
;
...
...
@@ -167,6 +168,7 @@ class U2Recognizer {
const
int
time_stamp_gap_
=
100
;
bool
input_finished_
;
std
::
thread
thread_
;
};
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/asr/recognizer/u2_recognizer_main.cc
浏览文件 @
8a225b17
...
...
@@ -49,6 +49,7 @@ int main(int argc, char* argv[]) {
ppspeech
::
U2Recognizer
recognizer
(
resource
);
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
recognizer
.
InitDecoder
();
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
...
...
@@ -79,7 +80,7 @@ int main(int argc, char* argv[]) {
recognizer
.
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer
.
SetFinished
();
recognizer
.
Set
Input
Finished
();
}
recognizer
.
Decode
();
if
(
recognizer
.
DecodedSomething
())
{
...
...
@@ -100,7 +101,6 @@ int main(int argc, char* argv[]) {
std
::
string
result
=
recognizer
.
GetFinalResult
();
recognizer
.
Reset
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
...
...
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
浏览文件 @
8a225b17
...
...
@@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
void
decode_func
(
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer
)
{
while
(
!
recognizer
->
IsFinished
())
{
recognizer
->
Decode
();
usleep
(
100
);
}
recognizer
->
Decode
();
recognizer
->
Rescoring
();
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
...
...
@@ -40,6 +31,7 @@ int main(int argc, char* argv[]) {
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_decode_time
=
0.0
;
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
...
...
@@ -59,7 +51,7 @@ int main(int argc, char* argv[]) {
new
ppspeech
::
U2Recognizer
(
resource
));
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
std
::
thread
recognizer_thread
(
decode_func
,
recognizer_ptr
);
recognizer_ptr
->
InitDecoder
(
);
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
...
...
@@ -74,7 +66,6 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
kaldi
::
Timer
timer
;
kaldi
::
Timer
local_timer
;
while
(
sample_offset
<
tot_samples
)
{
...
...
@@ -85,21 +76,23 @@ int main(int argc, char* argv[]) {
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer_ptr
->
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer_ptr
->
SetFinished
();
recognizer_ptr
->
Set
Input
Finished
();
}
// no overlap
sample_offset
+=
cur_chunk_size
;
}
CHECK
(
sample_offset
==
tot_samples
);
recognizer_ptr
->
WaitDecodeFinished
();
kaldi
::
Timer
timer
;
recognizer_ptr
->
AttentionRescoring
();
tot_attention_rescore_time
+=
timer
.
Elapsed
();
recognizer_thread
.
join
();
std
::
string
result
=
recognizer_ptr
->
GetFinalResult
();
recognizer_ptr
->
Reset
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
...
...
@@ -107,6 +100,7 @@ int main(int argc, char* argv[]) {
continue
;
}
tot_decode_time
+=
local_timer
.
Elapsed
();
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
...
...
@@ -115,9 +109,11 @@ int main(int argc, char* argv[]) {
++
num_done
;
}
recognizer_ptr
->
WaitFinished
();
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total rescore cost:"
<<
tot_attention_rescore_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
speechx/speechx/common/frontend/compute_fbank_main.cc
浏览文件 @
8a225b17
...
...
@@ -73,8 +73,7 @@ int main(int argc, char* argv[]) {
new
ppspeech
::
CMVN
(
FLAGS_cmvn_file
,
std
::
move
(
fbank
)));
// the feature cache output feature chunk by chunk.
ppspeech
::
FeatureCacheOptions
feat_cache_opts
;
ppspeech
::
FeatureCache
feature_cache
(
feat_cache_opts
,
std
::
move
(
cmvn
));
ppspeech
::
FeatureCache
feature_cache
(
kint16max
,
std
::
move
(
cmvn
));
LOG
(
INFO
)
<<
"fbank: "
<<
true
;
LOG
(
INFO
)
<<
"feat dim: "
<<
feature_cache
.
Dim
();
...
...
speechx/speechx/common/frontend/feature_cache.cc
浏览文件 @
8a225b17
...
...
@@ -20,10 +20,9 @@ using kaldi::BaseFloat;
using
std
::
unique_ptr
;
using
std
::
vector
;
FeatureCache
::
FeatureCache
(
FeatureCacheOptions
opts
,
FeatureCache
::
FeatureCache
(
size_t
max_size
,
unique_ptr
<
FrontendInterface
>
base_extractor
)
{
max_size_
=
opts
.
max_size
;
timeout_
=
opts
.
timeout
;
// ms
max_size_
=
max_size
;
base_extractor_
=
std
::
move
(
base_extractor
);
dim_
=
base_extractor_
->
Dim
();
}
...
...
@@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
void
FeatureCache
::
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
)
{
// read inputs
base_extractor_
->
Accept
(
inputs
);
// feed current data
bool
result
=
false
;
do
{
result
=
Compute
();
}
while
(
result
);
}
// pop feature chunk
bool
FeatureCache
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
feats
)
{
kaldi
::
Timer
timer
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
while
(
cache_
.
empty
()
&&
base_extractor_
->
IsFinished
()
==
false
)
{
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32
elapsed
=
static_cast
<
int32
>
(
timer
.
Elapsed
()
*
1000
);
// ms
if
(
elapsed
>
timeout_
)
{
return
false
;
}
usleep
(
100
);
// sleep 0.1 ms
// feed current data
if
(
cache_
.
empty
())
{
bool
result
=
false
;
do
{
result
=
Compute
();
}
while
(
result
);
}
if
(
cache_
.
empty
())
return
false
;
// read from cache
*
feats
=
cache_
.
front
();
cache_
.
pop
();
ready_feed_condition_
.
notify_one
();
VLOG
(
1
)
<<
"FeatureCache::Read cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
return
true
;
}
...
...
@@ -73,23 +63,15 @@ bool FeatureCache::Compute() {
kaldi
::
Timer
timer
;
int32
num_chunk
=
feature
.
size
()
/
dim_
;
nframe_
+=
num_chunk
;
VLOG
(
3
)
<<
"nframe computed: "
<<
nframe_
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunk
;
++
chunk_idx
)
{
int32
start
=
chunk_idx
*
dim_
;
vector
<
BaseFloat
>
feature_chunk
(
feature
.
data
()
+
start
,
feature
.
data
()
+
start
+
dim_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
while
(
cache_
.
size
()
>=
max_size_
)
{
// cache full, wait
ready_feed_condition_
.
wait
(
lock
);
}
// feed cache
cache_
.
push
(
feature_chunk
);
ready_read_condition_
.
notify_one
()
;
++
nframe_
;
}
VLOG
(
1
)
<<
"FeatureCache::Compute cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
...
...
@@ -97,4 +79,4 @@ bool FeatureCache::Compute() {
return
true
;
}
}
// namespace ppspeech
\ No newline at end of file
}
// namespace ppspeech
speechx/speechx/common/frontend/feature_cache.h
浏览文件 @
8a225b17
...
...
@@ -19,16 +19,10 @@
namespace
ppspeech
{
struct
FeatureCacheOptions
{
int32
max_size
;
int32
timeout
;
// ms
FeatureCacheOptions
()
:
max_size
(
kint16max
),
timeout
(
1
)
{}
};
class
FeatureCache
:
public
FrontendInterface
{
public:
explicit
FeatureCache
(
FeatureCacheOptions
opts
,
size_t
max_size
=
kint16max
,
std
::
unique_ptr
<
FrontendInterface
>
base_extractor
=
NULL
);
// Feed feats or waves
...
...
@@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface {
virtual
size_t
Dim
()
const
{
return
dim_
;
}
virtual
void
SetFinished
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
LOG
(
INFO
)
<<
"set finished"
;
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_
->
SetFinished
();
// read the last chunk data
Compute
();
// ready_feed_condition_.notify_one
();
base_extractor_
->
SetFinished
();
LOG
(
INFO
)
<<
"compute last feats done."
;
}
...
...
@@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface {
int32
dim_
;
size_t
max_size_
;
// cache capacity
int32
frame_chunk_size_
;
// window
int32
frame_chunk_stride_
;
// stride
std
::
unique_ptr
<
FrontendInterface
>
base_extractor_
;
kaldi
::
int32
timeout_
;
// ms
std
::
vector
<
kaldi
::
BaseFloat
>
remained_feature_
;
std
::
queue
<
std
::
vector
<
BaseFloat
>>
cache_
;
// feature cache
std
::
mutex
mutex_
;
std
::
condition_variable
ready_feed_condition_
;
std
::
condition_variable
ready_read_condition_
;
int32
nframe_
;
// num of feature computed
DISALLOW_COPY_AND_ASSIGN
(
FeatureCache
);
...
...
speechx/speechx/common/frontend/feature_pipeline.cc
浏览文件 @
8a225b17
...
...
@@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
new
ppspeech
::
CMVN
(
opts
.
cmvn_file
,
std
::
move
(
base_feature
)));
unique_ptr
<
FrontendInterface
>
cache
(
new
ppspeech
::
FeatureCache
(
opts
.
feature_cache_opts
,
std
::
move
(
cmvn
)));
new
ppspeech
::
FeatureCache
(
kint16max
,
std
::
move
(
cmvn
)));
base_extractor_
.
reset
(
new
ppspeech
::
Assembler
(
opts
.
assembler_opts
,
std
::
move
(
cache
)));
...
...
speechx/speechx/common/frontend/feature_pipeline.h
浏览文件 @
8a225b17
...
...
@@ -39,7 +39,6 @@ namespace ppspeech {
struct
FeaturePipelineOptions
{
std
::
string
cmvn_file
{};
knf
::
FbankOptions
fbank_opts
{};
FeatureCacheOptions
feature_cache_opts
{};
AssemblerOptions
assembler_opts
{};
static
FeaturePipelineOptions
InitFromFlags
()
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录