Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
6987751f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6987751f
编写于
10月 12, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix LogLikelihood and add AdvanceChunk
上级
5cc874e1
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
87 addition
and
33 deletion
+87
-33
speechx/speechx/base/common.h
speechx/speechx/base/common.h
+1
-0
speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc
speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc
+4
-4
speechx/speechx/kaldi/decoder/decodable-itf.h
speechx/speechx/kaldi/decoder/decodable-itf.h
+7
-4
speechx/speechx/nnet/decodable.cc
speechx/speechx/nnet/decodable.cc
+50
-14
speechx/speechx/nnet/decodable.h
speechx/speechx/nnet/decodable.h
+5
-1
speechx/speechx/nnet/ds2_nnet.h
speechx/speechx/nnet/ds2_nnet.h
+2
-0
speechx/speechx/nnet/nnet_itf.h
speechx/speechx/nnet/nnet_itf.h
+5
-1
speechx/speechx/nnet/u2_nnet.h
speechx/speechx/nnet/u2_nnet.h
+2
-0
speechx/speechx/nnet/u2_nnet_main.cc
speechx/speechx/nnet/u2_nnet_main.cc
+11
-9
未找到文件。
speechx/speechx/base/common.h
浏览文件 @
6987751f
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <cstring>
#include <deque>
...
...
speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc
浏览文件 @
6987751f
...
...
@@ -47,13 +47,13 @@ int main(int argc, char* argv[]) {
for
(
auto
obj
:
value
.
as_object
())
{
if
(
obj
.
key
()
==
"mean_stat"
)
{
LOG
(
INFO
)
<<
"mean_stat:"
<<
obj
.
value
();
VLOG
(
2
)
<<
"mean_stat:"
<<
obj
.
value
();
}
if
(
obj
.
key
()
==
"var_stat"
)
{
LOG
(
INFO
)
<<
"var_stat: "
<<
obj
.
value
();
VLOG
(
2
)
<<
"var_stat: "
<<
obj
.
value
();
}
if
(
obj
.
key
()
==
"frame_num"
)
{
LOG
(
INFO
)
<<
"frame_num: "
<<
obj
.
value
();
VLOG
(
2
)
<<
"frame_num: "
<<
obj
.
value
();
}
}
...
...
@@ -79,7 +79,7 @@ int main(int argc, char* argv[]) {
cmvn_stats
(
1
,
idx
)
=
var_stat_vec
[
idx
];
}
cmvn_stats
(
0
,
mean_size
)
=
frame_num
;
LOG
(
INFO
)
<<
cmvn_stats
;
VLOG
(
2
)
<<
cmvn_stats
;
kaldi
::
WriteKaldiObject
(
cmvn_stats
,
FLAGS_cmvn_write_path
,
FLAGS_binary
);
LOG
(
INFO
)
<<
"cmvn stats have write into: "
<<
FLAGS_cmvn_write_path
;
...
...
speechx/speechx/kaldi/decoder/decodable-itf.h
浏览文件 @
6987751f
...
...
@@ -101,7 +101,9 @@ namespace kaldi {
*/
class
DecodableInterface
{
public:
/// Returns the log likelihood, which will be negated in the decoder.
virtual
~
DecodableInterface
()
{}
/// Returns the log likelihood(logprob), which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
...
...
@@ -143,11 +145,12 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
virtual
int32
NumIndices
()
const
=
0
;
/// Returns the likelihood(prob), which will be postive in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
virtual
bool
FrameLikelihood
(
int32
frame
,
std
::
vector
<
kaldi
::
BaseFloat
>*
likelihood
)
=
0
;
virtual
~
DecodableInterface
()
{}
};
/// @}
}
// namespace Kaldi
...
...
speechx/speechx/nnet/decodable.cc
浏览文件 @
6987751f
...
...
@@ -55,18 +55,10 @@ int32 Decodable::NumIndices() const { return 0; }
// id.
int32
Decodable
::
TokenId2NnetId
(
int32
token_id
)
{
return
token_id
-
1
;
}
BaseFloat
Decodable
::
LogLikelihood
(
int32
frame
,
int32
index
)
{
CHECK_LE
(
index
,
nnet_out_cache_
.
NumCols
());
CHECK_LE
(
frame
,
frames_ready_
);
int32
frame_idx
=
frame
-
frame_offset_
;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return
acoustic_scale_
*
std
::
log
(
nnet_out_cache_
(
frame_idx
,
TokenId2NnetId
(
index
))
+
std
::
numeric_limits
<
float
>::
min
());
}
bool
Decodable
::
EnsureFrameHaveComputed
(
int32
frame
)
{
// decoding frame
if
(
frame
>=
frames_ready_
)
{
return
AdvanceChunk
();
}
...
...
@@ -74,26 +66,48 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
}
bool
Decodable
::
AdvanceChunk
()
{
kaldi
::
Timer
timer
;
// read feats
Vector
<
BaseFloat
>
features
;
if
(
frontend_
==
NULL
||
frontend_
->
Read
(
&
features
)
==
false
)
{
// no feat or frontend_ not init.
return
false
;
}
VLOG
(
2
)
<<
"Forward with "
<<
features
.
Dim
()
<<
" frames."
;
// forward feats
NnetOut
out
;
nnet_
->
FeedForward
(
features
,
frontend_
->
Dim
(),
&
out
);
int32
&
vocab_dim
=
out
.
vocab_dim
;
Vector
<
BaseFloat
>&
probs
=
out
.
logprobs
;
Vector
<
BaseFloat
>&
log
probs
=
out
.
logprobs
;
// cache nnet outupts
nnet_out_cache_
.
Resize
(
probs
.
Dim
()
/
vocab_dim
,
vocab_dim
);
nnet_out_cache_
.
CopyRowsFromVec
(
probs
);
nnet_out_cache_
.
Resize
(
log
probs
.
Dim
()
/
vocab_dim
,
vocab_dim
);
nnet_out_cache_
.
CopyRowsFromVec
(
log
probs
);
// update state
// update state
, decoding frame.
frame_offset_
=
frames_ready_
;
frames_ready_
+=
nnet_out_cache_
.
NumRows
();
VLOG
(
2
)
<<
"Forward feat chunk cost: "
<<
timer
.
Elapsed
()
<<
" sec."
;
return
true
;
}
bool
Decodable
::
AdvanceChunk
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
logprobs
,
int
*
vocab_dim
)
{
if
(
AdvanceChunk
()
==
false
)
{
return
false
;
}
int
nrows
=
nnet_out_cache_
.
NumRows
();
CHECK
(
nrows
==
(
frames_ready_
-
frame_offset_
));
if
(
nrows
<=
0
){
LOG
(
WARNING
)
<<
"No new nnet out in cache."
;
return
false
;
}
logprobs
->
Resize
(
nnet_out_cache_
.
NumRows
()
*
nnet_out_cache_
.
NumCols
());
logprobs
->
CopyRowsFromMat
(
nnet_out_cache_
);
*
vocab_dim
=
nnet_out_cache_
.
NumCols
();
return
true
;
}
...
...
@@ -113,6 +127,28 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return
true
;
}
BaseFloat
Decodable
::
LogLikelihood
(
int32
frame
,
int32
index
)
{
if
(
EnsureFrameHaveComputed
(
frame
)
==
false
)
{
return
false
;
}
CHECK_LE
(
index
,
nnet_out_cache_
.
NumCols
());
CHECK_LE
(
frame
,
frames_ready_
);
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
BaseFloat
logprob
=
0.0
;
int32
frame_idx
=
frame
-
frame_offset_
;
BaseFloat
nnet_out
=
nnet_out_cache_
(
frame_idx
,
TokenId2NnetId
(
index
));
if
(
nnet_
->
IsLogProb
()){
logprob
=
nnet_out
;
}
else
{
logprob
=
std
::
log
(
nnet_out
+
std
::
numeric_limits
<
float
>::
epsilon
());
}
CHECK
(
!
std
::
isnan
(
logprob
)
&&
!
std
::
isinf
(
logprob
));
return
acoustic_scale_
*
logprob
;
}
void
Decodable
::
Reset
()
{
if
(
frontend_
!=
nullptr
)
frontend_
->
Reset
();
if
(
nnet_
!=
nullptr
)
nnet_
->
Reset
();
...
...
speechx/speechx/nnet/decodable.h
浏览文件 @
6987751f
...
...
@@ -57,9 +57,13 @@ class Decodable : public kaldi::DecodableInterface {
std
::
shared_ptr
<
NnetInterface
>
Nnet
()
{
return
nnet_
;
}
private:
// forward nnet with feats
bool
AdvanceChunk
();
// forward nnet with feats, and get nnet output
bool
AdvanceChunk
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
logprobs
,
int
*
vocab_dim
);
private:
std
::
shared_ptr
<
FrontendInterface
>
frontend_
;
std
::
shared_ptr
<
NnetInterface
>
nnet_
;
...
...
speechx/speechx/nnet/ds2_nnet.h
浏览文件 @
6987751f
...
...
@@ -104,6 +104,8 @@ class PaddleNnet : public NnetInterface {
void
Reset
()
override
;
bool
IsLogProb
()
override
{
return
false
;
}
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
...
...
speechx/speechx/nnet/nnet_itf.h
浏览文件 @
6987751f
...
...
@@ -39,7 +39,8 @@ class NnetInterface {
// forward feat with nnet.
// nnet do not cache feats, feats cached by frontend.
// nnet cache model outputs, i.e. logprobs/encoder_outs.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual
void
FeedForward
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
features
,
const
int32
&
feature_dim
,
NnetOut
*
out
)
=
0
;
...
...
@@ -47,6 +48,9 @@ class NnetInterface {
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual
void
Reset
()
=
0
;
// true, nnet output is logprob; otherwise is prob,
virtual
bool
IsLogProb
()
=
0
;
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual
void
EncoderOuts
(
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>*
encoder_out
)
const
=
0
;
...
...
speechx/speechx/nnet/u2_nnet.h
浏览文件 @
6987751f
...
...
@@ -111,6 +111,8 @@ class U2Nnet : public U2NnetBase {
void
Reset
()
override
;
bool
IsLogProb
()
override
{
return
true
;
}
void
Dim
();
void
LoadModel
(
const
std
::
string
&
model_path_w_prefix
);
...
...
speechx/speechx/nnet/u2_nnet_main.cc
浏览文件 @
6987751f
...
...
@@ -98,6 +98,7 @@ int main(int argc, char* argv[]) {
// }
int32
frame_idx
=
0
;
int
vocab_dim
=
0
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
prob_vec
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>
encoder_out_vec
;
int32
ori_feature_len
=
feature
.
NumRows
();
...
...
@@ -138,17 +139,17 @@ int main(int argc, char* argv[]) {
}
// get nnet outputs
vector
<
kaldi
::
BaseFloat
>
prob
;
while
(
decodable
->
FrameLikelihood
(
frame_idx
,
&
prob
))
{
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
vec_tmp
(
prob
.
size
());
std
::
memcpy
(
vec_tmp
.
Data
(),
prob
.
data
(),
sizeof
(
kaldi
::
BaseFloat
)
*
prob
.
size
());
kaldi
::
Timer
timer
;
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
logprobs
;
bool
isok
=
decodable
->
AdvanceChunk
(
&
logprobs
,
&
vocab_dim
);
CHECK
(
isok
==
true
);
for
(
int
row_idx
=
0
;
row_idx
<
logprobs
.
Dim
()
/
vocab_dim
;
row_idx
++
)
{
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
vec_tmp
(
vocab_dim
);
std
::
memcpy
(
vec_tmp
.
Data
(),
logprobs
.
Data
()
+
row_idx
*
vocab_dim
,
sizeof
(
kaldi
::
BaseFloat
)
*
vocab_dim
);
prob_vec
.
push_back
(
vec_tmp
);
frame_idx
++
;
}
VLOG
(
2
)
<<
"frame_idx: "
<<
frame_idx
<<
" elapsed: "
<<
timer
.
Elapsed
()
<<
" sec."
;
}
// get encoder out
...
...
@@ -196,8 +197,9 @@ int main(int argc, char* argv[]) {
++
num_done
;
}
double
elapsed
=
timer
.
Elapsed
();
LOG
(
INFO
)
<<
" cost:"
<<
elapsed
<<
" sec"
;
LOG
(
INFO
)
<<
"
Program
cost:"
<<
elapsed
<<
" sec"
;
LOG
(
INFO
)
<<
"Done "
<<
num_done
<<
" utterances, "
<<
num_err
<<
" with errors."
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录