Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5042a168
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5042a168
编写于
2月 01, 2023
作者:
Y
YangZhou
提交者:
GitHub
2月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[speechx] add batch recognizer decode. (#2866)
* add recognizer_batch
上级
8a225b17
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
242 addition
and
19 deletion
+242
-19
speechx/speechx/asr/nnet/u2_nnet.cc
speechx/speechx/asr/nnet/u2_nnet.cc
+23
-12
speechx/speechx/asr/nnet/u2_nnet.h
speechx/speechx/asr/nnet/u2_nnet.h
+2
-2
speechx/speechx/asr/recognizer/CMakeLists.txt
speechx/speechx/asr/recognizer/CMakeLists.txt
+1
-0
speechx/speechx/asr/recognizer/u2_recognizer.cc
speechx/speechx/asr/recognizer/u2_recognizer.cc
+26
-4
speechx/speechx/asr/recognizer/u2_recognizer.h
speechx/speechx/asr/recognizer/u2_recognizer.h
+3
-1
speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc
speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc
+185
-0
speechx/speechx/common/base/common.h
speechx/speechx/common/base/common.h
+2
-0
未找到文件。
speechx/speechx/asr/nnet/u2_nnet.cc
浏览文件 @
5042a168
...
...
@@ -118,27 +118,38 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) {
// shallow copy
U2Nnet
::
U2Nnet
(
const
U2Nnet
&
other
)
{
// copy meta
right_context_
=
other
.
right_context_
;
subsampling_rate_
=
other
.
subsampling_rate_
;
sos_
=
other
.
sos_
;
eos_
=
other
.
eos_
;
is_bidecoder_
=
other
.
is_bidecoder_
;
chunk_size_
=
other
.
chunk_size_
;
num_left_chunks_
=
other
.
num_left_chunks_
;
forward_encoder_chunk_
=
other
.
forward_encoder_chunk_
;
forward_attention_decoder_
=
other
.
forward_attention_decoder_
;
ctc_activation_
=
other
.
ctc_activation_
;
offset_
=
other
.
offset_
;
// copy model ptr
model_
=
other
.
model_
;
model_
=
other
.
model_
->
Clone
();
ctc_activation_
=
model_
->
Function
(
"ctc_activation"
);
subsampling_rate_
=
model_
->
Attribute
<
int
>
(
"subsampling_rate"
);
right_context_
=
model_
->
Attribute
<
int
>
(
"right_context"
);
sos_
=
model_
->
Attribute
<
int
>
(
"sos_symbol"
);
eos_
=
model_
->
Attribute
<
int
>
(
"eos_symbol"
);
is_bidecoder_
=
model_
->
Attribute
<
int
>
(
"is_bidirectional_decoder"
);
forward_encoder_chunk_
=
model_
->
Function
(
"forward_encoder_chunk"
);
forward_attention_decoder_
=
model_
->
Function
(
"forward_attention_decoder"
);
ctc_activation_
=
model_
->
Function
(
"ctc_activation"
);
CHECK
(
forward_encoder_chunk_
.
IsValid
());
CHECK
(
forward_attention_decoder_
.
IsValid
());
CHECK
(
ctc_activation_
.
IsValid
());
LOG
(
INFO
)
<<
"Paddle Model Info: "
;
LOG
(
INFO
)
<<
"
\t
subsampling_rate "
<<
subsampling_rate_
;
LOG
(
INFO
)
<<
"
\t
right context "
<<
right_context_
;
LOG
(
INFO
)
<<
"
\t
sos "
<<
sos_
;
LOG
(
INFO
)
<<
"
\t
eos "
<<
eos_
;
LOG
(
INFO
)
<<
"
\t
is bidecoder "
<<
is_bidecoder_
<<
std
::
endl
;
// ignore inner states
}
std
::
shared_ptr
<
NnetBase
>
U2Nnet
::
C
opy
()
const
{
std
::
shared_ptr
<
NnetBase
>
U2Nnet
::
C
lone
()
const
{
auto
asr_model
=
std
::
make_shared
<
U2Nnet
>
(
*
this
);
// reset inner state for new decoding
asr_model
->
Reset
();
...
...
speechx/speechx/asr/nnet/u2_nnet.h
浏览文件 @
5042a168
...
...
@@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase {
num_left_chunks_
=
num_left_chunks
;
}
virtual
std
::
shared_ptr
<
NnetBase
>
C
opy
()
const
=
0
;
virtual
std
::
shared_ptr
<
NnetBase
>
C
lone
()
const
=
0
;
protected:
virtual
void
ForwardEncoderChunkImpl
(
...
...
@@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase {
std
::
shared_ptr
<
paddle
::
jit
::
Layer
>
model
()
const
{
return
model_
;
}
std
::
shared_ptr
<
NnetBase
>
C
opy
()
const
override
;
std
::
shared_ptr
<
NnetBase
>
C
lone
()
const
override
;
void
ForwardEncoderChunkImpl
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
...
...
speechx/speechx/asr/recognizer/CMakeLists.txt
浏览文件 @
5042a168
...
...
@@ -10,6 +10,7 @@ target_link_libraries(recognizer PUBLIC decoder)
set
(
TEST_BINS
u2_recognizer_main
u2_recognizer_thread_main
u2_recognizer_batch_main
)
foreach
(
bin_name IN LISTS TEST_BINS
)
...
...
speechx/speechx/asr/recognizer/u2_recognizer.cc
浏览文件 @
5042a168
...
...
@@ -43,12 +43,34 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
}
U2Recognizer
::
U2Recognizer
(
const
U2RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
)
:
opts_
(
resource
)
{
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
=
std
::
make_shared
<
FeaturePipeline
>
(
feature_opts
);
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
unit_table_
=
decoder_
->
VocabTable
();
symbol_table_
=
unit_table_
;
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
}
U2Recognizer
::~
U2Recognizer
()
{
SetInputFinished
();
WaitDecodeFinished
();
SetInputFinished
();
WaitDecodeFinished
();
}
void
U2Recognizer
::
WaitDecodeFinished
()
{
...
...
@@ -97,8 +119,8 @@ void U2Recognizer::RunDecoderSearchInternal() {
void
U2Recognizer
::
Accept
(
const
vector
<
BaseFloat
>&
waves
)
{
kaldi
::
Timer
timer
;
nnet_producer_
->
Accept
(
waves
);
VLOG
(
1
)
<<
"feed waves cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
<<
waves
.
size
()
<<
" samples."
;
VLOG
(
1
)
<<
"feed waves cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
<<
waves
.
size
()
<<
" samples."
;
}
void
U2Recognizer
::
Decode
()
{
...
...
speechx/speechx/asr/recognizer/u2_recognizer.h
浏览文件 @
5042a168
...
...
@@ -112,6 +112,8 @@ struct U2RecognizerResource {
class
U2Recognizer
{
public:
explicit
U2Recognizer
(
const
U2RecognizerResource
&
resouce
);
explicit
U2Recognizer
(
const
U2RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
);
~
U2Recognizer
();
void
InitDecoder
();
void
ResetContinuousDecoding
();
...
...
@@ -143,7 +145,7 @@ class U2Recognizer {
void
AttentionRescoring
();
private:
static
void
RunDecoderSearch
(
U2Recognizer
*
me
);
static
void
RunDecoderSearch
(
U2Recognizer
*
me
);
void
RunDecoderSearchInternal
();
void
UpdateResult
(
bool
finish
=
false
);
...
...
speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc
0 → 100644
浏览文件 @
5042a168
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
njob
,
3
,
"njob"
);
using
std
::
string
;
using
std
::
vector
;
void
SplitUtt
(
string
wavlist_file
,
vector
<
vector
<
string
>>*
uttlists
,
vector
<
vector
<
string
>>*
wavlists
,
int
njob
)
{
vector
<
string
>
wavlist
;
wavlists
->
resize
(
njob
);
uttlists
->
resize
(
njob
);
ppspeech
::
ReadFileToVector
(
wavlist_file
,
&
wavlist
);
for
(
size_t
idx
=
0
;
idx
<
wavlist
.
size
();
++
idx
)
{
string
utt_str
=
wavlist
[
idx
];
vector
<
string
>
utt_wav
=
ppspeech
::
StrSplit
(
utt_str
,
"
\t
"
);
LOG
(
INFO
)
<<
utt_wav
[
0
];
CHECK_EQ
(
utt_wav
.
size
(),
size_t
(
2
));
uttlists
->
at
(
idx
%
njob
).
push_back
(
utt_wav
[
0
]);
wavlists
->
at
(
idx
%
njob
).
push_back
(
utt_wav
[
1
]);
}
}
void
recognizer_func
(
const
ppspeech
::
U2RecognizerResource
&
resource
,
std
::
shared_ptr
<
ppspeech
::
NnetBase
>
nnet
,
std
::
vector
<
string
>
wavlist
,
std
::
vector
<
string
>
uttlist
,
std
::
vector
<
string
>*
results
)
{
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_decode_time
=
0.0
;
int
chunk_sample_size
=
FLAGS_streaming_chunk
*
FLAGS_sample_rate
;
if
(
wavlist
.
empty
())
return
;
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer_ptr
=
std
::
make_shared
<
ppspeech
::
U2Recognizer
>
(
resource
,
nnet
);
results
->
reserve
(
wavlist
.
size
());
for
(
size_t
idx
=
0
;
idx
<
wavlist
.
size
();
++
idx
)
{
std
::
string
utt
=
uttlist
[
idx
];
std
::
string
wav_file
=
wavlist
[
idx
];
std
::
ifstream
infile
;
infile
.
open
(
wav_file
,
std
::
ifstream
::
in
);
kaldi
::
WaveData
wave_data
;
wave_data
.
Read
(
infile
);
recognizer_ptr
->
InitDecoder
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
tot_wav_duration
+=
dur
;
int32
this_channel
=
0
;
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
waveform
(
wave_data
.
Data
(),
this_channel
);
int
tot_samples
=
waveform
.
Dim
();
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
kaldi
::
Timer
local_timer
;
while
(
sample_offset
<
tot_samples
)
{
int
cur_chunk_size
=
std
::
min
(
chunk_sample_size
,
tot_samples
-
sample_offset
);
std
::
vector
<
kaldi
::
BaseFloat
>
wav_chunk
(
cur_chunk_size
);
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
recognizer_ptr
->
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer_ptr
->
SetInputFinished
();
}
// no overlap
sample_offset
+=
cur_chunk_size
;
}
CHECK
(
sample_offset
==
tot_samples
);
recognizer_ptr
->
WaitDecodeFinished
();
kaldi
::
Timer
timer
;
recognizer_ptr
->
AttentionRescoring
();
tot_attention_rescore_time
+=
timer
.
Elapsed
();
std
::
string
result
=
recognizer_ptr
->
GetFinalResult
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
result
=
" "
;
}
tot_decode_time
+=
local_timer
.
Elapsed
();
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
results
->
push_back
(
result
);
++
num_done
;
}
recognizer_ptr
->
WaitFinished
();
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total rescore cost:"
<<
tot_attention_rescore_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
njob
=
FLAGS_njob
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
ThreadPool
threadpool
(
njob
);
vector
<
vector
<
string
>>
wavlist
;
vector
<
vector
<
string
>>
uttlist
;
vector
<
vector
<
string
>>
resultlist
(
njob
);
vector
<
std
::
future
<
void
>>
futurelist
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet
(
new
ppspeech
::
U2Nnet
(
resource
.
model_opts
));
SplitUtt
(
FLAGS_wav_rspecifier
,
&
uttlist
,
&
wavlist
,
njob
);
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
std
::
future
<
void
>
f
=
threadpool
.
enqueue
(
recognizer_func
,
resource
,
nnet
->
Clone
(),
wavlist
[
i
],
uttlist
[
i
],
&
resultlist
[
i
]);
futurelist
.
push_back
(
std
::
move
(
f
));
}
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
futurelist
[
i
].
get
();
}
for
(
size_t
idx
=
0
;
idx
<
njob
;
++
idx
)
{
for
(
size_t
utt_idx
=
0
;
utt_idx
<
uttlist
[
idx
].
size
();
++
utt_idx
)
{
string
utt
=
uttlist
[
idx
][
utt_idx
];
string
result
=
resultlist
[
idx
][
utt_idx
];
result_writer
.
Write
(
utt
,
result
);
}
}
return
0
;
}
speechx/speechx/common/base/common.h
浏览文件 @
5042a168
...
...
@@ -42,6 +42,8 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include <future>
#include <functional>
#include "base/basic_types.h"
#include "base/flags.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录