Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8690a00b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8690a00b
编写于
9月 13, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add feature pipeline layer(cmvn, fbank), but to_static and jit.layer output is not equal
上级
67709155
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
206 addition
and
36 deletion
+206
-36
paddlespeech/audio/compliance/kaldi.py
paddlespeech/audio/compliance/kaldi.py
+11
-11
paddlespeech/s2t/exps/u2/bin/test_wav.py
paddlespeech/s2t/exps/u2/bin/test_wav.py
+3
-0
paddlespeech/s2t/exps/u2/model.py
paddlespeech/s2t/exps/u2/model.py
+51
-24
paddlespeech/s2t/models/u2/u2.py
paddlespeech/s2t/models/u2/u2.py
+58
-0
paddlespeech/s2t/modules/cmvn.py
paddlespeech/s2t/modules/cmvn.py
+9
-1
paddlespeech/s2t/modules/fbank.py
paddlespeech/s2t/modules/fbank.py
+74
-0
未找到文件。
paddlespeech/audio/compliance/kaldi.py
浏览文件 @
8690a00b
...
...
@@ -74,16 +74,16 @@ def _feature_window_function(
window_size
:
int
,
blackman_coeff
:
float
,
dtype
:
int
,
)
->
Tensor
:
if
window_type
==
HANNING
:
if
window_type
==
"hann"
:
return
get_window
(
'hann'
,
window_size
,
fftbins
=
False
,
dtype
=
dtype
)
elif
window_type
==
HAMMING
:
elif
window_type
==
"hamming"
:
return
get_window
(
'hamming'
,
window_size
,
fftbins
=
False
,
dtype
=
dtype
)
elif
window_type
==
POVEY
:
elif
window_type
==
"povey"
:
return
get_window
(
'hann'
,
window_size
,
fftbins
=
False
,
dtype
=
dtype
).
pow
(
0.85
)
elif
window_type
==
RECTANGULAR
:
elif
window_type
==
"rect"
:
return
paddle
.
ones
([
window_size
],
dtype
=
dtype
)
elif
window_type
==
BLACKMAN
:
elif
window_type
==
"blackman"
:
a
=
2
*
math
.
pi
/
(
window_size
-
1
)
window_function
=
paddle
.
arange
(
window_size
,
dtype
=
dtype
)
return
(
blackman_coeff
-
0.5
*
paddle
.
cos
(
a
*
window_function
)
+
...
...
@@ -216,7 +216,7 @@ def spectrogram(waveform: Tensor,
sr
:
int
=
16000
,
snip_edges
:
bool
=
True
,
subtract_mean
:
bool
=
False
,
window_type
:
str
=
POVEY
)
->
Tensor
:
window_type
:
str
=
"povey"
)
->
Tensor
:
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
Args:
...
...
@@ -236,7 +236,7 @@ def spectrogram(waveform: Tensor,
snip_edges (bool, optional): Drop samples at the end of the waveform that can't fit a single frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
window_type (str, optional): Choose type of window for FFT computation. Defaults to
POVEY
.
window_type (str, optional): Choose type of window for FFT computation. Defaults to
"povey"
.
Returns:
Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames
...
...
@@ -418,11 +418,11 @@ def fbank(waveform: Tensor,
vtln_high
:
float
=-
500.0
,
vtln_low
:
float
=
100.0
,
vtln_warp
:
float
=
1.0
,
window_type
:
str
=
POVEY
)
->
Tensor
:
window_type
:
str
=
"povey"
)
->
Tensor
:
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
Args:
waveform (Tensor): A waveform tensor with shape `(C, T)`.
waveform (Tensor): A waveform tensor with shape `(C, T)`.
`C` is in the range [0,1].
blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0.
...
...
@@ -448,7 +448,7 @@ def fbank(waveform: Tensor,
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
window_type (str, optional): Choose type of window for FFT computation. Defaults to
POVEY
.
window_type (str, optional): Choose type of window for FFT computation. Defaults to
"povey"
.
Returns:
Tensor: A filter banks tensor with shape `(m, n_mels)`.
...
...
@@ -537,7 +537,7 @@ def mfcc(waveform: Tensor,
vtln_high
:
float
=-
500.0
,
vtln_low
:
float
=
100.0
,
vtln_warp
:
float
=
1.0
,
window_type
:
str
=
POVEY
)
->
Tensor
:
window_type
:
str
=
"povey"
)
->
Tensor
:
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is
identical to Kaldi's.
...
...
paddlespeech/s2t/exps/u2/bin/test_wav.py
浏览文件 @
8690a00b
...
...
@@ -18,6 +18,7 @@ from pathlib import Path
import
paddle
import
soundfile
import
numpy
as
np
from
yacs.config
import
CfgNode
from
paddlespeech.audio.transform.transformation
import
Transformation
...
...
@@ -77,6 +78,8 @@ class U2Infer():
feat
=
self
.
preprocessing
(
audio
,
**
self
.
preprocess_args
)
logger
.
info
(
f
"feat shape:
{
feat
.
shape
}
"
)
np
.
savetxt
(
"feat.transform.txt"
,
feat
)
ilen
=
paddle
.
to_tensor
(
feat
.
shape
[
0
])
xs
=
paddle
.
to_tensor
(
feat
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
decode_config
=
self
.
config
.
decode
...
...
paddlespeech/s2t/exps/u2/model.py
浏览文件 @
8690a00b
...
...
@@ -474,13 +474,20 @@ class U2Tester(U2Trainer):
def
export
(
self
):
infer_model
,
input_spec
=
self
.
load_inferspec
()
infer_model
.
eval
()
paddle
.
set_device
(
'cpu'
)
assert
isinstance
(
input_spec
,
list
),
type
(
input_spec
)
assert
isinstance
(
input_spec
,
(
list
,
tuple
)
),
type
(
input_spec
)
batch_size
,
feat_dim
,
model_size
,
num_left_chunks
=
input_spec
######################### infer_model.forward_encoder_chunk zero tensor online ############
# TODO: 80(feature dim) be configable
######################## infer_model.forward_encoder_chunk ############
input_spec
=
[
# (T,), int16
paddle
.
static
.
InputSpec
(
shape
=
[
None
],
dtype
=
'int16'
),
]
infer_model
.
forward_feature
=
paddle
.
jit
.
to_static
(
infer_model
.
forward_feature
,
input_spec
=
input_spec
)
######################### infer_model.forward_encoder_chunk ############
input_spec
=
[
# xs, (B, T, D)
paddle
.
static
.
InputSpec
(
shape
=
[
batch_size
,
None
,
feat_dim
],
dtype
=
'float32'
),
...
...
@@ -499,8 +506,16 @@ class U2Tester(U2Trainer):
infer_model
.
forward_encoder_chunk
=
paddle
.
jit
.
to_static
(
infer_model
.
forward_encoder_chunk
,
input_spec
=
input_spec
)
######################### infer_model.ctc_activation ########################
input_spec
=
[
# encoder_out, (B,T,D)
paddle
.
static
.
InputSpec
(
shape
=
[
batch_size
,
None
,
model_size
],
dtype
=
'float32'
)
]
infer_model
.
ctc_activation
=
paddle
.
jit
.
to_static
(
infer_model
.
ctc_activation
,
input_spec
=
input_spec
)
######################### infer_model.forward_attention_decoder ########################
# TODO: 512(encoder_output) be configable. 1 for BatchSize
input_spec
=
[
# hyps, (B, U)
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
None
],
dtype
=
'int64'
),
...
...
@@ -512,17 +527,11 @@ class U2Tester(U2Trainer):
infer_model
.
forward_attention_decoder
=
paddle
.
jit
.
to_static
(
infer_model
.
forward_attention_decoder
,
input_spec
=
input_spec
)
######################### infer_model.ctc_activation ########################
input_spec
=
[
# encoder_out, (B,T,D)
paddle
.
static
.
InputSpec
(
shape
=
[
batch_size
,
None
,
model_size
],
dtype
=
'float32'
)
]
infer_model
.
ctc_activation
=
paddle
.
jit
.
to_static
(
infer_model
.
ctc_activation
,
input_spec
=
input_spec
)
# jit save
logger
.
info
(
f
"export save:
{
self
.
args
.
export_path
}
"
)
paddle
.
jit
.
save
(
infer_model
,
self
.
args
.
export_path
,
combine_params
=
True
,
skip_forward
=
True
)
# test dy2static
def
flatten
(
out
):
if
isinstance
(
out
,
paddle
.
Tensor
):
...
...
@@ -536,26 +545,44 @@ class U2Tester(U2Trainer):
flatten_out
.
append
(
var
)
return
flatten_out
xs1
=
paddle
.
rand
(
shape
=
[
1
,
67
,
80
],
dtype
=
'float32'
)
# forward_encoder_chunk dygraph
xs1
=
paddle
.
full
([
1
,
67
,
80
],
0.1
,
dtype
=
'float32'
)
offset
=
paddle
.
to_tensor
([
0
],
dtype
=
'int32'
)
required_cache_size
=
num_left_chunks
att_cache
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
cnn_cache
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
xs
,
att_cache
,
cnn_cache
=
infer_model
.
forward_encoder_chunk
(
xs1
,
offset
,
required_cache_size
,
att_cache
,
cnn_cache
)
xs2
=
paddle
.
rand
(
shape
=
[
1
,
67
,
80
],
dtype
=
'float32'
)
offset
=
paddle
.
to_tensor
([
16
],
dtype
=
'int32'
)
out1
=
infer_model
.
forward_encoder_chunk
(
xs2
,
offset
,
required_cache_size
,
att_cache
,
cnn_cache
)
print
(
'py encoder'
,
out1
)
xs_d
,
att_cache_d
,
cnn_cache_d
=
infer_model
.
forward_encoder_chunk
(
xs1
,
offset
,
required_cache_size
,
att_cache
,
cnn_cache
)
import
soundfile
audio
,
sample_rate
=
soundfile
.
read
(
'./zh.wav'
,
dtype
=
"int16"
,
always_2d
=
True
)
audio
=
audio
[:,
0
]
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
audio
=
paddle
.
to_tensor
(
audio
,
paddle
.
int16
)
feat_d
=
infer_model
.
forward_feature
(
audio
)
logger
.
info
(
f
"
{
feat_d
}
"
)
np
.
savetxt
(
"feat.tostatic.txt"
,
feat_d
)
# load static model
from
paddle.jit.layer
import
Layer
layer
=
Layer
()
layer
.
load
(
self
.
args
.
export_path
,
paddle
.
CPUPlace
())
xs1
=
paddle
.
full
([
1
,
7
,
80
],
0.1
,
dtype
=
'float32'
)
# forward_encoder_chunk static
xs1
=
paddle
.
full
([
1
,
67
,
80
],
0.1
,
dtype
=
'float32'
)
offset
=
paddle
.
to_tensor
([
0
],
dtype
=
'int32'
)
att_cache
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
cnn_cache
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
cnn_cache
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
func
=
getattr
(
layer
,
'forward_encoder_chunk'
)
xs
,
att_cache
,
cnn_cache
=
func
(
xs1
,
offset
,
att_cache
,
cnn_cache
)
print
(
'py static encoder'
,
xs
)
xs_s
,
att_cache_s
,
cnn_cache_s
=
func
(
xs1
,
offset
,
att_cache
,
cnn_cache
)
np
.
testing
.
assert_allclose
(
xs_d
,
xs_s
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
att_cache_d
,
att_cache_s
,
atol
=
1e-4
)
np
.
testing
.
assert_allclose
(
cnn_cache_d
,
cnn_cache_s
,
atol
=
1e-4
)
# logger.info(f"forward_encoder_chunk output: {xs_s}")
# forward_feature static
func
=
getattr
(
layer
,
'forward_feature'
)
feat_s
=
func
(
audio
)[
0
]
logger
.
info
(
f
"
{
feat_s
}
"
)
np
.
testing
.
assert_allclose
(
feat_d
,
feat_s
,
atol
=
1e-5
)
paddlespeech/s2t/models/u2/u2.py
浏览文件 @
8690a00b
...
...
@@ -916,6 +916,50 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict):
    """Build the inference model together with its waveform front end.

    After the parent ``U2Model`` initialisation, the preprocessing YAML
    named by ``configs['preprocess_config']`` is loaded and its ``process``
    list is scanned for two step types:

    * ``fbank_kaldi``: creates ``self.fbank`` (a ``KaldiFbank`` layer) from
      the step's options, with ``n_mels`` overridden by
      ``configs['input_dim']`` and ``dither`` set to ``0.0``.
    * ``cmvn_json``: creates ``self.global_cmvn`` (a ``GlobalCMVN`` layer)
      from the accumulated statistics given by ``cmvn_path``, which may be
      an already-loaded dict or a path to a JSON file.

    Steps of any other type are ignored here.

    Args:
        configs (dict): model configuration. Must provide ``input_dim`` and
            ``preprocess_config`` (path to the preprocessing YAML file).
    """
    super().__init__(configs)

    from paddlespeech.s2t.modules.fbank import KaldiFbank
    import yaml
    import json
    import numpy as np

    input_dim = configs['input_dim']
    preprocess_path = configs['preprocess_config']
    with open(preprocess_path, encoding="utf-8") as f:
        conf = yaml.safe_load(f)
    # Report the type of the object that was actually checked; the previous
    # message read an attribute (`self.conf`) that this method never sets.
    assert isinstance(conf, dict), type(conf)

    # A distinct loop name keeps the config path above from being rebound.
    for step in conf['process']:
        assert isinstance(step, dict), type(step)
        opts = dict(step)  # copy so popping "type" leaves the config intact
        process_type = opts.pop("type")

        if process_type == 'fbank_kaldi':
            opts.update({'n_mels': input_dim})
            # NOTE(review): dither is forced off, presumably so the exported
            # feature pipeline is deterministic -- confirm.
            opts['dither'] = 0.0
            self.fbank = KaldiFbank(**opts)
            logger.info(f"{self.__class__.__name__} export: {self.fbank}")

        if process_type == 'cmvn_json':
            # align with paddlespeech.audio.transform.cmvn:GlobalCMVN
            std_floor = 1.0e-20

            cmvn = opts['cmvn_path']
            if isinstance(cmvn, dict):
                cmvn_stats = cmvn
            else:
                with open(cmvn, encoding="utf-8") as f:
                    cmvn_stats = json.load(f)

            # Turn accumulated sums into per-dimension mean and 1/std.
            count = cmvn_stats['frame_num']
            mean = np.array(cmvn_stats['mean_stat']) / count
            square_sums = np.array(cmvn_stats['var_stat'])
            var = square_sums / count - mean**2
            # Floor the std so the inverse below cannot divide by zero.
            std = np.maximum(np.sqrt(var), std_floor)
            istd = 1.0 / std
            self.global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean, dtype=paddle.float),
                paddle.to_tensor(istd, dtype=paddle.float))
            logger.info(
                f"{self.__class__.__name__} export: {self.global_cmvn}")
