PaddlePaddle / DeepSpeech

Commit 5e7e582d
Authored on Apr 07, 2021 by Hui Zhang

fix bugs

Parent: 2fa6bbbe
Showing 9 changed files with 116 additions and 41 deletions (+116 −41)
deepspeech/__init__.py               +62   −5
deepspeech/models/u2.py               +5   −5
deepspeech/modules/attention.py       +2   −2
deepspeech/modules/embedding.py       +6  −11
deepspeech/modules/encoder.py         +8   −6
deepspeech/modules/encoder_layer.py   +2   −2
deepspeech/utils/layer_tools.py       +9   −4
deepspeech/utils/tensor_utils.py      +1   −1
tests/u2_model_test.py               +21   −5
deepspeech/__init__.py  (view file @ 5e7e582d)

@@ -22,7 +22,8 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
+#TODO(Hui Zhang): remove fluid import
+from paddle.fluid import core

 logger = logging.getLogger(__name__)

 ########### hcak logging #############
@@ -44,10 +45,51 @@ paddle.int = 'int32'
 paddle.int64 = 'int64'
 paddle.long = 'int64'
+paddle.uint8 = 'uint8'
+paddle.uint16 = 'uint16'
 paddle.complex64 = 'complex64'
 paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
+
+
+def convert_dtype_to_string(tensor_dtype):
+    """
+    Convert the data type in numpy to the data type in Paddle
+    Args:
+        tensor_dtype(core.VarDesc.VarType): the data type in numpy.
+    Returns:
+        core.VarDesc.VarType: the data type in Paddle.
+    """
+    dtype = tensor_dtype
+    if dtype == core.VarDesc.VarType.FP32:
+        return paddle.float32
+    elif dtype == core.VarDesc.VarType.FP64:
+        return paddle.float64
+    elif dtype == core.VarDesc.VarType.FP16:
+        return paddle.float16
+    elif dtype == core.VarDesc.VarType.INT32:
+        return paddle.int32
+    elif dtype == core.VarDesc.VarType.INT16:
+        return paddle.int16
+    elif dtype == core.VarDesc.VarType.INT64:
+        return paddle.int64
+    elif dtype == core.VarDesc.VarType.BOOL:
+        return paddle.bool
+    elif dtype == core.VarDesc.VarType.BF16:
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return paddle.uint16
+    elif dtype == core.VarDesc.VarType.UINT8:
+        return paddle.uint8
+    elif dtype == core.VarDesc.VarType.INT8:
+        return paddle.int8
+    elif dtype == core.VarDesc.VarType.COMPLEX64:
+        return paddle.complex64
+    elif dtype == core.VarDesc.VarType.COMPLEX128:
+        return paddle.complex128
+    else:
+        raise ValueError("Not supported tensor dtype %s" % dtype)
+

 if not hasattr(paddle, 'softmax'):
     logger.warn("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):

 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
-    return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
+    return xs.equal(
+        paddle.to_tensor(
+            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))


 if not hasattr(paddle.Tensor, 'eq'):
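The reason for the extra conversion, shown as a small hedged sketch (assuming the Paddle 2.0-era behavior this module works around, where Tensor.dtype is a core.VarDesc.VarType enum rather than the dtype string that paddle.to_tensor expects):

import paddle

xs = paddle.to_tensor([1, 2, 3])
# The old eq() passed xs.dtype (an enum on that API) straight to to_tensor; the fix
# maps it to the registered string alias first, e.g.
# convert_dtype_to_string(xs.dtype) -> 'int64'.
ys = paddle.to_tensor(2, dtype='int64', place=xs.place)
print(xs.equal(ys))  # elementwise bool tensor: [False, True, False]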
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
     paddle.Tensor.view_as = view_as


+def is_broadcastable(shp1, shp2):
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
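As a quick check that the relaxed assertion and explicit broadcast behave like the torch masked_fill they emulate, here is a self-contained sketch (only assumption: a Paddle 2.x install, since paddle.broadcast_shape, Tensor.broadcast_to and paddle.where are the same 2.x APIs the patch itself relies on):

import paddle

def masked_fill_sketch(xs, mask, value):
    # broadcast the (possibly smaller) mask up to the data shape, then select
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    fill = paddle.ones_like(xs) * value
    return paddle.where(mask, fill, xs)

scores = paddle.rand([2, 4])
pad = paddle.to_tensor([[False, False, True, True],
                        [False, True, True, True]])
print(masked_fill_sketch(scores, pad, -1e9))  # padded positions become -1e9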
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):

 def masked_fill_(xs: paddle.Tensor,
                  mask: paddle.Tensor,
                  value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     ret = paddle.where(mask, trues, xs)
     paddle.assign(ret, output=xs)
@@ -414,4 +471,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):

 if not hasattr(paddle.jit, 'export'):
     logger.warn("register user export to paddle.jit, remove this when fixed!")
-    setattr(paddle.jit, 'export', paddle.jit.to_static)
\ No newline at end of file
+    setattr(paddle.jit, 'export', paddle.jit.to_static)
deepspeech/models/u2.py  (view file @ 5e7e582d)

@@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
 from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
-from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
+from deepspeech.modules.loss import LabelSmoothingLoss
 from deepspeech.frontend.utility import load_cmvn
@@ -633,7 +633,7 @@ class U2Model(nn.Module):

 class U2TransformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())

-        self.__init__(
+        super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
             decoder=decoder,
@@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):

 class U2ConformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())

-        self.__init__(
+        super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
             decoder=decoder,
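The two fixes in this file follow the same pattern: the config-driven subclass constructors were missing self and delegated to self.__init__ instead of the parent class. A toy sketch of the corrected pattern (names here are illustrative, not the real U2Model signature):

class Base:
    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

class FromConfig(Base):
    def __init__(self, configs: dict):      # 'self' must be the first parameter
        encoder = configs['encoder']
        decoder = configs['decoder']
        # super().__init__ reaches Base; self.__init__ would dispatch right back
        # into this subclass constructor with the wrong arguments.
        super().__init__(encoder=encoder, decoder=decoder)

model = FromConfig({'encoder': 'enc', 'decoder': 'dec'})
print(model.encoder, model.decoder)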
deepspeech/modules/attention.py  (view file @ 5e7e582d)

@@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
+        matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))

         # compute matrix b and matrix d
         # (batch, head, time1, time2)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
+        matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
         # Remove rel_shift since it is useless in speech recognition,
         # and it requires special attention for streaming.
         # matrix_bd = self.rel_shift(matrix_bd)
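These were leftover torch.matmul calls; the Paddle replacements are shape-for-shape drop-ins. A minimal shape check, with placeholder sizes (B, H, T, D are made up for the sketch):

import paddle

B, H, T, D = 2, 4, 6, 8
q_with_bias_u = paddle.rand([B, H, T, D])
k = paddle.rand([B, H, T, D])
# transpose takes an explicit permutation list in Paddle; the product is the
# (batch, head, time1, time2) score matrix described in the comments above.
matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
print(matrix_ac.shape)  # [2, 4, 6, 6]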
deepspeech/modules/embedding.py  (view file @ 5e7e582d)

@@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer):
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.pe = paddle.zeros(self.max_len, self.d_model)  #[T,D]
+        self.pe = paddle.zeros([self.max_len, self.d_model])  #[T,D]
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
@@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer):
             paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
             paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
         """
-        T = paddle.shape(x)[1]
-        assert offset + T < self.max_len
-        #assert offset + x.size(1) < self.max_len
-        #self.pe = self.pe.to(x.device)
-        #pos_emb = self.pe[:, offset:offset + x.size(1)]
+        T = x.shape[1]
+        assert offset + x.size(1) < self.max_len
+        #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + T]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
@@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding):
             paddle.Tensor: Encoded tensor (batch, time, `*`).
             paddle.Tensor: Positional embedding tensor (1, time, `*`).
         """
-        #T = paddle.shape()[1]
-        #assert offset + T < self.max_len
         assert offset + x.size(1) < self.max_len
-        #self.pe = self.pe.to(x.device)
         x = x * self.xscale
-        pos_emb = self.pe[:, offset:offset + x.size(1)]
-        #pos_emb = self.pe[:, offset:offset + T]
+        #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
+        pos_emb = self.pe[:, offset:offset + x.shape[1]]
        return self.dropout(x), self.dropout(pos_emb)
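Both methods now slice the precomputed table with a plain Python int taken from x.shape[1], because Tensor.__getitem__ at this Paddle version did not accept a Tensor as a slice bound, which is what the TODO comments above refer to. A standalone sketch of the slicing, with made-up sizes:

import paddle

pe = paddle.rand([1, 5000, 256])    # positional table, (1, max_len, D)
x = paddle.rand([8, 100, 256])      # (B, T, D)
offset = 0
T = x.shape[1]                      # Python int, safe to use in a slice
pos_emb = pe[:, offset:offset + T]
print(pos_emb.shape)                # [1, 100, 256]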
deepspeech/modules/encoder.py  (view file @ 5e7e582d)

@@ -23,7 +23,7 @@ from paddle.nn import initializer as I
 from deepspeech.modules.attention import MultiHeadedAttention
 from deepspeech.modules.attention import RelPositionMultiHeadedAttention
-from deepspeech.modules.convolution import ConvolutionModule
+from deepspeech.modules.conformer_convolution import ConvolutionModule
 from deepspeech.modules.embedding import PositionalEncoding
 from deepspeech.modules.embedding import RelPositionalEncoding
 from deepspeech.modules.encoder_layer import TransformerEncoderLayer
@@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
 from deepspeech.modules.subsampling import Conv2dSubsampling6
 from deepspeech.modules.subsampling import Conv2dSubsampling8
 from deepspeech.modules.subsampling import LinearNoSubsampling
-from deepspeech.modules.mask import make_pad_mask
+from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.mask import add_optional_chunk_mask
 from deepspeech.modules.activation import get_activation
@@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer):
             encoder output tensor, lens and mask
         """
         masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)
+        #TODO(Hui Zhang): mask_pad = ~masks
+        mask_pad = masks.logical_not()
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
-        xs, pos_emb, masks = self.embed(xs, masks, offset=0)
-        mask_pad = ~masks
+        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
+        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
         chunk_masks = add_optional_chunk_mask(xs, masks,
             self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
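Both TODOs mark version workarounds rather than logic changes: logical_not() stands in for the ~ operator on a bool mask, and the mask is cast to the feature dtype before the embedding front end slices it (that is roughly what type_as amounts to here). A small sketch of the two operations on their own, assuming any Paddle 2.x build:

import paddle

masks = paddle.to_tensor([[True, True, False]]).unsqueeze(1)  # (B, 1, L), True = valid frame
mask_pad = masks.logical_not()          # padding positions, same result as ~masks
masks_cast = masks.astype('float32')    # bool -> float, comparable to masks.type_as(xs)
print(mask_pad.numpy(), masks_cast.numpy())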
@@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder):
                  concat_after: bool = False,
                  static_chunk_size: int = 0,
                  use_dynamic_chunk: bool = False,
-                 global_cmvn: torch.nn.Module = None,
+                 global_cmvn: nn.Layer = None,
                  use_dynamic_left_chunk: bool = False,
                  positionwise_conv_kernel_size: int = 1,
                  macaron_style: bool = True,
@@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder):
         self.encoders = nn.ModuleList([
             ConformerEncoderLayer(
                 size=output_size,
-                eself_attn=ncoder_selfattn_layer(*encoder_selfattn_layer_args),
+                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
                     *positionwise_layer_args) if macaron_style else None,
deepspeech/modules/encoder_layer.py  (view file @ 5e7e582d)

@@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer):
         if output_cache is not None:
             x = paddle.concat([output_cache, x], axis=1)

-        fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
         return x, mask, fake_cnn_cache
@@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer):
         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
-        new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:
deepspeech/utils/layer_tools.py  (view file @ 5e7e582d)

@@ -24,13 +24,18 @@ __all__ = [

 def summary(layer: nn.Layer, print_func=print):
     num_params = num_elements = 0
-    print_func("layer summary:")
+    if print_func:
+        print_func(f"{layer.__class__.__name__} summary:")
     for name, param in layer.state_dict().items():
-        print_func("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
+        if print_func:
+            print_func("{} | {} | {}".format(name, param.shape,
+                                             np.prod(param.shape)))
         num_elements += np.prod(param.shape)
         num_params += 1
-    print_func("layer has {} parameters, {} elements.".format(num_params,
-                                                              num_elements))
+    if print_func:
+        print_func(
+            f"{layer.__class__.__name__} has {num_params} parameters, {num_elements} elements."
+        )


 def gradient_norm(layer: nn.Layer):
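With the guards in place, passing print_func=None turns summary() into a silent pass over state_dict(), which is how the updated tests below call it. A usage sketch (assumes the repo is on PYTHONPATH so deepspeech.utils.layer_tools imports):

import paddle.nn as nn
from deepspeech.utils.layer_tools import summary

layer = nn.Linear(4, 2)
summary(layer)        # prints "Linear summary:", one line per parameter, then the totals
summary(layer, None)  # no output; still iterates the parameters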
deepspeech/utils/tensor_utils.py  (view file @ 5e7e582d)

@@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
     ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
     ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
+    return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)


 def th_accuracy(pad_outputs: paddle.Tensor,
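The call site switches from pad_list to pad_sequence with an explicit padding_value. As a reference for what that produces on the ragged ys_in/ys_out lists, here is a minimal stand-alone sketch of the padding behavior (pad_sequence_sketch is a toy re-implementation for illustration, not the pad_sequence referenced above):

import paddle

def pad_sequence_sketch(seqs, padding_value=0):
    # stack variable-length 1-D tensors into (B, Tmax), filling tails with padding_value
    max_len = max(s.shape[0] for s in seqs)
    padded = []
    for s in seqs:
        pad_len = max_len - s.shape[0]
        if pad_len > 0:
            filler = paddle.full([pad_len], padding_value, dtype=s.dtype)
            s = paddle.concat([s, filler])
        padded.append(s)
    return paddle.stack(padded)

ys = [paddle.to_tensor([7, 8, 9]), paddle.to_tensor([4, 5])]
print(pad_sequence_sketch(ys, padding_value=-1))
# [[ 7,  8,  9],
#  [ 4,  5, -1]]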
tests/u2_model_test.py  (view file @ 5e7e582d)

@@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN
 from deepspeech.models.u2 import U2TransformerModel
 from deepspeech.models.u2 import U2ConformerModel
+from deepspeech.utils.layer_tools import summary


 class TestU2Model(unittest.TestCase):
@@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase):
         paddle.set_device('cpu')
         self.batch_size = 2
-        self.feat_dim = 161
+        self.feat_dim = 83
         self.max_len = 64
+        self.vocab_size = 4239

         #(B, T, D)
         audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim)
@@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase):
             length_normalized_loss: false
         """
         cfg = CN().load_cfg(conf_str)
-        print(cfg)
-        model = U2TransformerModel()
+        cfg.input_dim = self.feat_dim
+        cfg.output_dim = self.vocab_size
+        cfg.cmvn_file = None
+        cfg.cmvn_file_type = 'npz'
+        cfg.freeze()
+        model = U2TransformerModel(cfg)
+        summary(model, None)
+        output = model(self.audio, self.audio_len, self.text, self.text_len)
+        print(output)

     def test_conformer(self):
         conf_str = """
@@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase):
             length_normalized_loss: false
         """
         cfg = CN().load_cfg(conf_str)
-        print(cfg)
-        model = U2ConformerModel()
+        cfg.input_dim = self.feat_dim
+        cfg.output_dim = self.vocab_size
+        cfg.cmvn_file = None
+        cfg.cmvn_file_type = 'npz'
+        cfg.freeze()
+        model = U2ConformerModel(cfg)
+        summary(model, None)
+        output = model(self.audio, self.audio_len, self.text, self.text_len)
+        print(output)


 if __name__ == '__main__':