Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b05ead51
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b05ead51
编写于
4月 18, 2023
作者:
Y
YangZhou
提交者:
GitHub
4月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[engine]add recognizer api && clean params && make a shared decoder resource (#3165)
上级
11ce08b2
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
453 addition
and
150 deletion
+453
-150
runtime/engine/asr/decoder/ctc_beam_search_opt.h
runtime/engine/asr/decoder/ctc_beam_search_opt.h
+5
-34
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
+2
-3
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
+1
-2
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
...engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
+5
-4
runtime/engine/asr/decoder/ctc_tlg_decoder.cc
runtime/engine/asr/decoder/ctc_tlg_decoder.cc
+3
-1
runtime/engine/asr/decoder/ctc_tlg_decoder.h
runtime/engine/asr/decoder/ctc_tlg_decoder.h
+8
-1
runtime/engine/asr/decoder/param.h
runtime/engine/asr/decoder/param.h
+1
-15
runtime/engine/asr/nnet/nnet_itf.h
runtime/engine/asr/nnet/nnet_itf.h
+1
-26
runtime/engine/asr/nnet/nnet_producer.cc
runtime/engine/asr/nnet/nnet_producer.cc
+1
-2
runtime/engine/asr/recognizer/CMakeLists.txt
runtime/engine/asr/recognizer/CMakeLists.txt
+3
-0
runtime/engine/asr/recognizer/recognizer.cc
runtime/engine/asr/recognizer/recognizer.cc
+34
-1
runtime/engine/asr/recognizer/recognizer.h
runtime/engine/asr/recognizer/recognizer.h
+16
-1
runtime/engine/asr/recognizer/recognizer_batch_main2.cc
runtime/engine/asr/recognizer/recognizer_batch_main2.cc
+168
-0
runtime/engine/asr/recognizer/recognizer_controller.cc
runtime/engine/asr/recognizer/recognizer_controller.cc
+1
-2
runtime/engine/asr/recognizer/recognizer_controller.h
runtime/engine/asr/recognizer/recognizer_controller.h
+0
-2
runtime/engine/asr/recognizer/recognizer_controller_impl.cc
runtime/engine/asr/recognizer/recognizer_controller_impl.cc
+5
-32
runtime/engine/asr/recognizer/recognizer_controller_impl.h
runtime/engine/asr/recognizer/recognizer_controller_impl.h
+1
-3
runtime/engine/asr/recognizer/recognizer_impl.h
runtime/engine/asr/recognizer/recognizer_impl.h
+0
-13
runtime/engine/asr/recognizer/recognizer_instance.cc
runtime/engine/asr/recognizer/recognizer_instance.cc
+66
-0
runtime/engine/asr/recognizer/recognizer_instance.h
runtime/engine/asr/recognizer/recognizer_instance.h
+42
-0
runtime/engine/asr/recognizer/recognizer_resource.h
runtime/engine/asr/recognizer/recognizer_resource.h
+13
-4
runtime/engine/common/utils/file_utils.cc
runtime/engine/common/utils/file_utils.cc
+29
-0
runtime/engine/common/utils/file_utils.h
runtime/engine/common/utils/file_utils.h
+3
-0
runtime/examples/u2pp_ol/wenetspeech/local/decode.sh
runtime/examples/u2pp_ol/wenetspeech/local/decode.sh
+2
-2
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
+1
-1
runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh
...me/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh
+1
-1
runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh
...ime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh
+41
-0
未找到文件。
runtime/engine/asr/decoder/ctc_beam_search_opt.h
浏览文件 @
b05ead51
...
@@ -22,51 +22,22 @@ namespace ppspeech {
...
@@ -22,51 +22,22 @@ namespace ppspeech {
struct
CTCBeamSearchOptions
{
struct
CTCBeamSearchOptions
{
// common
// common
int
blank
;
int
blank
;
std
::
string
word_symbol_table
;
// ds2
std
::
string
dict_file
;
std
::
string
lm_path
;
int
beam_size
;
BaseFloat
alpha
;
BaseFloat
beta
;
BaseFloat
cutoff_prob
;
int
cutoff_top_n
;
int
num_proc_bsearch
;
// u2
// u2
int
first_beam_size
;
int
first_beam_size
;
int
second_beam_size
;
int
second_beam_size
;
CTCBeamSearchOptions
()
CTCBeamSearchOptions
()
:
blank
(
0
),
:
blank
(
0
),
dict_file
(
"vocab.txt"
),
word_symbol_table
(
"vocab.txt"
),
lm_path
(
""
),
beam_size
(
300
),
alpha
(
1.9
f
),
beta
(
5.0
),
cutoff_prob
(
0.99
f
),
cutoff_top_n
(
40
),
num_proc_bsearch
(
10
),
first_beam_size
(
10
),
first_beam_size
(
10
),
second_beam_size
(
10
)
{}
second_beam_size
(
10
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
std
::
string
module
=
"Ds2BeamSearchConfig: "
;
std
::
string
module
=
"CTCBeamSearchOptions: "
;
opts
->
Register
(
"dict"
,
&
dict_file
,
module
+
"vocab file path."
);
opts
->
Register
(
"word_symbol_table"
,
&
word_symbol_table
,
module
+
"vocab file path."
);
opts
->
Register
(
"lm-path"
,
&
lm_path
,
module
+
"ngram language model path."
);
opts
->
Register
(
"alpha"
,
&
alpha
,
module
+
"alpha"
);
opts
->
Register
(
"beta"
,
&
beta
,
module
+
"beta"
);
opts
->
Register
(
"beam-size"
,
&
beam_size
,
module
+
"beam size for beam search method"
);
opts
->
Register
(
"cutoff-prob"
,
&
cutoff_prob
,
module
+
"cutoff probs"
);
opts
->
Register
(
"cutoff-top-n"
,
&
cutoff_top_n
,
module
+
"cutoff top n"
);
opts
->
Register
(
"num-proc-bsearch"
,
&
num_proc_bsearch
,
module
+
"num proc bsearch"
);
opts
->
Register
(
"blank"
,
&
blank
,
"blank id, default is 0."
);
opts
->
Register
(
"blank"
,
&
blank
,
"blank id, default is 0."
);
module
=
"U2BeamSearchConfig: "
;
opts
->
Register
(
opts
->
Register
(
"first-beam-size"
,
&
first_beam_size
,
module
+
"first beam size."
);
"first-beam-size"
,
&
first_beam_size
,
module
+
"first beam size."
);
opts
->
Register
(
"second-beam-size"
,
opts
->
Register
(
"second-beam-size"
,
...
...
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
浏览文件 @
b05ead51
...
@@ -30,11 +30,10 @@ using paddle::platform::TracerEventType;
...
@@ -30,11 +30,10 @@ using paddle::platform::TracerEventType;
namespace
ppspeech
{
namespace
ppspeech
{
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
std
::
string
&
vocab_path
,
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
)
const
CTCBeamSearchOptions
&
opts
)
:
opts_
(
opts
)
{
:
opts_
(
opts
)
{
unit_table_
=
std
::
shared_ptr
<
fst
::
SymbolTable
>
(
unit_table_
=
std
::
shared_ptr
<
fst
::
SymbolTable
>
(
fst
::
SymbolTable
::
ReadText
(
vocab_path
));
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
CHECK
(
unit_table_
!=
nullptr
);
CHECK
(
unit_table_
!=
nullptr
);
Reset
();
Reset
();
...
...
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
浏览文件 @
b05ead51
...
@@ -27,8 +27,7 @@ namespace ppspeech {
...
@@ -27,8 +27,7 @@ namespace ppspeech {
class
ContextGraph
;
class
ContextGraph
;
class
CTCPrefixBeamSearch
:
public
DecoderBase
{
class
CTCPrefixBeamSearch
:
public
DecoderBase
{
public:
public:
CTCPrefixBeamSearch
(
const
std
::
string
&
vocab_path
,
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
const
CTCBeamSearchOptions
&
opts
);
~
CTCPrefixBeamSearch
()
{}
~
CTCPrefixBeamSearch
()
{}
SearchType
Type
()
const
{
return
SearchType
::
kPrefixBeamSearch
;
}
SearchType
Type
()
const
{
return
SearchType
::
kPrefixBeamSearch
;
}
...
...
runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc
浏览文件 @
b05ead51
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
DEFINE_string
(
feature_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
feature_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_string
(
vocab_path
,
""
,
"vocab path"
);
DEFINE_string
(
word_symbol_table
,
""
,
"vocab path"
);
DEFINE_string
(
model_path
,
""
,
"paddle nnet model"
);
DEFINE_string
(
model_path
,
""
,
"paddle nnet model"
);
...
@@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
...
@@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
CHECK_NE
(
FLAGS_result_wspecifier
,
""
);
CHECK_NE
(
FLAGS_result_wspecifier
,
""
);
CHECK_NE
(
FLAGS_feature_rspecifier
,
""
);
CHECK_NE
(
FLAGS_feature_rspecifier
,
""
);
CHECK_NE
(
FLAGS_
vocab_path
,
""
);
CHECK_NE
(
FLAGS_
word_symbol_table
,
""
);
CHECK_NE
(
FLAGS_model_path
,
""
);
CHECK_NE
(
FLAGS_model_path
,
""
);
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_
vocab_path
;
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_
word_symbol_table
;
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
FLAGS_feature_rspecifier
);
FLAGS_feature_rspecifier
);
...
@@ -80,7 +80,8 @@ int main(int argc, char* argv[]) {
...
@@ -80,7 +80,8 @@ int main(int argc, char* argv[]) {
opts
.
blank
=
0
;
opts
.
blank
=
0
;
opts
.
first_beam_size
=
10
;
opts
.
first_beam_size
=
10
;
opts
.
second_beam_size
=
10
;
opts
.
second_beam_size
=
10
;
ppspeech
::
CTCPrefixBeamSearch
decoder
(
FLAGS_vocab_path
,
opts
);
opts
.
word_symbol_table
=
FLAGS_word_symbol_table
;
ppspeech
::
CTCPrefixBeamSearch
decoder
(
opts
);
int32
chunk_size
=
FLAGS_receptive_field_length
+
int32
chunk_size
=
FLAGS_receptive_field_length
+
...
...
runtime/engine/asr/decoder/ctc_tlg_decoder.cc
浏览文件 @
b05ead51
...
@@ -13,12 +13,14 @@
...
@@ -13,12 +13,14 @@
// limitations under the License.
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
namespace
ppspeech
{
namespace
ppspeech
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
:
opts_
(
opts
)
{
TLGDecoder
::
TLGDecoder
(
TLGDecoderOptions
opts
)
:
opts_
(
opts
)
{
fst_
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
opts
.
fst_path
))
;
fst_
=
opts
.
fst_ptr
;
CHECK
(
fst_
!=
nullptr
);
CHECK
(
fst_
!=
nullptr
);
CHECK
(
!
opts
.
word_symbol_table
.
empty
());
word_symbol_table_
.
reset
(
word_symbol_table_
.
reset
(
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
fst
::
SymbolTable
::
ReadText
(
opts
.
word_symbol_table
));
...
...
runtime/engine/asr/decoder/ctc_tlg_decoder.h
浏览文件 @
b05ead51
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "decoder/decoder_itf.h"
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "util/parse-options.h"
#include "utils/file_utils.h"
DECLARE_string
(
word_symbol_table
);
DECLARE_string
(
word_symbol_table
);
DECLARE_string
(
graph_path
);
DECLARE_string
(
graph_path
);
...
@@ -33,9 +34,10 @@ struct TLGDecoderOptions {
...
@@ -33,9 +34,10 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
// todo remove later, add into decode resource
std
::
string
word_symbol_table
;
std
::
string
word_symbol_table
;
std
::
string
fst_path
;
std
::
string
fst_path
;
std
::
shared_ptr
<
fst
::
Fst
<
fst
::
StdArc
>>
fst_ptr
;
int
nbest
;
int
nbest
;
TLGDecoderOptions
()
:
word_symbol_table
(
""
),
fst_path
(
""
),
nbest
(
10
)
{}
TLGDecoderOptions
()
:
word_symbol_table
(
""
),
fst_path
(
""
),
fst_ptr
(
nullptr
),
nbest
(
10
)
{}
static
TLGDecoderOptions
InitFromFlags
()
{
static
TLGDecoderOptions
InitFromFlags
()
{
TLGDecoderOptions
decoder_opts
;
TLGDecoderOptions
decoder_opts
;
...
@@ -44,6 +46,11 @@ struct TLGDecoderOptions {
...
@@ -44,6 +46,11 @@ struct TLGDecoderOptions {
LOG
(
INFO
)
<<
"fst path: "
<<
decoder_opts
.
fst_path
;
LOG
(
INFO
)
<<
"fst path: "
<<
decoder_opts
.
fst_path
;
LOG
(
INFO
)
<<
"fst symbole table: "
<<
decoder_opts
.
word_symbol_table
;
LOG
(
INFO
)
<<
"fst symbole table: "
<<
decoder_opts
.
word_symbol_table
;
if
(
!
decoder_opts
.
fst_path
.
empty
())
{
CHECK
(
FileExists
(
decoder_opts
.
fst_path
));
decoder_opts
.
fst_ptr
.
reset
(
fst
::
Fst
<
fst
::
StdArc
>::
Read
(
FLAGS_graph_path
));
}
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
...
...
runtime/engine/asr/decoder/param.h
浏览文件 @
b05ead51
...
@@ -37,28 +37,14 @@ DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
...
@@ -37,28 +37,14 @@ DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
// nnet
DEFINE_string
(
vocab_path
,
""
,
"nnet vocab path."
);
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
#ifdef USE_ONNX
#ifdef USE_ONNX
DEFINE_bool
(
with_onnx_model
,
false
,
"True mean the model path is onnx model path"
);
DEFINE_bool
(
with_onnx_model
,
false
,
"True mean the model path is onnx model path"
);
#endif
#endif
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string
(
model_input_names
,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"
,
"model input names"
);
DEFINE_string
(
model_output_names
,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0"
,
"model output names"
);
DEFINE_string
(
model_cache_names
,
"chunk_state_h_box,chunk_state_c_box"
,
"model cache names"
);
DEFINE_string
(
model_cache_shapes
,
"5-1-1024,5-1-1024"
,
"model cache shapes"
);
// decoder
// decoder
DEFINE_double
(
acoustic_scale
,
1
.
0
,
"acoustic scale"
);
DEFINE_double
(
acoustic_scale
,
1
.
0
,
"acoustic scale"
);
DEFINE_string
(
graph_path
,
""
,
"decoder graph"
);
DEFINE_string
(
graph_path
,
""
,
"decoder graph"
);
DEFINE_string
(
word_symbol_table
,
""
,
"word symbol table"
);
DEFINE_string
(
word_symbol_table
,
""
,
"word symbol table"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
DEFINE_int32
(
max_active
,
7500
,
"max active"
);
...
...
runtime/engine/asr/nnet/nnet_itf.h
浏览文件 @
b05ead51
...
@@ -33,24 +33,12 @@ namespace ppspeech {
...
@@ -33,24 +33,12 @@ namespace ppspeech {
struct
ModelOptions
{
struct
ModelOptions
{
// common
// common
int
subsample_rate
{
1
};
int
subsample_rate
{
1
};
int
thread_num
{
1
};
// predictor thread pool size for ds2;
bool
use_gpu
{
false
};
bool
use_gpu
{
false
};
std
::
string
model_path
;
std
::
string
model_path
;
#ifdef USE_ONNX
#ifdef USE_ONNX
bool
with_onnx_model
{
false
};
bool
with_onnx_model
{
false
};
#endif
#endif
std
::
string
param_path
;
// ds2 for inference
std
::
string
input_names
{};
std
::
string
output_names
{};
std
::
string
cache_names
{};
std
::
string
cache_shape
{};
bool
switch_ir_optim
{
false
};
bool
enable_fc_padding
{
false
};
bool
enable_profile
{
false
};
static
ModelOptions
InitFromFlags
()
{
static
ModelOptions
InitFromFlags
()
{
ModelOptions
opts
;
ModelOptions
opts
;
opts
.
subsample_rate
=
FLAGS_subsampling_rate
;
opts
.
subsample_rate
=
FLAGS_subsampling_rate
;
...
@@ -61,19 +49,6 @@ struct ModelOptions {
...
@@ -61,19 +49,6 @@ struct ModelOptions {
opts
.
with_onnx_model
=
FLAGS_with_onnx_model
;
opts
.
with_onnx_model
=
FLAGS_with_onnx_model
;
LOG
(
INFO
)
<<
"with onnx model: "
<<
opts
.
with_onnx_model
;
LOG
(
INFO
)
<<
"with onnx model: "
<<
opts
.
with_onnx_model
;
#endif
#endif
opts
.
param_path
=
FLAGS_param_path
;
LOG
(
INFO
)
<<
"param path: "
<<
opts
.
param_path
;
LOG
(
INFO
)
<<
"DS2 param: "
;
opts
.
cache_names
=
FLAGS_model_cache_names
;
LOG
(
INFO
)
<<
" cache names: "
<<
opts
.
cache_names
;
opts
.
cache_shape
=
FLAGS_model_cache_shapes
;
LOG
(
INFO
)
<<
" cache shape: "
<<
opts
.
cache_shape
;
opts
.
input_names
=
FLAGS_model_input_names
;
LOG
(
INFO
)
<<
" input names: "
<<
opts
.
input_names
;
opts
.
output_names
=
FLAGS_model_output_names
;
LOG
(
INFO
)
<<
" output names: "
<<
opts
.
output_names
;
return
opts
;
return
opts
;
}
}
};
};
...
@@ -121,7 +96,7 @@ class NnetInterface {
...
@@ -121,7 +96,7 @@ class NnetInterface {
class
NnetBase
:
public
NnetInterface
{
class
NnetBase
:
public
NnetInterface
{
public:
public:
int
SubsamplingRate
()
const
{
return
subsampling_rate_
;
}
int
SubsamplingRate
()
const
{
return
subsampling_rate_
;
}
virtual
std
::
shared_ptr
<
NnetBase
>
Clone
()
const
=
0
;
protected:
protected:
int
subsampling_rate_
{
1
};
int
subsampling_rate_
{
1
};
};
};
...
...
runtime/engine/asr/nnet/nnet_producer.cc
浏览文件 @
b05ead51
...
@@ -45,7 +45,7 @@ void NnetProducer::Acceptlikelihood(
...
@@ -45,7 +45,7 @@ void NnetProducer::Acceptlikelihood(
bool
NnetProducer
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
bool
NnetProducer
::
Read
(
std
::
vector
<
kaldi
::
BaseFloat
>*
nnet_prob
)
{
bool
flag
=
cache_
.
pop
(
nnet_prob
);
bool
flag
=
cache_
.
pop
(
nnet_prob
);
LOG
(
INFO
)
<<
"nnet cache_ size: "
<<
cache_
.
size
();
VLOG
(
1
)
<<
"nnet cache_ size: "
<<
cache_
.
size
();
return
flag
;
return
flag
;
}
}
...
@@ -53,7 +53,6 @@ bool NnetProducer::Compute() {
...
@@ -53,7 +53,6 @@ bool NnetProducer::Compute() {
vector
<
BaseFloat
>
features
;
vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
// no feat or frontend_ not init.
LOG
(
INFO
)
<<
"no feat avalible"
;
if
(
frontend_
->
IsFinished
()
==
true
)
{
if
(
frontend_
->
IsFinished
()
==
true
)
{
finished_
=
true
;
finished_
=
true
;
}
}
...
...
runtime/engine/asr/recognizer/CMakeLists.txt
浏览文件 @
b05ead51
...
@@ -3,6 +3,8 @@ set(srcs)
...
@@ -3,6 +3,8 @@ set(srcs)
list
(
APPEND srcs
list
(
APPEND srcs
recognizer_controller.cc
recognizer_controller.cc
recognizer_controller_impl.cc
recognizer_controller_impl.cc
recognizer_instance.cc
recognizer.cc
)
)
add_library
(
recognizer STATIC
${
srcs
}
)
add_library
(
recognizer STATIC
${
srcs
}
)
...
@@ -10,6 +12,7 @@ target_link_libraries(recognizer PUBLIC decoder)
...
@@ -10,6 +12,7 @@ target_link_libraries(recognizer PUBLIC decoder)
set
(
TEST_BINS
set
(
TEST_BINS
recognizer_batch_main
recognizer_batch_main
recognizer_batch_main2
recognizer_main
recognizer_main
)
)
...
...
runtime/engine/asr/recognizer/recognizer.cc
浏览文件 @
b05ead51
...
@@ -10,4 +10,37 @@
...
@@ -10,4 +10,37 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
\ No newline at end of file
#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"
// Initialise the shared recognizer.
// The four arguments are passed through unchanged to
// ppspeech::RecognizerInstance::Init on the singleton returned by
// GetInstance(), and that call's result is returned to the caller.
bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance) {
    auto&& recognizer = ppspeech::RecognizerInstance::GetInstance();
    return recognizer.Init(
        model_file, word_symbol_table_file, fst_file, num_instance);
}
// Ask the shared RecognizerInstance for an instance id and return it as-is.
int GetRecognizerInstanceId() {
    auto&& recognizer = ppspeech::RecognizerInstance::GetInstance();
    return recognizer.GetRecognizerInstanceId();
}
// Forward to RecognizerInstance::InitDecoder for the given instance id.
void InitDecoder(int instance_id) {
    ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
}
// Hand a block of samples to the recognizer instance `instance_id` by
// forwarding to RecognizerInstance::Accept.
void AcceptData(const std::vector<float>& waves, int instance_id) {
    ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
}
// Forward to RecognizerInstance::SetInputFinished for the given instance id.
void SetInputFinished(int instance_id) {
    ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
}
// Return the string produced by RecognizerInstance::GetResult for the given
// instance id.
std::string GetFinalResult(int instance_id) {
    auto&& recognizer = ppspeech::RecognizerInstance::GetInstance();
    return recognizer.GetResult(instance_id);
}
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer.h
浏览文件 @
b05ead51
...
@@ -10,4 +10,19 @@
...
@@ -10,4 +10,19 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
\ No newline at end of file
#pragma once
#include <string>
#include <vector>
// Free-function recognizer API. recognizer.cc implements each function by
// forwarding to the ppspeech::RecognizerInstance singleton.
//
// InitRecognizer: forwards all four arguments to RecognizerInstance::Init
// and returns its result.
// NOTE(review): presumably true means success and num_instance is the number
// of recognizer workers to create -- confirm against RecognizerInstance.
bool
InitRecognizer
(
const
std
::
string
&
model_file
,
const
std
::
string
&
word_symbol_table_file
,
const
std
::
string
&
fst_file
,
int
num_instance
);
// GetRecognizerInstanceId: returns the id reported by
// RecognizerInstance::GetRecognizerInstanceId. The batch driver
// (recognizer_batch_main2.cc) retries while the value is -1, so -1
// presumably means "no instance available yet" -- confirm.
int
GetRecognizerInstanceId
();
// InitDecoder: forwards instance_id to RecognizerInstance::InitDecoder.
void
InitDecoder
(
int
instance_id
);
// AcceptData: forwards a block of samples and the instance id to
// RecognizerInstance::Accept.
void
AcceptData
(
const
std
::
vector
<
float
>&
waves
,
int
instance_id
);
// SetInputFinished: forwards instance_id to
// RecognizerInstance::SetInputFinished.
void
SetInputFinished
(
int
instance_id
);
// GetFinalResult: returns the string from RecognizerInstance::GetResult for
// instance_id.
std
::
string
GetFinalResult
(
int
instance_id
);
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_batch_main2.cc
0 → 100644
浏览文件 @
b05ead51
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer.h"
DEFINE_string
(
wav_rspecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
result_wspecifier
,
""
,
"test result wspecifier"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
njob
,
3
,
"njob"
);
using
std
::
string
;
using
std
::
vector
;
// Distribute the entries of a wav list round-robin over `njob` jobs.
//
// Every entry that ppspeech::ReadFileToVector loads from `wavlist_file` is
// split with ppspeech::StrSplit and must yield exactly two fields: an
// utterance id and a wav path. Entry number `idx` goes to job `idx % njob`,
// so (*uttlists)[j][k] and (*wavlists)[j][k] always describe the same
// utterance.
//
//   wavlist_file  path handed to ppspeech::ReadFileToVector
//   uttlists      out: resized to njob, receives the utterance ids per job
//   wavlists      out: resized to njob, receives the wav paths per job
//   njob          number of jobs; used as a modulus, so it must be positive
//                 (not validated here)
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    vector<string> wavlist;
    wavlists->resize(njob);
    uttlists->resize(njob);
    ppspeech::ReadFileToVector(wavlist_file, &wavlist);
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        const string& utt_str = wavlist[idx];
        // NOTE(review): the delimiter is reproduced exactly as it appears in
        // the reviewed text; confirm whether the original literal also
        // contains a leading space (i.e. " \t").
        const vector<string> utt_wav = ppspeech::StrSplit(utt_str, "\t");
        // Validate the field count BEFORE indexing: an entry that splits
        // into zero fields would otherwise be read out of bounds by the log
        // statement below.
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];
        uttlists->at(idx % njob).push_back(utt_wav[0]);
        wavlists->at(idx % njob).push_back(utt_wav[1]);
    }
}
void
recognizer_func
(
std
::
vector
<
string
>
wavlist
,
std
::
vector
<
string
>
uttlist
,
std
::
vector
<
string
>*
results
)
{
int32
num_done
=
0
,
num_err
=
0
;
double
tot_wav_duration
=
0.0
;
double
tot_attention_rescore_time
=
0.0
;
double
tot_decode_time
=
0.0
;
int
chunk_sample_size
=
FLAGS_streaming_chunk
*
FLAGS_sample_rate
;
if
(
wavlist
.
empty
())
return
;
results
->
reserve
(
wavlist
.
size
());
for
(
size_t
idx
=
0
;
idx
<
wavlist
.
size
();
++
idx
)
{
std
::
string
utt
=
uttlist
[
idx
];
std
::
string
wav_file
=
wavlist
[
idx
];
std
::
ifstream
infile
;
infile
.
open
(
wav_file
,
std
::
ifstream
::
in
);
kaldi
::
WaveData
wave_data
;
wave_data
.
Read
(
infile
);
int32
recog_id
=
-
1
;
while
(
recog_id
==
-
1
)
{
recog_id
=
GetRecognizerInstanceId
();
}
InitDecoder
(
recog_id
);
LOG
(
INFO
)
<<
"utt: "
<<
utt
;
LOG
(
INFO
)
<<
"wav dur: "
<<
wave_data
.
Duration
()
<<
" sec."
;
double
dur
=
wave_data
.
Duration
();
tot_wav_duration
+=
dur
;
int32
this_channel
=
0
;
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
waveform
(
wave_data
.
Data
(),
this_channel
);
int
tot_samples
=
waveform
.
Dim
();
LOG
(
INFO
)
<<
"wav len (sample): "
<<
tot_samples
;
int
sample_offset
=
0
;
kaldi
::
Timer
local_timer
;
while
(
sample_offset
<
tot_samples
)
{
int
cur_chunk_size
=
std
::
min
(
chunk_sample_size
,
tot_samples
-
sample_offset
);
std
::
vector
<
kaldi
::
BaseFloat
>
wav_chunk
(
cur_chunk_size
);
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
[
i
]
=
waveform
(
sample_offset
+
i
);
}
AcceptData
(
wav_chunk
,
recog_id
);
// no overlap
sample_offset
+=
cur_chunk_size
;
}
SetInputFinished
(
recog_id
);
CHECK
(
sample_offset
==
tot_samples
);
std
::
string
result
=
GetFinalResult
(
recog_id
);
if
(
result
.
empty
())
{
// the TokenWriter can not write empty string.
++
num_err
;
LOG
(
INFO
)
<<
" the result of "
<<
utt
<<
" is empty"
;
result
=
" "
;
}
tot_decode_time
+=
local_timer
.
Elapsed
();
LOG
(
INFO
)
<<
utt
<<
" "
<<
result
;
LOG
(
INFO
)
<<
" RTF: "
<<
local_timer
.
Elapsed
()
/
dur
<<
" dur: "
<<
dur
<<
" cost: "
<<
local_timer
.
Elapsed
();
results
->
push_back
(
result
);
++
num_done
;
}
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" out of "
<<
(
num_err
+
num_done
);
LOG
(
INFO
)
<<
"total wav duration is: "
<<
tot_wav_duration
<<
" sec"
;
LOG
(
INFO
)
<<
"total decode cost:"
<<
tot_decode_time
<<
" sec"
;
LOG
(
INFO
)
<<
"RTF is: "
<<
tot_decode_time
/
tot_wav_duration
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
int
sample_rate
=
FLAGS_sample_rate
;
float
streaming_chunk
=
FLAGS_streaming_chunk
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
int
njob
=
FLAGS_njob
;
LOG
(
INFO
)
<<
"sr: "
<<
sample_rate
;
LOG
(
INFO
)
<<
"chunk size (s): "
<<
streaming_chunk
;
LOG
(
INFO
)
<<
"chunk size (sample): "
<<
chunk_sample_size
;
InitRecognizer
(
FLAGS_model_path
,
FLAGS_word_symbol_table
,
FLAGS_graph_path
,
njob
);
ThreadPool
threadpool
(
njob
);
vector
<
vector
<
string
>>
wavlist
;
vector
<
vector
<
string
>>
uttlist
;
vector
<
vector
<
string
>>
resultlist
(
njob
);
vector
<
std
::
future
<
void
>>
futurelist
;
SplitUtt
(
FLAGS_wav_rspecifier
,
&
uttlist
,
&
wavlist
,
njob
);
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
std
::
future
<
void
>
f
=
threadpool
.
enqueue
(
recognizer_func
,
wavlist
[
i
],
uttlist
[
i
],
&
resultlist
[
i
]);
futurelist
.
push_back
(
std
::
move
(
f
));
}
for
(
size_t
i
=
0
;
i
<
njob
;
++
i
)
{
futurelist
[
i
].
get
();
}
for
(
size_t
idx
=
0
;
idx
<
njob
;
++
idx
)
{
for
(
size_t
utt_idx
=
0
;
utt_idx
<
uttlist
[
idx
].
size
();
++
utt_idx
)
{
string
utt
=
uttlist
[
idx
][
utt_idx
];
string
result
=
resultlist
[
idx
][
utt_idx
];
result_writer
.
Write
(
utt
,
result
);
}
}
return
0
;
}
runtime/engine/asr/recognizer/recognizer_controller.cc
浏览文件 @
b05ead51
...
@@ -18,10 +18,9 @@
...
@@ -18,10 +18,9 @@
namespace
ppspeech
{
namespace
ppspeech
{
RecognizerController
::
RecognizerController
(
int
num_worker
,
RecognizerResource
resource
)
{
RecognizerController
::
RecognizerController
(
int
num_worker
,
RecognizerResource
resource
)
{
nnet_
=
std
::
make_shared
<
ppspeech
::
U2Nnet
>
(
resource
.
model_opts
);
recognizer_workers
.
resize
(
num_worker
);
recognizer_workers
.
resize
(
num_worker
);
for
(
size_t
i
=
0
;
i
<
num_worker
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_worker
;
++
i
)
{
recognizer_workers
[
i
].
reset
(
new
ppspeech
::
RecognizerControllerImpl
(
resource
,
nnet_
->
Clone
()
));
recognizer_workers
[
i
].
reset
(
new
ppspeech
::
RecognizerControllerImpl
(
resource
));
waiting_workers
.
push
(
i
);
waiting_workers
.
push
(
i
);
}
}
}
}
...
...
runtime/engine/asr/recognizer/recognizer_controller.h
浏览文件 @
b05ead51
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
#include <memory>
#include <memory>
#include "recognizer/recognizer_controller_impl.h"
#include "recognizer/recognizer_controller_impl.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
namespace
ppspeech
{
...
@@ -34,7 +33,6 @@ class RecognizerController {
...
@@ -34,7 +33,6 @@ class RecognizerController {
private:
private:
std
::
queue
<
int
>
waiting_workers
;
std
::
queue
<
int
>
waiting_workers
;
std
::
shared_ptr
<
ppspeech
::
U2Nnet
>
nnet_
;
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
std
::
vector
<
std
::
unique_ptr
<
ppspeech
::
RecognizerControllerImpl
>>
recognizer_workers
;
std
::
vector
<
std
::
unique_ptr
<
ppspeech
::
RecognizerControllerImpl
>>
recognizer_workers
;
...
...
runtime/engine/asr/recognizer/recognizer_controller_impl.cc
浏览文件 @
b05ead51
...
@@ -26,24 +26,24 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
...
@@ -26,24 +26,24 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
new
FeaturePipeline
(
feature_opts
));
new
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
NnetBase
>
nnet
;
std
::
shared_ptr
<
NnetBase
>
nnet
;
#ifndef USE_ONNX
#ifndef USE_ONNX
nnet
.
reset
(
new
U2Nnet
(
resource
.
model_opts
)
);
nnet
=
resource
.
nnet
->
Clone
(
);
#else
#else
if
(
resource
.
model_opts
.
with_onnx_model
){
if
(
resource
.
model_opts
.
with_onnx_model
){
nnet
.
reset
(
new
U2OnnxNnet
(
resource
.
model_opts
));
nnet
.
reset
(
new
U2OnnxNnet
(
resource
.
model_opts
));
}
else
{
}
else
{
nnet
.
reset
(
new
U2Nnet
(
resource
.
model_opts
)
);
nnet
=
resource
.
nnet
->
Clone
(
);
}
}
#endif
#endif
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
nnet_producer_
.
reset
(
new
NnetProducer
(
nnet
,
feature_pipeline
));
nnet_thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
nnet_thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
.
empty
())
{
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
.
empty
())
{
LOG
(
INFO
)
<<
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
;
LOG
(
INFO
)
<<
"Init PrefixBeamSearch Decoder"
;
decoder_
=
std
::
make_unique
<
CTCPrefixBeamSearch
>
(
decoder_
=
std
::
make_unique
<
CTCPrefixBeamSearch
>
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
);
resource
.
decoder_opts
.
ctc_prefix_search_opts
);
}
else
{
}
else
{
LOG
(
INFO
)
<<
"Init TLGDecoder"
;
decoder_
=
std
::
make_unique
<
TLGDecoder
>
(
decoder_
=
std
::
make_unique
<
TLGDecoder
>
(
resource
.
decoder_opts
.
tlg_decoder_opts
);
resource
.
decoder_opts
.
tlg_decoder_opts
);
}
}
...
@@ -55,33 +55,6 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
...
@@ -55,33 +55,6 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
result_
.
clear
();
result_
.
clear
();
}
}
RecognizerControllerImpl
::
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
)
:
opts_
(
resource
)
{
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
std
::
shared_ptr
<
FeaturePipeline
>
feature_pipeline
=
std
::
make_shared
<
FeaturePipeline
>
(
feature_opts
);
nnet_producer_
=
std
::
make_shared
<
NnetProducer
>
(
nnet
,
feature_pipeline
);
nnet_thread_
=
std
::
thread
(
RunNnetEvaluation
,
this
);
decodable_
.
reset
(
new
Decodable
(
nnet_producer_
,
am_scale
));
CHECK_NE
(
resource
.
vocab_path
,
""
);
if
(
resource
.
decoder_opts
.
tlg_decoder_opts
.
fst_path
==
""
)
{
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
}
else
{
decoder_
.
reset
(
new
TLGDecoder
(
resource
.
decoder_opts
.
tlg_decoder_opts
));
}
symbol_table_
=
decoder_
->
WordSymbolTable
();
global_frame_offset_
=
0
;
input_finished_
=
false
;
num_frames_
=
0
;
result_
.
clear
();
}
RecognizerControllerImpl
::~
RecognizerControllerImpl
()
{
RecognizerControllerImpl
::~
RecognizerControllerImpl
()
{
WaitFinished
();
WaitFinished
();
}
}
...
...
runtime/engine/asr/recognizer/recognizer_controller_impl.h
浏览文件 @
b05ead51
...
@@ -32,8 +32,6 @@ namespace ppspeech {
...
@@ -32,8 +32,6 @@ namespace ppspeech {
class
RecognizerControllerImpl
{
class
RecognizerControllerImpl
{
public:
public:
explicit
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
);
explicit
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
);
explicit
RecognizerControllerImpl
(
const
RecognizerResource
&
resource
,
std
::
shared_ptr
<
NnetBase
>
nnet
);
~
RecognizerControllerImpl
();
~
RecognizerControllerImpl
();
void
Accept
(
std
::
vector
<
float
>
data
);
void
Accept
(
std
::
vector
<
float
>
data
);
void
InitDecoder
();
void
InitDecoder
();
...
@@ -88,4 +86,4 @@ class RecognizerControllerImpl {
...
@@ -88,4 +86,4 @@ class RecognizerControllerImpl {
DISALLOW_COPY_AND_ASSIGN
(
RecognizerControllerImpl
);
DISALLOW_COPY_AND_ASSIGN
(
RecognizerControllerImpl
);
};
};
}
}
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_impl.h
已删除
100644 → 0
浏览文件 @
11ce08b2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_instance.cc
0 → 100644
浏览文件 @
b05ead51
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_instance.h"
namespace
ppspeech
{
// Returns the single RecognizerInstance, held in a function-local static
// that is constructed on the first call.
RecognizerInstance& RecognizerInstance::GetInstance() {
    static RecognizerInstance singleton;
    return singleton;
}
// Builds the recognizer resource and creates the controller.
//
// model_file:             path stored into resource.model_opts.model_path.
// word_symbol_table_file: symbol table path; given to the TLG decoder options
//                         when fst_file is set, otherwise to the CTC prefix
//                         beam search options.
// fst_file:               decoding graph path; an empty string selects the
//                         CTC prefix beam search configuration.
// num_instance:           forwarded to the RecognizerController constructor.
//
// Returns true; no failure path is reported from here.
//
// NOTE(review): RecognizerResource::InitFromFlags() appears to construct
// resource.nnet from the flag-derived model options before model_path is
// overridden below - confirm that model_file actually reaches the network.
bool RecognizerInstance::Init(const std::string& model_file,
                              const std::string& word_symbol_table_file,
                              const std::string& fst_file,
                              int num_instance) {
    RecognizerResource resource = RecognizerResource::InitFromFlags();
    resource.model_opts.model_path = model_file;

    if (!fst_file.empty()) {
        // Previously fst_path was assigned twice, so the symbol table path
        // overwrote the graph path and the TLG word table was never set.
        resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
        resource.decoder_opts.tlg_decoder_opts.word_symbol_table =
            word_symbol_table_file;
    } else {
        resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
            word_symbol_table_file;
    }

    recognizer_controller_ =
        std::make_unique<RecognizerController>(num_instance, resource);
    return true;
}
// Forwards to RecognizerController::InitDecoder for the instance `idx`.
void RecognizerInstance::InitDecoder(int idx) {
    recognizer_controller_->InitDecoder(idx);
}
// Returns the id reported by RecognizerController::GetRecognizerInstanceId.
int RecognizerInstance::GetRecognizerInstanceId() {
    const int instance_id = recognizer_controller_->GetRecognizerInstanceId();
    return instance_id;
}
// Passes the samples in `waves` to the controller for the instance `idx`.
void RecognizerInstance::Accept(const std::vector<float>& waves,
                                int idx) const {
    recognizer_controller_->Accept(waves, idx);
}
// Forwards to RecognizerController::SetInputFinished for the instance `idx`.
void RecognizerInstance::SetInputFinished(int idx) const {
    recognizer_controller_->SetInputFinished(idx);
}
// Returns the controller's final result for the instance `idx`
// (RecognizerController::GetFinalResult).
std::string RecognizerInstance::GetResult(int idx) const {
    std::string final_result = recognizer_controller_->GetFinalResult(idx);
    return final_result;
}
}
\ No newline at end of file
runtime/engine/asr/recognizer/recognizer_i
mpl.cc
→
runtime/engine/asr/recognizer/recognizer_i
nstance.h
浏览文件 @
b05ead51
...
@@ -10,4 +10,33 @@
...
@@ -10,4 +10,33 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
\ No newline at end of file
#pragma once
#include "base/common.h"
#include "recognizer/recognizer_controller.h"
namespace
ppspeech
{
// Facade over a RecognizerController. GetInstance() hands out a shared
// object; the remaining members take an `idx` that is passed through to the
// controller.
class RecognizerInstance {
  public:
    // Accessor for the shared instance.
    static RecognizerInstance& GetInstance();

    RecognizerInstance() = default;
    ~RecognizerInstance() = default;

    // Configures the recognizer from the given model, symbol table and
    // (optional) FST files and creates the underlying controller.
    bool Init(const std::string& model_file,
              const std::string& word_symbol_table_file,
              const std::string& fst_file,
              int num_instance);

    // Returns a recognizer instance id obtained from the controller.
    int GetRecognizerInstanceId();

    // Initializes the decoder for `idx`.
    void InitDecoder(int idx);

    // Supplies audio samples for `idx`.
    void Accept(const std::vector<float>& waves, int idx) const;

    // Signals that no more input will arrive for `idx`.
    void SetInputFinished(int idx) const;

    // Returns the recognition result for `idx`.
    std::string GetResult(int idx) const;

  private:
    // Owned controller; assigned in Init().
    std::unique_ptr<RecognizerController> recognizer_controller_;
};
}
// namespace ppspeech
runtime/engine/asr/recognizer/recognizer_resource.h
浏览文件 @
b05ead51
...
@@ -12,7 +12,6 @@ DECLARE_double(reverse_weight);
...
@@ -12,7 +12,6 @@ DECLARE_double(reverse_weight);
DECLARE_int32
(
nbest
);
DECLARE_int32
(
nbest
);
DECLARE_int32
(
blank
);
DECLARE_int32
(
blank
);
DECLARE_double
(
acoustic_scale
);
DECLARE_double
(
acoustic_scale
);
DECLARE_string
(
vocab_path
);
DECLARE_string
(
word_symbol_table
);
DECLARE_string
(
word_symbol_table
);
namespace
ppspeech
{
namespace
ppspeech
{
...
@@ -52,6 +51,8 @@ struct DecodeOptions {
...
@@ -52,6 +51,8 @@ struct DecodeOptions {
decoder_opts
.
ctc_prefix_search_opts
.
blank
=
FLAGS_blank
;
decoder_opts
.
ctc_prefix_search_opts
.
blank
=
FLAGS_blank
;
decoder_opts
.
ctc_prefix_search_opts
.
first_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
first_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
second_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
second_beam_size
=
FLAGS_nbest
;
decoder_opts
.
ctc_prefix_search_opts
.
word_symbol_table
=
FLAGS_word_symbol_table
;
decoder_opts
.
tlg_decoder_opts
=
decoder_opts
.
tlg_decoder_opts
=
ppspeech
::
TLGDecoderOptions
::
InitFromFlags
();
ppspeech
::
TLGDecoderOptions
::
InitFromFlags
();
...
@@ -68,18 +69,17 @@ struct DecodeOptions {
...
@@ -68,18 +69,17 @@ struct DecodeOptions {
};
};
struct
RecognizerResource
{
struct
RecognizerResource
{
// decodable opt
kaldi
::
BaseFloat
acoustic_scale
{
1.0
};
kaldi
::
BaseFloat
acoustic_scale
{
1.0
};
std
::
string
vocab_path
{};
FeaturePipelineOptions
feature_pipeline_opts
{};
FeaturePipelineOptions
feature_pipeline_opts
{};
ModelOptions
model_opts
{};
ModelOptions
model_opts
{};
DecodeOptions
decoder_opts
{};
DecodeOptions
decoder_opts
{};
std
::
shared_ptr
<
NnetBase
>
nnet
;
static
RecognizerResource
InitFromFlags
()
{
static
RecognizerResource
InitFromFlags
()
{
RecognizerResource
resource
;
RecognizerResource
resource
;
resource
.
vocab_path
=
FLAGS_vocab_path
;
resource
.
acoustic_scale
=
FLAGS_acoustic_scale
;
resource
.
acoustic_scale
=
FLAGS_acoustic_scale
;
LOG
(
INFO
)
<<
"vocab path: "
<<
resource
.
vocab_path
;
LOG
(
INFO
)
<<
"acoustic_scale: "
<<
resource
.
acoustic_scale
;
LOG
(
INFO
)
<<
"acoustic_scale: "
<<
resource
.
acoustic_scale
;
resource
.
feature_pipeline_opts
=
resource
.
feature_pipeline_opts
=
...
@@ -89,6 +89,15 @@ struct RecognizerResource {
...
@@ -89,6 +89,15 @@ struct RecognizerResource {
<<
resource
.
feature_pipeline_opts
.
assembler_opts
.
fill_zero
;
<<
resource
.
feature_pipeline_opts
.
assembler_opts
.
fill_zero
;
resource
.
model_opts
=
ppspeech
::
ModelOptions
::
InitFromFlags
();
resource
.
model_opts
=
ppspeech
::
ModelOptions
::
InitFromFlags
();
resource
.
decoder_opts
=
ppspeech
::
DecodeOptions
::
InitFromFlags
();
resource
.
decoder_opts
=
ppspeech
::
DecodeOptions
::
InitFromFlags
();
#ifndef USE_ONNX
resource
.
nnet
.
reset
(
new
U2Nnet
(
resource
.
model_opts
));
#else
if
(
resource
.
model_opts
.
with_onnx_model
){
resource
.
nnet
.
reset
(
new
U2OnnxNnet
(
resource
.
model_opts
));
}
else
{
resource
.
nnet
.
reset
(
new
U2Nnet
(
resource
.
model_opts
));
}
#endif
return
resource
;
return
resource
;
}
}
};
};
...
...
runtime/engine/common/utils/file_utils.cc
浏览文件 @
b05ead51
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#include "utils/file_utils.h"
#include "utils/file_utils.h"
#include <sys/stat.h>
namespace
ppspeech
{
namespace
ppspeech
{
bool
ReadFileToVector
(
const
std
::
string
&
filename
,
bool
ReadFileToVector
(
const
std
::
string
&
filename
,
...
@@ -40,4 +42,31 @@ std::string ReadFile2String(const std::string& path) {
...
@@ -40,4 +42,31 @@ std::string ReadFile2String(const std::string& path) {
return
std
::
string
((
std
::
istreambuf_iterator
<
char
>
(
input_file
)),
return
std
::
string
((
std
::
istreambuf_iterator
<
char
>
(
input_file
)),
std
::
istreambuf_iterator
<
char
>
());
std
::
istreambuf_iterator
<
char
>
());
}
}
// Reports whether stat() succeeds for `strFilename`.
//
// Adapted from:
// https://github.com/kaldi-asr/kaldi/blob/master/src/fstext/deterministic-fst-test.cc
//
// A false result does not distinguish "missing" from other stat() failures,
// e.g. lacking permission on the containing folder; inspect stat()'s error
// details if that level of checking is required.
bool FileExists(const std::string& strFilename) {
    struct stat file_attributes;
    // stat() returns 0 when the attributes could be read, i.e. the path exists.
    return stat(strFilename.c_str(), &file_attributes) == 0;
}
}
// namespace ppspeech
}
// namespace ppspeech
runtime/engine/common/utils/file_utils.h
浏览文件 @
b05ead51
...
@@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename,
...
@@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename,
std
::
vector
<
std
::
string
>*
data
);
std
::
vector
<
std
::
string
>*
data
);
std
::
string
ReadFile2String
(
const
std
::
string
&
path
);
std
::
string
ReadFile2String
(
const
std
::
string
&
path
);
bool
FileExists
(
const
std
::
string
&
filename
);
}
// namespace ppspeech
}
// namespace ppspeech
runtime/examples/u2pp_ol/wenetspeech/local/decode.sh
浏览文件 @
b05ead51
...
@@ -14,7 +14,7 @@ text=$data/test/text
...
@@ -14,7 +14,7 @@ text=$data/test/text
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/decoder.log
\
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/decoder.log
\
ctc_prefix_beam_search_decoder_main
\
ctc_prefix_beam_search_decoder_main
\
--model_path
=
$model_dir
/export.jit
\
--model_path
=
$model_dir
/export.jit
\
--
vocab_path
=
$model_dir
/unit.txt
\
--
word_symbol_table
=
$model_dir
/unit.txt
\
--nnet_decoder_chunk
=
16
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--receptive_field_length
=
7
\
--subsampling_rate
=
4
\
--subsampling_rate
=
4
\
...
@@ -23,4 +23,4 @@ ctc_prefix_beam_search_decoder_main \
...
@@ -23,4 +23,4 @@ ctc_prefix_beam_search_decoder_main \
cat
$data
/split
${
nj
}
/
*
/result_decode.ark
>
$exp
/aishell.decode.rsl
cat
$data
/split
${
nj
}
/
*
/result_decode.ark
>
$exp
/aishell.decode.rsl
utils/compute-wer.py
--char
=
1
--v
=
1
$text
$exp
/aishell.decode.rsl
>
$exp
/aishell.decode.err
utils/compute-wer.py
--char
=
1
--v
=
1
$text
$exp
/aishell.decode.rsl
>
$exp
/aishell.decode.err
tail
-n
7
$exp
/aishell.decode.err
tail
-n
7
$exp
/aishell.decode.err
\ No newline at end of file
runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh
浏览文件 @
b05ead51
...
@@ -21,7 +21,7 @@ recognizer_main \
...
@@ -21,7 +21,7 @@ recognizer_main \
--num_bins
=
80
\
--num_bins
=
80
\
--cmvn_file
=
$model_dir
/mean_std.json
\
--cmvn_file
=
$model_dir
/mean_std.json
\
--model_path
=
$model_dir
/export.jit
\
--model_path
=
$model_dir
/export.jit
\
--
vocab_path
=
$model_dir
/unit.txt
\
--
word_symbol_table
=
$model_dir
/unit.txt
\
--nnet_decoder_chunk
=
16
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--receptive_field_length
=
7
\
--subsampling_rate
=
4
\
--subsampling_rate
=
4
\
...
...
runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh
浏览文件 @
b05ead51
...
@@ -21,7 +21,7 @@ u2_recognizer_main \
...
@@ -21,7 +21,7 @@ u2_recognizer_main \
--num_bins
=
80
\
--num_bins
=
80
\
--cmvn_file
=
$model_dir
/mean_std.json
\
--cmvn_file
=
$model_dir
/mean_std.json
\
--model_path
=
$model_dir
/export
\
--model_path
=
$model_dir
/export
\
--
vocab_path
=
$model_dir
/unit.txt
\
--
word_symbol_table
=
$model_dir
/unit.txt
\
--nnet_decoder_chunk
=
16
\
--nnet_decoder_chunk
=
16
\
--receptive_field_length
=
7
\
--receptive_field_length
=
7
\
--subsampling_rate
=
4
\
--subsampling_rate
=
4
\
...
...
runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh
0 → 100755
浏览文件 @
b05ead51
#!/bin/bash
# Decode the aishell test set with recognizer_main using a TLG graph, then
# score the output with compute-wer.py.
set -e

# Defaults; utils/parse_options.sh is sourced below and may override them.
data=data
exp=exp
nj=40

. utils/parse_options.sh

mkdir -p "$exp"
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text

# Split the wav list into $nj parts, one per parallel job.
./local/split_data.sh "$data" "$data/$aishell_wav_scp" "$aishell_wav_scp" "$nj"

# Decoding graph and its word symbol table.
lang_dir=./data/lang_test/
graph=$lang_dir/TLG.fst
word_table=$lang_dir/words.txt

# run.pl substitutes JOB with the job index for each parallel run.
utils/run.pl JOB=1:"$nj" "$data/split${nj}/JOB/recognizer_wfst.log" \
recognizer_main \
    --use_fbank=true \
    --num_bins=80 \
    --cmvn_file="$model_dir/mean_std.json" \
    --model_path="$model_dir/export.jit" \
    --graph_path="$graph" \
    --word_symbol_table="$word_table" \
    --nnet_decoder_chunk=16 \
    --receptive_field_length=7 \
    --subsampling_rate=4 \
    --wav_rspecifier="scp:$data/split${nj}/JOB/${aishell_wav_scp}" \
    --result_wspecifier="ark,t:$data/split${nj}/JOB/result_recognizer_wfst.ark"

# Merge the per-job results (the glob must stay outside the quotes) and score.
cat "$data/split${nj}"/*/result_recognizer_wfst.ark > "$exp/aishell_recognizer_wfst"
utils/compute-wer.py --char=1 --v=1 "$text" "$exp/aishell_recognizer_wfst" > "$exp/aishell.recognizer_wfst.err"
echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer_wfst.err"
tail -n 7 "$exp/aishell.recognizer_wfst.err"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录