def
forward
(
self
,
feats
,
feats_lengths
,
...
...
@@ -939,3 +983,17 @@ class U2InferModel(U2Model):
# num_decoding_left_chunks=num_decoding_left_chunks,
# simulate_streaming=simulate_streaming)
return
feats
,
feats_lengths
def forward_feature(self, x):
    """Run the feature pipeline on a raw waveform.

    The input is cast to float32, passed through ``self.fbank`` and then
    through ``self.global_cmvn``.

    Args:
        x (paddle.Tensor): waveform (T,).

    Returns:
        paddle.Tensor: feature (T, D).
    """
    waveform = paddle.cast(x, paddle.float32)
    return self.global_cmvn(self.fbank(waveform))
\ No newline at end of file
paddlespeech/s2t/modules/cmvn.py
浏览文件 @
8690a00b
...
...
@@ -40,6 +40,14 @@ class GlobalCMVN(nn.Layer):
self
.
register_buffer
(
"mean"
,
mean
)
self
.
register_buffer
(
"istd"
,
istd
)
def __repr__(self):
    """Return a debug string with the class name, ``mean``, ``istd`` and ``norm_var``."""
    template = "{name}(mean={mean}, istd={istd}, norm_var={norm_var})"
    fields = dict(
        name=self.__class__.__name__,
        mean=self.mean,
        istd=self.istd,
        norm_var=self.norm_var)
    return template.format(**fields)
def
forward
(
self
,
x
:
paddle
.
Tensor
):
"""
Args:
...
...
@@ -50,4 +58,4 @@ class GlobalCMVN(nn.Layer):
x
=
x
-
self
.
mean
if
self
.
norm_var
:
x
=
x
*
self
.
istd
return
x
return
x
\ No newline at end of file
paddlespeech/s2t/modules/fbank.py
0 → 100644
浏览文件 @
8690a00b
import
paddle
from
paddle
import
nn
from
paddlespeech.audio.compliance
import
kaldi
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
'KaldiFbank'
]
class KaldiFbank(nn.Layer):
    """Layer wrapper around ``kaldi.fbank`` for a single 1-D waveform."""

    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            energy_floor=0.0,
            dither=0.0):
        """
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame windows
            energy_floor (float): Floor on energy in Spectrogram computation (absolute)
            dither (float): Dithering constant. Default 0.0
        """
        super().__init__()
        self.fs = fs
        self.n_mels = n_mels
        # Samples per millisecond; used to convert the sample-based window
        # and shift sizes into the millisecond values handed to kaldi.fbank.
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms
        self.n_frame_shift = n_shift / num_point_ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        """Return a debug string listing the main feature settings."""
        # The template previously ended with a stray ")" which produced an
        # unbalanced string such as "KaldiFbank(..., dither=0.0))".
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def forward(self, x: paddle.Tensor):
        """
        Args:
            x (paddle.Tensor): shape (Ti).
                Not support: [Time, Channel] and Batch mode.

        Returns:
            paddle.Tensor: (T, D)
        """
        assert x.ndim == 1

        feat = kaldi.fbank(
            x.unsqueeze(0),  # append channel dim, (C, Ti)
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=self.dither,
            energy_floor=self.energy_floor,
            sr=self.fs)

        assert feat.ndim == 2  # (T,D)
        return feat
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录