Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f35a87ab
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f35a87ab
编写于
4月 06, 2023
作者:
Y
YangZhou
提交者:
GitHub
4月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Engine] recognizer controller refactor (#3139)
* refactor recognizer_controller * clean frontend file
上级
591b957b
变更
36
隐藏空白更改
内联
并排
Showing
36 changed file
with
449 addition
and
956 deletion
+449
-956
runtime/CMakeLists.txt
runtime/CMakeLists.txt
+1
-0
runtime/engine/asr/decoder/CMakeLists.txt
runtime/engine/asr/decoder/CMakeLists.txt
+1
-1
runtime/engine/asr/nnet/CMakeLists.txt
runtime/engine/asr/nnet/CMakeLists.txt
+0
-12
runtime/engine/asr/nnet/decodable.h
runtime/engine/asr/nnet/decodable.h
+2
-0
runtime/engine/asr/nnet/nnet_producer.cc
runtime/engine/asr/nnet/nnet_producer.cc
+5
-41
runtime/engine/asr/nnet/nnet_producer.h
runtime/engine/asr/nnet/nnet_producer.h
+5
-24
runtime/engine/asr/recognizer/CMakeLists.txt
runtime/engine/asr/recognizer/CMakeLists.txt
+3
-4
runtime/engine/asr/recognizer/recognizer.cc
runtime/engine/asr/recognizer/recognizer.cc
+13
-0
runtime/engine/asr/recognizer/recognizer.h
runtime/engine/asr/recognizer/recognizer.h
+13
-0
runtime/engine/asr/recognizer/recognizer_batch_main.cc
runtime/engine/asr/recognizer/recognizer_batch_main.cc
+5
-8
runtime/engine/asr/recognizer/recognizer_controller.cc
runtime/engine/asr/recognizer/recognizer_controller.cc
+9
-9
runtime/engine/asr/recognizer/recognizer_controller.h
runtime/engine/asr/recognizer/recognizer_controller.h
+8
-3
runtime/engine/asr/recognizer/recognizer_controller_impl.cc
runtime/engine/asr/recognizer/recognizer_controller_impl.cc
+116
-119
runtime/engine/asr/recognizer/recognizer_controller_impl.h
runtime/engine/asr/recognizer/recognizer_controller_impl.h
+91
-0
runtime/engine/asr/recognizer/recognizer_impl.cc
runtime/engine/asr/recognizer/recognizer_impl.cc
+13
-0
runtime/engine/asr/recognizer/recognizer_impl.h
runtime/engine/asr/recognizer/recognizer_impl.h
+13
-0
runtime/engine/asr/recognizer/recognizer_main.cc
runtime/engine/asr/recognizer/recognizer_main.cc
+6
-6
runtime/engine/asr/recognizer/recognizer_resource.h
runtime/engine/asr/recognizer/recognizer_resource.h
+4
-88
runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc
runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc
+0
-185
runtime/engine/common/frontend/assembler.cc
runtime/engine/common/frontend/assembler.cc
+4
-7
runtime/engine/common/frontend/audio_cache.cc
runtime/engine/common/frontend/audio_cache.cc
+7
-21
runtime/engine/common/frontend/audio_cache.h
runtime/engine/common/frontend/audio_cache.h
+1
-1
runtime/engine/common/frontend/fbank.cc
runtime/engine/common/frontend/fbank.cc
+0
-62
runtime/engine/common/frontend/feature_cache.cc
runtime/engine/common/frontend/feature_cache.cc
+3
-2
runtime/engine/common/frontend/feature_cache.h
runtime/engine/common/frontend/feature_cache.h
+4
-6
runtime/engine/common/frontend/mfcc.cc
runtime/engine/common/frontend/mfcc.cc
+0
-109
runtime/engine/common/frontend/mfcc.h
runtime/engine/common/frontend/mfcc.h
+0
-75
runtime/engine/common/utils/blank_process.cc
runtime/engine/common/utils/blank_process.cc
+0
-26
runtime/engine/common/utils/blank_process.h
runtime/engine/common/utils/blank_process.h
+0
-9
runtime/engine/common/utils/strings.cc
runtime/engine/common/utils/strings.cc
+70
-1
runtime/engine/common/utils/strings.h
runtime/engine/common/utils/strings.h
+7
-1
runtime/engine/common/utils/strings_test.cc
runtime/engine/common/utils/strings_test.cc
+44
-1
runtime/engine/common/utils/text_process.cc
runtime/engine/common/utils/text_process.cc
+0
-74
runtime/engine/common/utils/text_process.h
runtime/engine/common/utils/text_process.h
+0
-13
runtime/engine/common/utils/text_process_test.cc
runtime/engine/common/utils/text_process_test.cc
+0
-47
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
+1
-1
未找到文件。
runtime/CMakeLists.txt
浏览文件 @
f35a87ab
...
...
@@ -60,6 +60,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# https://github.com/google/brotli/pull/655
option
(
BUILD_SHARED_LIBS
"Build shared libraries"
ON
)
option
(
WITH_PPS_DEBUG
"debug option"
OFF
)
if
(
WITH_PPS_DEBUG
)
add_definitions
(
"-DPPS_DEBUG"
)
...
...
runtime/engine/asr/decoder/CMakeLists.txt
浏览文件 @
f35a87ab
...
...
@@ -16,7 +16,7 @@ set(TEST_BINS
foreach
(
bin_name IN LISTS TEST_BINS
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils
libgflags_nothreads.so
glog kaldi-base kaldi-matrix kaldi-util
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils
gflags
glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
-ldl
)
...
...
runtime/engine/asr/nnet/CMakeLists.txt
浏览文件 @
f35a87ab
...
...
@@ -16,18 +16,6 @@ target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURC
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set
(
bin_name u2_nnet_thread_main
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
utils kaldi-util kaldi-matrix gflags glog nnet frontend
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
runtime/engine/asr/nnet/decodable.h
浏览文件 @
f35a87ab
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
#include "matrix/kaldi-matrix.h"
...
...
runtime/engine/asr/nnet/nnet_producer.cc
浏览文件 @
f35a87ab
...
...
@@ -24,42 +24,11 @@ using std::vector;
NnetProducer
::
NnetProducer
(
std
::
shared_ptr
<
NnetBase
>
nnet
,
std
::
shared_ptr
<
FrontendInterface
>
frontend
)
:
nnet_
(
nnet
),
frontend_
(
frontend
)
{
abort_
=
false
;
Reset
();
if
(
nnet_
!=
nullptr
)
thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
}
void
NnetProducer
::
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
)
{
frontend_
->
Accept
(
inputs
);
condition_variable_
.
notify_one
();
}
void
NnetProducer
::
WaitProduce
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
read_mutex_
);
while
(
frontend_
->
IsFinished
()
==
false
&&
cache_
.
empty
())
{
condition_read_ready_
.
wait
(
lock
);
}
return
;
}
void
NnetProducer
::
RunNnetEvaluation
(
NnetProducer
*
me
)
{
me
->
RunNnetEvaluationInteral
();
}
void
NnetProducer
::
RunNnetEvaluationInteral
()
{
bool
result
=
false
;
LOG
(
INFO
)
<<
"NnetEvaluationInteral begin"
;
while
(
!
abort_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
condition_variable_
.
wait
(
lock
);
do
{
result
=
Compute
();
}
while
(
result
);
if
(
frontend_
->
IsFinished
()
==
true
)
{
if
(
cache_
.
empty
())
finished_
=
true
;
}
}
LOG
(
INFO
)
<<
"NnetEvaluationInteral exit"
;
}
void
NnetProducer
::
Acceptlikelihood
(
...
...
@@ -76,14 +45,7 @@ void NnetProducer::Acceptlikelihood(
bool
NnetProducer
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
bool
flag
=
cache_
.
pop
(
nnet_prob
);
condition_variable_
.
notify_one
();
return
flag
;
}
bool
NnetProducer
::
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
Compute
();
if
(
frontend_
->
IsFinished
()
&&
cache_
.
empty
())
finished_
=
true
;
bool
flag
=
cache_
.
pop
(
nnet_prob
);
LOG
(
INFO
)
<<
"nnet cache_ size: "
<<
cache_
.
size
();
return
flag
;
}
...
...
@@ -91,7 +53,10 @@ bool NnetProducer::Compute() {
vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
VLOG
(
2
)
<<
"no feat avalible"
;
LOG
(
INFO
)
<<
"no feat avalible"
;
if
(
frontend_
->
IsFinished
()
==
true
)
{
finished_
=
true
;
}
return
false
;
}
CHECK_GE
(
frontend_
->
Dim
(),
0
);
...
...
@@ -107,7 +72,6 @@ bool NnetProducer::Compute() {
out
.
logprobs
.
data
()
+
idx
*
vocab_dim
,
out
.
logprobs
.
data
()
+
(
idx
+
1
)
*
vocab_dim
);
cache_
.
push_back
(
logprob
);
condition_read_ready_
.
notify_one
();
}
return
true
;
}
...
...
runtime/engine/asr/nnet/nnet_producer.h
浏览文件 @
f35a87ab
...
...
@@ -25,7 +25,6 @@ class NnetProducer {
public:
explicit
NnetProducer
(
std
::
shared_ptr
<
NnetBase
>
nnet
,
std
::
shared_ptr
<
FrontendInterface
>
frontend
=
NULL
);
// Feed feats or waves
void
Accept
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
inputs
);
...
...
@@ -33,36 +32,24 @@ class NnetProducer {
// nnet
bool
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
bool
ReadandCompute
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
);
static
void
RunNnetEvaluation
(
NnetProducer
*
me
);
void
RunNnetEvaluationInteral
();
void
WaitProduce
();
void
Wait
()
{
abort_
=
true
;
condition_variable_
.
notify_one
();
if
(
thread_
.
joinable
())
thread_
.
join
();
}
bool
Empty
()
const
{
return
cache_
.
empty
();
}
void
SetInputFinished
()
{
LOG
(
INFO
)
<<
"set finished"
;
frontend_
->
SetFinished
();
condition_variable_
.
notify_one
();
}
// the compute thread exit
bool
IsFinished
()
const
{
return
finished_
;
}
~
NnetProducer
()
{
if
(
thread_
.
joinable
())
thread_
.
join
();
bool
IsFinished
()
const
{
return
(
frontend_
->
IsFinished
()
&&
finished_
);
}
~
NnetProducer
()
{}
void
Reset
()
{
if
(
frontend_
!=
NULL
)
frontend_
->
Reset
();
if
(
nnet_
!=
NULL
)
nnet_
->
Reset
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
cache_
.
clear
();
finished_
=
false
;
}
...
...
@@ -71,19 +58,13 @@ class NnetProducer {
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
);
private:
bool
Compute
();
private:
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
NnetBase
>
nnet_
;
SafeQueue
<
std
::
vector
<
kaldi
::
BaseFloat
>>
cache_
;
std
::
mutex
mutex_
;
std
::
mutex
read_mutex_
;
std
::
condition_variable
condition_variable_
;
std
::
condition_variable
condition_read_ready_
;
std
::
thread
thread_
;
bool
finished_
;
bool
abort_
;
DISALLOW_COPY_AND_ASSIGN
(
NnetProducer
);
};
...
...
runtime/engine/asr/recognizer/CMakeLists.txt
浏览文件 @
f35a87ab
set
(
srcs
)
list
(
APPEND srcs
u2_recognizer.cc
recognizer_controller.cc
recognizer_controller_impl.cc
)
add_library
(
recognizer STATIC
${
srcs
}
)
target_link_libraries
(
recognizer PUBLIC decoder
)
set
(
TEST_BINS
u2_recognizer_main
u2_recognizer_batch_main
recognizer_batch_main
recognizer_main
)
foreach
(
bin_name IN LISTS TEST_BINS
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
recognizer nnet decoder fst utils
libgflags_nothreads.so
glog kaldi-base kaldi-matrix kaldi-util
)
target_link_libraries
(
${
bin_name
}
recognizer nnet decoder fst utils
gflags
glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
-ldl
)
...
...
runtime/engine/asr/recognizer/recognizer.cc
0 → 100644
浏览文件 @
f35a87ab
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer.h
0 → 100644
浏览文件 @
f35a87ab
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_batch_main.cc
浏览文件 @
f35a87ab
...
...
@@ -19,7 +19,6 @@
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
...
...
@@ -69,9 +68,10 @@ void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
kaldi
::
WaveData
wave_data
;
wave_data
.
Read
(
infile
);
int32
recog_id
=
-
1
;
while
(
recog_id
!
=
-
1
)
{
while
(
recog_id
=
=
-
1
)
{
recog_id
=
recognizer_controller
->
GetRecognizerInstanceId
();
}
recognizer_controller
->
InitDecoder
(
recog_id
);
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
...
...
@@ -96,13 +96,10 @@ void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
}
recognizer_controller
->
Accept
(
wav_chunk
,
recog_id
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer_controller
->
SetInputFinished
(
recog_id
);
}
// no overlap
sample_offset
+=
cur_chunk_size
;
}
recognizer_controller
->
SetInputFinished
(
recog_id
);
CHECK
(
sample_offset
==
tot_samples
);
std
::
string
result
=
recognizer_controller
->
GetFinalResult
(
recog_id
);
if
(
result
.
empty
())
{
...
...
@@ -142,8 +139,8 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2
RecognizerResource
resource
=
ppspeech
::
U2
RecognizerResource
::
InitFromFlags
();
ppspeech
::
RecognizerResource
resource
=
ppspeech
::
RecognizerResource
::
InitFromFlags
();
ppspeech
::
RecognizerController
recognizer_controller
(
njob
,
resource
);
ThreadPool
threadpool
(
njob
);
vector
<
vector
<
string
>>
wavlist
;
...
...
runtime/engine/asr/recognizer/recognizer_controller.cc
浏览文件 @
f35a87ab
...
...
@@ -13,17 +13,15 @@
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
RecognizerController
::
RecognizerController
(
int
num_worker
,
U2
RecognizerResource
resource
)
{
RecognizerController
::
RecognizerController
(
int
num_worker
,
RecognizerResource
resource
)
{
nnet_
=
std
::
make_shared
<
ppspeech
::
U2Nnet
>
(
resource
.
model_opts
);
recognizer_workers
.
resize
(
num_worker
);
for
(
size_t
i
=
0
;
i
<
num_worker
;
++
i
)
{
recognizer_workers
[
i
].
reset
(
new
ppspeech
::
U2Recognizer
(
resource
,
nnet_
->
Clone
()));
recognizer_workers
[
i
]
->
InitDecoder
();
recognizer_workers
[
i
].
reset
(
new
ppspeech
::
RecognizerControllerImpl
(
resource
,
nnet_
->
Clone
()));
waiting_workers
.
push
(
i
);
}
}
...
...
@@ -43,16 +41,18 @@ int RecognizerController::GetRecognizerInstanceId() {
RecognizerController
::~
RecognizerController
()
{
for
(
size_t
i
=
0
;
i
<
recognizer_workers
.
size
();
++
i
)
{
recognizer_workers
[
i
]
->
SetInputFinished
();
recognizer_workers
[
i
]
->
WaitDecodeFinished
();
recognizer_workers
[
i
]
->
WaitFinished
();
}
}
void
RecognizerController
::
InitDecoder
(
int
idx
)
{
recognizer_workers
[
idx
]
->
InitDecoder
();
}
std
::
string
RecognizerController
::
GetFinalResult
(
int
idx
)
{
recognizer_workers
[
idx
]
->
WaitDecodeFinished
();
recognizer_workers
[
idx
]
->
WaitDecode
r
Finished
();
recognizer_workers
[
idx
]
->
AttentionRescoring
();
std
::
string
result
=
recognizer_workers
[
idx
]
->
GetFinalResult
();
recognizer_workers
[
idx
]
->
InitDecoder
();
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
waiting_workers
.
push
(
idx
);
...
...
@@ -68,4 +68,4 @@ void RecognizerController::SetInputFinished(int idx) {
recognizer_workers
[
idx
]
->
SetInputFinished
();
}
}
\ No newline at end of file
}
runtime/engine/asr/recognizer/recognizer_controller.h
浏览文件 @
f35a87ab
...
...
@@ -12,19 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <memory>
#include "recognizer/
u2_recognizer
.h"
#include "recognizer/
recognizer_controller_impl
.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
class
RecognizerController
{
public:
explicit
RecognizerController
(
int
num_worker
,
U2
RecognizerResource
resource
);
explicit
RecognizerController
(
int
num_worker
,
RecognizerResource
resource
);
~
RecognizerController
();
int
GetRecognizerInstanceId
();
void
InitDecoder
(
int
idx
);
void
Accept
(
std
::
vector
<
float
>
data
,
int
idx
);
void
SetInputFinished
(
int
idx
);
std
::
string
GetFinalResult
(
int
idx
);
...
...
@@ -33,7 +36,9 @@ class RecognizerController {
std
::
queue
<
int
>
waiting_workers
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet_
;
std
::
mutex
mutex_
;
std
::
vector
<
std
::
unique_ptr
<
ppspeech
::
U2Recognizer
>>
recognizer_workers
;
std
::
vector
<
std
::
unique_ptr
<
ppspeech
::
RecognizerControllerImpl
>>
recognizer_workers
;
DISALLOW_COPY_AND_ASSIGN
(
RecognizerController
);
};
}
\ No newline at end of file
runtime/engine/asr/recognizer/
u2_recognizer
.cc
→
runtime/engine/asr/recognizer/
recognizer_controller_impl
.cc
浏览文件 @
f35a87ab
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,21 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "recognizer/recognizer_controller_impl.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "common/utils/strings.h"
namespace
ppspeech
{
using
kaldi
::
BaseFloat
;
using
std
::
unique_ptr
;
using
std
::
vector
;
U2Recognizer
::
U2Recognizer
(
const
U2RecognizerResource
&
resource
)
:
opts_
(
resource
)
{
RecognizerControllerImpl
::
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
)
:
opts_
(
resource
)
{
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
(
...
...
@@ -42,8 +35,9 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
}
#endif
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
)
);
nnet_thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
.
empty
())
{
LOG
(
INFO
)
<<
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
;
...
...
@@ -55,21 +49,21 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
}
symbol_table_
=
decoder_
->
WordSymbolTable
();
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
result_
.
clear
();
}
U2Recognizer
::
U2Recognizer
(
const
U2
RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
)
:
opts_
(
resource
)
{
RecognizerControllerImpl
::
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
)
:
opts_
(
resource
)
{
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
=
std
::
make_shared
<
FeaturePipeline
>
(
feature_opts
);
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
nnet_producer_
=
std
::
make_shared
<
NnetProducer
>
(
nnet
,
feature_pipeline
);
nnet_thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
...
...
@@ -88,21 +82,72 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
result_
.
clear
();
}
U2Recognizer
::~
U2Recognizer
()
{
SetInputFinished
();
WaitDecodeFinished
();
RecognizerControllerImpl
::~
RecognizerControllerImpl
()
{
WaitFinished
();
}
void
RecognizerControllerImpl
::
Reset
()
{
nnet_producer_
->
Reset
();
}
void
RecognizerControllerImpl
::
RunDecoder
(
RecognizerControllerImpl
*
me
)
{
me
->
RunDecoderInternal
();
}
void
RecognizerControllerImpl
::
RunDecoderInternal
()
{
LOG
(
INFO
)
<<
"DecoderInternal begin"
;
while
(
!
nnet_producer_
->
IsFinished
())
{
nnet_condition_
.
notify_one
();
decoder_
->
AdvanceDecode
(
decodable_
);
}
decoder_
->
AdvanceDecode
(
decodable_
);
UpdateResult
(
false
);
LOG
(
INFO
)
<<
"DecoderInternal exit"
;
}
void
RecognizerControllerImpl
::
WaitDecoderFinished
()
{
if
(
decoder_thread_
.
joinable
())
decoder_thread_
.
join
();
}
void
RecognizerControllerImpl
::
RunNnetEvaluation
(
RecognizerControllerImpl
*
me
)
{
me
->
RunNnetEvaluationInternal
();
}
void
RecognizerControllerImpl
::
SetInputFinished
()
{
nnet_producer_
->
SetInputFinished
();
nnet_condition_
.
notify_one
();
LOG
(
INFO
)
<<
"Set Input Finished"
;
}
void
RecognizerControllerImpl
::
WaitFinished
()
{
abort_
=
true
;
LOG
(
INFO
)
<<
"nnet wait finished"
;
nnet_condition_
.
notify_one
();
if
(
nnet_thread_
.
joinable
())
{
nnet_thread_
.
join
();
}
}
void
U2Recognizer
::
WaitDecodeFinished
()
{
if
(
thread_
.
joinable
())
thread_
.
join
();
void
RecognizerControllerImpl
::
RunNnetEvaluationInternal
()
{
bool
result
=
false
;
LOG
(
INFO
)
<<
"NnetEvaluationInteral begin"
;
while
(
!
abort_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
nnet_mutex_
);
nnet_condition_
.
wait
(
lock
);
do
{
result
=
nnet_producer_
->
Compute
();
decoder_condition_
.
notify_one
();
}
while
(
result
);
}
LOG
(
INFO
)
<<
"NnetEvaluationInteral exit"
;
}
void
U2Recognizer
::
WaitFinished
(
)
{
if
(
thread_
.
joinable
())
thread_
.
join
(
);
nnet_
producer_
->
Wait
();
void
RecognizerControllerImpl
::
Accept
(
std
::
vector
<
float
>
data
)
{
nnet_producer_
->
Accept
(
data
);
nnet_
condition_
.
notify_one
();
}
void
U2Recognizer
::
InitDecoder
()
{
void
RecognizerControllerImpl
::
InitDecoder
()
{
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
...
...
@@ -110,51 +155,56 @@ void U2Recognizer::InitDecoder() {
decodable_
->
Reset
();
decoder_
->
Reset
();
thread_
=
std
::
thread
(
RunDecoderSearch
,
this
);
decoder_thread_
=
std
::
thread
(
RunDecoder
,
this
);
}
void
U2Recognizer
::
ResetContinuousDecoding
()
{
global_frame_offset_
=
num_frames_
;
num_frames_
=
0
;
result_
.
clear
();
void
RecognizerControllerImpl
::
AttentionRescoring
()
{
decoder_
->
FinalizeSearch
();
UpdateResult
(
false
);
decodable_
->
Reset
();
decoder_
->
Reset
();
}
// No need to do rescoring
if
(
0.0
==
opts_
.
decoder_opts
.
rescoring_weight
)
{
LOG_EVERY_N
(
WARNING
,
3
)
<<
"Not do AttentionRescoring!"
;
return
;
}
LOG_EVERY_N
(
WARNING
,
3
)
<<
"Do AttentionRescoring!"
;
void
U2Recognizer
::
RunDecoderSearch
(
U2Recognizer
*
me
)
{
me
->
RunDecoderSearchInternal
();
}
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const
auto
&
hypotheses
=
decoder_
->
Inputs
();
int
num_hyps
=
hypotheses
.
size
();
if
(
num_hyps
<=
0
)
{
return
;
}
void
U2Recognizer
::
RunDecoderSearchInternal
()
{
LOG
(
INFO
)
<<
"DecoderSearchInteral begin"
;
while
(
!
nnet_producer_
->
IsFinished
())
{
nnet_producer_
->
WaitProduce
();
decoder_
->
AdvanceDecode
(
decodable_
);
std
::
vector
<
float
>
rescoring_score
;
decodable_
->
AttentionRescoring
(
hypotheses
,
opts_
.
decoder_opts
.
reverse_weight
,
&
rescoring_score
);
// combine ctc score and rescoring score
for
(
size_t
i
=
0
;
i
<
num_hyps
;
i
++
)
{
VLOG
(
3
)
<<
"hyp "
<<
i
<<
" rescoring_score: "
<<
rescoring_score
[
i
]
<<
" ctc_score: "
<<
result_
[
i
].
score
<<
" rescoring_weight: "
<<
opts_
.
decoder_opts
.
rescoring_weight
<<
" ctc_weight: "
<<
opts_
.
decoder_opts
.
ctc_weight
;
result_
[
i
].
score
=
opts_
.
decoder_opts
.
rescoring_weight
*
rescoring_score
[
i
]
+
opts_
.
decoder_opts
.
ctc_weight
*
result_
[
i
].
score
;
VLOG
(
3
)
<<
"hyp: "
<<
result_
[
0
].
sentence
<<
" score: "
<<
result_
[
0
].
score
;
}
decoder_
->
AdvanceDecode
(
decodable_
);
UpdateResult
(
false
);
LOG
(
INFO
)
<<
"DecoderSearchInteral exit"
;
}
void
U2Recognizer
::
Accept
(
const
vector
<
BaseFloat
>&
waves
)
{
kaldi
::
Timer
timer
;
nnet_producer_
->
Accept
(
waves
);
VLOG
(
1
)
<<
"feed waves cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
<<
waves
.
size
()
<<
" samples."
;
std
::
sort
(
result_
.
begin
(),
result_
.
end
(),
DecodeResult
::
CompareFunc
);
VLOG
(
3
)
<<
"result: "
<<
result_
[
0
].
sentence
<<
" score: "
<<
result_
[
0
].
score
;
}
void
U2Recognizer
::
Decode
()
{
decoder_
->
AdvanceDecode
(
decodable_
);
UpdateResult
(
false
);
}
std
::
string
RecognizerControllerImpl
::
GetFinalResult
()
{
return
result_
[
0
].
sentence
;
}
void
U2Recognizer
::
Rescoring
()
{
// Do attention Rescoring
AttentionRescoring
();
}
std
::
string
RecognizerControllerImpl
::
GetPartialResult
()
{
return
result_
[
0
].
sentence
;
}
void
U2Recognizer
::
UpdateResult
(
bool
finish
)
{
void
RecognizerControllerImpl
::
UpdateResult
(
bool
finish
)
{
const
auto
&
hypotheses
=
decoder_
->
Outputs
();
const
auto
&
inputs
=
decoder_
->
Inputs
();
const
auto
&
likelihood
=
decoder_
->
Likelihood
();
...
...
@@ -169,10 +219,9 @@ void U2Recognizer::UpdateResult(bool finish) {
path
.
score
=
likelihood
[
i
];
for
(
size_t
j
=
0
;
j
<
hypothesis
.
size
();
j
++
)
{
std
::
string
word
=
symbol_table_
->
Find
(
hypothesis
[
j
]);
// path.sentence += (" " + word); // todo SmileGoat: add blank
// processor
path
.
sentence
+=
word
;
// todo SmileGoat: add blank processor
path
.
sentence
+=
(
" "
+
word
);
}
path
.
sentence
=
DelBlank
(
path
.
sentence
);
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
...
...
@@ -229,56 +278,4 @@ void U2Recognizer::UpdateResult(bool finish) {
}
}
void
U2Recognizer
::
AttentionRescoring
()
{
decoder_
->
FinalizeSearch
();
UpdateResult
(
false
);
// No need to do rescoring
if
(
0.0
==
opts_
.
decoder_opts
.
rescoring_weight
)
{
LOG_EVERY_N
(
WARNING
,
3
)
<<
"Not do AttentionRescoring!"
;
return
;
}
LOG_EVERY_N
(
WARNING
,
3
)
<<
"Do AttentionRescoring!"
;
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const
auto
&
hypotheses
=
decoder_
->
Inputs
();
int
num_hyps
=
hypotheses
.
size
();
if
(
num_hyps
<=
0
)
{
return
;
}
std
::
vector
<
float
>
rescoring_score
;
decodable_
->
AttentionRescoring
(
hypotheses
,
opts_
.
decoder_opts
.
reverse_weight
,
&
rescoring_score
);
// combine ctc score and rescoring score
for
(
size_t
i
=
0
;
i
<
num_hyps
;
i
++
)
{
VLOG
(
3
)
<<
"hyp "
<<
i
<<
" rescoring_score: "
<<
rescoring_score
[
i
]
<<
" ctc_score: "
<<
result_
[
i
].
score
<<
" rescoring_weight: "
<<
opts_
.
decoder_opts
.
rescoring_weight
<<
" ctc_weight: "
<<
opts_
.
decoder_opts
.
ctc_weight
;
result_
[
i
].
score
=
opts_
.
decoder_opts
.
rescoring_weight
*
rescoring_score
[
i
]
+
opts_
.
decoder_opts
.
ctc_weight
*
result_
[
i
].
score
;
VLOG
(
3
)
<<
"hyp: "
<<
result_
[
0
].
sentence
<<
" score: "
<<
result_
[
0
].
score
;
}
std
::
sort
(
result_
.
begin
(),
result_
.
end
(),
DecodeResult
::
CompareFunc
);
VLOG
(
3
)
<<
"result: "
<<
result_
[
0
].
sentence
<<
" score: "
<<
result_
[
0
].
score
;
}
std
::
string
U2Recognizer
::
GetFinalResult
()
{
return
result_
[
0
].
sentence
;
}
std
::
string
U2Recognizer
::
GetPartialResult
()
{
return
result_
[
0
].
sentence
;
}
void
U2Recognizer
::
SetInputFinished
()
{
nnet_producer_
->
SetInputFinished
();
input_finished_
=
true
;
}
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_controller_impl.h
0 → 100644
浏览文件 @
f35a87ab
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "nnet/decodable.h"
#include "recognizer/recognizer_resource.h"
#include <memory>
namespace ppspeech {

// Streaming recognizer implementation: owns the decodable, the decoder, the
// nnet producer, and two worker threads with their mutex/condition pairs.
// NOTE(review): the member functions are defined in the .cc file, which is
// not visible here; comments below only describe what this header shows.
class RecognizerControllerImpl {
  public:
    explicit RecognizerControllerImpl(const RecognizerResource& resource);
    // Same as above, but with a caller-supplied nnet instance.
    explicit RecognizerControllerImpl(const RecognizerResource& resource,
                                      std::shared_ptr<NnetBase> nnet);
    ~RecognizerControllerImpl();

    // Takes `data` by value.
    void Accept(std::vector<float> data);
    void InitDecoder();
    void SetInputFinished();
    std::string GetFinalResult();
    std::string GetPartialResult();
    void Rescoring();
    void Reset();
    void WaitDecoderFinished();
    void WaitFinished();
    void AttentionRescoring();

    // True when a first result exists and its sentence is non-empty.
    bool DecodedSomething() const {
        return !result_.empty() && !result_[0].sentence.empty();
    }

    // TODO: placeholder, always returns 1.
    int FrameShiftInMs() const { return 1; }

  private:
    // Static entry points taking the instance pointer; presumably the bodies
    // of nnet_thread_ / decoder_thread_ -- TODO confirm in the .cc file.
    static void RunNnetEvaluation(RecognizerControllerImpl* me);
    void RunNnetEvaluationInternal();
    static void RunDecoder(RecognizerControllerImpl* me);
    void RunDecoderInternal();
    void UpdateResult(bool finish = false);

    std::shared_ptr<Decodable> decodable_;
    std::unique_ptr<DecoderBase> decoder_;
    std::shared_ptr<NnetProducer> nnet_producer_;

    // e2e unit symbol table
    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
    std::vector<DecodeResult> result_;

    RecognizerResource opts_;
    bool abort_ = false;
    // global decoded frame offset
    // (initialized here so the object never holds an indeterminate value)
    int global_frame_offset_ = 0;
    // cur decoded frame num
    int num_frames_ = 0;
    // timestamp gap between words in a sentence
    const int time_stamp_gap_ = 100;
    bool input_finished_ = false;

    std::mutex nnet_mutex_;
    std::mutex decoder_mutex_;
    std::condition_variable nnet_condition_;
    std::condition_variable decoder_condition_;

    std::thread nnet_thread_;
    std::thread decoder_thread_;

    DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
};

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_impl.cc
0 → 100644
浏览文件 @
f35a87ab
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_impl.h
0 → 100644
浏览文件 @
f35a87ab
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
\ No newline at end of file
runtime/engine/asr/recognizer/
u2_
recognizer_main.cc
→
runtime/engine/asr/recognizer/recognizer_main.cc
浏览文件 @
f35a87ab
...
...
@@ -15,7 +15,7 @@
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "recognizer/
u2_recogniz
er.h"
#include "recognizer/
recognizer_controll
er.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
...
...
@@ -45,10 +45,10 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2
RecognizerResource
resource
=
ppspeech
::
U2
RecognizerResource
::
InitFromFlags
();
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer_ptr
(
new
ppspeech
::
U2Recognizer
(
resource
));
ppspeech
::
RecognizerResource
resource
=
ppspeech
::
RecognizerResource
::
InitFromFlags
();
std
::
shared_ptr
<
ppspeech
::
RecognizerControllerImpl
>
recognizer_ptr
(
new
ppspeech
::
RecognizerControllerImpl
(
resource
));
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
recognizer_ptr
->
InitDecoder
();
...
...
@@ -84,7 +84,7 @@ int main(int argc, char* argv[]) {
}
CHECK
(
sample_offset
==
tot_samples
);
recognizer_ptr
->
SetInputFinished
();
recognizer_ptr
->
WaitDecodeFinished
();
recognizer_ptr
->
WaitDecode
r
Finished
();
kaldi
::
Timer
timer
;
recognizer_ptr
->
AttentionRescoring
();
...
...
runtime/engine/asr/recognizer/
u2_recognizer
.h
→
runtime/engine/asr/recognizer/
recognizer_resource
.h
浏览文件 @
f35a87ab
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/feature_pipeline.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/decodable.h"
DECLARE_int32
(
nnet_decoder_chunk
);
DECLARE_int32
(
num_left_chunks
);
...
...
@@ -87,7 +67,7 @@ struct DecodeOptions {
}
};
struct
U2
RecognizerResource
{
struct
RecognizerResource
{
kaldi
::
BaseFloat
acoustic_scale
{
1.0
};
std
::
string
vocab_path
{};
...
...
@@ -95,8 +75,8 @@ struct U2RecognizerResource {
ModelOptions
model_opts
{};
DecodeOptions
decoder_opts
{};
static
U2
RecognizerResource
InitFromFlags
()
{
U2
RecognizerResource
resource
;
static
RecognizerResource
InitFromFlags
()
{
RecognizerResource
resource
;
resource
.
vocab_path
=
FLAGS_vocab_path
;
resource
.
acoustic_scale
=
FLAGS_acoustic_scale
;
LOG
(
INFO
)
<<
"vocab path: "
<<
resource
.
vocab_path
;
...
...
@@ -113,68 +93,4 @@ struct U2RecognizerResource {
}
};
// Streaming recognizer: accepts waveform chunks, decodes on thread_ and
// exposes partial/final results.
// NOTE(review): the member functions are defined in the .cc file, which is
// not visible here; comments below only describe what this header shows.
class U2Recognizer {
  public:
    explicit U2Recognizer(const U2RecognizerResource& resource);
    // Same as above, but with a caller-supplied nnet instance.
    explicit U2Recognizer(const U2RecognizerResource& resource,
                          std::shared_ptr<NnetBase> nnet);
    ~U2Recognizer();

    void InitDecoder();
    void ResetContinuousDecoding();

    void Accept(const std::vector<kaldi::BaseFloat>& waves);
    void Decode();
    void Rescoring();

    std::string GetFinalResult();
    std::string GetPartialResult();

    void SetInputFinished();
    // Returns the input_finished_ flag.
    bool IsFinished() const { return input_finished_; }
    void WaitDecodeFinished();
    void WaitFinished();

    // True when a first result exists and its sentence is non-empty.
    bool DecodedSomething() const {
        return !result_.empty() && !result_[0].sentence.empty();
    }

    int FrameShiftInMs() const {
        // one decoder frame length in ms, todo
        return 1;
        // return decodable_->Nnet()->SubsamplingRate() *
        //       feature_pipeline_->FrameShift();
    }

    const std::vector<DecodeResult>& Result() const { return result_; }

    void AttentionRescoring();

  private:
    // Static entry point taking the instance pointer; presumably the body of
    // thread_ -- TODO confirm in the .cc file.
    static void RunDecoderSearch(U2Recognizer* me);
    void RunDecoderSearchInternal();
    void UpdateResult(bool finish = false);

  private:
    U2RecognizerResource opts_;

    std::shared_ptr<NnetProducer> nnet_producer_;
    std::shared_ptr<Decodable> decodable_;
    std::unique_ptr<DecoderBase> decoder_;

    // e2e unit symbol table
    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;

    std::vector<DecodeResult> result_;

    // global decoded frame offset
    // (initialized here so the object never holds an indeterminate value)
    int global_frame_offset_ = 0;
    // cur decoded frame num
    int num_frames_ = 0;
    // timestamp gap between words in a sentence
    const int time_stamp_gap_ = 100;
    // Initialized so IsFinished() is well-defined before SetInputFinished().
    bool input_finished_ = false;

    std::thread thread_;
};
}
// namespace ppspeech
}
//namespace ppspeech
\ No newline at end of file
runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc
已删除
100644 → 0
浏览文件 @
591b957b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
njob
,
3
,
"njob"
);
using
std
::
string
;
using
std
::
vector
;
// Reads `wavlist_file` (one "<utt><delim><wav>" entry per line) and deals the
// entries round-robin into `njob` buckets.
//
// @param wavlist_file  path of the list file
// @param uttlists      out: uttlists->at(j) receives the utterance ids of job j
// @param wavlists      out: wavlists->at(j) receives the wav paths of job j
// @param njob          number of buckets; must be positive
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    // idx % njob below would divide by zero (or index out of range) otherwise.
    CHECK_GT(njob, 0);

    vector<string> wavlist;
    wavlists->resize(njob);
    uttlists->resize(njob);
    ppspeech::ReadFileToVector(wavlist_file, &wavlist);

    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        // NOTE(review): delimiter reproduced exactly as shown in the source
        // view; confirm no leading blank was lost from the literal.
        const vector<string> utt_wav = ppspeech::StrSplit(wavlist[idx], "\t");
        // Validate the field count before touching utt_wav[0] / utt_wav[1].
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];

        const size_t job = idx % njob;
        uttlists->at(job).push_back(utt_wav[0]);
        wavlists->at(job).push_back(utt_wav[1]);
    }
}
// Decodes every wav in `wavlist` with one recognizer and appends one result
// string per utterance to `results`, then logs timing statistics.
//
// @param resource  recognizer configuration
// @param nnet      nnet instance handed to the recognizer
// @param wavlist   wav file paths to decode
// @param uttlist   utterance ids, parallel to `wavlist`
// @param results   out: one entry per utterance (" " when decoding was empty)
void recognizer_func(const ppspeech::U2RecognizerResource& resource,
                     std::shared_ptr<ppspeech::NnetBase> nnet,
                     std::vector<string> wavlist,
                     std::vector<string> uttlist,
                     std::vector<string>* results) {
    int32 num_done = 0, num_err = 0;
    double tot_wav_duration = 0.0;
    double tot_attention_rescore_time = 0.0;
    double tot_decode_time = 0.0;

    int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
    if (wavlist.empty()) return;

    // uttlist[idx] is read for every wav, so the lists must be parallel.
    CHECK_EQ(uttlist.size(), wavlist.size());
    // A non-positive chunk would make the feeding loop below spin forever.
    CHECK_GT(chunk_sample_size, 0);

    std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr =
        std::make_shared<ppspeech::U2Recognizer>(resource, nnet);

    results->reserve(wavlist.size());
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        const std::string& utt = uttlist[idx];
        const std::string& wav_file = wavlist[idx];

        // wav data is binary; open it as such so no platform translates bytes.
        std::ifstream infile;
        infile.open(wav_file, std::ifstream::in | std::ifstream::binary);
        kaldi::WaveData wave_data;
        wave_data.Read(infile);

        recognizer_ptr->InitDecoder();
        LOG(INFO) << "utt: " << utt;
        LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
        double dur = wave_data.Duration();
        tot_wav_duration += dur;

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        kaldi::Timer local_timer;

        // Feed the waveform chunk by chunk, no overlap.
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk[i] = waveform(sample_offset + i);
            }

            recognizer_ptr->Accept(wav_chunk);
            sample_offset += cur_chunk_size;
        }
        CHECK(sample_offset == tot_samples);

        // Signal end of input once all samples are fed. The previous code only
        // did so when the last chunk was shorter than chunk_sample_size, so a
        // wav whose length is an exact multiple of the chunk size (or an empty
        // wav) never marked the input as finished before waiting.
        recognizer_ptr->SetInputFinished();
        recognizer_ptr->WaitDecodeFinished();

        kaldi::Timer timer;
        recognizer_ptr->AttentionRescoring();
        tot_attention_rescore_time += timer.Elapsed();

        std::string result = recognizer_ptr->GetFinalResult();
        if (result.empty()) {
            // the TokenWriter can not write empty string.
            ++num_err;
            LOG(INFO) << " the result of " << utt << " is empty";
            result = " ";
        }

        tot_decode_time += local_timer.Elapsed();
        LOG(INFO) << utt << " " << result;
        LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
                  << " cost: " << local_timer.Elapsed();

        results->push_back(result);
        ++num_done;
    }
    recognizer_ptr->WaitFinished();

    LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
    LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
    LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
    LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
    LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
njob
=
FLAGS_njob
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
ThreadPool
threadpool
(
njob
);
vector
<
vector
<
string
>>
wavlist
;
vector
<
vector
<
string
>>
uttlist
;
vector
<
vector
<
string
>>
resultlist
(
njob
);
vector
<
std
::
future
<
void
>>
futurelist
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet
(
new
ppspeech
::
U2Nnet
(
resource
.
model_opts
));
SplitUtt
(
FLAGS_wav_rspecifier
,
&
uttlist
,
&
wavlist
,
njob
);
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
std
::
future
<
void
>
f
=
threadpool
.
enqueue
(
recognizer_func
,
resource
,
nnet
->
Clone
(),
wavlist
[
i
],
uttlist
[
i
],
&
resultlist
[
i
]);
futurelist
.
push_back
(
std
::
move
(
f
));
}
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
futurelist
[
i
].
get
();
}
for
(
size_t
idx
=
0
;
idx
<
njob
;
++
idx
)
{
for
(
size_t
utt_idx
=
0
;
utt_idx
<
uttlist
[
idx
].
size
();
++
utt_idx
)
{
string
utt
=
uttlist
[
idx
][
utt_idx
];
string
result
=
resultlist
[
idx
][
utt_idx
];
result_writer
.
Write
(
utt
,
result
);
}
}
return
0
;
}
runtime/engine/common/frontend/assembler.cc
浏览文件 @
f35a87ab
...
...
@@ -52,23 +52,20 @@ bool Assembler::Compute(vector<BaseFloat>* feats) {
vector
<
BaseFloat
>
feature
;
bool
result
=
base_extractor_
->
Read
(
&
feature
);
if
(
result
==
false
||
feature
.
size
()
==
0
)
{
VLOG
(
3
)
<<
"result: "
<<
result
VLOG
(
1
)
<<
"result: "
<<
result
<<
" feature dim: "
<<
feature
.
size
();
if
(
IsFinished
()
==
false
)
{
VLOG
(
3
)
<<
"finished reading feature. cache size: "
VLOG
(
1
)
<<
"finished reading feature. cache size: "
<<
feature_cache_
.
size
();
return
false
;
}
else
{
VLOG
(
3
)
<<
"break"
;
VLOG
(
1
)
<<
"break"
;
break
;
}
}
CHECK
(
feature
.
size
()
==
dim_
);
feature_cache_
.
push
(
feature
);
nframes_
+=
1
;
VLOG
(
3
)
<<
"nframes: "
<<
nframes_
;
VLOG
(
1
)
<<
"nframes: "
<<
nframes_
;
}
if
(
feature_cache_
.
size
()
<
receptive_filed_length_
)
{
...
...
runtime/engine/common/frontend/audio_cache.cc
浏览文件 @
f35a87ab
...
...
@@ -56,28 +56,14 @@ bool AudioCache::Read(vector<BaseFloat>* waves) {
kaldi
::
Timer
timer
;
size_t
chunk_size
=
waves
->
size
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
while
(
chunk_size
>
size_
)
{
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock,
// so replace with timeout_
// ready_read_condition_.wait(lock);
int32
elapsed
=
static_cast
<
int32
>
(
timer
.
Elapsed
()
*
1000
);
if
(
elapsed
>
timeout_
)
{
if
(
finished_
==
true
)
{
// read last chunk data
break
;
}
if
(
chunk_size
>
size_
)
{
return
false
;
}
}
usleep
(
100
);
// sleep 0.1 ms
}
// read last chunk data
if
(
chunk_size
>
size_
)
{
chunk_size
=
size_
;
waves
->
resize
(
chunk_size
);
if
(
finished_
==
false
)
{
return
false
;
}
else
{
// read last chunk data
chunk_size
=
size_
;
waves
->
resize
(
chunk_size
);
}
}
for
(
size_t
idx
=
0
;
idx
<
chunk_size
;
++
idx
)
{
...
...
runtime/engine/common/frontend/audio_cache.h
浏览文件 @
f35a87ab
...
...
@@ -39,7 +39,7 @@ class AudioCache : public FrontendInterface {
finished_
=
true
;
}
virtual
bool
IsFinished
()
const
{
return
finished_
;
}
virtual
bool
IsFinished
()
const
{
return
finished_
&&
(
size_
==
0
)
;
}
void
Reset
()
override
{
offset_
=
0
;
...
...
runtime/engine/common/frontend/fbank.cc
已删除
100644 → 0
浏览文件 @
591b957b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/fbank.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {

using kaldi::BaseFloat;
using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::vector;

// Keeps a copy of the options and builds the wrapped computer from them.
FbankComputer::FbankComputer(const Options& opts)
    : opts_(opts), computer_(opts) {}

// Feature dimension: the number of mel bins, plus one slot when use_energy
// is set.
int32 FbankComputer::Dim() const {
    const int32 energy_slots = opts_.use_energy ? 1 : 0;
    return opts_.mel_opts.num_bins + energy_slots;
}

// True only when both use_energy and raw_energy are set.
bool FbankComputer::NeedRawLogEnergy() {
    if (!opts_.use_energy) {
        return false;
    }
    return opts_.raw_energy;
}

// Computes the log mel energies of one frame.
//
// @param window  in/out: the frame; it is passed by pointer to RealFft and
//                ComputePowerSpectrum, which rewrite it
// @param feat    out: the mel energies are written at offset mel_offset
// @return always true
bool FbankComputer::Compute(Vector<BaseFloat>* window,
                            Vector<BaseFloat>* feat) {
    RealFft(window, true);
    kaldi::ComputePowerSpectrum(window);

    const kaldi::MelBanks& banks = *(computer_.GetMelBanks(1.0));

    // View over the first Dim()/2 + 1 entries of the transformed window.
    const int32 num_spectrum_bins = window->Dim() / 2 + 1;
    SubVector<BaseFloat> spectrum(*window, 0, num_spectrum_bins);
    if (!opts_.use_power) {
        spectrum.ApplyPow(0.5);
    }

    // Slot 0 is skipped when energy is enabled and htk_compat is off.
    const bool skip_first_slot = opts_.use_energy && !opts_.htk_compat;
    const int32 mel_offset = skip_first_slot ? 1 : 0;
    SubVector<BaseFloat> mel_out(*feat, mel_offset, opts_.mel_opts.num_bins);

    banks.Compute(spectrum, &mel_out);
    // Floor before the log.
    mel_out.ApplyFloor(1e-07);
    mel_out.ApplyLog();
    return true;
}

}  // namespace ppspeech
runtime/engine/common/frontend/feature_cache.cc
浏览文件 @
f35a87ab
...
...
@@ -49,7 +49,8 @@ bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
// read from cache
*
feats
=
cache_
.
front
();
cache_
.
pop
();
VLOG
(
1
)
<<
"FeatureCache::Read cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
VLOG
(
2
)
<<
"FeatureCache::Read cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
VLOG
(
1
)
<<
"FeatureCache::size : "
<<
cache_
.
size
();
return
true
;
}
...
...
@@ -74,7 +75,7 @@ bool FeatureCache::Compute() {
++
nframe_
;
}
VLOG
(
1
)
<<
"FeatureCache::Compute cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
VLOG
(
2
)
<<
"FeatureCache::Compute cost: "
<<
timer
.
Elapsed
()
<<
" sec. "
<<
num_chunk
<<
" feats."
;
return
true
;
}
...
...
runtime/engine/common/frontend/feature_cache.h
浏览文件 @
f35a87ab
...
...
@@ -36,21 +36,19 @@ class FeatureCache : public FrontendInterface {
virtual
void
SetFinished
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
LOG
(
INFO
)
<<
"set finished"
;
// read the last chunk data
Compute
();
base_extractor_
->
SetFinished
();
LOG
(
INFO
)
<<
"compute last feats done."
;
}
virtual
bool
IsFinished
()
const
{
return
base_extractor_
->
IsFinished
();
}
virtual
bool
IsFinished
()
const
{
return
base_extractor_
->
IsFinished
()
&&
cache_
.
empty
();
}
void
Reset
()
override
{
std
::
queue
<
std
::
vector
<
BaseFloat
>>
empty
;
VLOG
(
1
)
<<
"feature cache size: "
<<
cache_
.
size
();
std
::
swap
(
cache_
,
empty
);
nframe_
=
0
;
base_extractor_
->
Reset
();
VLOG
(
3
)
<<
"feature cache reset: cache size: "
<<
cache_
.
size
();
}
private:
...
...
runtime/engine/common/frontend/mfcc.cc
已删除
100644 → 0
浏览文件 @
591b957b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/mfcc.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h"
#include "kaldi/matrix/matrix-functions.h"
namespace
ppspeech
{
using
kaldi
::
BaseFloat
;
using
kaldi
::
int32
;
using
kaldi
::
Matrix
;
using
kaldi
::
SubVector
;
using
kaldi
::
Vector
;
using
kaldi
::
VectorBase
;
using
std
::
vector
;
Mfcc
::
Mfcc
(
const
MfccOptions
&
opts
,
std
::
unique_ptr
<
FrontendInterface
>
base_extractor
)
:
opts_
(
opts
),
computer_
(
opts
.
mfcc_opts
),
window_function_
(
computer_
.
GetFrameOptions
())
{
base_extractor_
=
std
::
move
(
base_extractor
);
chunk_sample_size_
=
static_cast
<
int32
>
(
opts
.
streaming_chunk
*
opts
.
frame_opts
.
samp_freq
);
}
void
Mfcc
::
Accept
(
const
VectorBase
<
BaseFloat
>&
inputs
)
{
base_extractor_
->
Accept
(
inputs
);
}
bool
Mfcc
::
Read
(
Vector
<
BaseFloat
>*
feats
)
{
Vector
<
BaseFloat
>
wav
(
chunk_sample_size_
);
bool
flag
=
base_extractor_
->
Read
(
&
wav
);
if
(
flag
==
false
||
wav
.
Dim
()
==
0
)
return
false
;
// append remaned waves
int32
wav_len
=
wav
.
Dim
();
int32
left_len
=
remained_wav_
.
Dim
();
Vector
<
BaseFloat
>
waves
(
left_len
+
wav_len
);
waves
.
Range
(
0
,
left_len
).
CopyFromVec
(
remained_wav_
);
waves
.
Range
(
left_len
,
wav_len
).
CopyFromVec
(
wav
);
// compute speech feature
Compute
(
waves
,
feats
);
// cache remaned waves
kaldi
::
FrameExtractionOptions
frame_opts
=
computer_
.
GetFrameOptions
();
int32
num_frames
=
kaldi
::
NumFrames
(
waves
.
Dim
(),
frame_opts
);
int32
frame_shift
=
frame_opts
.
WindowShift
();
int32
left_samples
=
waves
.
Dim
()
-
frame_shift
*
num_frames
;
remained_wav_
.
Resize
(
left_samples
);
remained_wav_
.
CopyFromVec
(
waves
.
Range
(
frame_shift
*
num_frames
,
left_samples
));
return
true
;
}
// Compute spectrogram feat
bool
Mfcc
::
Compute
(
const
Vector
<
BaseFloat
>&
waves
,
Vector
<
BaseFloat
>*
feats
)
{
const
FrameExtractionOptions
&
frame_opts
=
computer_
.
GetFrameOptions
();
int32
num_samples
=
waves
.
Dim
();
int32
frame_length
=
frame_opts
.
WindowSize
();
int32
sample_rate
=
frame_opts
.
samp_freq
;
if
(
num_samples
<
frame_length
)
{
return
true
;
}
int32
num_frames
=
kaldi
::
NumFrames
(
num_samples
,
frame_opts
);
feats
->
Rsize
(
num_frames
*
Dim
());
Vector
<
BaseFloat
>
window
;
bool
need_raw_log_energy
=
computer_
.
NeedRawLogEnergy
();
for
(
int32
frame
=
0
;
frame
<
num_frames
;
frame
++
)
{
BaseFloat
raw_log_energy
=
0.0
;
kaldi
::
ExtractWindow
(
0
,
waves
,
frame
,
frame_opts
,
window_function_
,
&
window
,
need_raw_log_energy
?
&
raw_log_energy
:
NULL
);
Vector
<
BaseFloat
>
this_feature
(
computer_
.
Dim
(),
kUndefined
);
// note: this online feature-extraction code does not support VTLN.
BaseFloat
vtln_warp
=
1.0
;
computer_
.
Compute
(
raw_log_energy
,
vtln_warp
,
&
window
,
&
this_feature
);
SubVector
<
BaseFloat
>
output_row
(
feats
->
Data
()
+
frame
*
Dim
(),
Dim
());
output_row
.
CopyFromVec
(
this_feature
);
}
return
true
;
}
}
// namespace ppspeech
\ No newline at end of file
runtime/engine/common/frontend/mfcc.h
已删除
100644 → 0
浏览文件 @
591b957b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {

// Options of the streaming MFCC front end.
struct MfccOptions {
    kaldi::MfccOptions mfcc_opts;
    kaldi::BaseFloat streaming_chunk;  // second

    // Initializers listed in declaration order.
    MfccOptions() : mfcc_opts(), streaming_chunk(0.1) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("streaming-chunk",
                       &streaming_chunk,
                       "streaming chunk size, default: 0.1 sec");
        mfcc_opts.Register(opts);
    }
};

// MFCC feature extractor layered on another FrontendInterface that provides
// the waveform.
class Mfcc : public FrontendInterface {
  public:
    explicit Mfcc(const MfccOptions& opts,
                  std::unique_ptr<FrontendInterface> base_extractor);
    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);

    // the dim_ is the dim of single frame feature
    virtual size_t Dim() const { return computer_.Dim(); }

    virtual void SetFinished() { base_extractor_->SetFinished(); }

    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }

    // Resets the base extractor and drops the cached leftover samples.
    virtual void Reset() {
        base_extractor_->Reset();
        remained_wav_.Resize(0);
    }

  private:
    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
                 kaldi::Vector<kaldi::BaseFloat>* feats);

    MfccOptions opts_;
    std::unique_ptr<FrontendInterface> base_extractor_;

    // computer_ must be declared before window_function_: members are
    // initialized in declaration order and the constructor builds
    // window_function_ from computer_.GetFrameOptions().
    kaldi::MfccComputer computer_;
    kaldi::FeatureWindowFunction window_function_;

    // Samples read per Read() call; set by the constructor.
    kaldi::int32 chunk_sample_size_ = 0;

    // features_ is the Mfcc or Plp or Fbank features that we have already
    // computed.
    kaldi::Vector<kaldi::BaseFloat> features_;
    kaldi::Vector<kaldi::BaseFloat> remained_wav_;

    DISALLOW_COPY_AND_ASSIGN(Mfcc);
};

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/common/utils/blank_process.cc
已删除
100644 → 0
浏览文件 @
591b957b
#include "utils/blank_process.h"
namespace ppspeech {

// Removes blanks from `str`, keeping a single blank only between two ASCII
// letters that were separated by one or more blanks in the input.
//
// @param str  input text (may contain multi-byte UTF-8 characters)
// @return the text with all other blanks dropped
std::string BlankProcess(const std::string& str) {
    // isalpha() is undefined for negative values, which the bytes of UTF-8
    // text are where char is signed; go through unsigned char.
    const auto is_alpha = [](char c) {
        return isalpha(static_cast<unsigned char>(c)) != 0;
    };

    std::string out;
    out.reserve(str.size());

    const std::string::size_type end = str.size();
    std::string::size_type p = 0;
    std::string::size_type q = 0;  // index in str of the last char written
    bool wrote_any = false;        // q is meaningful only once this is set

    while (p != end) {
        while (p != end && str[p] == ' ') {
            ++p;
        }
        if (p == end) break;

        // add a space when the last and current chars are in English and
        // there have been space(s) between them
        if (wrote_any && is_alpha(str[p]) && is_alpha(str[q]) &&
            str[p - 1] == ' ') {
            out += ' ';
        }
        out += str[p];
        q = p;
        wrote_any = true;
        ++p;
    }
    return out;
}

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/common/utils/blank_process.h
已删除
100644 → 0
浏览文件 @
591b957b
// Include guard: without it, a second inclusion re-declares everything below.
#pragma once

#include <string>
#include <vector>
#include <cctype>

namespace ppspeech {

// Returns a copy of `str` with blanks removed, except that one blank is kept
// between two ASCII letters that were separated by blanks.
std::string BlankProcess(const std::string& str);

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/common/utils/strings.cc
浏览文件 @
f35a87ab
...
...
@@ -49,6 +49,75 @@ std::string StrJoin(const std::vector<std::string>& strs, const char* delim) {
return
ss
.
str
();
}
// Removes blanks from `str`, keeping a single blank only between two ASCII
// letters that were separated by one or more blanks in the input.
//
// @param str  input text (may contain multi-byte UTF-8 characters)
// @return the text with all other blanks dropped
std::string DelBlank(const std::string& str) {
    // isalpha() is undefined for negative values, which the bytes of UTF-8
    // text are where char is signed; go through unsigned char.
    const auto is_alpha = [](char c) {
        return isalpha(static_cast<unsigned char>(c)) != 0;
    };

    std::string out;
    out.reserve(str.size());

    const std::string::size_type end = str.size();
    std::string::size_type ptr_in = 0;   // read position in str
    std::string::size_type ptr_out = 0;  // index in str of the last char kept
    bool wrote_any = false;  // ptr_out is meaningful only once this is set

    while (ptr_in != end) {
        while (ptr_in != end && str[ptr_in] == ' ') {
            ++ptr_in;
        }
        if (ptr_in == end) break;

        // add a space when the last and current chars are in English and
        // there have been space(s) between them
        if (wrote_any && is_alpha(str[ptr_in]) && is_alpha(str[ptr_out]) &&
            str[ptr_in - 1] == ' ') {
            out += ' ';
        }
        out += str[ptr_in];
        ptr_out = ptr_in;
        wrote_any = true;
        ++ptr_in;
    }
    return out;
}
// Surrounds every run of ASCII letters in `str` with blanks.
//
// A blank is written before a run unless the preceding input char is already
// a blank, and one blank is always written after the run. All other chars are
// copied unchanged, so a blank that follows a word in the input ends up
// doubled in the output.
//
// @param str  input text (may contain multi-byte UTF-8 characters)
// @return the text with blanks added around English words
std::string AddBlank(const std::string& str) {
    // isalpha() is undefined for negative values, which the bytes of UTF-8
    // text are where char is signed; go through unsigned char.
    const auto is_alpha = [](char c) {
        return isalpha(static_cast<unsigned char>(c)) != 0;
    };

    std::string out;
    out.reserve(str.size() + 2);

    const std::string::size_type end = str.size();
    std::string::size_type ptr = 0;  // read position in str
    while (ptr != end) {
        if (!is_alpha(str[ptr])) {
            out += str[ptr];
            ++ptr;
            continue;
        }

        // add pre-space for an English word
        if (ptr == 0 || str[ptr - 1] != ' ') {
            out += ' ';
        }
        // Bounded by `end` explicitly instead of relying on str[size()].
        while (ptr != end && is_alpha(str[ptr])) {
            out += str[ptr];
            ++ptr;
        }
        // add post-space for an English word
        out += ' ';
    }
    return out;
}
// Swaps numerator and denominator of every "<tag>num1/num2<tag>" span in
// `str`, dropping the tags: "x<tag>3/1<tag>y" becomes "x1/3y".
//
// Text outside tagged spans is copied unchanged. Malformed input is copied
// verbatim instead of being mangled: an opening tag without a closing tag
// leaves the remainder untouched, and a tagged span without '/' is kept with
// its tags. (Previously such input, e.g. "<tag>12<tag>", never terminated
// because the not-found position was truncated to int and reused as an index.)
//
// @param str  input text
// @return the text with tagged fractions reversed
std::string ReverseFraction(const std::string& str) {
    const std::string tag = "<tag>";
    const std::string::size_type npos = std::string::npos;

    std::string out;
    out.reserve(str.size());

    std::string::size_type ptr = 0;  // start of the part not yet copied
    while (ptr < str.size()) {
        // find the position of left tag, right tag and '/'.
        // (xxx<tag>num1/num2<tag>)
        const std::string::size_type left = str.find(tag, ptr);
        if (left == npos) break;
        const std::string::size_type num_begin = left + tag.size();
        const std::string::size_type right = str.find(tag, num_begin);
        if (right == npos) break;  // unterminated: remainder is copied below

        out.append(str, ptr, left - ptr);  // content before left tag (xxx)

        const std::string::size_type frac = str.find('/', num_begin);
        if (frac == npos || frac > right) {
            // No fraction between the tags: keep the span as it is.
            out.append(str, left, right + tag.size() - left);
        } else {
            // num2/num1
            out.append(str, frac + 1, right - frac - 1);
            out += '/';
            out.append(str, num_begin, frac - num_begin);
        }
        ptr = right + tag.size();
    }
    out.append(str, ptr, npos);
    return out;
}
#ifdef _MSC_VER
std
::
wstring
ToWString
(
const
std
::
string
&
str
)
{
unsigned
len
=
str
.
size
()
*
2
;
...
...
@@ -61,4 +130,4 @@ std::wstring ToWString(const std::string& str) {
}
#endif
}
// namespace ppspeech
\ No newline at end of file
}
// namespace ppspeech
runtime/engine/common/utils/strings.h
浏览文件 @
f35a87ab
...
...
@@ -25,8 +25,14 @@ std::vector<std::string> StrSplit(const std::string& str,
std
::
string
StrJoin
(
const
std
::
vector
<
std
::
string
>&
strs
,
const
char
*
delim
);
std
::
string
DelBlank
(
const
std
::
string
&
str
);
std
::
string
AddBlank
(
const
std
::
string
&
str
);
std
::
string
ReverseFraction
(
const
std
::
string
&
str
);
#ifdef _MSC_VER
std
::
wstring
ToWString
(
const
std
::
string
&
str
);
#endif
}
// namespace ppspeech
\ No newline at end of file
}
// namespace ppspeech
runtime/engine/common/utils/strings_test.cc
浏览文件 @
f35a87ab
...
...
@@ -32,4 +32,47 @@ TEST(StringTest, StrJoinTest) {
std
::
vector
<
std
::
string
>
ins
{
"hello"
,
"world"
};
std
::
string
out
=
ppspeech
::
StrJoin
(
ins
,
" "
);
EXPECT_THAT
(
out
,
"hello world"
);
}
\ No newline at end of file
}
// DelBlank drops blanks between CJK characters and digits but keeps the blank
// between English words. The suite is named StringTest like the other tests
// in this file, and the strings are compared directly so a failure prints the
// actual value.
TEST(StringTest, DelBlankTest) {
    std::string test_str = "我 今天 去 了 超市 花了 120 元。";
    std::string out_str = ppspeech::DelBlank(test_str);
    EXPECT_EQ(out_str, "我今天去了超市花了120元。");

    test_str = "how are you today";
    out_str = ppspeech::DelBlank(test_str);
    EXPECT_EQ(out_str, "how are you today");

    test_str = "我 的 paper 在 哪里?";
    out_str = ppspeech::DelBlank(test_str);
    EXPECT_EQ(out_str, "我的paper在哪里?");
}
// AddBlank puts a blank on each side of every English word.
// NOTE(review): the expected literals are reproduced exactly as shown in the
// source view; confirm no repeated blanks were collapsed there, since AddBlank
// also copies the input's own blanks.
TEST(StringTest, AddBlankTest) {
    const std::string english = ppspeech::AddBlank("how are you");
    EXPECT_EQ(english.compare(" how are you "), 0);

    const std::string mixed = ppspeech::AddBlank("欢迎来到China。");
    EXPECT_EQ(mixed.compare("欢迎来到 China 。"), 0);
}
// ReverseFraction swaps numerator and denominator of each tagged fraction and
// removes the tags. The result is echoed before each check.
TEST(StringTest, ReverseFractionTest) {
    const std::string single = ppspeech::ReverseFraction("<tag>3/1<tag>");
    const int single_cmp = single.compare("1/3");
    std::cout << single << std::endl;
    EXPECT_EQ(single_cmp, 0);

    const std::string pair =
        ppspeech::ReverseFraction("<tag>3/1<tag> <tag>100/10000<tag>");
    const int pair_cmp = pair.compare("1/3 10000/100");
    std::cout << pair << std::endl;
    EXPECT_EQ(pair_cmp, 0);
}
runtime/engine/common/utils/text_process.cc
已删除
100644 → 0
浏览文件 @
591b957b
#include "utils/text_process.h"
namespace ppspeech {

// Removes blanks from `str`, keeping a single blank only between two ASCII
// letters that were separated by one or more blanks in the input.
//
// @param str  input text (may contain multi-byte UTF-8 characters)
// @return the text with all other blanks dropped
std::string DelBlank(const std::string& str) {
    // isalpha() is undefined for negative values, which the bytes of UTF-8
    // text are where char is signed; go through unsigned char.
    const auto is_alpha = [](char c) {
        return isalpha(static_cast<unsigned char>(c)) != 0;
    };

    std::string out;
    out.reserve(str.size());

    const std::string::size_type end = str.size();
    std::string::size_type ptr_in = 0;   // read position in str
    std::string::size_type ptr_out = 0;  // index in str of the last char kept
    bool wrote_any = false;  // ptr_out is meaningful only once this is set

    while (ptr_in != end) {
        while (ptr_in != end && str[ptr_in] == ' ') {
            ++ptr_in;
        }
        if (ptr_in == end) break;

        // add a space when the last and current chars are in English and
        // there have been space(s) between them
        if (wrote_any && is_alpha(str[ptr_in]) && is_alpha(str[ptr_out]) &&
            str[ptr_in - 1] == ' ') {
            out += ' ';
        }
        out += str[ptr_in];
        ptr_out = ptr_in;
        wrote_any = true;
        ++ptr_in;
    }
    return out;
}

// Surrounds every run of ASCII letters in `str` with blanks: one before the
// run unless the preceding input char is already a blank, and always one
// after it. All other chars are copied unchanged.
//
// @param str  input text (may contain multi-byte UTF-8 characters)
// @return the text with blanks added around English words
std::string AddBlank(const std::string& str) {
    // See DelBlank for why the cast to unsigned char is needed.
    const auto is_alpha = [](char c) {
        return isalpha(static_cast<unsigned char>(c)) != 0;
    };

    std::string out;
    out.reserve(str.size() + 2);

    const std::string::size_type end = str.size();
    std::string::size_type ptr = 0;  // read position in str
    while (ptr != end) {
        if (!is_alpha(str[ptr])) {
            out += str[ptr];
            ++ptr;
            continue;
        }

        // add pre-space for an English word
        if (ptr == 0 || str[ptr - 1] != ' ') {
            out += ' ';
        }
        // Bounded by `end` explicitly instead of relying on str[size()].
        while (ptr != end && is_alpha(str[ptr])) {
            out += str[ptr];
            ++ptr;
        }
        // add post-space for an English word
        out += ' ';
    }
    return out;
}

// Swaps numerator and denominator of every "<tag>num1/num2<tag>" span in
// `str`, dropping the tags: "x<tag>3/1<tag>y" becomes "x1/3y".
//
// Malformed input is copied verbatim: an opening tag without a closing tag
// leaves the remainder untouched, and a tagged span without '/' is kept with
// its tags. (Previously such input, e.g. "<tag>12<tag>", never terminated
// because the not-found position was truncated to int and reused as an index.)
//
// @param str  input text
// @return the text with tagged fractions reversed
std::string ReverseFraction(const std::string& str) {
    const std::string tag = "<tag>";
    const std::string::size_type npos = std::string::npos;

    std::string out;
    out.reserve(str.size());

    std::string::size_type ptr = 0;  // start of the part not yet copied
    while (ptr < str.size()) {
        // find the position of left tag, right tag and '/'.
        // (xxx<tag>num1/num2<tag>)
        const std::string::size_type left = str.find(tag, ptr);
        if (left == npos) break;
        const std::string::size_type num_begin = left + tag.size();
        const std::string::size_type right = str.find(tag, num_begin);
        if (right == npos) break;  // unterminated: remainder is copied below

        out.append(str, ptr, left - ptr);  // content before left tag (xxx)

        const std::string::size_type frac = str.find('/', num_begin);
        if (frac == npos || frac > right) {
            // No fraction between the tags: keep the span as it is.
            out.append(str, left, right + tag.size() - left);
        } else {
            // num2/num1
            out.append(str, frac + 1, right - frac - 1);
            out += '/';
            out.append(str, num_begin, frac - num_begin);
        }
        ptr = right + tag.size();
    }
    out.append(str, ptr, npos);
    return out;
}

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/common/utils/text_process.h
已删除
100644 → 0
浏览文件 @
591b957b
#include <string>
#include <vector>
#include <cctype>
namespace ppspeech {

// Returns a copy of `str` with its spaces removed; spaces that separate two
// alphabetic characters are kept as a single space.
std::string DelBlank(const std::string& str);

// Returns a copy of `str` in which each run of alphabetic characters is set
// off by spaces from the text around it.
std::string AddBlank(const std::string& str);

// Returns a copy of `str` in which every "<tag>a/b<tag>" span is replaced by
// "b/a"; text outside the markers is copied through.
std::string ReverseFraction(const std::string& str);

}  // namespace ppspeech
\ No newline at end of file
runtime/engine/common/utils/text_process_test.cc
已删除
100644 → 0
浏览文件 @
591b957b
#include "utils/text_process.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
// Verifies ppspeech::DelBlank on a Chinese sentence with digits, an English
// sentence, and a mixed Chinese/English sentence: the result must equal the
// expected text.
TEST(TextProcess, DelBlankTest) {
    struct Case {
        const char* input;
        const char* expected;
    };
    const Case kCases[] = {
        {"我 今天 去 了 超市 花了 120 元。", "我今天去了超市花了120元。"},
        {"how are you today", "how are you today"},
        {"我 的 paper 在 哪里?", "我的paper在哪里?"},
    };

    for (const Case& item : kCases) {
        const std::string stripped = ppspeech::DelBlank(item.input);
        // compare() returns 0 only when both strings are identical.
        const int ret = stripped.compare(item.expected);
        EXPECT_EQ(ret, 0);
    }
}
// Verifies ppspeech::AddBlank on a pure-English sentence and on a mixed
// Chinese/English sentence: the result must equal the expected padded text.
TEST(TextProcess, AddBlankTest) {
    struct Case {
        const char* input;
        const char* expected;
    };
    const Case kCases[] = {
        {"how are you", " how are you "},
        {"欢迎来到China。", "欢迎来到 China 。"},
    };

    for (const Case& item : kCases) {
        const std::string padded = ppspeech::AddBlank(item.input);
        // compare() returns 0 only when both strings are identical.
        const int ret = padded.compare(item.expected);
        EXPECT_EQ(ret, 0);
    }
}
// Verifies ppspeech::ReverseFraction on one tagged fraction and on two tagged
// fractions separated by a space. Each result is echoed to stdout before the
// expectation is evaluated.
TEST(TextProcess, ReverseFractionTest) {
    struct Case {
        const char* input;
        const char* expected;
    };
    const Case kCases[] = {
        {"<tag>3/1<tag>", "1/3"},
        {"<tag>3/1<tag> <tag>100/10000<tag>", "1/3 10000/100"},
    };

    for (const Case& item : kCases) {
        const std::string reversed = ppspeech::ReverseFraction(item.input);
        // compare() returns 0 only when both strings are identical.
        const int ret = reversed.compare(item.expected);
        std::cout << reversed << std::endl;
        EXPECT_EQ(ret, 0);
    }
}
\ No newline at end of file
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
浏览文件 @
f35a87ab
...
...
@@ -16,7 +16,7 @@ text=$data/test/text
./local/split_data.sh
$data
$data
/
$aishell_wav_scp
$aishell_wav_scp
$nj
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recognizer.log
\
u2_
recognizer_main
\
recognizer_main
\
--use_fbank
=
true
\
--num_bins
=
80
\
--cmvn_file
=
$model_dir
/mean_std.json
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录