Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
767f6dd4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“0a80ddfeff31982df25c806c0b8a6be96ae7efd6”上不存在“mobile/src/operators/transpose_op.cpp”
未验证
提交
767f6dd4
编写于
3月 27, 2023
作者:
Y
YangZhou
提交者:
GitHub
3月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[engine] add recognizer_controller && fix build bugs (#3086)
* fix asr compile
上级
2be7e572
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
246 addition
and
80 deletion
+246
-80
runtime/CMakeLists.txt
runtime/CMakeLists.txt
+7
-5
runtime/cmake/gflags.cmake
runtime/cmake/gflags.cmake
+2
-1
runtime/cmake/openfst.cmake
runtime/cmake/openfst.cmake
+1
-1
runtime/engine/CMakeLists.txt
runtime/engine/CMakeLists.txt
+1
-1
runtime/engine/asr/decoder/CMakeLists.txt
runtime/engine/asr/decoder/CMakeLists.txt
+2
-2
runtime/engine/asr/nnet/u2_nnet.cc
runtime/engine/asr/nnet/u2_nnet.cc
+2
-1
runtime/engine/asr/recognizer/CMakeLists.txt
runtime/engine/asr/recognizer/CMakeLists.txt
+4
-2
runtime/engine/asr/recognizer/recognizer_batch_main.cc
runtime/engine/asr/recognizer/recognizer_batch_main.cc
+99
-43
runtime/engine/asr/recognizer/recognizer_controller.cc
runtime/engine/asr/recognizer/recognizer_controller.cc
+71
-0
runtime/engine/asr/recognizer/recognizer_controller.h
runtime/engine/asr/recognizer/recognizer_controller.h
+39
-0
runtime/engine/asr/recognizer/u2_recognizer_main.cc
runtime/engine/asr/recognizer/u2_recognizer_main.cc
+15
-21
runtime/engine/common/frontend/CMakeLists.txt
runtime/engine/common/frontend/CMakeLists.txt
+2
-2
runtime/engine/kaldi/fstbin/CMakeLists.txt
runtime/engine/kaldi/fstbin/CMakeLists.txt
+1
-1
未找到文件。
runtime/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -14,11 +14,6 @@ set(PPS_VERSION_MINOR 0)
...
@@ -14,11 +14,6 @@ set(PPS_VERSION_MINOR 0)
set
(
PPS_VERSION_PATCH 0
)
set
(
PPS_VERSION_PATCH 0
)
set
(
PPS_VERSION
"
${
PPS_VERSION_MAJOR
}
.
${
PPS_VERSION_MINOR
}
.
${
PPS_VERSION_PATCH
}
"
)
set
(
PPS_VERSION
"
${
PPS_VERSION_MAJOR
}
.
${
PPS_VERSION_MINOR
}
.
${
PPS_VERSION_PATCH
}
"
)
# fc_patch dir
set
(
FETCHCONTENT_QUIET off
)
get_filename_component
(
fc_patch
"fc_patch"
REALPATH BASE_DIR
"
${
CMAKE_SOURCE_DIR
}
"
)
set
(
FETCHCONTENT_BASE_DIR
${
fc_patch
}
)
# compiler option
# compiler option
# Keep the same with openfst, -fPIC or -fpic
# Keep the same with openfst, -fPIC or -fpic
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--std=c++14 -pthread -fPIC -O0 -Wall -g -ldl"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--std=c++14 -pthread -fPIC -O0 -Wall -g -ldl"
)
...
@@ -50,11 +45,18 @@ set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
...
@@ -50,11 +45,18 @@ set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
include
(
FetchContent
)
include
(
FetchContent
)
include
(
ExternalProject
)
include
(
ExternalProject
)
# fc_patch dir
set
(
FETCHCONTENT_QUIET off
)
get_filename_component
(
fc_patch
"fc_patch"
REALPATH BASE_DIR
"
${
CMAKE_SOURCE_DIR
}
"
)
set
(
FETCHCONTENT_BASE_DIR
${
fc_patch
}
)
###############################################################################
###############################################################################
# Option Configurations
# Option Configurations
###############################################################################
###############################################################################
# https://github.com/google/brotli/pull/655
# https://github.com/google/brotli/pull/655
option
(
BUILD_SHARED_LIBS
"Build shared libraries"
ON
)
option
(
BUILD_SHARED_LIBS
"Build shared libraries"
ON
)
option
(
NDEBUG
"debug option"
OFF
)
option
(
WITH_ASR
"build asr"
ON
)
option
(
WITH_ASR
"build asr"
ON
)
option
(
WITH_CLS
"build cls"
ON
)
option
(
WITH_CLS
"build cls"
ON
)
...
...
runtime/cmake/gflags.cmake
浏览文件 @
767f6dd4
...
@@ -9,5 +9,6 @@ FetchContent_MakeAvailable(gflags)
...
@@ -9,5 +9,6 @@ FetchContent_MakeAvailable(gflags)
# openfst need
# openfst need
include_directories
(
${
gflags_BINARY_DIR
}
/include
)
include_directories
(
${
gflags_BINARY_DIR
}
/include
)
link_directories
(
${
gflags_BINARY_DIR
}
)
install
(
FILES
${
gflags_BINARY_DIR
}
/libgflags_nothreads.a DESTINATION lib
)
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
\ No newline at end of file
runtime/cmake/openfst.cmake
浏览文件 @
767f6dd4
...
@@ -30,7 +30,7 @@ ExternalProject_Add(openfst
...
@@ -30,7 +30,7 @@ ExternalProject_Add(openfst
CONFIGURE_COMMAND
${
openfst_SOURCE_DIR
}
/configure --prefix=
${
openfst_PREFIX_DIR
}
CONFIGURE_COMMAND
${
openfst_SOURCE_DIR
}
/configure --prefix=
${
openfst_PREFIX_DIR
}
"CPPFLAGS=-I
${
gflags_BINARY_DIR
}
/include -I
${
glog_SOURCE_DIR
}
/src -I
${
glog_BINARY_DIR
}
"
"CPPFLAGS=-I
${
gflags_BINARY_DIR
}
/include -I
${
glog_SOURCE_DIR
}
/src -I
${
glog_BINARY_DIR
}
"
"LDFLAGS=-L
${
gflags_BINARY_DIR
}
-L
${
glog_BINARY_DIR
}
"
"LDFLAGS=-L
${
gflags_BINARY_DIR
}
-L
${
glog_BINARY_DIR
}
"
"LIBS=-lgflags_nothreads -lglog -lpthread"
"LIBS=-lgflags_nothreads -lglog -lpthread
-fPIC
"
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
PROJECT_SOURCE_DIR
}
/patch/openfst
${
openfst_SOURCE_DIR
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
PROJECT_SOURCE_DIR
}
/patch/openfst
${
openfst_SOURCE_DIR
}
BUILD_COMMAND make -j 4
BUILD_COMMAND make -j 4
)
)
...
...
runtime/engine/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -21,4 +21,4 @@ if(WITH_VAD)
...
@@ -21,4 +21,4 @@ if(WITH_VAD)
add_subdirectory
(
vad
)
add_subdirectory
(
vad
)
endif
()
endif
()
add_subdirectory
(
codelab
)
add_subdirectory
(
codelab
)
\ No newline at end of file
runtime/engine/asr/decoder/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -16,9 +16,9 @@ set(TEST_BINS
...
@@ -16,9 +16,9 @@ set(TEST_BINS
foreach
(
bin_name IN LISTS TEST_BINS
)
foreach
(
bin_name IN LISTS TEST_BINS
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils
gflags
glog kaldi-base kaldi-matrix kaldi-util
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils
libgflags_nothreads.so
glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
-ldl
)
endforeach
()
endforeach
()
runtime/engine/asr/nnet/u2_nnet.cc
浏览文件 @
767f6dd4
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
#include "nnet/u2_nnet.h"
#include "nnet/u2_nnet.h"
#include <type_traits>
#ifdef WITH_PROFILING
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -214,7 +215,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -214,7 +215,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
// not cache feature in nnet
// not cache feature in nnet
CHECK_EQ
(
cached_feats_
.
size
(),
0
);
CHECK_EQ
(
cached_feats_
.
size
(),
0
);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value
, true);
CHECK_EQ
((
std
::
is_same
<
float
,
kaldi
::
BaseFloat
>::
value
)
,
true
);
std
::
memcpy
(
feats_ptr
,
std
::
memcpy
(
feats_ptr
,
chunk_feats
.
data
(),
chunk_feats
.
data
(),
chunk_feats
.
size
()
*
sizeof
(
kaldi
::
BaseFloat
));
chunk_feats
.
size
()
*
sizeof
(
kaldi
::
BaseFloat
));
...
...
runtime/engine/asr/recognizer/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -2,6 +2,7 @@ set(srcs)
...
@@ -2,6 +2,7 @@ set(srcs)
list
(
APPEND srcs
list
(
APPEND srcs
u2_recognizer.cc
u2_recognizer.cc
recognizer_controller.cc
)
)
add_library
(
recognizer STATIC
${
srcs
}
)
add_library
(
recognizer STATIC
${
srcs
}
)
...
@@ -11,13 +12,14 @@ set(TEST_BINS
...
@@ -11,13 +12,14 @@ set(TEST_BINS
u2_recognizer_main
u2_recognizer_main
u2_recognizer_thread_main
u2_recognizer_thread_main
u2_recognizer_batch_main
u2_recognizer_batch_main
recognizer_batch_main
)
)
foreach
(
bin_name IN LISTS TEST_BINS
)
foreach
(
bin_name IN LISTS TEST_BINS
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
recognizer nnet decoder fst utils
gflags
glog kaldi-base kaldi-matrix kaldi-util
)
target_link_libraries
(
${
bin_name
}
recognizer nnet decoder fst utils
libgflags_nothreads.so
glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
-ldl
)
endforeach
()
endforeach
()
runtime/engine/asr/recognizer/
u2_recognizer_thread
_main.cc
→
runtime/engine/asr/recognizer/
recognizer_batch
_main.cc
浏览文件 @
767f6dd4
...
@@ -12,48 +12,66 @@
...
@@ -12,48 +12,66 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
njob
,
3
,
"njob"
);
using
std
::
string
;
using
std
::
vector
;
void
SplitUtt
(
string
wavlist_file
,
vector
<
vector
<
string
>>*
uttlists
,
vector
<
vector
<
string
>>*
wavlists
,
int
njob
)
{
vector
<
string
>
wavlist
;
wavlists
->
resize
(
njob
);
uttlists
->
resize
(
njob
);
ppspeech
::
ReadFileToVector
(
wavlist_file
,
&
wavlist
);
for
(
size_t
idx
=
0
;
idx
<
wavlist
.
size
();
++
idx
)
{
string
utt_str
=
wavlist
[
idx
];
vector
<
string
>
utt_wav
=
ppspeech
::
StrSplit
(
utt_str
,
"
\t
"
);
LOG
(
INFO
)
<<
utt_wav
[
0
];
CHECK_EQ
(
utt_wav
.
size
(),
size_t
(
2
));
uttlists
->
at
(
idx
%
njob
).
push_back
(
utt_wav
[
0
]);
wavlists
->
at
(
idx
%
njob
).
push_back
(
utt_wav
[
1
]);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
void
recognizer_func
(
ppspeech
::
RecognizerController
*
recognizer_controller
,
gflags
::
SetUsageMessage
(
"Usage:"
);
std
::
vector
<
string
>
wavlist
,
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
std
::
vector
<
string
>
uttlist
,
google
::
InitGoogleLogging
(
argv
[
0
]);
std
::
vector
<
string
>*
results
)
{
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int32
num_done
=
0
,
num_err
=
0
;
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_wav_duration
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_decode_time
=
0.0
;
double
tot_decode_time
=
0.0
;
int
chunk_sample_size
=
FLAGS_streaming_chunk
*
FLAGS_sample_rate
;
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
if
(
wavlist
.
empty
())
return
;
FLAGS_wav_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
results
->
reserve
(
wavlist
.
size
());
for
(
size_t
idx
=
0
;
idx
<
wavlist
.
size
();
++
idx
)
{
int
sample_rate
=
FLAGS_sample_rate
;
std
::
string
utt
=
uttlist
[
idx
];
float
streaming_chunk
=
FLAGS_streaming_chunk
;
std
::
string
wav_file
=
wavlist
[
idx
];
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
std
::
ifstream
infile
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
infile
.
open
(
wav_file
,
std
::
ifstream
::
in
);
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
kaldi
::
WaveData
wave_data
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
wave_data
.
Read
(
infile
);
int32
recog_id
=
-
1
;
ppspeech
::
U2RecognizerResource
resource
=
while
(
recog_id
!=
-
1
)
{
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
recog_id
=
recognizer_controller
->
GetRecognizerInstanceId
();
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer_ptr
(
}
new
ppspeech
::
U2Recognizer
(
resource
));
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
recognizer_ptr
->
InitDecoder
();
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
double
dur
=
wave_data
.
Duration
();
...
@@ -77,27 +95,21 @@ int main(int argc, char* argv[]) {
...
@@ -77,27 +95,21 @@ int main(int argc, char* argv[]) {
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
}
recognizer_
ptr
->
Accept
(
wav_chunk
);
recognizer_
controller
->
Accept
(
wav_chunk
,
recog_id
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
if
(
cur_chunk_size
<
chunk_sample_size
)
{
recognizer_ptr
->
SetInputFinished
(
);
recognizer_controller
->
SetInputFinished
(
recog_id
);
}
}
// no overlap
// no overlap
sample_offset
+=
cur_chunk_size
;
sample_offset
+=
cur_chunk_size
;
}
}
CHECK
(
sample_offset
==
tot_samples
);
CHECK
(
sample_offset
==
tot_samples
);
recognizer_ptr
->
WaitDecodeFinished
();
std
::
string
result
=
recognizer_controller
->
GetFinalResult
(
recog_id
);
kaldi
::
Timer
timer
;
recognizer_ptr
->
AttentionRescoring
();
tot_attention_rescore_time
+=
timer
.
Elapsed
();
std
::
string
result
=
recognizer_ptr
->
GetFinalResult
();
if
(
result
.
empty
())
{
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
// the TokenWriter can not write empty string.
++
num_err
;
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
continue
;
result
=
" "
;
}
}
tot_decode_time
+=
local_timer
.
Elapsed
();
tot_decode_time
+=
local_timer
.
Elapsed
();
...
@@ -105,15 +117,59 @@ int main(int argc, char* argv[]) {
...
@@ -105,15 +117,59 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
<<
" cost: "
<<
local_timer
.
Elapsed
();
result_writer
.
Write
(
utt
,
result
);
results
->
push_back
(
result
);
++
num_done
;
++
num_done
;
}
}
recognizer_ptr
->
WaitFinished
();
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total rescore cost:"
<<
tot_attention_rescore_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
njob
=
FLAGS_njob
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
ppspeech
::
RecognizerController
recognizer_controller
(
njob
,
resource
);
ThreadPool
threadpool
(
njob
);
vector
<
vector
<
string
>>
wavlist
;
vector
<
vector
<
string
>>
uttlist
;
vector
<
vector
<
string
>>
resultlist
(
njob
);
vector
<
std
::
future
<
void
>>
futurelist
;
SplitUtt
(
FLAGS_wav_rspecifier
,
&
uttlist
,
&
wavlist
,
njob
);
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
std
::
future
<
void
>
f
=
threadpool
.
enqueue
(
recognizer_func
,
&
recognizer_controller
,
wavlist
[
i
],
uttlist
[
i
],
&
resultlist
[
i
]);
futurelist
.
push_back
(
std
::
move
(
f
));
}
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
futurelist
[
i
].
get
();
}
for
(
size_t
idx
=
0
;
idx
<
njob
;
++
idx
)
{
for
(
size_t
utt_idx
=
0
;
utt_idx
<
uttlist
[
idx
].
size
();
++
utt_idx
)
{
string
utt
=
uttlist
[
idx
][
utt_idx
];
string
result
=
resultlist
[
idx
][
utt_idx
];
result_writer
.
Write
(
utt
,
result
);
}
}
return
0
;
}
runtime/engine/asr/recognizer/recognizer_controller.cc
0 → 100644
浏览文件 @
767f6dd4
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
RecognizerController
::
RecognizerController
(
int
num_worker
,
U2RecognizerResource
resource
)
{
nnet_
=
std
::
make_shared
<
ppspeech
::
U2Nnet
>
(
resource
.
model_opts
);
recognizer_workers
.
resize
(
num_worker
);
for
(
size_t
i
=
0
;
i
<
num_worker
;
++
i
)
{
recognizer_workers
[
i
].
reset
(
new
ppspeech
::
U2Recognizer
(
resource
,
nnet_
->
Clone
()));
recognizer_workers
[
i
]
->
InitDecoder
();
waiting_workers
.
push
(
i
);
}
}
int
RecognizerController
::
GetRecognizerInstanceId
()
{
if
(
waiting_workers
.
empty
())
{
return
-
1
;
}
int
idx
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
idx
=
waiting_workers
.
front
();
waiting_workers
.
pop
();
}
return
idx
;
}
RecognizerController
::~
RecognizerController
()
{
for
(
size_t
i
=
0
;
i
<
recognizer_workers
.
size
();
++
i
)
{
recognizer_workers
[
i
]
->
SetInputFinished
();
recognizer_workers
[
i
]
->
WaitDecodeFinished
();
}
}
std
::
string
RecognizerController
::
GetFinalResult
(
int
idx
)
{
recognizer_workers
[
idx
]
->
WaitDecodeFinished
();
recognizer_workers
[
idx
]
->
AttentionRescoring
();
std
::
string
result
=
recognizer_workers
[
idx
]
->
GetFinalResult
();
recognizer_workers
[
idx
]
->
InitDecoder
();
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
waiting_workers
.
push
(
idx
);
}
return
result
;
}
void
RecognizerController
::
Accept
(
std
::
vector
<
float
>
data
,
int
idx
)
{
recognizer_workers
[
idx
]
->
Accept
(
data
);
}
void
RecognizerController
::
SetInputFinished
(
int
idx
)
{
recognizer_workers
[
idx
]
->
SetInputFinished
();
}
}
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_controller.h
0 → 100644
浏览文件 @
767f6dd4
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <memory>
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
class
RecognizerController
{
public:
explicit
RecognizerController
(
int
num_worker
,
U2RecognizerResource
resource
);
~
RecognizerController
();
int
GetRecognizerInstanceId
();
void
Accept
(
std
::
vector
<
float
>
data
,
int
idx
);
void
SetInputFinished
(
int
idx
);
std
::
string
GetFinalResult
(
int
idx
);
private:
std
::
queue
<
int
>
waiting_workers
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet_
;
std
::
mutex
mutex_
;
std
::
vector
<
std
::
unique_ptr
<
ppspeech
::
U2Recognizer
>>
recognizer_workers
;
};
}
\ No newline at end of file
runtime/engine/asr/recognizer/u2_recognizer_main.cc
浏览文件 @
767f6dd4
...
@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
...
@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32
num_done
=
0
,
num_err
=
0
;
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_wav_duration
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_decode_time
=
0.0
;
double
tot_decode_time
=
0.0
;
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
...
@@ -46,10 +47,11 @@ int main(int argc, char* argv[]) {
...
@@ -46,10 +47,11 @@ int main(int argc, char* argv[]) {
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
resource
=
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
ppspeech
::
U2RecognizerResource
::
InitFromFlags
();
ppspeech
::
U2Recognizer
recognizer
(
resource
);
std
::
shared_ptr
<
ppspeech
::
U2Recognizer
>
recognizer_ptr
(
new
ppspeech
::
U2Recognizer
(
resource
));
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
recognizer
.
InitDecoder
();
recognizer
_ptr
->
InitDecoder
();
std
::
string
utt
=
wav_reader
.
Key
();
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
...
@@ -64,8 +66,6 @@ int main(int argc, char* argv[]) {
...
@@ -64,8 +66,6 @@ int main(int argc, char* argv[]) {
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
int
sample_offset
=
0
;
int
cnt
=
0
;
kaldi
::
Timer
timer
;
kaldi
::
Timer
local_timer
;
kaldi
::
Timer
local_timer
;
while
(
sample_offset
<
tot_samples
)
{
while
(
sample_offset
<
tot_samples
)
{
...
@@ -76,32 +76,23 @@ int main(int argc, char* argv[]) {
...
@@ -76,32 +76,23 @@ int main(int argc, char* argv[]) {
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer
.
Accept
(
wav_chunk
);
recognizer_ptr
->
Accept
(
wav_chunk
);
if
(
cur_chunk_size
<
chunk_sample_size
)
{
if
(
cur_chunk_size
==
(
tot_samples
-
sample_offset
))
{
recognizer
.
SetInputFinished
();
recognizer_ptr
->
SetInputFinished
();
}
recognizer
.
Decode
();
if
(
recognizer
.
DecodedSomething
())
{
LOG
(
INFO
)
<<
"Pratial result: "
<<
cnt
<<
" "
<<
recognizer
.
GetPartialResult
();
}
}
// no overlap
// no overlap
sample_offset
+=
cur_chunk_size
;
sample_offset
+=
cur_chunk_size
;
cnt
++
;
}
}
CHECK
(
sample_offset
==
tot_samples
);
CHECK
(
sample_offset
==
tot_samples
);
recognizer_ptr
->
WaitDecodeFinished
();
// second pass decoding
kaldi
::
Timer
timer
;
recognizer
.
Rescoring
();
recognizer_ptr
->
AttentionRescoring
();
tot_attention_rescore_time
+=
timer
.
Elapsed
();
tot_decode_time
+=
timer
.
Elapsed
();
std
::
string
result
=
recognizer
.
GetFinalResult
();
std
::
string
result
=
recognizer_ptr
->
GetFinalResult
();
if
(
result
.
empty
())
{
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
// the TokenWriter can not write empty string.
++
num_err
;
++
num_err
;
...
@@ -109,6 +100,7 @@ int main(int argc, char* argv[]) {
...
@@ -109,6 +100,7 @@ int main(int argc, char* argv[]) {
continue
;
continue
;
}
}
tot_decode_time
+=
local_timer
.
Elapsed
();
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
<<
" cost: "
<<
local_timer
.
Elapsed
();
...
@@ -117,9 +109,11 @@ int main(int argc, char* argv[]) {
...
@@ -117,9 +109,11 @@ int main(int argc, char* argv[]) {
++
num_done
;
++
num_done
;
}
}
recognizer_ptr
->
WaitFinished
();
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"total rescore cost:"
<<
tot_attention_rescore_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
}
runtime/engine/common/frontend/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -26,5 +26,5 @@ foreach(bin_name IN LISTS BINS)
...
@@ -26,5 +26,5 @@ foreach(bin_name IN LISTS BINS)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
target_link_libraries
(
${
bin_name
}
PUBLIC frontend base utils kaldi-util gflags Threads::Threads extern_glog
)
target_link_libraries
(
${
bin_name
}
PUBLIC frontend base utils kaldi-util libgflags_nothreads.so Threads::Threads extern_glog
)
endforeach
()
endforeach
()
\ No newline at end of file
runtime/engine/kaldi/fstbin/CMakeLists.txt
浏览文件 @
767f6dd4
...
@@ -11,5 +11,5 @@ fsttablecompose
...
@@ -11,5 +11,5 @@ fsttablecompose
foreach
(
binary IN LISTS BINS
)
foreach
(
binary IN LISTS BINS
)
add_executable
(
${
binary
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
binary
}
.cc
)
add_executable
(
${
binary
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
binary
}
.cc
)
target_include_directories
(
${
binary
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
${
binary
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
binary
}
PUBLIC kaldi-fstext glog
gflags
fst dl
)
target_link_libraries
(
${
binary
}
PUBLIC kaldi-fstext glog
libgflags_nothreads.so
fst dl
)
endforeach
()
endforeach
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录