Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b2bc6eb5
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b2bc6eb5
编写于
3月 16, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add layer for transformer
上级
9cf8c1a5
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
924 addition
and
30 deletion
+924
-30
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+3
-2
deepspeech/modules/activation.py
deepspeech/modules/activation.py
+172
-26
deepspeech/modules/conformer_convolution.py
deepspeech/modules/conformer_convolution.py
+149
-0
deepspeech/modules/conv.py
deepspeech/modules/conv.py
+26
-1
deepspeech/modules/encoder_layer.py
deepspeech/modules/encoder_layer.py
+277
-0
deepspeech/modules/positionwise_feed_forward.py
deepspeech/modules/positionwise_feed_forward.py
+59
-0
deepspeech/modules/subsampling.py
deepspeech/modules/subsampling.py
+235
-0
deepspeech/training/gradclip.py
deepspeech/training/gradclip.py
+3
-1
未找到文件。
deepspeech/exps/deepspeech2/model.py
浏览文件 @
b2bc6eb5
...
...
@@ -28,7 +28,7 @@ from paddle import distributed as dist
from
paddle.io
import
DataLoader
from
deepspeech.training
import
Trainer
from
deepspeech.training.gradclip
import
MyClipGradByGlobalNorm
from
deepspeech.training.gradclip
import
ClipGradByGlobalNormWithLog
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils
import
layer_tools
...
...
@@ -125,7 +125,8 @@ class DeepSpeech2Trainer(Trainer):
layer_tools
.
print_params
(
model
,
self
.
logger
.
info
)
grad_clip
=
MyClipGradByGlobalNorm
(
config
.
training
.
global_grad_clip
)
grad_clip
=
ClipGradByGlobalNormWithLog
(
config
.
training
.
global_grad_clip
)
lr_scheduler
=
paddle
.
optimizer
.
lr
.
ExponentialDecay
(
learning_rate
=
config
.
training
.
lr
,
gamma
=
config
.
training
.
lr_decay
,
...
...
deepspeech/modules/activation.py
浏览文件 @
b2bc6eb5
...
...
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Union
import
logging
import
numpy
as
np
import
math
from
collections
import
OrderedDict
import
paddle
from
paddle
import
nn
...
...
@@ -23,7 +25,7 @@ from paddle.nn import initializer as I
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'brelu'
,
"
softplus"
,
"gelu_accurate"
,
"gelu"
,
'Swish'
]
__all__
=
[
'brelu'
,
"
glu"
]
def
brelu
(
x
,
t_min
=
0.0
,
t_max
=
24.0
,
name
=
None
):
...
...
@@ -33,36 +35,180 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return
x
.
maximum
(
t_min
).
minimum
(
t_max
)
def
softplus
(
x
):
"""Softplus function."""
if
hasattr
(
paddle
.
nn
.
functional
,
'softplus'
):
#return paddle.nn.functional.softplus(x.float()).type_as(x)
return
paddle
.
nn
.
functional
.
softplus
(
x
)
else
:
raise
NotImplementedError
#
def softplus(x):
#
"""Softplus function."""
#
if hasattr(paddle.nn.functional, 'softplus'):
#
#return paddle.nn.functional.softplus(x.float()).type_as(x)
#
return paddle.nn.functional.softplus(x)
#
else:
#
raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
def
gelu_accurate
(
x
):
"""Gaussian Error Linear Units (GELU) activation."""
# [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
if
not
hasattr
(
gelu_accurate
,
"_a"
):
gelu_accurate
.
_a
=
math
.
sqrt
(
2
/
math
.
pi
)
return
0.5
*
x
*
(
1
+
paddle
.
tanh
(
gelu_accurate
.
_a
*
(
x
+
0.044715
*
paddle
.
pow
(
x
,
3
)
)))
# def gelu
(x):
#
"""Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x
)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0
)))
def
gelu
(
x
):
"""Gaussian Error Linear Units (GELU) activation."""
if
hasattr
(
torch
.
nn
.
functional
,
'gelu'
):
#return torch.nn.functional.gelu(x.float()).type_as(x)
return
torch
.
nn
.
functional
.
ge
lu
(
x
)
# TODO(Hui Zhang): remove this activation
def
glu
(
x
,
dim
=-
1
):
"""The gated linear unit (GLU) activation."""
if
hasattr
(
nn
.
functional
,
'glu'
):
return
nn
.
functional
.
g
lu
(
x
)
else
:
return
x
*
0.5
*
(
1.0
+
paddle
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
a
,
b
=
x
.
split
(
2
,
axis
=
dim
)
act_b
=
F
.
sigmoid
(
b
)
return
a
*
act_b
# TODO(Hui Zhang): remove this activation
if
not
hasattr
(
nn
.
functional
,
'glu'
):
setattr
(
nn
.
functional
,
'glu'
,
glu
)
# TODO(Hui Zhang): remove this activation
class
GLU
(
nn
.
Layer
):
"""Gated Linear Units (GLU) Layer"""
def
__init__
(
self
,
dim
:
int
=-
1
):
super
().
__init__
()
self
.
dim
=
dim
def
forward
(
self
,
xs
):
return
glu
(
xs
,
dim
=
self
.
dim
)
class
LinearGLUBlock
(
nn
.
Layer
):
"""A linear Gated Linear Units (GLU) block."""
def
__init__
(
self
,
idim
:
int
):
""" GLU.
Args:
idim (int): input and output dimension
"""
super
().
__init__
()
self
.
fc
=
nn
.
Linear
(
idim
,
idim
*
2
)
def
forward
(
self
,
xs
):
return
glu
(
self
.
fc
(
xs
),
dim
=-
1
)
# TODO(Hui Zhang): remove this Layer
class
ConstantPad2d
(
nn
.
Layer
):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def
__init__
(
self
,
padding
:
Union
[
tuple
,
list
,
int
],
value
:
float
):
"""
Args:
paddle ([tuple]): the size of the padding.
If is int, uses the same padding in all boundaries.
If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
value ([flaot]): pad value
"""
self
.
padding
=
padding
if
isinstance
(
padding
,
[
tuple
,
list
])
else
[
padding
]
*
4
self
.
value
=
value
def
forward
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
return
nn
.
functional
.
pad
(
xs
,
self
.
padding
,
mode
=
'constant'
,
value
=
self
.
value
,
data_format
=
'NCHW'
)
class
ConvGLUBlock
(
nn
.
Layer
):
def
__init__
(
self
,
kernel_size
,
in_ch
,
out_ch
,
bottlececk_dim
=
0
,
dropout
=
0.
):
"""A convolutional Gated Linear Units (GLU) block.
Args:
kernel_size (int): kernel size
in_ch (int): number of input channels
out_ch (int): number of output channels
bottlececk_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
dropout (float): dropout probability. Defaults to 0..
"""
super
().
__init__
()
self
.
conv_residual
=
None
if
in_ch
!=
out_ch
:
self
.
conv_residual
=
nn
.
utils
.
weight_norm
(
nn
.
Conv2D
(
in_channels
=
in_ch
,
out_channels
=
out_ch
,
kernel_size
=
(
1
,
1
)),
name
=
'weight'
,
dim
=
0
)
self
.
dropout_residual
=
nn
.
Dropout
(
p
=
dropout
)
self
.
pad_left
=
ConstantPad2d
((
0
,
0
,
kernel_size
-
1
,
0
),
0
)
layers
=
OrderedDict
()
if
bottlececk_dim
==
0
:
layers
[
'conv'
]
=
nn
.
utils
.
weight_norm
(
nn
.
Conv2D
(
in_channels
=
in_ch
,
out_channels
=
out_ch
*
2
,
kernel_size
=
(
kernel_size
,
1
)),
name
=
'weight'
,
dim
=
0
)
# TODO(hirofumi0810): padding?
layers
[
'dropout'
]
=
nn
.
Dropout
(
p
=
dropout
)
layers
[
'glu'
]
=
GLU
()
elif
bottlececk_dim
>
0
:
layers
[
'conv_in'
]
=
nn
.
utils
.
weight_norm
(
nn
.
Conv2D
(
in_channels
=
in_ch
,
out_channels
=
bottlececk_dim
,
kernel_size
=
(
1
,
1
)),
name
=
'weight'
,
dim
=
0
)
layers
[
'dropout_in'
]
=
nn
.
Dropout
(
p
=
dropout
)
layers
[
'conv_bottleneck'
]
=
nn
.
utils
.
weight_norm
(
nn
.
Conv2D
(
in_channels
=
bottlececk_dim
,
out_channels
=
bottlececk_dim
,
kernel_size
=
(
kernel_size
,
1
)),
name
=
'weight'
,
dim
=
0
)
layers
[
'dropout'
]
=
nn
.
Dropout
(
p
=
dropout
)
layers
[
'glu'
]
=
GLU
()
layers
[
'conv_out'
]
=
nn
.
utils
.
weight_norm
(
nn
.
Conv2D
(
in_channels
=
bottlececk_dim
,
out_channels
=
out_ch
*
2
,
kernel_size
=
(
1
,
1
)),
name
=
'weight'
,
dim
=
0
)
layers
[
'dropout_out'
]
=
nn
.
Dropout
(
p
=
dropout
)
class
Swish
(
nn
.
Layer
):
"""Construct an Swish object."""
self
.
layers
=
nn
.
Sequential
(
layers
)
def
forward
(
self
,
x
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Return Swish activation function."""
return
x
*
F
.
sigmoid
(
x
)
def
forward
(
self
,
xs
):
"""Forward pass.
Args:
xs (FloatTensor): `[B, in_ch, T, feat_dim]`
Returns:
out (FloatTensor): `[B, out_ch, T, feat_dim]`
"""
residual
=
xs
if
self
.
conv_residual
is
not
None
:
residual
=
self
.
dropout_residual
(
self
.
conv_residual
(
residual
))
xs
=
self
.
pad_left
(
xs
)
# `[B, embed_dim, T+kernel-1, 1]`
xs
=
self
.
layers
(
xs
)
# `[B, out_ch * 2, T ,1]`
xs
=
xs
+
residual
return
xs
deepspeech/modules/conformer_convolution.py
0 → 100644
浏览文件 @
b2bc6eb5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvolutionModule definition."""
from
typing
import
Optional
,
Tuple
from
typeguard
import
check_argument_types
import
logging
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
# init F.glu func
# TODO(Hui Zhang): remove this line
import
deepspeech.modules.activation
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'ConvolutionModule'
]
class
ConvolutionModule
(
nn
.
Layer
):
"""ConvolutionModule in Conformer model."""
def
__init__
(
self
,
channels
:
int
,
kernel_size
:
int
=
15
,
activation
:
nn
.
Layer
=
nn
.
ReLU
(),
norm
:
str
=
"batch_norm"
,
causal
:
bool
=
False
,
bias
:
bool
=
True
):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
activation (nn.Layer): Activation Layer.
norm (str): Normalization type, 'batch_norm' or 'layer_norm'
causal (bool): Whether use causal convolution or not
bias (bool): Whether Conv with bias or not
"""
assert
check_argument_types
()
super
().
__init__
()
self
.
pointwise_conv1
=
nn
.
Conv1D
(
channels
,
2
*
channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
None
if
bias
else
False
,
# None for True as default
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if
causal
:
padding
=
0
self
.
lorder
=
kernel_size
-
1
else
:
# kernel_size should be an odd number for none causal convolution
assert
(
kernel_size
-
1
)
%
2
==
0
padding
=
(
kernel_size
-
1
)
//
2
self
.
lorder
=
0
self
.
depthwise_conv
=
nn
.
Conv1D
(
channels
,
channels
,
kernel_size
,
stride
=
1
,
padding
=
padding
,
groups
=
channels
,
bias
=
None
if
bias
else
False
,
# None for True as default
)
assert
norm
in
[
'batch_norm'
,
'layer_norm'
]
if
norm
==
"batch_norm"
:
self
.
use_layer_norm
=
False
self
.
norm
=
nn
.
BatchNorm1D
(
channels
)
else
:
self
.
use_layer_norm
=
True
self
.
norm
=
nn
.
LayerNorm
(
channels
)
self
.
pointwise_conv2
=
nn
.
Conv1D
(
channels
,
channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
None
if
bias
else
False
,
# None for True as default
)
self
.
activation
=
activation
def
forward
(
self
,
x
:
paddle
.
Tensor
,
cache
:
Optional
[
paddle
.
Tensor
]
=
None
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
cache (paddle.Tensor): left context cache, it is only
used in causal convolution. (#batch, channels, time)
Returns:
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time)
"""
# exchange the temporal dimension and the feature dimension
x
=
x
.
transpose
([
0
,
2
,
1
])
# [B, C, T]
if
self
.
lorder
>
0
:
if
cache
is
None
:
x
=
nn
.
functional
.
pad
(
x
,
(
self
.
lorder
,
0
),
'constant'
,
0.0
,
data_format
=
'NCL'
)
else
:
assert
cache
.
shape
[
0
]
==
x
.
shape
[
0
]
# B
assert
cache
.
shape
[
1
]
==
x
.
shape
[
1
]
# C
x
=
paddle
.
concat
((
cache
,
x
),
axis
=
2
)
assert
(
x
.
shape
[
2
]
>
self
.
lorder
)
new_cache
=
x
[:,
:,
-
self
.
lorder
:]
#[B, C, T]
else
:
# It's better we just return None if no cache is requried,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
place
)
# GLU mechanism
x
=
self
.
pointwise_conv1
(
x
)
# (batch, 2*channel, dim)
x
=
nn
.
functional
.
glu
(
x
,
dim
=
1
)
# (batch, channel, dim)
# 1D Depthwise Conv
x
=
self
.
depthwise_conv
(
x
)
if
self
.
use_layer_norm
:
x
=
x
.
transpose
([
0
,
2
,
1
])
# [B, T, C]
x
=
self
.
activation
(
self
.
norm
(
x
))
if
self
.
use_layer_norm
:
x
=
x
.
transpose
([
0
,
2
,
1
])
# [B, C, T]
x
=
self
.
pointwise_conv2
(
x
)
x
=
x
.
transpose
([
0
,
2
,
1
])
# [B, T, C]
return
x
,
new_cache
deepspeech/modules/conv.py
浏览文件 @
b2bc6eb5
...
...
@@ -24,7 +24,32 @@ from deepspeech.modules.activation import brelu
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'ConvStack'
]
__all__
=
[
'ConvStack'
,
"conv_output_size"
]
def
conv_output_size
(
I
,
F
,
P
,
S
):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
# When Pstart == Pend == 0
# O = (I - F - S) // S
# https://iq.opengenus.org/output-size-of-convolution/
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
# Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
return
(
I
-
F
+
2
*
P
-
S
)
//
S
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class
ConvBn
(
nn
.
Layer
):
...
...
deepspeech/modules/encoder_layer.py
0 → 100644
浏览文件 @
b2bc6eb5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder self-attention layer definition."""
from
typing
import
Optional
,
Tuple
import
logging
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"TransformerEncoderLayer"
,
"ConformerEncoderLayer"
]
class
TransformerEncoderLayer
(
nn
.
Layer
):
"""Encoder layer module."""
def
__init__
(
self
,
size
:
int
,
self_attn
:
nn
.
Layer
,
feed_forward
:
nn
.
Layer
,
dropout_rate
:
float
,
normalize_before
:
bool
=
True
,
concat_after
:
bool
=
False
,
):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: to use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super
().
__init__
()
self
.
self_attn
=
self_attn
self
.
feed_forward
=
feed_forward
self
.
norm1
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
self
.
norm2
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
self
.
dropout
=
nn
.
Dropout
(
dropout_rate
)
self
.
size
=
size
self
.
normalize_before
=
normalize_before
self
.
concat_after
=
concat_after
# concat_linear may be not used in forward fuction,
# but will be saved in the *.pt
self
.
concat_linear
=
nn
.
Linear
(
size
+
size
,
size
)
def
forward
(
self
,
x
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
pos_emb
:
paddle
.
Tensor
,
output_cache
:
Optional
[
paddle
.
Tensor
]
=
None
,
cnn_cache
:
Optional
[
paddle
.
Tensor
]
=
None
,
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute encoded features.
Args:
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
if
output_cache
is
None
:
x_q
=
x
else
:
assert
output_cache
.
shape
[
0
]
==
x
.
shape
[
0
]
assert
output_cache
.
shape
[
1
]
<
x
.
shape
[
1
]
assert
output_cache
.
shape
[
2
]
==
self
.
size
chunk
=
x
.
shape
[
1
]
-
output_cache
.
shape
[
1
]
x_q
=
x
[:,
-
chunk
:,
:]
residual
=
residual
[:,
-
chunk
:,
:]
mask
=
mask
[:,
-
chunk
:,
:]
if
self
.
concat_after
:
x_concat
=
paddle
.
concat
(
(
x
,
self
.
self_attn
(
x_q
,
x
,
x
,
mask
)),
axis
=-
1
)
x
=
residual
+
self
.
concat_linear
(
x_concat
)
else
:
x
=
residual
+
self
.
dropout
(
self
.
self_attn
(
x_q
,
x
,
x
,
mask
))
if
not
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
x
=
residual
+
self
.
dropout
(
self
.
feed_forward
(
x
))
if
not
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
if
output_cache
is
not
None
:
x
=
paddle
.
concat
([
output_cache
,
x
],
axis
=
1
)
fake_cnn_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
place
)
return
x
,
mask
,
fake_cnn_cache
class
ConformerEncoderLayer
(
nn
.
Layer
):
"""Encoder layer module."""
def
__init__
(
self
,
size
:
int
,
self_attn
:
int
,
feed_forward
:
Optional
[
nn
.
Layer
]
=
None
,
feed_forward_macaron
:
Optional
[
nn
.
Layer
]
=
None
,
conv_module
:
Optional
[
nn
.
Layer
]
=
None
,
dropout_rate
:
float
=
0.1
,
normalize_before
:
bool
=
True
,
concat_after
:
bool
=
False
,
):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (nn.Layer): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (nn.Layer): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super
().
__init__
()
self
.
self_attn
=
self_attn
self
.
feed_forward
=
feed_forward
self
.
feed_forward_macaron
=
feed_forward_macaron
self
.
conv_module
=
conv_module
self
.
norm_ff
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
# for the FNN module
self
.
norm_mha
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
# for the MHA module
if
feed_forward_macaron
is
not
None
:
self
.
norm_ff_macaron
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
self
.
ff_scale
=
0.5
else
:
self
.
ff_scale
=
1.0
if
self
.
conv_module
is
not
None
:
self
.
norm_conv
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
# for the CNN module
self
.
norm_final
=
nn
.
LayerNorm
(
size
,
epsilon
=
1e-12
)
# for the final output of the block
self
.
dropout
=
nn
.
Dropout
(
dropout_rate
)
self
.
size
=
size
self
.
normalize_before
=
normalize_before
self
.
concat_after
=
concat_after
self
.
concat_linear
=
nn
.
Linear
(
size
+
size
,
size
)
def
forward
(
self
,
x
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
pos_emb
:
paddle
.
Tensor
,
output_cache
:
Optional
[
paddle
.
Tensor
]
=
None
,
cnn_cache
:
Optional
[
paddle
.
Tensor
]
=
None
,
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute encoded features.
Args:
x (paddle.Tensor): (#batch, time, size)
mask (paddle.Tensor): Mask tensor for the input (#batch, time,time).
pos_emb (paddle.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time).
"""
# whether to use macaron style FFN
if
self
.
feed_forward_macaron
is
not
None
:
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm_ff_macaron
(
x
)
x
=
residual
+
self
.
ff_scale
*
self
.
dropout
(
self
.
feed_forward_macaron
(
x
))
if
not
self
.
normalize_before
:
x
=
self
.
norm_ff_macaron
(
x
)
# multi-headed self-attention module
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm_mha
(
x
)
if
output_cache
is
None
:
x_q
=
x
else
:
assert
output_cache
.
shape
[
0
]
==
x
.
shape
[
0
]
assert
output_cache
.
shape
[
1
]
<
x
.
shape
[
1
]
assert
output_cache
.
shape
[
2
]
==
self
.
size
chunk
=
x
.
shape
[
1
]
-
output_cache
.
shape
[
1
]
x_q
=
x
[:,
-
chunk
:,
:]
residual
=
residual
[:,
-
chunk
:,
:]
mask
=
mask
[:,
-
chunk
:,
:]
x_att
=
self
.
self_attn
(
x_q
,
x
,
x
,
pos_emb
,
mask
)
if
self
.
concat_after
:
x_concat
=
paddle
.
concat
((
x
,
x_att
),
axis
=-
1
)
x
=
residual
+
self
.
concat_linear
(
x_concat
)
else
:
x
=
residual
+
self
.
dropout
(
x_att
)
if
not
self
.
normalize_before
:
x
=
self
.
norm_mha
(
x
)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
place
)
if
self
.
conv_module
is
not
None
:
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm_conv
(
x
)
x
,
new_cnn_cache
=
self
.
conv_module
(
x
,
cnn_cache
)
x
=
residual
+
self
.
dropout
(
x
)
if
not
self
.
normalize_before
:
x
=
self
.
norm_conv
(
x
)
# feed forward module
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm_ff
(
x
)
x
=
residual
+
self
.
ff_scale
*
self
.
dropout
(
self
.
feed_forward
(
x
))
if
not
self
.
normalize_before
:
x
=
self
.
norm_ff
(
x
)
if
self
.
conv_module
is
not
None
:
x
=
self
.
norm_final
(
x
)
if
output_cache
is
not
None
:
x
=
paddle
.
concat
([
output_cache
,
x
],
axis
=
1
)
return
x
,
mask
,
new_cnn_cache
deepspeech/modules/positionwise_feed_forward.py
0 → 100644
浏览文件 @
b2bc6eb5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import
logging
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"PositionwiseFeedForward"
]
class
PositionwiseFeedForward
(
nn
.
Layer
):
"""Positionwise feed forward layer."""
def
__init__
(
self
,
idim
:
int
,
hidden_units
:
int
,
dropout_rate
:
float
,
activation
:
nn
.
Layer
=
nn
.
ReLU
()):
"""Construct a PositionwiseFeedForward object.
FeedForward are appied on each position of the sequence.
The output dim is same with the input dim.
Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (paddle.nn.Layer): Activation function
"""
super
().
__init__
()
self
.
w_1
=
nn
.
Linear
(
idim
,
hidden_units
)
self
.
activation
=
activation
self
.
dropout
=
nn
.
Dropout
(
dropout_rate
)
self
.
w_2
=
nn
.
Linear
(
hidden_units
,
idim
)
def
forward
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Forward function.
Args:
xs: input tensor (B, Lmax, D)
Returns:
output tensor, (B, Lmax, D)
"""
return
self
.
w_2
(
self
.
dropout
(
self
.
activation
(
self
.
w_1
(
xs
))))
deepspeech/modules/subsampling.py
0 → 100644
浏览文件 @
b2bc6eb5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subsampling layer definition."""
from
typing
import
Tuple
import
logging
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.embedding
import
PositionalEncoding
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"LinearNoSubsampling"
,
"Conv2dSubsampling4"
,
"Conv2dSubsampling6"
,
"Conv2dSubsampling8"
]
class
BaseSubsampling
(
nn
.
Layer
):
def
__init__
(
self
,
pos_enc_class
:
PositionalEncoding
):
super
().
__init__
()
self
.
pos_enc
=
pos_enc_class
self
.
right_context
=
0
self
.
subsampling_rate
=
1
def
position_encoding
(
self
,
offset
:
int
,
size
:
int
)
->
paddle
.
Tensor
:
return
self
.
pos_enc
.
position_encoding
(
offset
,
size
)
class
LinearNoSubsampling
(
BaseSubsampling
):
"""Linear transform the input without subsampling."""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
dropout_rate
:
float
,
pos_enc_class
:
PositionalEncoding
):
"""Construct an linear object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc_class (PositionalEncoding): position encoding class
"""
super
().
__init__
(
pos_enc_class
)
self
.
out
=
nn
.
Sequential
(
nn
.
Linear
(
idim
,
odim
),
nn
.
LayerNorm
(
odim
,
epsilon
=
1e-12
),
nn
.
Dropout
(
dropout_rate
),
)
self
.
right_context
=
0
self
.
subsampling_rate
=
1
def
forward
(
self
,
x
:
paddle
.
Tensor
,
x_mask
:
paddle
.
Tensor
,
offset
:
int
=
0
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Input x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
paddle.Tensor: positional encoding
paddle.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x
=
self
.
out
(
x
)
x
,
pos_emb
=
self
.
pos_enc
(
x
,
offset
)
return
x
,
pos_emb
,
x_mask
class
Conv2dSubsampling4
(
BaseSubsampling
):
"""Convolutional 2D subsampling (to 1/4 length)."""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
dropout_rate
:
float
,
pos_enc_class
:
PositionalEncoding
):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super
().
__init__
(
pos_enc_class
)
self
.
conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
1
,
odim
,
3
,
2
),
nn
.
ReLU
(),
nn
.
Conv2D
(
odim
,
odim
,
3
,
2
),
nn
.
ReLU
(),
)
self
.
linear
=
nn
.
Linear
(
odim
*
(((
idim
-
1
)
//
2
-
1
)
//
2
),
odim
)
self
.
subsampling_rate
=
4
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
self
.
right_context
=
6
def
forward
(
self
,
x
:
paddle
.
Tensor
,
x_mask
:
paddle
.
Tensor
,
offset
:
int
=
0
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x
=
x
.
unsqueeze
(
1
)
# (b, c=1, t, f)
x
=
self
.
conv
(
x
)
b
,
c
,
t
,
f
=
paddle
.
shape
(
x
)
x
=
self
.
linear
(
x
.
transpose
([
0
,
1
,
2
,
3
]).
reshape
([
b
,
t
,
c
*
f
]))
x
,
pos_emb
=
self
.
pos_enc
(
x
,
offset
)
return
x
,
pos_emb
,
x_mask
[:,
:,
:
-
2
:
2
][:,
:,
:
-
2
:
2
]
class
Conv2dSubsampling6
(
BaseSubsampling
):
"""Convolutional 2D subsampling (to 1/6 length)."""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
dropout_rate
:
float
,
pos_enc_class
:
PositionalEncoding
):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (PositionalEncoding): Custom position encoding layer.
"""
super
().
__init__
(
pos_enc_class
)
self
.
conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
1
,
odim
,
3
,
2
),
nn
.
ReLU
(),
nn
.
Conv2D
(
odim
,
odim
,
5
,
3
),
nn
.
ReLU
(),
)
# O = (I - F + Pstart + Pend) // S + 1
# when Padding == 0, O = (I - F - S) // S
self
.
linear
=
nn
.
Linear
(
odim
*
(((
idim
-
1
)
//
2
-
2
)
//
3
),
odim
)
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
self
.
subsampling_rate
=
6
self
.
right_context
=
14
def
forward
(
self
,
x
:
paddle
.
Tensor
,
x_mask
:
paddle
.
Tensor
,
offset
:
int
=
0
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x
=
x
.
unsqueeze
(
1
)
# (b, c, t, f)
x
=
self
.
conv
(
x
)
b
,
c
,
t
,
f
=
paddle
.
shape
(
x
)
x
=
self
.
linear
(
x
.
transpose
([
0
,
1
,
2
,
3
]).
reshape
([
b
,
t
,
c
*
f
]))
x
,
pos_emb
=
self
.
pos_enc
(
x
,
offset
)
return
x
,
pos_emb
,
x_mask
[:,
:,
:
-
2
:
2
][:,
:,
:
-
4
:
3
]
class
Conv2dSubsampling8
(
BaseSubsampling
):
"""Convolutional 2D subsampling (to 1/8 length)."""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
dropout_rate
:
float
,
pos_enc_class
:
PositionalEncoding
):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
super
().
__init__
(
pos_enc_class
)
self
.
conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
1
,
odim
,
3
,
2
),
nn
.
ReLU
(),
nn
.
Conv2D
(
odim
,
odim
,
3
,
2
),
nn
.
ReLU
(),
nn
.
Conv2D
(
odim
,
odim
,
3
,
2
),
nn
.
ReLU
(),
)
self
.
linear
=
nn
.
Linear
(
odim
*
((((
idim
-
1
)
//
2
-
1
)
//
2
-
1
)
//
2
),
odim
)
self
.
subsampling_rate
=
8
# The right context for every conv layer is computed by:
# (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
# 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
self
.
right_context
=
14
def
forward
(
self
,
x
:
paddle
.
Tensor
,
x_mask
:
paddle
.
Tensor
,
offset
:
int
=
0
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Subsample x.
Args:
x (paddle.Tensor): Input tensor (#batch, time, idim).
x_mask (paddle.Tensor): Input mask (#batch, 1, time).
Returns:
paddle.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
paddle.Tensor: positional encoding
paddle.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x
=
x
.
unsqueeze
(
1
)
# (b, c, t, f)
x
=
self
.
conv
(
x
)
x
=
self
.
linear
(
x
.
transpose
([
0
,
1
,
2
,
3
]).
reshape
([
b
,
t
,
c
*
f
]))
x
,
pos_emb
=
self
.
pos_enc
(
x
,
offset
)
return
x
,
pos_emb
,
x_mask
[:,
:,
:
-
2
:
2
][:,
:,
:
-
2
:
2
][:,
:,
:
-
2
:
2
]
deepspeech/training/gradclip.py
浏览文件 @
b2bc6eb5
...
...
@@ -21,8 +21,10 @@ from paddle.fluid import core
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"ClipGradByGlobalNormWithLog"
]
class
MyClipGradByGlobalNorm
(
paddle
.
nn
.
ClipGradByGlobalNorm
):
class
ClipGradByGlobalNormWithLog
(
paddle
.
nn
.
ClipGradByGlobalNorm
):
def
__init__
(
self
,
clip_norm
):
super
().
__init__
(
clip_norm
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录