PaddlePaddle / DeepSpeech
Commit b2bc6eb5
Authored March 16, 2021 by Hui Zhang
Parent commit: 9cf8c1a5

    add layer for transformer

8 changed files with 924 additions and 30 deletions (+924, -30)
deepspeech/exps/deepspeech2/model.py             +3    -2
deepspeech/modules/activation.py                 +172  -26
deepspeech/modules/conformer_convolution.py      +149  -0
deepspeech/modules/conv.py                       +26   -1
deepspeech/modules/encoder_layer.py              +277  -0
deepspeech/modules/positionwise_feed_forward.py  +59   -0
deepspeech/modules/subsampling.py                +235  -0
deepspeech/training/gradclip.py                  +3    -1
deepspeech/exps/deepspeech2/model.py

@@ -28,7 +28,7 @@ from paddle import distributed as dist
 from paddle.io import DataLoader

 from deepspeech.training import Trainer
-from deepspeech.training.gradclip import MyClipGradByGlobalNorm
+from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.utils import mp_tools
 from deepspeech.utils import layer_tools
@@ -125,7 +125,8 @@ class DeepSpeech2Trainer(Trainer):
         layer_tools.print_params(model, self.logger.info)

-        grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip)
+        grad_clip = ClipGradByGlobalNormWithLog(
+            config.training.global_grad_clip)
         lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
             learning_rate=config.training.lr,
             gamma=config.training.lr_decay,
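For orientation only: the renamed clipper is normally handed to the Paddle optimizer through its grad_clip argument. The sketch below is not part of this diff; the optimizer construction sits outside the hunk, so the choice of Adam and the `model`/`config` objects are assumptions that mirror the names visible above.

# Minimal sketch (not from the diff): plugging the renamed clipper into an
# optimizer. `model` and `config` are assumed to exist as in this trainer.
import paddle
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog

grad_clip = ClipGradByGlobalNormWithLog(config.training.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
    learning_rate=config.training.lr, gamma=config.training.lr_decay)
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    grad_clip=grad_clip)  # global-norm clipping (with logging) on every step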
deepspeech/modules/activation.py

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
 import logging
 import numpy as np
 import math
+from collections import OrderedDict

 import paddle
 from paddle import nn
@@ -23,7 +25,7 @@ from paddle.nn import initializer as I

 logger = logging.getLogger(__name__)

-__all__ = ['brelu', "softplus", "gelu_accurate", "gelu", 'Swish']
+__all__ = ['brelu', "glu"]


 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -33,36 +35,180 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)


-def softplus(x):
-    """Softplus function."""
-    if hasattr(paddle.nn.functional, 'softplus'):
-        #return paddle.nn.functional.softplus(x.float()).type_as(x)
-        return paddle.nn.functional.softplus(x)
-    else:
-        raise NotImplementedError
-
-
-def gelu_accurate(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
-    if not hasattr(gelu_accurate, "_a"):
-        gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
-                                      (x + 0.044715 * paddle.pow(x, 3))))
-
-
-def gelu(x):
-    """Gaussian Error Linear Units (GELU) activation."""
-    if hasattr(torch.nn.functional, 'gelu'):
-        #return torch.nn.functional.gelu(x.float()).type_as(x)
-        return torch.nn.functional.gelu(x)
-    else:
-        return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
-
-
-class Swish(nn.Layer):
-    """Construct an Swish object."""
-
-    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
-        """Return Swish activation function."""
-        return x * F.sigmoid(x)
+# def softplus(x):
+#     """Softplus function."""
+#     if hasattr(paddle.nn.functional, 'softplus'):
+#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
+#         return paddle.nn.functional.softplus(x)
+#     else:
+#         raise NotImplementedError
+
+# def gelu_accurate(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
+#     if not hasattr(gelu_accurate, "_a"):
+#         gelu_accurate._a = math.sqrt(2 / math.pi)
+#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
+#                                       (x + 0.044715 * paddle.pow(x, 3))))
+
+# def gelu(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     if hasattr(nn.functional, 'gelu'):
+#         #return nn.functional.gelu(x.float()).type_as(x)
+#         return nn.functional.gelu(x)
+#     else:
+#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+
+
+# TODO(Hui Zhang): remove this activation
+def glu(x, dim=-1):
+    """The gated linear unit (GLU) activation."""
+    if hasattr(nn.functional, 'glu'):
+        return nn.functional.glu(x)
+    else:
+        a, b = x.split(2, axis=dim)
+        act_b = F.sigmoid(b)
+        return a * act_b
+
+
+# TODO(Hui Zhang): remove this activation
+if not hasattr(nn.functional, 'glu'):
+    setattr(nn.functional, 'glu', glu)
+
+
+# TODO(Hui Zhang): remove this activation
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return glu(xs, dim=self.dim)
+
+
+class LinearGLUBlock(nn.Layer):
+    """A linear Gated Linear Units (GLU) block."""
+
+    def __init__(self, idim: int):
+        """ GLU.
+        Args:
+            idim (int): input and output dimension
+        """
+        super().__init__()
+        self.fc = nn.Linear(idim, idim * 2)
+
+    def forward(self, xs):
+        return glu(self.fc(xs), dim=-1)
+
+
+# TODO(Hui Zhang): remove this Layer
+class ConstantPad2d(nn.Layer):
+    """Pads the input tensor boundaries with a constant value.
+    For N-dimensional padding, use paddle.nn.functional.pad().
+    """
+
+    def __init__(self, padding: Union[tuple, list, int], value: float):
+        """
+        Args:
+            paddle ([tuple]): the size of the padding.
+                If is int, uses the same padding in all boundaries.
+                If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
+            value ([flaot]): pad value
+        """
+        self.padding = padding if isinstance(padding,
+                                             [tuple, list]) else [padding] * 4
+        self.value = value
+
+    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
+        return nn.functional.pad(
+            xs,
+            self.padding,
+            mode='constant',
+            value=self.value,
+            data_format='NCHW')
+
+
+class ConvGLUBlock(nn.Layer):
+    def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
+                 dropout=0.):
+        """A convolutional Gated Linear Units (GLU) block.
+        Args:
+            kernel_size (int): kernel size
+            in_ch (int): number of input channels
+            out_ch (int): number of output channels
+            bottlececk_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
+            dropout (float): dropout probability. Defaults to 0..
+        """
+        super().__init__()
+
+        self.conv_residual = None
+        if in_ch != out_ch:
+            self.conv_residual = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            self.dropout_residual = nn.Dropout(p=dropout)
+
+        self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
+
+        layers = OrderedDict()
+        if bottlececk_dim == 0:
+            layers['conv'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=out_ch * 2,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            # TODO(hirofumi0810): padding?
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+        elif bottlececk_dim > 0:
+            layers['conv_in'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=bottlececk_dim,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_in'] = nn.Dropout(p=dropout)
+            layers['conv_bottleneck'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottlececk_dim,
+                    out_channels=bottlececk_dim,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+            layers['conv_out'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottlececk_dim,
+                    out_channels=out_ch * 2,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_out'] = nn.Dropout(p=dropout)
+
+        self.layers = nn.Sequential(layers)
+
+    def forward(self, xs):
+        """Forward pass.
+        Args:
+            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
+        Returns:
+            out (FloatTensor): `[B, out_ch, T, feat_dim]`
+        """
+        residual = xs
+        if self.conv_residual is not None:
+            residual = self.dropout_residual(self.conv_residual(residual))
+        xs = self.pad_left(xs)  # `[B, embed_dim, T+kernel-1, 1]`
+        xs = self.layers(xs)  # `[B, out_ch * 2, T ,1]`
+        xs = xs + residual
+        return xs
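For orientation (not part of the commit): the new GLU helpers gate over one dimension and therefore halve it, while LinearGLUBlock first doubles the feature dimension with a Linear layer so the block preserves shape. A minimal sketch, assuming the module above is importable as deepspeech.modules.activation:

# Minimal sketch (not from the diff): shapes produced by the new GLU helpers.
import paddle
from deepspeech.modules.activation import glu, GLU, LinearGLUBlock

x = paddle.randn([4, 10, 16])
y = glu(x)                       # gates half of the last dim: [4, 10, 8]
y2 = GLU()(x)                    # same as glu(x, dim=-1)
block = LinearGLUBlock(idim=16)  # Linear(16 -> 32) followed by GLU
z = block(x)                     # back to [4, 10, 16]
print(y.shape, y2.shape, z.shape)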
deepspeech/modules/conformer_convolution.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvolutionModule definition."""
from typing import Optional, Tuple

from typeguard import check_argument_types
import logging

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

# init F.glu func
# TODO(Hui Zhang): remove this line
import deepspeech.modules.activation

logger = logging.getLogger(__name__)

__all__ = ['ConvolutionModule']


class ConvolutionModule(nn.Layer):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int=15,
                 activation: nn.Layer=nn.ReLU(),
                 norm: str="batch_norm",
                 causal: bool=False,
                 bias: bool=True):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            activation (nn.Layer): Activation Layer.
            norm (str): Normalization type, 'batch_norm' or 'layer_norm'
            causal (bool): Whether use causal convolution or not
            bias (bool): Whether Conv with bias or not
        """
        assert check_argument_types()
        super().__init__()
        self.pointwise_conv1 = nn.Conv1D(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=None if bias else False,  # None for True as default
        )

        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0:
        #    it's a causal convolution, the input will be padded with
        #    `self.lorder` frames on the left in forward (causal conv impl).
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0

        self.depthwise_conv = nn.Conv1D(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=None if bias else False,  # None for True as default
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1D(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1D(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=None if bias else False,  # None for True as default
        )
        self.activation = activation

    def forward(self,
                x: paddle.Tensor,
                cache: Optional[paddle.Tensor]=None
                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute convolution module.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, channels).
            cache (paddle.Tensor): left context cache, it is only
                used in causal convolution. (#batch, channels, time)
        Returns:
            paddle.Tensor: Output tensor (#batch, time, channels).
            paddle.Tensor: Output cache tensor (#batch, channels, time)
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose([0, 2, 1])  # [B, C, T]

        if self.lorder > 0:
            if cache is None:
                x = nn.functional.pad(
                    x, (self.lorder, 0), 'constant', 0.0, data_format='NCL')
            else:
                assert cache.shape[0] == x.shape[0]  # B
                assert cache.shape[1] == x.shape[1]  # C
                x = paddle.concat((cache, x), axis=2)

            assert (x.shape[2] > self.lorder)
            new_cache = x[:, :, -self.lorder:]  # [B, C, T]
        else:
            # It's better we just return None if no cache is requried,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])  # [B, T, C]
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])  # [B, C, T]
        x = self.pointwise_conv2(x)
        x = x.transpose([0, 2, 1])  # [B, T, C]
        return x, new_cache
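To make the left-context caching in forward() concrete, here is a small standalone sketch (not from the commit) of the same bookkeeping using plain Paddle ops; the chunk sizes are arbitrary and lorder mirrors kernel_size - 1 for a causal depthwise conv.

# Standalone sketch (not from the diff) of the causal-cache bookkeeping in
# ConvolutionModule.forward, using plain Paddle ops only.
import paddle
import paddle.nn.functional as F

lorder = 14                                  # kernel_size - 1 for kernel_size=15
x = paddle.randn([2, 256, 50])               # already transposed to [B, C, T]

# first chunk: no cache yet, so pad `lorder` zero frames on the left
x_padded = F.pad(x, [lorder, 0], mode='constant', value=0.0, data_format='NCL')
cache = x_padded[:, :, -lorder:]             # keep the last `lorder` frames

# next chunk: prepend the cache instead of zero padding
x_next = paddle.randn([2, 256, 50])
x_next = paddle.concat((cache, x_next), axis=2)
print(x_padded.shape, x_next.shape)          # both [2, 256, 64]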
deepspeech/modules/conv.py

@@ -24,7 +24,32 @@ from deepspeech.modules.activation import brelu
 logger = logging.getLogger(__name__)

-__all__ = ['ConvStack']
+__all__ = ['ConvStack', "conv_output_size"]
+
+
+def conv_output_size(I, F, P, S):
+    # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
+    # Output size after Conv:
+    #   By noting I the length of the input volume size,
+    #   F the length of the filter,
+    #   P the amount of zero padding,
+    #   S the stride,
+    #   then the output size O of the feature map along that dimension is given by:
+    #       O = (I - F + Pstart + Pend) // S + 1
+    #   When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
+    #   When Pstart == Pend == 0
+    #       O = (I - F - S) // S
+    # https://iq.opengenus.org/output-size-of-convolution/
+    #   Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
+    #   Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
+    return (I - F + 2 * P - S) // S
+
+
+# receptive field calculator
+# https://fomoro.com/research/article/receptive-field-calculator
+# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
+# https://distill.pub/2019/computing-receptive-fields/
+# Rl-1 = Sl * Rl + (Kl - Sl)


 class ConvBn(nn.Layer):
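As a sanity check on the sizing comments above (added here, not part of the commit), the standard formula O = (I - K + 2P) // S + 1 reproduces the feature dimension that the new subsampling layers feed into their Linear projection:

# Worked example (not from the diff): output length of two stride-2, kernel-3,
# unpadded convs, as used by Conv2dSubsampling4, for an 80-dim fbank input.
I, K, P, S = 80, 3, 0, 2
o1 = (I - K + 2 * P) // S + 1         # 39 after the first conv
o2 = (o1 - K + 2 * P) // S + 1        # 19 after the second conv
assert o2 == ((I - 1) // 2 - 1) // 2  # matches the Linear sizing in subsampling.py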
deepspeech/modules/encoder_layer.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import logging

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

logger = logging.getLogger(__name__)

__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"]


class TransformerEncoderLayer(nn.Layer):
    """Encoder layer module."""

    def __init__(
            self,
            size: int,
            self_attn: nn.Layer,
            feed_forward: nn.Layer,
            dropout_rate: float,
            normalize_before: bool=True,
            concat_after: bool=False, ):
        """Construct an EncoderLayer object.
        Args:
            size (int): Input dimension.
            self_attn (nn.Layer): Self-attention module instance.
                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
                instance can be used as the argument.
            feed_forward (nn.Layer): Feed-forward module instance.
                `PositionwiseFeedForward`, instance can be used as the argument.
            dropout_rate (float): Dropout rate.
            normalize_before (bool):
                True: use layer_norm before each sub-block.
                False: to use layer_norm after each sub-block.
            concat_after (bool): Whether to concat attention layer's input and
                output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
        """
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
        self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # concat_linear may be not used in forward fuction,
        # but will be saved in the *.pt
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
            self,
            x: paddle.Tensor,
            mask: paddle.Tensor,
            pos_emb: paddle.Tensor,
            output_cache: Optional[paddle.Tensor]=None,
            cnn_cache: Optional[paddle.Tensor]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute encoded features.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, size).
            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
            pos_emb (paddle.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            output_cache (paddle.Tensor): Cache tensor of the output
                (#batch, time2, size), time2 < time in x.
            cnn_cache (paddle.Tensor): not used here, it's for interface
                compatibility to ConformerEncoderLayer
        Returns:
            paddle.Tensor: Output tensor (#batch, time, size).
            paddle.Tensor: Mask tensor (#batch, time).
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if output_cache is None:
            x_q = x
        else:
            assert output_cache.shape[0] == x.shape[0]
            assert output_cache.shape[1] < x.shape[1]
            assert output_cache.shape[2] == self.size
            chunk = x.shape[1] - output_cache.shape[1]
            x_q = x[:, -chunk:, :]
            residual = residual[:, -chunk:, :]
            mask = mask[:, -chunk:, :]

        if self.concat_after:
            x_concat = paddle.concat(
                (x, self.self_attn(x_q, x, x, mask)), axis=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if output_cache is not None:
            x = paddle.concat([output_cache, x], axis=1)

        fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
        return x, mask, fake_cnn_cache


class ConformerEncoderLayer(nn.Layer):
    """Encoder layer module."""

    def __init__(
            self,
            size: int,
            self_attn: int,
            feed_forward: Optional[nn.Layer]=None,
            feed_forward_macaron: Optional[nn.Layer]=None,
            conv_module: Optional[nn.Layer]=None,
            dropout_rate: float=0.1,
            normalize_before: bool=True,
            concat_after: bool=False, ):
        """Construct an EncoderLayer object.
        Args:
            size (int): Input dimension.
            self_attn (nn.Layer): Self-attention module instance.
                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
                instance can be used as the argument.
            feed_forward (nn.Layer): Feed-forward module instance.
                `PositionwiseFeedForward` instance can be used as the argument.
            feed_forward_macaron (nn.Layer): Additional feed-forward module
                instance.
                `PositionwiseFeedForward` instance can be used as the argument.
            conv_module (nn.Layer): Convolution module instance.
                `ConvlutionModule` instance can be used as the argument.
            dropout_rate (float): Dropout rate.
            normalize_before (bool):
                True: use layer_norm before each sub-block.
                False: use layer_norm after each sub-block.
            concat_after (bool): Whether to concat attention layer's input and
                output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
        """
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, epsilon=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, epsilon=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(
                size, epsilon=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, epsilon=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
            self,
            x: paddle.Tensor,
            mask: paddle.Tensor,
            pos_emb: paddle.Tensor,
            output_cache: Optional[paddle.Tensor]=None,
            cnn_cache: Optional[paddle.Tensor]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute encoded features.
        Args:
            x (paddle.Tensor): (#batch, time, size)
            mask (paddle.Tensor): Mask tensor for the input (#batch, time,time).
            pos_emb (paddle.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            output_cache (paddle.Tensor): Cache tensor of the output
                (#batch, time2, size), time2 < time in x.
            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
        Returns:
            paddle.Tensor: Output tensor (#batch, time, size).
            paddle.Tensor: Mask tensor (#batch, time).
        """
        # whether to use macaron style FFN
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if output_cache is None:
            x_q = x
        else:
            assert output_cache.shape[0] == x.shape[0]
            assert output_cache.shape[1] < x.shape[1]
            assert output_cache.shape[2] == self.size
            chunk = x.shape[1] - output_cache.shape[1]
            x_q = x[:, -chunk:, :]
            residual = residual[:, -chunk:, :]
            mask = mask[:, -chunk:, :]

        x_att = self.self_attn(x_q, x, x, pos_emb, mask)

        if self.concat_after:
            x_concat = paddle.concat((x, x_att), axis=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if output_cache is not None:
            x = paddle.concat([output_cache, x], axis=1)

        return x, mask, new_cnn_cache
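The sketch below (not part of the commit) only exercises the residual/pre-norm plumbing of TransformerEncoderLayer. The real attention module is not added in this diff, so StubAttention is a placeholder with the call signature the layer expects; its name, the sizes, and the dummy mask/pos_emb shapes are assumptions.

# Minimal plumbing sketch (not from the diff); StubAttention stands in for the
# attention module, which is not part of this commit.
import paddle
from paddle import nn
from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward


class StubAttention(nn.Layer):
    """Placeholder with the (query, key, value, mask) signature used above."""

    def forward(self, query, key, value, mask):
        return paddle.zeros_like(query)


size = 256
layer = TransformerEncoderLayer(
    size=size,
    self_attn=StubAttention(),
    feed_forward=PositionwiseFeedForward(size, 1024, 0.1),
    dropout_rate=0.1)

x = paddle.randn([2, 30, size])
mask = paddle.ones([2, 1, 30])         # ignored by the stub
pos_emb = paddle.zeros([1, 30, size])  # unused by this layer, kept for API parity
y, mask_out, _ = layer(x, mask, pos_emb)
print(y.shape)                         # [2, 30, 256]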
deepspeech/modules/positionwise_feed_forward.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import logging

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

logger = logging.getLogger(__name__)

__all__ = ["PositionwiseFeedForward"]


class PositionwiseFeedForward(nn.Layer):
    """Positionwise feed forward layer."""

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: nn.Layer=nn.ReLU()):
        """Construct a PositionwiseFeedForward object.
        FeedForward are appied on each position of the sequence.
        The output dim is same with the input dim.
        Args:
            idim (int): Input dimenstion.
            hidden_units (int): The number of hidden units.
            dropout_rate (float): Dropout rate.
            activation (paddle.nn.Layer): Activation function
        """
        super().__init__()
        self.w_1 = nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)
        self.w_2 = nn.Linear(hidden_units, idim)

    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
        """Forward function.
        Args:
            xs: input tensor (B, Lmax, D)
        Returns:
            output tensor, (B, Lmax, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
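A minimal usage sketch (not from the diff), assuming the module is importable; the FFN projects each position D -> hidden_units -> D, so the input shape is preserved:

# Minimal sketch (not from the diff): the FFN keeps the input shape.
import paddle
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
xs = paddle.randn([2, 30, 256])  # (B, Lmax, D)
ys = ffn(xs)                     # (B, Lmax, D), projected 256 -> 1024 -> 256
print(ys.shape)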
deepspeech/modules/subsampling.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Subsampling layer definition."""
from typing import Tuple
import logging

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from deepspeech.modules.embedding import PositionalEncoding

logger = logging.getLogger(__name__)

__all__ = [
    "LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
    "Conv2dSubsampling8"
]


class BaseSubsampling(nn.Layer):
    def __init__(self, pos_enc_class: PositionalEncoding):
        super().__init__()
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform the input without subsampling."""

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: PositionalEncoding):
        """Construct an linear object.
        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate.
            pos_enc_class (PositionalEncoding): position encoding class
        """
        super().__init__(pos_enc_class)
        self.out = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim, epsilon=1e-12),
            nn.Dropout(dropout_rate), )
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Input x.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
        Returns:
            paddle.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            paddle.Tensor: positional encoding
            paddle.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .
        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length)."""

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: PositionalEncoding):
        """Construct an Conv2dSubsampling4 object.
        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate.
        """
        super().__init__(pos_enc_class)
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(), )
        self.linear = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
        self.subsampling_rate = 4
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
        # 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
        self.right_context = 6

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]


class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length)."""

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: PositionalEncoding):
        """Construct an Conv2dSubsampling6 object.
        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate.
            pos_enc (PositionalEncoding): Custom position encoding layer.
        """
        super().__init__(pos_enc_class)
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 5, 3),
            nn.ReLU(), )
        # O = (I - F + Pstart + Pend) // S + 1
        # when Padding == 0, O = (I - F - S) // S
        self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
        # 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
        self.subsampling_rate = 6
        self.right_context = 14

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]


class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length)."""

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: PositionalEncoding):
        """Construct an Conv2dSubsampling8 object.
        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate.
        """
        super().__init__(pos_enc_class)
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(), )
        self.linear = nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.subsampling_rate = 8
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
        # 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
        self.right_context = 14

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        x = self.linear(x.transpose([0, 1, 2, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
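A quick check (added here, not part of the commit) that the mask slicing used above tracks the conv output lengths: one application of `[:, :, :-2:2]` keeps (T - 1) // 2 frames on a length-T axis, which equals the output of an unpadded kernel-3, stride-2 conv; applying it twice matches Conv2dSubsampling4.

# Pure-Python check (not from the diff) of the mask-slicing arithmetic.
T = 100
after_once = len(range(T)[:-2:2])              # 49
after_twice = len(range(after_once)[:-2:2])    # 24
assert after_once == (T - 3) // 2 + 1          # first conv: kernel 3, stride 2
assert after_twice == (after_once - 3) // 2 + 1  # second conv
print(after_once, after_twice)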
deepspeech/training/gradclip.py

@@ -21,8 +21,10 @@ from paddle.fluid import core
 logger = logging.getLogger(__name__)

+__all__ = ["ClipGradByGlobalNormWithLog"]
+

-class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
+class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def __init__(self, clip_norm):
         super().__init__(clip_norm)