Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5046d8ee
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5046d8ee
编写于
12月 27, 2022
作者:
Y
YangZhou
提交者:
GitHub
12月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Speechx] add nnet prob cache && make 2 thread decode work (#2769)
* add nnet cache && make 2 thread work * do not compile websocket
上级
f8caaf46
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
415 addition
and
107 deletion
+415
-107
speechx/CMakeLists.txt
speechx/CMakeLists.txt
+1
-1
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
...peechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
+8
-5
speechx/speechx/asr/nnet/CMakeLists.txt
speechx/speechx/asr/nnet/CMakeLists.txt
+11
-11
speechx/speechx/asr/nnet/decodable.cc
speechx/speechx/asr/nnet/decodable.cc
+23
-65
speechx/speechx/asr/nnet/decodable.h
speechx/speechx/asr/nnet/decodable.h
+5
-11
speechx/speechx/asr/nnet/nnet_producer.cc
speechx/speechx/asr/nnet/nnet_producer.cc
+84
-0
speechx/speechx/asr/nnet/nnet_producer.h
speechx/speechx/asr/nnet/nnet_producer.h
+73
-0
speechx/speechx/asr/recognizer/CMakeLists.txt
speechx/speechx/asr/recognizer/CMakeLists.txt
+1
-0
speechx/speechx/asr/recognizer/u2_recognizer.cc
speechx/speechx/asr/recognizer/u2_recognizer.cc
+8
-7
speechx/speechx/asr/recognizer/u2_recognizer.h
speechx/speechx/asr/recognizer/u2_recognizer.h
+5
-5
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
+123
-0
speechx/speechx/asr/server/CMakeLists.txt
speechx/speechx/asr/server/CMakeLists.txt
+1
-1
speechx/speechx/common/base/common.h
speechx/speechx/common/base/common.h
+1
-1
speechx/speechx/common/base/safe_queue.h
speechx/speechx/common/base/safe_queue.h
+71
-0
未找到文件。
speechx/CMakeLists.txt
浏览文件 @
5046d8ee
...
...
@@ -45,7 +45,7 @@ option(USE_PROFILING "enable c++ profling" OFF)
option
(
WITH_TESTING
"unit test"
ON
)
option
(
USING_U2
"compile u2 model."
ON
)
option
(
USING_DS2
"compile with ds2 model."
O
N
)
option
(
USING_DS2
"compile with ds2 model."
O
FF
)
option
(
USING_GPU
"u2 compute on GPU."
OFF
)
...
...
speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
浏览文件 @
5046d8ee
...
...
@@ -18,6 +18,7 @@
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string
(
feature_rspecifier
,
""
,
"test feature rspecifier"
);
...
...
@@ -39,7 +40,7 @@ using kaldi::BaseFloat;
using
kaldi
::
Matrix
;
using
std
::
vector
;
// test
ds
2 online decoder by feeding speech feature
// test
u
2 online decoder by feeding speech feature
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
...
...
@@ -69,8 +70,10 @@ int main(int argc, char* argv[]) {
// decodeable
std
::
shared_ptr
<
ppspeech
::
DataCache
>
raw_data
=
std
::
make_shared
<
ppspeech
::
DataCache
>
();
std
::
shared_ptr
<
ppspeech
::
NnetProducer
>
nnet_producer
=
std
::
make_shared
<
ppspeech
::
NnetProducer
>
(
nnet
,
raw_data
);
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
=
std
::
make_shared
<
ppspeech
::
Decodable
>
(
nnet
,
raw_data
);
std
::
make_shared
<
ppspeech
::
Decodable
>
(
nnet
_producer
);
// decoder
ppspeech
::
CTCBeamSearchOptions
opts
;
...
...
@@ -114,9 +117,9 @@ int main(int argc, char* argv[]) {
ori_feature_len
-
chunk_idx
*
chunk_stride
,
chunk_size
);
}
if
(
this_chunk_size
<
receptive_field_length
)
{
LOG
(
WARNING
)
<<
"utt: "
<<
utt
<<
" skip last "
<<
this_chunk_size
<<
" frames, expect is "
<<
receptive_field_length
;
LOG
(
WARNING
)
<<
"utt: "
<<
utt
<<
" skip last "
<<
this_chunk_size
<<
" frames, expect is "
<<
receptive_field_length
;
break
;
}
...
...
speechx/speechx/asr/nnet/CMakeLists.txt
浏览文件 @
5046d8ee
set
(
srcs decodable.cc
)
set
(
srcs decodable.cc
nnet_producer.cc
)
if
(
USING_DS2
)
list
(
APPEND srcs ds2_nnet.cc
)
...
...
@@ -27,13 +27,13 @@ if(USING_DS2)
endif
()
# test bin
if
(
USING_U2
)
set
(
bin_name u2_nnet_main
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
utils kaldi-util kaldi-matrix gflags glog nnet
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
endif
()
#
if(USING_U2)
#
set(bin_name u2_nnet_main)
#
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
#
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
#
endif()
speechx/speechx/asr/nnet/decodable.cc
浏览文件 @
5046d8ee
...
...
@@ -21,19 +21,16 @@ using kaldi::Matrix;
using
kaldi
::
Vector
;
using
std
::
vector
;
Decodable
::
Decodable
(
const
std
::
shared_ptr
<
NnetBase
>&
nnet
,
const
std
::
shared_ptr
<
FrontendInterface
>&
frontend
,
Decodable
::
Decodable
(
const
std
::
shared_ptr
<
NnetProducer
>&
nnet_producer
,
kaldi
::
BaseFloat
acoustic_scale
)
:
frontend_
(
frontend
),
nnet_
(
nnet
),
:
nnet_producer_
(
nnet_producer
),
frame_offset_
(
0
),
frames_ready_
(
0
),
acoustic_scale_
(
acoustic_scale
)
{}
// for debug
void
Decodable
::
Acceptlikelihood
(
const
Matrix
<
BaseFloat
>&
likelihood
)
{
nnet_out_cache_
=
likelihood
;
frames_ready_
+=
likelihood
.
NumRows
();
nnet_producer_
->
Acceptlikelihood
(
likelihood
);
}
...
...
@@ -43,7 +40,7 @@ int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1;
bool
Decodable
::
IsLastFrame
(
int32
frame
)
{
bool
flag
=
EnsureFrameHaveComputed
(
frame
);
EnsureFrameHaveComputed
(
frame
);
return
frame
>=
frames_ready_
;
}
...
...
@@ -64,32 +61,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
bool
Decodable
::
AdvanceChunk
()
{
kaldi
::
Timer
timer
;
// read feats
Vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
VLOG
(
3
)
<<
"decodable exit;"
;
return
false
;
}
CHECK_GE
(
frontend_
->
Dim
(),
0
);
VLOG
(
1
)
<<
"AdvanceChunk feat cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
VLOG
(
2
)
<<
"Forward in "
<<
features
.
Dim
()
/
frontend_
->
Dim
()
<<
" feats."
;
// forward feats
NnetOut
out
;
nnet_
->
FeedForward
(
features
,
frontend_
->
Dim
(),
&
out
);
int32
&
vocab_dim
=
out
.
vocab_dim
;
Vector
<
BaseFloat
>&
logprobs
=
out
.
logprobs
;
VLOG
(
2
)
<<
"Forward out "
<<
logprobs
.
Dim
()
/
vocab_dim
<<
" decoder frames."
;
// cache nnet outupts
nnet_out_cache_
.
Resize
(
logprobs
.
Dim
()
/
vocab_dim
,
vocab_dim
);
nnet_out_cache_
.
CopyRowsFromVec
(
logprobs
);
// update state, decoding frame.
bool
flag
=
nnet_producer_
->
Read
(
&
framelikelihood_
);
if
(
flag
==
false
)
return
false
;
frame_offset_
=
frames_ready_
;
frames_ready_
+=
nnet_out_cache_
.
NumRows
()
;
frames_ready_
+=
1
;
VLOG
(
1
)
<<
"AdvanceChunk feat + forward cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
return
true
;
...
...
@@ -101,17 +76,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
return
false
;
}
int
nrows
=
nnet_out_cache_
.
NumRows
();
CHECK
(
nrows
==
(
frames_ready_
-
frame_offset_
));
if
(
nrows
<=
0
)
{
if
(
framelikelihood_
.
empty
())
{
LOG
(
WARNING
)
<<
"No new nnet out in cache."
;
return
false
;
}
logprobs
->
Resize
(
nnet_out_cache_
.
NumRows
()
*
nnet_out_cache_
.
NumCols
());
logprobs
->
CopyRowsFromMat
(
nnet_out_cache_
);
*
vocab_dim
=
nnet_out_cache_
.
NumCols
();
size_t
dim
=
framelikelihood_
.
size
();
logprobs
->
Resize
(
framelikelihood_
.
size
());
std
::
memcpy
(
logprobs
->
Data
(),
framelikelihood_
.
data
(),
dim
*
sizeof
(
kaldi
::
BaseFloat
));
*
vocab_dim
=
framelikelihood_
.
size
();
return
true
;
}
...
...
@@ -122,19 +97,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return
false
;
}
int
nrows
=
nnet_out_cache_
.
NumRows
();
CHECK
(
nrows
==
(
frames_ready_
-
frame_offset_
));
int
vocab_size
=
nnet_out_cache_
.
NumCols
();
likelihood
->
resize
(
vocab_size
);
for
(
int32
idx
=
0
;
idx
<
vocab_size
;
++
idx
)
{
(
*
likelihood
)[
idx
]
=
nnet_out_cache_
(
frame
-
frame_offset_
,
idx
)
*
acoustic_scale_
;
VLOG
(
4
)
<<
"nnet out: "
<<
frame
<<
" offset:"
<<
frame_offset_
<<
" "
<<
nnet_out_cache_
.
NumRows
()
<<
" logprob: "
<<
nnet_out_cache_
(
frame
-
frame_offset_
,
idx
);
}
CHECK_EQ
(
1
,
(
frames_ready_
-
frame_offset_
));
*
likelihood
=
framelikelihood_
;
return
true
;
}
...
...
@@ -143,37 +107,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return
false
;
}
CHECK_LE
(
index
,
nnet_out_cache_
.
NumCols
());
CHECK_LE
(
index
,
framelikelihood_
.
size
());
CHECK_LE
(
frame
,
frames_ready_
);
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
BaseFloat
logprob
=
0.0
;
int32
frame_idx
=
frame
-
frame_offset_
;
BaseFloat
nnet_out
=
nnet_out_cache_
(
frame_idx
,
TokenId2NnetId
(
index
));
if
(
nnet_
->
IsLogProb
())
{
logprob
=
nnet_out
;
}
else
{
logprob
=
std
::
log
(
nnet_out
+
std
::
numeric_limits
<
float
>::
epsilon
());
}
CHECK
(
!
std
::
isnan
(
logprob
)
&&
!
std
::
isinf
(
logprob
));
CHECK_EQ
(
frame_idx
,
0
);
logprob
=
framelikelihood_
[
TokenId2NnetId
(
index
)];
return
acoustic_scale_
*
logprob
;
}
void
Decodable
::
Reset
()
{
if
(
frontend_
!=
nullptr
)
frontend_
->
Reset
();
if
(
nnet_
!=
nullptr
)
nnet_
->
Reset
();
if
(
nnet_producer_
!=
nullptr
)
nnet_producer_
->
Reset
();
frame_offset_
=
0
;
frames_ready_
=
0
;
nnet_out_cache_
.
Resize
(
0
,
0
);
framelikelihood_
.
clear
(
);
}
void
Decodable
::
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
{
kaldi
::
Timer
timer
;
nnet_
->
AttentionRescoring
(
hyps
,
reverse_weight
,
rescoring_score
);
nnet_
producer_
->
AttentionRescoring
(
hyps
,
reverse_weight
,
rescoring_score
);
VLOG
(
1
)
<<
"Attention Rescoring cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
}
}
// namespace ppspeech
\ No newline at end of file
}
// namespace ppspeech
speechx/speechx/asr/nnet/decodable.h
浏览文件 @
5046d8ee
...
...
@@ -13,10 +13,10 @@
// limitations under the License.
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/nnet_producer.h"
namespace
ppspeech
{
...
...
@@ -24,8 +24,7 @@ struct DecodableOpts;
class
Decodable
:
public
kaldi
::
DecodableInterface
{
public:
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetBase
>&
nnet
,
const
std
::
shared_ptr
<
FrontendInterface
>&
frontend
,
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetProducer
>&
nnet_producer
,
kaldi
::
BaseFloat
acoustic_scale
=
1.0
);
// void Init(DecodableOpts config);
...
...
@@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface {
void
Reset
();
bool
IsInputFinished
()
const
{
return
frontend
_
->
IsFinished
();
}
bool
IsInputFinished
()
const
{
return
nnet_producer
_
->
IsFinished
();
}
bool
EnsureFrameHaveComputed
(
int32
frame
);
int32
TokenId2NnetId
(
int32
token_id
);
std
::
shared_ptr
<
NnetBase
>
Nnet
()
{
return
nnet_
;
}
// for offline test
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
private:
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
NnetBase
>
nnet_
;
// nnet outputs' cache
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
nnet_out_cache_
;
std
::
shared_ptr
<
NnetProducer
>
nnet_producer_
;
// the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame
...
...
@@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame
int32
current_log_post_subsampled_offset_
;
int32
num_chunk_computed_
;
std
::
vector
<
kaldi
::
BaseFloat
>
framelikelihood_
;
kaldi
::
BaseFloat
acoustic_scale_
;
};
...
...
speechx/speechx/asr/nnet/nnet_producer.cc
0 → 100644
浏览文件 @
5046d8ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet_producer.h"
namespace
ppspeech
{
using
kaldi
::
Vector
;
using
kaldi
::
BaseFloat
;
NnetProducer
::
NnetProducer
(
std
::
shared_ptr
<
NnetBase
>
nnet
,
std
::
shared_ptr
<
FrontendInterface
>
frontend
)
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{}
void
NnetProducer
::
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
inputs
)
{
frontend_
->
Accept
(
inputs
);
bool
result
=
false
;
do
{
result
=
Compute
();
}
while
(
result
);
}
void
NnetProducer
::
Acceptlikelihood
(
const
kaldi
::
Matrix
<
BaseFloat
>&
likelihood
)
{
std
::
vector
<
BaseFloat
>
prob
;
prob
.
resize
(
likelihood
.
NumCols
());
for
(
size_t
idx
=
0
;
idx
<
likelihood
.
NumRows
();
++
idx
)
{
for
(
size_t
col
=
0
;
col
<
likelihood
.
NumCols
();
++
col
)
{
prob
[
col
]
=
likelihood
(
idx
,
col
);
cache_
.
push_back
(
prob
);
}
}
}
bool
NnetProducer
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
bool
flag
=
cache_
.
pop
(
nnet_prob
);
return
flag
;
}
bool
NnetProducer
::
Compute
()
{
Vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
VLOG
(
3
)
<<
"no feat avalible"
;
return
false
;
}
CHECK_GE
(
frontend_
->
Dim
(),
0
);
VLOG
(
2
)
<<
"Forward in "
<<
features
.
Dim
()
/
frontend_
->
Dim
()
<<
" feats."
;
NnetOut
out
;
nnet_
->
FeedForward
(
features
,
frontend_
->
Dim
(),
&
out
);
int32
&
vocab_dim
=
out
.
vocab_dim
;
Vector
<
BaseFloat
>&
logprobs
=
out
.
logprobs
;
size_t
nframes
=
logprobs
.
Dim
()
/
vocab_dim
;
VLOG
(
2
)
<<
"Forward out "
<<
nframes
<<
" decoder frames."
;
std
::
vector
<
BaseFloat
>
logprob
(
vocab_dim
);
// remove later.
for
(
size_t
idx
=
0
;
idx
<
nframes
;
++
idx
)
{
for
(
size_t
prob_idx
=
0
;
prob_idx
<
vocab_dim
;
++
prob_idx
)
{
logprob
[
prob_idx
]
=
logprobs
(
idx
*
vocab_dim
+
prob_idx
);
}
cache_
.
push_back
(
logprob
);
}
return
true
;
}
void
NnetProducer
::
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
{
nnet_
->
AttentionRescoring
(
hyps
,
reverse_weight
,
rescoring_score
);
}
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/asr/nnet/nnet_producer.h
0 → 100644
浏览文件 @
5046d8ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "base/safe_queue.h"
#include "frontend/audio/frontend_itf.h"
#include "nnet/nnet_itf.h"
namespace
ppspeech
{
class
NnetProducer
{
public:
explicit
NnetProducer
(
std
::
shared_ptr
<
NnetBase
>
nnet
,
std
::
shared_ptr
<
FrontendInterface
>
frontend
=
NULL
);
// Feed feats or waves
void
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
inputs
);
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
BaseFloat
>&
likelihood
);
// nnet
bool
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
Empty
()
const
{
return
cache_
.
empty
();
}
void
SetFinished
()
{
LOG
(
INFO
)
<<
"set finished"
;
// std::unique_lock<std::mutex> lock(mutex_);
frontend_
->
SetFinished
();
// read the last chunk data
Compute
();
// ready_feed_condition_.notify_one();
LOG
(
INFO
)
<<
"compute last feats done."
;
}
bool
IsFinished
()
const
{
return
frontend_
->
IsFinished
();
}
void
Reset
()
{
frontend_
->
Reset
();
nnet_
->
Reset
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
cache_
.
clear
();
}
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
);
private:
bool
Compute
();
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
NnetBase
>
nnet_
;
SafeQueue
<
std
::
vector
<
kaldi
::
BaseFloat
>>
cache_
;
DISALLOW_COPY_AND_ASSIGN
(
NnetProducer
);
};
}
// namespace ppspeech
speechx/speechx/asr/recognizer/CMakeLists.txt
浏览文件 @
5046d8ee
...
...
@@ -30,6 +30,7 @@ endif()
if
(
USING_U2
)
set
(
TEST_BINS
u2_recognizer_main
u2_recognizer_thread_main
)
foreach
(
bin_name IN LISTS TEST_BINS
)
...
...
speechx/speechx/asr/recognizer/u2_recognizer.cc
浏览文件 @
5046d8ee
...
...
@@ -27,13 +27,13 @@ using std::vector;
U2Recognizer
::
U2Recognizer
(
const
U2RecognizerResource
&
resource
)
:
opts_
(
resource
)
{
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
feature_pipeline_
.
reset
(
new
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
(
new
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
NnetBase
>
nnet
(
new
U2Nnet
(
resource
.
model_opts
));
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
decodable_
.
reset
(
new
Decodable
(
nnet
,
feature_pipeline_
,
am_scale
));
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
...
...
@@ -49,6 +49,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
void
U2Recognizer
::
Reset
()
{
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
...
...
@@ -68,7 +69,7 @@ void U2Recognizer::ResetContinuousDecoding() {
void
U2Recognizer
::
Accept
(
const
VectorBase
<
BaseFloat
>&
waves
)
{
kaldi
::
Timer
timer
;
feature_pipeline
_
->
Accept
(
waves
);
nnet_producer
_
->
Accept
(
waves
);
VLOG
(
1
)
<<
"feed waves cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
<<
waves
.
Dim
()
<<
" samples."
;
}
...
...
@@ -210,7 +211,7 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std
::
string
U2Recognizer
::
GetPartialResult
()
{
return
result_
[
0
].
sentence
;
}
void
U2Recognizer
::
SetFinished
()
{
feature_pipeline
_
->
SetFinished
();
nnet_producer
_
->
SetFinished
();
input_finished_
=
true
;
}
...
...
speechx/speechx/asr/recognizer/u2_recognizer.h
浏览文件 @
5046d8ee
...
...
@@ -130,11 +130,11 @@ class U2Recognizer {
return
!
result_
.
empty
()
&&
!
result_
[
0
].
sentence
.
empty
();
}
int
FrameShiftInMs
()
const
{
// one decoder frame length in ms
return
decodable_
->
Nnet
()
->
SubsamplingRate
()
*
feature_pipeline_
->
FrameShift
();
// one decoder frame length in ms, todo
return
1
;
// return decodable_->Nnet()->SubsamplingRate() *
// feature_pipeline_->FrameShift();
}
...
...
@@ -149,7 +149,7 @@ class U2Recognizer {
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
_
;
std
::
shared_ptr
<
NnetProducer
>
nnet_producer
_
;
std
::
shared_ptr
<
Decodable
>
decodable_
;
std
::
unique_ptr
<
CTCPrefixBeamSearch
>
decoder_
;
...
...
speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc
0 → 100644
浏览文件 @
5046d8ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
void
decode_func
(
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer
)
{
while
(
!
recognizer
->
IsFinished
())
{
recognizer
->
Decode
();
usleep
(
100
);
}
recognizer
->
Decode
();
recognizer
->
Rescoring
();
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_decode_time
=
0.0
;
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
FLAGS_wav_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer_ptr
(
new
ppspeech
::
U2Recognizer
(
resource
));
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
std
::
thread
recognizer_thread
(
decode_func
,
recognizer_ptr
);
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
tot_wav_duration
+=
dur
;
int32
this_channel
=
0
;
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
waveform
(
wave_data
.
Data
(),
this_channel
);
int
tot_samples
=
waveform
.
Dim
();
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
kaldi
::
Timer
timer
;
kaldi
::
Timer
local_timer
;
while
(
sample_offset
<
tot_samples
)
{
int
cur_chunk_size
=
std
::
min
(
chunk_sample_size
,
tot_samples
-
sample_offset
);
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
wav_chunk
(
cur_chunk_size
);
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
(
i
)
=
waveform
(
sample_offset
+
i
);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer_ptr
->
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer_ptr
->
SetFinished
();
}
// no overlap
sample_offset
+=
cur_chunk_size
;
}
CHECK
(
sample_offset
==
tot_samples
);
recognizer_thread
.
join
();
std
::
string
result
=
recognizer_ptr
->
GetFinalResult
();
recognizer_ptr
->
Reset
();
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
continue
;
}
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
result_writer
.
Write
(
utt
,
result
);
++
num_done
;
}
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
speechx/speechx/asr/server/CMakeLists.txt
浏览文件 @
5046d8ee
add_subdirectory
(
websocket
)
#
add_subdirectory(websocket)
speechx/speechx/common/base/common.h
浏览文件 @
5046d8ee
...
...
@@ -48,4 +48,4 @@
#include "base/log.h"
#include "base/macros.h"
#include "utils/file_utils.h"
#include "utils/math.h"
\ No newline at end of file
#include "utils/math.h"
speechx/speechx/common/base/safe_queue.h
0 → 100644
浏览文件 @
5046d8ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
namespace
ppspeech
{
template
<
typename
T
>
class
SafeQueue
{
public:
explicit
SafeQueue
(
size_t
capacity
=
0
);
void
push_back
(
const
T
&
in
);
bool
pop
(
T
*
out
);
bool
empty
()
const
{
return
buffer_
.
empty
();
}
size_t
size
()
const
{
return
buffer_
.
size
();
}
void
clear
();
private:
std
::
mutex
mutex_
;
std
::
condition_variable
condition_
;
std
::
deque
<
T
>
buffer_
;
size_t
capacity_
;
};
template
<
typename
T
>
SafeQueue
<
T
>::
SafeQueue
(
size_t
capacity
)
:
capacity_
(
capacity
)
{}
template
<
typename
T
>
void
SafeQueue
<
T
>::
push_back
(
const
T
&
in
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
capacity_
>
0
&&
buffer_
.
size
()
==
capacity_
)
{
condition_
.
wait
(
lock
,
[
this
]
{
return
capacity_
>=
buffer_
.
size
();
});
}
buffer_
.
push_back
(
in
);
condition_
.
notify_one
();
}
template
<
typename
T
>
bool
SafeQueue
<
T
>::
pop
(
T
*
out
)
{
if
(
buffer_
.
empty
())
{
return
false
;
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
condition_
.
wait
(
lock
,
[
this
]
{
return
buffer_
.
size
()
>
0
;
});
*
out
=
std
::
move
(
buffer_
.
front
());
buffer_
.
pop_front
();
condition_
.
notify_one
();
return
true
;
}
template
<
typename
T
>
void
SafeQueue
<
T
>::
clear
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
buffer_
.
clear
();
condition_
.
notify_one
();
}
}
// namespace ppspeech
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录