Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5e7e582d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5e7e582d
编写于
4月 07, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
2fa6bbbe
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
116 addition
and
41 deletion
+116
-41
deepspeech/__init__.py
deepspeech/__init__.py
+62
-5
deepspeech/models/u2.py
deepspeech/models/u2.py
+5
-5
deepspeech/modules/attention.py
deepspeech/modules/attention.py
+2
-2
deepspeech/modules/embedding.py
deepspeech/modules/embedding.py
+6
-11
deepspeech/modules/encoder.py
deepspeech/modules/encoder.py
+8
-6
deepspeech/modules/encoder_layer.py
deepspeech/modules/encoder_layer.py
+2
-2
deepspeech/utils/layer_tools.py
deepspeech/utils/layer_tools.py
+9
-4
deepspeech/utils/tensor_utils.py
deepspeech/utils/tensor_utils.py
+1
-1
tests/u2_model_test.py
tests/u2_model_test.py
+21
-5
未找到文件。
deepspeech/__init__.py
浏览文件 @
5e7e582d
...
@@ -22,7 +22,8 @@ import paddle
...
@@ -22,7 +22,8 @@ import paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
paddle.nn
import
initializer
as
I
#TODO(Hui Zhang): remove fluid import
from
paddle.fluid
import
core
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
########### hcak logging #############
########### hcak logging #############
...
@@ -44,10 +45,51 @@ paddle.int = 'int32'
...
@@ -44,10 +45,51 @@ paddle.int = 'int32'
paddle
.
int64
=
'int64'
paddle
.
int64
=
'int64'
paddle
.
long
=
'int64'
paddle
.
long
=
'int64'
paddle
.
uint8
=
'uint8'
paddle
.
uint8
=
'uint8'
paddle
.
uint16
=
'uint16'
paddle
.
complex64
=
'complex64'
paddle
.
complex64
=
'complex64'
paddle
.
complex128
=
'complex128'
paddle
.
complex128
=
'complex128'
paddle
.
cdouble
=
'complex128'
paddle
.
cdouble
=
'complex128'
def
convert_dtype_to_string
(
tensor_dtype
):
"""
Convert the data type in numpy to the data type in Paddle
Args:
tensor_dtype(core.VarDesc.VarType): the data type in numpy.
Returns:
core.VarDesc.VarType: the data type in Paddle.
"""
dtype
=
tensor_dtype
if
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
return
paddle
.
float32
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP64
:
return
paddle
.
float64
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
return
paddle
.
float16
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT32
:
return
paddle
.
int32
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT16
:
return
paddle
.
int16
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
return
paddle
.
int64
elif
dtype
==
core
.
VarDesc
.
VarType
.
BOOL
:
return
paddle
.
bool
elif
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
return
paddle
.
uint16
elif
dtype
==
core
.
VarDesc
.
VarType
.
UINT8
:
return
paddle
.
uint8
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT8
:
return
paddle
.
int8
elif
dtype
==
core
.
VarDesc
.
VarType
.
COMPLEX64
:
return
paddle
.
complex64
elif
dtype
==
core
.
VarDesc
.
VarType
.
COMPLEX128
:
return
paddle
.
complex128
else
:
raise
ValueError
(
"Not supported tensor dtype %s"
%
dtype
)
if
not
hasattr
(
paddle
,
'softmax'
):
if
not
hasattr
(
paddle
,
'softmax'
):
logger
.
warn
(
"register user softmax to paddle, remove this when fixed!"
)
logger
.
warn
(
"register user softmax to paddle, remove this when fixed!"
)
setattr
(
paddle
,
'softmax'
,
paddle
.
nn
.
functional
.
softmax
)
setattr
(
paddle
,
'softmax'
,
paddle
.
nn
.
functional
.
softmax
)
...
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):
...
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):
def
eq
(
xs
:
paddle
.
Tensor
,
ys
:
Union
[
paddle
.
Tensor
,
float
])
->
paddle
.
Tensor
:
def
eq
(
xs
:
paddle
.
Tensor
,
ys
:
Union
[
paddle
.
Tensor
,
float
])
->
paddle
.
Tensor
:
return
xs
.
equal
(
paddle
.
to_tensor
(
ys
,
dtype
=
xs
.
dtype
,
place
=
xs
.
place
))
return
xs
.
equal
(
paddle
.
to_tensor
(
ys
,
dtype
=
convert_dtype_to_string
(
xs
.
dtype
),
place
=
xs
.
place
))
if
not
hasattr
(
paddle
.
Tensor
,
'eq'
):
if
not
hasattr
(
paddle
.
Tensor
,
'eq'
):
...
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
...
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
paddle
.
Tensor
.
view_as
=
view_as
paddle
.
Tensor
.
view_as
=
view_as
def
is_broadcastable
(
shp1
,
shp2
):
for
a
,
b
in
zip
(
shp1
[::
-
1
],
shp2
[::
-
1
]):
if
a
==
1
or
b
==
1
or
a
==
b
:
pass
else
:
return
False
return
True
def
masked_fill
(
xs
:
paddle
.
Tensor
,
def
masked_fill
(
xs
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
value
:
Union
[
float
,
int
]):
assert
xs
.
shape
==
mask
.
shape
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
==
True
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
trues
=
paddle
.
ones_like
(
xs
)
*
value
xs
=
paddle
.
where
(
mask
,
trues
,
xs
)
xs
=
paddle
.
where
(
mask
,
trues
,
xs
)
return
xs
return
xs
...
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
...
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
def
masked_fill_
(
xs
:
paddle
.
Tensor
,
def
masked_fill_
(
xs
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
value
:
Union
[
float
,
int
]):
assert
xs
.
shape
==
mask
.
shape
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
==
True
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
trues
=
paddle
.
ones_like
(
xs
)
*
value
ret
=
paddle
.
where
(
mask
,
trues
,
xs
)
ret
=
paddle
.
where
(
mask
,
trues
,
xs
)
paddle
.
assign
(
ret
,
output
=
xs
)
paddle
.
assign
(
ret
,
output
=
xs
)
...
@@ -414,4 +471,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
...
@@ -414,4 +471,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
if
not
hasattr
(
paddle
.
jit
,
'export'
):
if
not
hasattr
(
paddle
.
jit
,
'export'
):
logger
.
warn
(
"register user export to paddle.jit, remove this when fixed!"
)
logger
.
warn
(
"register user export to paddle.jit, remove this when fixed!"
)
setattr
(
paddle
.
jit
,
'export'
,
paddle
.
jit
.
to_static
)
setattr
(
paddle
.
jit
,
'export'
,
paddle
.
jit
.
to_static
)
\ No newline at end of file
deepspeech/models/u2.py
浏览文件 @
5e7e582d
...
@@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
...
@@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
from
deepspeech.modules.encoder
import
TransformerEncoder
from
deepspeech.modules.encoder
import
TransformerEncoder
from
deepspeech.modules.ctc
import
CTCDecoder
from
deepspeech.modules.ctc
import
CTCDecoder
from
deepspeech.modules.decoder
import
TransformerDecoder
from
deepspeech.modules.decoder
import
TransformerDecoder
from
deepspeech.modules.l
abel_smoothing_l
oss
import
LabelSmoothingLoss
from
deepspeech.modules.loss
import
LabelSmoothingLoss
from
deepspeech.frontend.utility
import
load_cmvn
from
deepspeech.frontend.utility
import
load_cmvn
...
@@ -633,7 +633,7 @@ class U2Model(nn.Module):
...
@@ -633,7 +633,7 @@ class U2Model(nn.Module):
class
U2TransformerModel
(
U2Model
):
class
U2TransformerModel
(
U2Model
):
def
__init__
(
configs
:
dict
):
def
__init__
(
self
,
configs
:
dict
):
if
configs
[
'cmvn_file'
]
is
not
None
:
if
configs
[
'cmvn_file'
]
is
not
None
:
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
configs
[
'cmvn_file_type'
])
configs
[
'cmvn_file_type'
])
...
@@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
...
@@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
**
configs
[
'decoder_conf'
])
**
configs
[
'decoder_conf'
])
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
s
elf
.
__init__
(
s
uper
()
.
__init__
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
encoder
=
encoder
,
encoder
=
encoder
,
decoder
=
decoder
,
decoder
=
decoder
,
...
@@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):
...
@@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):
class
U2ConformerModel
(
U2Model
):
class
U2ConformerModel
(
U2Model
):
def
__init__
(
configs
:
dict
):
def
__init__
(
self
,
configs
:
dict
):
if
configs
[
'cmvn_file'
]
is
not
None
:
if
configs
[
'cmvn_file'
]
is
not
None
:
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
configs
[
'cmvn_file_type'
])
configs
[
'cmvn_file_type'
])
...
@@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
...
@@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
**
configs
[
'decoder_conf'
])
**
configs
[
'decoder_conf'
])
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
s
elf
.
__init__
(
s
uper
()
.
__init__
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
encoder
=
encoder
,
encoder
=
encoder
,
decoder
=
decoder
,
decoder
=
decoder
,
...
...
deepspeech/modules/attention.py
浏览文件 @
5e7e582d
...
@@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
...
@@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# first compute matrix a and matrix c
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
# (batch, head, time1, time2)
matrix_ac
=
torch
.
matmul
(
q_with_bias_u
,
k
.
transpose
([
0
,
1
,
3
,
2
]))
matrix_ac
=
paddle
.
matmul
(
q_with_bias_u
,
k
.
transpose
([
0
,
1
,
3
,
2
]))
# compute matrix b and matrix d
# compute matrix b and matrix d
# (batch, head, time1, time2)
# (batch, head, time1, time2)
matrix_bd
=
torch
.
matmul
(
q_with_bias_v
,
p
.
transpose
([
0
,
1
,
3
,
2
]))
matrix_bd
=
paddle
.
matmul
(
q_with_bias_v
,
p
.
transpose
([
0
,
1
,
3
,
2
]))
# Remove rel_shift since it is useless in speech recognition,
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
# matrix_bd = self.rel_shift(matrix_bd)
...
...
deepspeech/modules/embedding.py
浏览文件 @
5e7e582d
...
@@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer):
...
@@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer):
self
.
max_len
=
max_len
self
.
max_len
=
max_len
self
.
xscale
=
paddle
.
to_tensor
(
math
.
sqrt
(
self
.
d_model
))
self
.
xscale
=
paddle
.
to_tensor
(
math
.
sqrt
(
self
.
d_model
))
self
.
dropout
=
nn
.
Dropout
(
p
=
dropout_rate
)
self
.
dropout
=
nn
.
Dropout
(
p
=
dropout_rate
)
self
.
pe
=
paddle
.
zeros
(
self
.
max_len
,
self
.
d_model
)
#[T,D]
self
.
pe
=
paddle
.
zeros
(
[
self
.
max_len
,
self
.
d_model
]
)
#[T,D]
position
=
paddle
.
arange
(
position
=
paddle
.
arange
(
0
,
self
.
max_len
,
dtype
=
paddle
.
float32
).
unsqueeze
(
1
)
#[T, 1]
0
,
self
.
max_len
,
dtype
=
paddle
.
float32
).
unsqueeze
(
1
)
#[T, 1]
...
@@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer):
...
@@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
"""
T
=
paddle
.
shape
(
x
)[
1
]
T
=
x
.
shape
[
1
]
assert
offset
+
T
<
self
.
max_len
assert
offset
+
x
.
size
(
1
)
<
self
.
max_len
#assert offset + x.size(1) < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
#self.pe = self.pe.to(x.device)
#pos_emb = self.pe[:, offset:offset + x.size(1)]
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
T
]
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
T
]
x
=
x
*
self
.
xscale
+
pos_emb
x
=
x
*
self
.
xscale
+
pos_emb
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
...
@@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding):
...
@@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
"""
#T = paddle.shape()[1]
#assert offset + T < self.max_len
assert
offset
+
x
.
size
(
1
)
<
self
.
max_len
assert
offset
+
x
.
size
(
1
)
<
self
.
max_len
#self.pe = self.pe.to(x.device)
x
=
x
*
self
.
xscale
x
=
x
*
self
.
xscale
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
x
.
size
(
1
)]
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
#pos_emb = self.pe[:, offset:offset + T
]
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
x
.
shape
[
1
]
]
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
deepspeech/modules/encoder.py
浏览文件 @
5e7e582d
...
@@ -23,7 +23,7 @@ from paddle.nn import initializer as I
...
@@ -23,7 +23,7 @@ from paddle.nn import initializer as I
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.attention
import
RelPositionMultiHeadedAttention
from
deepspeech.modules.attention
import
RelPositionMultiHeadedAttention
from
deepspeech.modules.convolution
import
ConvolutionModule
from
deepspeech.modules.con
former_con
volution
import
ConvolutionModule
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.embedding
import
RelPositionalEncoding
from
deepspeech.modules.embedding
import
RelPositionalEncoding
from
deepspeech.modules.encoder_layer
import
TransformerEncoderLayer
from
deepspeech.modules.encoder_layer
import
TransformerEncoderLayer
...
@@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
...
@@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
from
deepspeech.modules.subsampling
import
Conv2dSubsampling6
from
deepspeech.modules.subsampling
import
Conv2dSubsampling6
from
deepspeech.modules.subsampling
import
Conv2dSubsampling8
from
deepspeech.modules.subsampling
import
Conv2dSubsampling8
from
deepspeech.modules.subsampling
import
LinearNoSubsampling
from
deepspeech.modules.subsampling
import
LinearNoSubsampling
from
deepspeech.modules.mask
import
make_pad_mask
from
deepspeech.modules.mask
import
make_
non_
pad_mask
from
deepspeech.modules.mask
import
add_optional_chunk_mask
from
deepspeech.modules.mask
import
add_optional_chunk_mask
from
deepspeech.modules.activation
import
get_activation
from
deepspeech.modules.activation
import
get_activation
...
@@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer):
...
@@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer):
encoder output tensor, lens and mask
encoder output tensor, lens and mask
"""
"""
masks
=
make_non_pad_mask
(
xs_lens
).
unsqueeze
(
1
)
# (B, 1, L)
masks
=
make_non_pad_mask
(
xs_lens
).
unsqueeze
(
1
)
# (B, 1, L)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad
=
masks
.
logical_not
()
if
self
.
global_cmvn
is
not
None
:
if
self
.
global_cmvn
is
not
None
:
xs
=
self
.
global_cmvn
(
xs
)
xs
=
self
.
global_cmvn
(
xs
)
xs
,
pos_emb
,
masks
=
self
.
embed
(
xs
,
masks
,
offset
=
0
)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
mask_pad
=
~
masks
xs
,
pos_emb
,
masks
=
self
.
embed
(
xs
,
masks
.
type_as
(
xs
),
offset
=
0
)
chunk_masks
=
add_optional_chunk_mask
(
chunk_masks
=
add_optional_chunk_mask
(
xs
,
masks
,
self
.
use_dynamic_chunk
,
self
.
use_dynamic_left_chunk
,
xs
,
masks
,
self
.
use_dynamic_chunk
,
self
.
use_dynamic_left_chunk
,
decoding_chunk_size
,
self
.
static_chunk_size
,
decoding_chunk_size
,
self
.
static_chunk_size
,
...
@@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder):
...
@@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder):
concat_after
:
bool
=
False
,
concat_after
:
bool
=
False
,
static_chunk_size
:
int
=
0
,
static_chunk_size
:
int
=
0
,
use_dynamic_chunk
:
bool
=
False
,
use_dynamic_chunk
:
bool
=
False
,
global_cmvn
:
torch
.
nn
.
Module
=
None
,
global_cmvn
:
nn
.
Layer
=
None
,
use_dynamic_left_chunk
:
bool
=
False
,
use_dynamic_left_chunk
:
bool
=
False
,
positionwise_conv_kernel_size
:
int
=
1
,
positionwise_conv_kernel_size
:
int
=
1
,
macaron_style
:
bool
=
True
,
macaron_style
:
bool
=
True
,
...
@@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder):
...
@@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder):
self
.
encoders
=
nn
.
ModuleList
([
self
.
encoders
=
nn
.
ModuleList
([
ConformerEncoderLayer
(
ConformerEncoderLayer
(
size
=
output_size
,
size
=
output_size
,
eself_attn
=
ncoder_selfattn_layer
(
*
encoder_selfattn_layer_args
),
self_attn
=
e
ncoder_selfattn_layer
(
*
encoder_selfattn_layer_args
),
feed_forward
=
positionwise_layer
(
*
positionwise_layer_args
),
feed_forward
=
positionwise_layer
(
*
positionwise_layer_args
),
feed_forward_macaron
=
positionwise_layer
(
feed_forward_macaron
=
positionwise_layer
(
*
positionwise_layer_args
)
if
macaron_style
else
None
,
*
positionwise_layer_args
)
if
macaron_style
else
None
,
...
...
deepspeech/modules/encoder_layer.py
浏览文件 @
5e7e582d
...
@@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer):
...
@@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer):
if
output_cache
is
not
None
:
if
output_cache
is
not
None
:
x
=
paddle
.
concat
([
output_cache
,
x
],
axis
=
1
)
x
=
paddle
.
concat
([
output_cache
,
x
],
axis
=
1
)
fake_cnn_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
plac
e
)
fake_cnn_cache
=
paddle
.
zeros
([
1
],
dtype
=
x
.
dtyp
e
)
return
x
,
mask
,
fake_cnn_cache
return
x
,
mask
,
fake_cnn_cache
...
@@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer):
...
@@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer):
# convolution module
# convolution module
# Fake new cnn cache here, and then change it in conv_module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
plac
e
)
new_cnn_cache
=
paddle
.
zeros
([
1
],
dtype
=
x
.
dtyp
e
)
if
self
.
conv_module
is
not
None
:
if
self
.
conv_module
is
not
None
:
residual
=
x
residual
=
x
if
self
.
normalize_before
:
if
self
.
normalize_before
:
...
...
deepspeech/utils/layer_tools.py
浏览文件 @
5e7e582d
...
@@ -24,13 +24,18 @@ __all__ = [
...
@@ -24,13 +24,18 @@ __all__ = [
def
summary
(
layer
:
nn
.
Layer
,
print_func
=
print
):
def
summary
(
layer
:
nn
.
Layer
,
print_func
=
print
):
num_params
=
num_elements
=
0
num_params
=
num_elements
=
0
print_func
(
"layer summary:"
)
if
print_func
:
print_func
(
f
"
{
layer
.
__class__
.
__name__
}
summary:"
)
for
name
,
param
in
layer
.
state_dict
().
items
():
for
name
,
param
in
layer
.
state_dict
().
items
():
print_func
(
"{}|{}|{}"
.
format
(
name
,
param
.
shape
,
np
.
prod
(
param
.
shape
)))
if
print_func
:
print_func
(
"{} | {} | {}"
.
format
(
name
,
param
.
shape
,
np
.
prod
(
param
.
shape
)))
num_elements
+=
np
.
prod
(
param
.
shape
)
num_elements
+=
np
.
prod
(
param
.
shape
)
num_params
+=
1
num_params
+=
1
print_func
(
"layer has {} parameters, {} elements."
.
format
(
num_params
,
if
print_func
:
num_elements
))
print_func
(
f
"
{
layer
.
__class__
.
__name__
}
has
{
num_params
}
parameters,
{
num_elements
}
elements."
)
def
gradient_norm
(
layer
:
nn
.
Layer
):
def
gradient_norm
(
layer
:
nn
.
Layer
):
...
...
deepspeech/utils/tensor_utils.py
浏览文件 @
5e7e582d
...
@@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
...
@@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ys
=
[
y
[
y
!=
ignore_id
]
for
y
in
ys_pad
]
# parse padded ys
ys
=
[
y
[
y
!=
ignore_id
]
for
y
in
ys_pad
]
# parse padded ys
ys_in
=
[
paddle
.
cat
([
_sos
,
y
],
dim
=
0
)
for
y
in
ys
]
ys_in
=
[
paddle
.
cat
([
_sos
,
y
],
dim
=
0
)
for
y
in
ys
]
ys_out
=
[
paddle
.
cat
([
y
,
_eos
],
dim
=
0
)
for
y
in
ys
]
ys_out
=
[
paddle
.
cat
([
y
,
_eos
],
dim
=
0
)
for
y
in
ys
]
return
pad_
list
(
ys_in
,
eos
),
pad_list
(
ys_out
,
ignore_id
)
return
pad_
sequence
(
ys_in
,
padding_value
=
eos
),
pad_sequence
(
ys_out
,
padding_value
=
ignore_id
)
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
...
...
tests/u2_model_test.py
浏览文件 @
5e7e582d
...
@@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN
...
@@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN
from
deepspeech.models.u2
import
U2TransformerModel
from
deepspeech.models.u2
import
U2TransformerModel
from
deepspeech.models.u2
import
U2ConformerModel
from
deepspeech.models.u2
import
U2ConformerModel
from
deepspeech.utils.layer_tools
import
summary
class
TestU2Model
(
unittest
.
TestCase
):
class
TestU2Model
(
unittest
.
TestCase
):
...
@@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase):
...
@@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase):
paddle
.
set_device
(
'cpu'
)
paddle
.
set_device
(
'cpu'
)
self
.
batch_size
=
2
self
.
batch_size
=
2
self
.
feat_dim
=
161
self
.
feat_dim
=
83
self
.
max_len
=
64
self
.
max_len
=
64
self
.
vocab_size
=
4239
#(B, T, D)
#(B, T, D)
audio
=
np
.
random
.
randn
(
self
.
batch_size
,
self
.
max_len
,
self
.
feat_dim
)
audio
=
np
.
random
.
randn
(
self
.
batch_size
,
self
.
max_len
,
self
.
feat_dim
)
...
@@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase):
...
@@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase):
length_normalized_loss: false
length_normalized_loss: false
"""
"""
cfg
=
CN
().
load_cfg
(
conf_str
)
cfg
=
CN
().
load_cfg
(
conf_str
)
print
(
cfg
)
cfg
.
input_dim
=
self
.
feat_dim
model
=
U2TransformerModel
()
cfg
.
output_dim
=
self
.
vocab_size
cfg
.
cmvn_file
=
None
cfg
.
cmvn_file_type
=
'npz'
cfg
.
freeze
()
model
=
U2TransformerModel
(
cfg
)
summary
(
model
,
None
)
output
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
print
(
output
)
def
test_conformer
(
self
):
def
test_conformer
(
self
):
conf_str
=
"""
conf_str
=
"""
...
@@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase):
...
@@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase):
length_normalized_loss: false
length_normalized_loss: false
"""
"""
cfg
=
CN
().
load_cfg
(
conf_str
)
cfg
=
CN
().
load_cfg
(
conf_str
)
print
(
cfg
)
cfg
.
input_dim
=
self
.
feat_dim
model
=
U2ConformerModel
()
cfg
.
output_dim
=
self
.
vocab_size
cfg
.
cmvn_file
=
None
cfg
.
cmvn_file_type
=
'npz'
cfg
.
freeze
()
model
=
U2ConformerModel
(
cfg
)
summary
(
model
,
None
)
output
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
print
(
output
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录