Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5c8725e8
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5c8725e8
编写于
10月 12, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
unify model opts; add attention rescore in decodable; rename ds2 ctc beam search
上级
6987751f
变更
15
隐藏空白更改
内联
并排
Showing
15 changed files
with
96 additions
and
88 deletions
+96
-88
speechx/examples/codelab/decoder/run.sh
speechx/examples/codelab/decoder/run.sh
+1
-1
speechx/examples/ds2_ol/aishell/run.sh
speechx/examples/ds2_ol/aishell/run.sh
+2
-2
speechx/examples/ds2_ol/aishell/run_fbank.sh
speechx/examples/ds2_ol/aishell/run_fbank.sh
+2
-2
speechx/speechx/decoder/CMakeLists.txt
speechx/speechx/decoder/CMakeLists.txt
+1
-1
speechx/speechx/decoder/ctc_beam_search_decoder.h
speechx/speechx/decoder/ctc_beam_search_decoder.h
+2
-0
speechx/speechx/decoder/ctc_beam_search_decoder_main.cc
speechx/speechx/decoder/ctc_beam_search_decoder_main.cc
+1
-1
speechx/speechx/decoder/ctc_prefix_beam_search.cc
speechx/speechx/decoder/ctc_prefix_beam_search.cc
+0
-0
speechx/speechx/decoder/param.h
speechx/speechx/decoder/param.h
+1
-0
speechx/speechx/nnet/decodable.cc
speechx/speechx/nnet/decodable.cc
+6
-0
speechx/speechx/nnet/decodable.h
speechx/speechx/nnet/decodable.h
+18
-13
speechx/speechx/nnet/ds2_nnet.h
speechx/speechx/nnet/ds2_nnet.h
+6
-45
speechx/speechx/nnet/nnet_itf.h
speechx/speechx/nnet/nnet_itf.h
+51
-0
speechx/speechx/nnet/u2_nnet.cc
speechx/speechx/nnet/u2_nnet.cc
+1
-1
speechx/speechx/nnet/u2_nnet.h
speechx/speechx/nnet/u2_nnet.h
+3
-21
speechx/speechx/nnet/u2_nnet_main.cc
speechx/speechx/nnet/u2_nnet_main.cc
+1
-1
未找到文件。
speechx/examples/codelab/decoder/run.sh
浏览文件 @
5c8725e8
...
@@ -69,7 +69,7 @@ compute_linear_spectrogram_main \
...
@@ -69,7 +69,7 @@ compute_linear_spectrogram_main \
echo
"compute linear spectrogram feature."
echo
"compute linear spectrogram feature."
# run ctc beam search decoder as streaming
# run ctc beam search decoder as streaming
ctc_
prefix_
beam_search_decoder_main
\
ctc_beam_search_decoder_main
\
--result_wspecifier
=
ark,t:
$exp_dir
/result.txt
\
--result_wspecifier
=
ark,t:
$exp_dir
/result.txt
\
--feature_rspecifier
=
ark:
$feat_wspecifier
\
--feature_rspecifier
=
ark:
$feat_wspecifier
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
...
...
speechx/examples/ds2_ol/aishell/run.sh
浏览文件 @
5c8725e8
...
@@ -84,7 +84,7 @@ fi
...
@@ -84,7 +84,7 @@ fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# recognizer
# recognizer
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.wolm.log
\
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.wolm.log
\
ctc_
prefix_
beam_search_decoder_main
\
ctc_beam_search_decoder_main
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/feat.scp
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/feat.scp
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
...
@@ -103,7 +103,7 @@ fi
...
@@ -103,7 +103,7 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# decode with lm
# decode with lm
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.lm.log
\
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.lm.log
\
ctc_
prefix_
beam_search_decoder_main
\
ctc_beam_search_decoder_main
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/feat.scp
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/feat.scp
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--model_path
=
$model_dir
/avg_1.jit.pdmodel
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
--param_path
=
$model_dir
/avg_1.jit.pdiparams
\
...
...
speechx/examples/ds2_ol/aishell/run_fbank.sh
浏览文件 @
5c8725e8
...
@@ -84,7 +84,7 @@ fi
...
@@ -84,7 +84,7 @@ fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# recognizer
# recognizer
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.fbank.wolm.log
\
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.fbank.wolm.log
\
ctc_
prefix_
beam_search_decoder_main
\
ctc_beam_search_decoder_main
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/fbank_feat.scp
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/fbank_feat.scp
\
--model_path
=
$model_dir
/avg_5.jit.pdmodel
\
--model_path
=
$model_dir
/avg_5.jit.pdmodel
\
--param_path
=
$model_dir
/avg_5.jit.pdiparams
\
--param_path
=
$model_dir
/avg_5.jit.pdiparams
\
...
@@ -102,7 +102,7 @@ fi
...
@@ -102,7 +102,7 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# decode with lm
# decode with lm
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.fbank.lm.log
\
utils/run.pl
JOB
=
1:
$nj
$data
/split
${
nj
}
/JOB/recog.fbank.lm.log
\
ctc_
prefix_
beam_search_decoder_main
\
ctc_beam_search_decoder_main
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/fbank_feat.scp
\
--feature_rspecifier
=
scp:
$data
/split
${
nj
}
/JOB/fbank_feat.scp
\
--model_path
=
$model_dir
/avg_5.jit.pdmodel
\
--model_path
=
$model_dir
/avg_5.jit.pdmodel
\
--param_path
=
$model_dir
/avg_5.jit.pdiparams
\
--param_path
=
$model_dir
/avg_5.jit.pdiparams
\
...
...
speechx/speechx/decoder/CMakeLists.txt
浏览文件 @
5c8725e8
...
@@ -12,7 +12,7 @@ add_library(decoder STATIC
...
@@ -12,7 +12,7 @@ add_library(decoder STATIC
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder
)
set
(
BINS
set
(
BINS
ctc_
prefix_
beam_search_decoder_main
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
nnet_logprob_decoder_main
recognizer_main
recognizer_main
tlg_decoder_main
tlg_decoder_main
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.h
浏览文件 @
5c8725e8
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
// used by deepspeech2
#include "base/common.h"
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/scorer.h"
...
...
speechx/speechx/decoder/ctc_
prefix_
beam_search_decoder_main.cc
→
speechx/speechx/decoder/ctc_beam_search_decoder_main.cc
浏览文件 @
5c8725e8
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//
todo refactor, replace with gtest
//
used by deepspeech2
#include "base/flags.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/log.h"
...
...
speechx/speechx/decoder/ctc_prefix_beam_search.cc
0 → 100644
浏览文件 @
5c8725e8
speechx/speechx/decoder/param.h
浏览文件 @
5c8725e8
...
@@ -67,6 +67,7 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
...
@@ -67,6 +67,7 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
frame_opts
.
dither
=
0.0
;
frame_opts
.
dither
=
0.0
;
frame_opts
.
frame_shift_ms
=
10
;
frame_opts
.
frame_shift_ms
=
10
;
opts
.
use_fbank
=
FLAGS_use_fbank
;
opts
.
use_fbank
=
FLAGS_use_fbank
;
LOG
(
INFO
)
<<
"feature type: "
<<
opts
.
use_fbank
?
"fbank"
:
"linear"
;
if
(
opts
.
use_fbank
)
{
if
(
opts
.
use_fbank
)
{
opts
.
to_float32
=
false
;
opts
.
to_float32
=
false
;
frame_opts
.
window_type
=
"povey"
;
frame_opts
.
window_type
=
"povey"
;
...
...
speechx/speechx/nnet/decodable.cc
浏览文件 @
5c8725e8
...
@@ -157,4 +157,10 @@ void Decodable::Reset() {
...
@@ -157,4 +157,10 @@ void Decodable::Reset() {
nnet_out_cache_
.
Resize
(
0
,
0
);
nnet_out_cache_
.
Resize
(
0
,
0
);
}
}
void
Decodable
::
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
){
nnet_
->
AttentionRescoring
(
hyps
,
reverse_weight
,
rescoring_score
);
}
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/nnet/decodable.h
浏览文件 @
5c8725e8
...
@@ -30,23 +30,31 @@ class Decodable : public kaldi::DecodableInterface {
...
@@ -30,23 +30,31 @@ class Decodable : public kaldi::DecodableInterface {
// void Init(DecodableOpts config);
// void Init(DecodableOpts config);
// nnet logprob output
// nnet logprob output
, used by wfst
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
// nnet output
virtual
bool
FrameLikelihood
(
int32
frame
,
std
::
vector
<
kaldi
::
BaseFloat
>*
likelihood
);
// forward nnet with feats
bool
AdvanceChunk
();
// forward nnet with feats, and get nnet output
bool
AdvanceChunk
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
logprobs
,
int
*
vocab_dim
);
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
);
virtual
bool
IsLastFrame
(
int32
frame
);
virtual
bool
IsLastFrame
(
int32
frame
);
// nnet output dim, e.g. vocab size
// nnet output dim, e.g. vocab size
virtual
int32
NumIndices
()
const
;
virtual
int32
NumIndices
()
const
;
// nnet prob output
virtual
bool
FrameLikelihood
(
int32
frame
,
std
::
vector
<
kaldi
::
BaseFloat
>*
likelihood
);
virtual
int32
NumFramesReady
()
const
;
virtual
int32
NumFramesReady
()
const
;
// for offline test
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
void
Reset
();
void
Reset
();
bool
IsInputFinished
()
const
{
return
frontend_
->
IsFinished
();
}
bool
IsInputFinished
()
const
{
return
frontend_
->
IsFinished
();
}
...
@@ -57,11 +65,8 @@ class Decodable : public kaldi::DecodableInterface {
...
@@ -57,11 +65,8 @@ class Decodable : public kaldi::DecodableInterface {
std
::
shared_ptr
<
NnetInterface
>
Nnet
()
{
return
nnet_
;
}
std
::
shared_ptr
<
NnetInterface
>
Nnet
()
{
return
nnet_
;
}
// forward nnet with feats
// for offline test
bool
AdvanceChunk
();
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
// forward nnet with feats, and get nnet output
bool
AdvanceChunk
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
logprobs
,
int
*
vocab_dim
);
private:
private:
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
...
...
speechx/speechx/nnet/ds2_nnet.h
浏览文件 @
5c8725e8
...
@@ -15,56 +15,11 @@
...
@@ -15,56 +15,11 @@
#include <numeric>
#include <numeric>
#include "base/common.h"
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include "nnet/nnet_itf.h"
#include "nnet/nnet_itf.h"
#include "paddle_inference_api.h"
#include "paddle_inference_api.h"
namespace
ppspeech
{
namespace
ppspeech
{
struct
ModelOptions
{
std
::
string
model_path
;
std
::
string
param_path
;
int
thread_num
;
// predictor thread pool size
bool
use_gpu
;
bool
switch_ir_optim
;
std
::
string
input_names
;
std
::
string
output_names
;
std
::
string
cache_names
;
std
::
string
cache_shape
;
bool
enable_fc_padding
;
bool
enable_profile
;
ModelOptions
()
:
model_path
(
""
),
param_path
(
""
),
thread_num
(
2
),
use_gpu
(
false
),
input_names
(
""
),
output_names
(
""
),
cache_names
(
""
),
cache_shape
(
""
),
switch_ir_optim
(
false
),
enable_fc_padding
(
false
),
enable_profile
(
false
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
opts
->
Register
(
"model-param"
,
&
param_path
,
"params model file path"
);
opts
->
Register
(
"thread-num"
,
&
thread_num
,
"thread num"
);
opts
->
Register
(
"use-gpu"
,
&
use_gpu
,
"if use gpu"
);
opts
->
Register
(
"input-names"
,
&
input_names
,
"paddle input names"
);
opts
->
Register
(
"output-names"
,
&
output_names
,
"paddle output names"
);
opts
->
Register
(
"cache-names"
,
&
cache_names
,
"cache names"
);
opts
->
Register
(
"cache-shape"
,
&
cache_shape
,
"cache shape"
);
opts
->
Register
(
"switch-ir-optiom"
,
&
switch_ir_optim
,
"paddle SwitchIrOptim option"
);
opts
->
Register
(
"enable-fc-padding"
,
&
enable_fc_padding
,
"paddle EnableFCPadding option"
);
opts
->
Register
(
"enable-profile"
,
&
enable_profile
,
"paddle EnableProfile option"
);
}
};
template
<
typename
T
>
template
<
typename
T
>
class
Tensor
{
class
Tensor
{
...
@@ -100,6 +55,12 @@ class PaddleNnet : public NnetInterface {
...
@@ -100,6 +55,12 @@ class PaddleNnet : public NnetInterface {
const
int32
&
feature_dim
,
const
int32
&
feature_dim
,
NnetOut
*
out
)
override
;
NnetOut
*
out
)
override
;
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
override
{
VLOG
(
2
)
<<
"deepspeech2 not has AttentionRescoring."
;
}
void
Dim
();
void
Dim
();
void
Reset
()
override
;
void
Reset
()
override
;
...
...
speechx/speechx/nnet/nnet_itf.h
浏览文件 @
5c8725e8
...
@@ -18,9 +18,56 @@
...
@@ -18,9 +18,56 @@
#include "base/basic_types.h"
#include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace
ppspeech
{
namespace
ppspeech
{
struct
ModelOptions
{
std
::
string
model_path
;
std
::
string
param_path
;
int
thread_num
;
// predictor thread pool size for ds2;
bool
use_gpu
;
bool
switch_ir_optim
;
std
::
string
input_names
;
std
::
string
output_names
;
std
::
string
cache_names
;
std
::
string
cache_shape
;
bool
enable_fc_padding
;
bool
enable_profile
;
ModelOptions
()
:
model_path
(
""
),
param_path
(
""
),
thread_num
(
1
),
use_gpu
(
false
),
input_names
(
""
),
output_names
(
""
),
cache_names
(
""
),
cache_shape
(
""
),
switch_ir_optim
(
false
),
enable_fc_padding
(
false
),
enable_profile
(
false
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
opts
->
Register
(
"model-param"
,
&
param_path
,
"params model file path"
);
opts
->
Register
(
"thread-num"
,
&
thread_num
,
"thread num"
);
opts
->
Register
(
"use-gpu"
,
&
use_gpu
,
"if use gpu"
);
opts
->
Register
(
"input-names"
,
&
input_names
,
"paddle input names"
);
opts
->
Register
(
"output-names"
,
&
output_names
,
"paddle output names"
);
opts
->
Register
(
"cache-names"
,
&
cache_names
,
"cache names"
);
opts
->
Register
(
"cache-shape"
,
&
cache_shape
,
"cache shape"
);
opts
->
Register
(
"switch-ir-optiom"
,
&
switch_ir_optim
,
"paddle SwitchIrOptim option"
);
opts
->
Register
(
"enable-fc-padding"
,
&
enable_fc_padding
,
"paddle EnableFCPadding option"
);
opts
->
Register
(
"enable-profile"
,
&
enable_profile
,
"paddle EnableProfile option"
);
}
};
struct
NnetOut
{
struct
NnetOut
{
// nnet out. maybe logprob or prob. Most of the time this is logprob.
// nnet out. maybe logprob or prob. Most of the time this is logprob.
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
logprobs
;
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
logprobs
;
...
@@ -45,6 +92,10 @@ class NnetInterface {
...
@@ -45,6 +92,10 @@ class NnetInterface {
const
int32
&
feature_dim
,
const
int32
&
feature_dim
,
NnetOut
*
out
)
=
0
;
NnetOut
*
out
)
=
0
;
virtual
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
=
0
;
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual
void
Reset
()
=
0
;
virtual
void
Reset
()
=
0
;
...
...
speechx/speechx/nnet/u2_nnet.cc
浏览文件 @
5c8725e8
...
@@ -166,7 +166,7 @@ void U2Nnet::Warmup() {
...
@@ -166,7 +166,7 @@ void U2Nnet::Warmup() {
Reset
();
Reset
();
}
}
U2Nnet
::
U2Nnet
(
const
U2
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
U2Nnet
::
U2Nnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
LoadModel
(
opts_
.
model_path
);
LoadModel
(
opts_
.
model_path
);
}
}
...
...
speechx/speechx/nnet/u2_nnet.h
浏览文件 @
5c8725e8
...
@@ -17,28 +17,14 @@
...
@@ -17,28 +17,14 @@
#include "base/common.h"
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include "nnet/nnet_itf.h"
#include "nnet/nnet_itf.h"
#include "paddle/extension.h"
#include "paddle/extension.h"
#include "paddle/jit/all.h"
#include "paddle/jit/all.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/all.h"
namespace
ppspeech
{
namespace
ppspeech
{
struct
U2ModelOptions
{
std
::
string
model_path
;
int
thread_num
;
bool
use_gpu
;
U2ModelOptions
()
:
model_path
(
""
),
thread_num
(
1
),
use_gpu
(
false
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
opts
->
Register
(
"thread-num"
,
&
thread_num
,
"thread num"
);
opts
->
Register
(
"use-gpu"
,
&
use_gpu
,
"if use gpu"
);
}
};
class
U2NnetBase
:
public
NnetInterface
{
class
U2NnetBase
:
public
NnetInterface
{
public:
public:
...
@@ -65,10 +51,6 @@ class U2NnetBase : public NnetInterface {
...
@@ -65,10 +51,6 @@ class U2NnetBase : public NnetInterface {
std
::
vector
<
kaldi
::
BaseFloat
>*
ctc_probs
,
std
::
vector
<
kaldi
::
BaseFloat
>*
ctc_probs
,
int32
*
vocab_dim
);
int32
*
vocab_dim
);
virtual
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
=
0
;
protected:
protected:
virtual
void
ForwardEncoderChunkImpl
(
virtual
void
ForwardEncoderChunkImpl
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
...
@@ -102,7 +84,7 @@ class U2NnetBase : public NnetInterface {
...
@@ -102,7 +84,7 @@ class U2NnetBase : public NnetInterface {
class
U2Nnet
:
public
U2NnetBase
{
class
U2Nnet
:
public
U2NnetBase
{
public:
public:
U2Nnet
(
const
U2
ModelOptions
&
opts
);
U2Nnet
(
const
ModelOptions
&
opts
);
U2Nnet
(
const
U2Nnet
&
other
);
U2Nnet
(
const
U2Nnet
&
other
);
void
FeedForward
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
features
,
void
FeedForward
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
features
,
...
@@ -143,7 +125,7 @@ class U2Nnet : public U2NnetBase {
...
@@ -143,7 +125,7 @@ class U2Nnet : public U2NnetBase {
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>*
encoder_out
)
const
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>*
encoder_out
)
const
;
private:
private:
U2
ModelOptions
opts_
;
ModelOptions
opts_
;
phi
::
Place
dev_
;
phi
::
Place
dev_
;
std
::
shared_ptr
<
paddle
::
jit
::
Layer
>
model_
{
nullptr
};
std
::
shared_ptr
<
paddle
::
jit
::
Layer
>
model_
{
nullptr
};
...
...
speechx/speechx/nnet/u2_nnet_main.cc
浏览文件 @
5c8725e8
...
@@ -58,7 +58,7 @@ int main(int argc, char* argv[]) {
...
@@ -58,7 +58,7 @@ int main(int argc, char* argv[]) {
kaldi
::
BaseFloatMatrixWriter
nnet_out_writer
(
FLAGS_nnet_prob_wspecifier
);
kaldi
::
BaseFloatMatrixWriter
nnet_out_writer
(
FLAGS_nnet_prob_wspecifier
);
kaldi
::
BaseFloatMatrixWriter
nnet_encoder_outs_writer
(
FLAGS_nnet_encoder_outs_wspecifier
);
kaldi
::
BaseFloatMatrixWriter
nnet_encoder_outs_writer
(
FLAGS_nnet_encoder_outs_wspecifier
);
ppspeech
::
U2
ModelOptions
model_opts
;
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
model_opts
.
model_path
=
FLAGS_model_path
;
int32
chunk_size
=
int32
chunk_size
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录