Commit 45663511
Authored Oct 24, 2021 by Hui Zhang

add transformer lm and encoder score api

Parent: c5f66921
Showing 5 changed files with 309 additions and 6 deletions (+309 -6)
deepspeech/models/lm/__init__.py      +0    -0
deepspeech/models/lm/transformer.py   +259  -0
deepspeech/modules/encoder.py         +41   -0
deepspeech/modules/encoder_layer.py   +3    -3
deepspeech/modules/subsampling.py     +6    -3
deepspeech/models/lm/__init__.py  (new file, mode 100644; empty)
deepspeech/models/lm/transformer.py  (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import List
from typing import Tuple

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface


#LMInterface
class TransformerLM(nn.Layer, BatchScorerInterface):
    def __init__(self,
                 n_vocab: int,
                 pos_enc: str=None,
                 embed_unit: int=128,
                 att_unit: int=256,
                 head: int=2,
                 unit: int=1024,
                 layer: int=4,
                 dropout_rate: float=0.5,
                 emb_dropout_rate: float=0.0,
                 att_dropout_rate: float=0.0,
                 tie_weights: bool=False):
        nn.Layer.__init__(self)

        if pos_enc == "sinusoidal":
            pos_enc_layer_type = "abs_pos"
        elif pos_enc is None:
            #TODO
            pos_enc_layer_type = "None"
        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(n_vocab, embed_unit)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        self.encoder = TransformerEncoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_layer_type=pos_enc_layer_type,
            concat_after=False,
            static_chunk_size=1,
            use_dynamic_chunk=False,
            use_dynamic_left_chunk=False)

        self.decoder = nn.Linear(att_unit, n_vocab)

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logging.info("Att Dropout set to {}".format(att_dropout_rate))

        if tie_weights:
            assert (
                att_unit == embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.decoder.weight = self.embed.weight

    def _target_mask(self, ys_in_pad):
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, x: paddle.Tensor, xlens, t: paddle.Tensor
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (paddle.Tensor): Input ids. (batch, len)
            t (paddle.Tensor): Target ids. (batch, len)

        Returns:
            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
        """
        xm = x != 0
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)

        xlen = xm.sum(axis=1)
        h, _ = self.encoder(emb, xlen)
        y = self.decoder(h)

        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    # beam search API (see ScorerInterface)
    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        """Score new token.

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): encoder feature that generates ys.

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                paddle.float32 scores for next token (n_vocab)
                and next state for ys
        """
        y = y.unsqueeze(0)

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(y))
        else:
            emb = self.embed(y)

        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(axis=-1).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self, ys: paddle.Tensor, states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                paddle.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(ys))
        else:
            emb = self.embed(ys)

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(axis=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list


if __name__ == "__main__":
    tlm = TransformerLM(
        n_vocab=5002,
        pos_enc=None,
        embed_unit=128,
        att_unit=512,
        head=8,
        unit=2048,
        layer=16,
        dropout_rate=0.5, )
    # n_vocab: int,
    # pos_enc: str=None,
    # embed_unit: int=128,
    # att_unit: int=256,
    # head: int=2,
    # unit: int=1024,
    # layer: int=4,
    # dropout_rate: float=0.5,
    # emb_dropout_rate: float = 0.0,
    # att_dropout_rate: float = 0.0,
    # tie_weights: bool = False,):

    paddle.set_device("cpu")
    model_dict = paddle.load("transformerLM.pdparams")
    tlm.set_state_dict(model_dict)

    tlm.eval()

    # Test the score
    input2 = np.array([5])
    input2 = paddle.to_tensor(input2)
    state = (None, None, 0)
    output, state = tlm.score(input2, state, None)

    input3 = np.array([10])
    input3 = paddle.to_tensor(input3)
    output, state = tlm.score(input3, state, None)

    input4 = np.array([0])
    input4 = paddle.to_tensor(input4)
    output, state = tlm.score(input4, state, None)
    print("output", output)
    """
    # Test the batch score
    batch_size = 2
    inp2 = np.array([[5], [10]])
    inp2 = paddle.to_tensor(inp2)
    output, states = tlm.batch_score(
        inp2, [(None,None,0)] * batch_size)
    inp3 = np.array([[100], [30]])
    inp3 = paddle.to_tensor(inp3)
    output, states = tlm.batch_score(
        inp3, states)
    print("output", output)
    #print("cache", cache)
    #np.save("output_pd.npy", output)
    """
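A minimal sketch of how the score API above is typically driven during decoding: call `score` once per step with the growing prefix, feeding the returned cache back in as the next `state`. The helper `score_prefix` and the token ids are hypothetical (not part of this commit); starting from `state = None` relies on the `cache is None` branch of `forward_one_step` in the encoder diff below, whereas the repo's own smoke test above passes `(None, None, 0)`.

import paddle

def score_prefix(lm, token_ids):
    """Sum log p(token_ids[i] | token_ids[:i]) for a prefix, token by token.

    `lm` is a TransformerLM in eval mode; `token_ids` is a plain list of ints
    (made up by the caller -- nothing here comes from the commit itself).
    """
    state = None                     # no per-layer cache before the first step
    total_logp = 0.0
    for i in range(1, len(token_ids)):
        prefix = paddle.to_tensor(token_ids[:i], dtype="int64")
        logp, state = lm.score(prefix, state, None)   # logp: (n_vocab,) log-probs
        total_logp += float(logp[token_ids[i]])
    return total_logp

For corpus-level evaluation, `forward` already returns what its docstring calls the perplexity ingredients: with `avg, nll, count = lm(x, xlens, t)`, perplexity is `exp(nll / count)`.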
deepspeech/modules/encoder.py

@@ -31,6 +31,7 @@ from deepspeech.modules.encoder_layer import TransformerEncoderLayer
 from deepspeech.modules.mask import add_optional_chunk_mask
 from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
+from deepspeech.modules.subsampling import Conv2dSubsampling
 from deepspeech.modules.subsampling import Conv2dSubsampling4
 from deepspeech.modules.subsampling import Conv2dSubsampling6
 from deepspeech.modules.subsampling import Conv2dSubsampling8
@@ -370,6 +371,46 @@ class TransformerEncoder(BaseEncoder):
                 concat_after=concat_after) for _ in range(num_blocks)
         ])
 
+    def forward_one_step(
+            self,
+            xs: paddle.Tensor,
+            masks: paddle.Tensor,
+            cache=None,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Encode input frame.
+
+        Args:
+            xs (paddle.Tensor): Input tensor. (B, T, D)
+            masks (paddle.Tensor): Mask tensor. (B, 1, T)
+            cache (List[paddle.Tensor]): List of cache tensors.
+
+        Returns:
+            paddle.Tensor: Output tensor.
+            paddle.Tensor: Mask tensor.
+            List[paddle.Tensor]: List of new cache tensors.
+        """
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        if isinstance(self.embed, Conv2dSubsampling):
+            # xs, masks = self.embed(xs, masks)
+            #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
+            xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
+        else:
+            xs = self.embed(xs)
+        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
+        masks = masks.astype(paddle.bool)
+
+        if cache is None:
+            cache = [None for _ in range(len(self.encoders))]
+        new_cache = []
+        for c, e in zip(cache, self.encoders):
+            xs, masks, _ = e(xs, masks, output_cache=c)
+            new_cache.append(xs)
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        return xs, masks, new_cache
+
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""
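The cache argument is what makes step-by-step decoding cheap: the caller keeps the list of per-layer outputs from the previous step and passes it back, and each layer receives its own previous output as `output_cache`. A self-contained toy sketch of that calling protocol (ToyLayer only mimics the interface; it is not the real TransformerEncoderLayer, and the mask handling is omitted):

import paddle

class ToyLayer:
    """Stand-in with the same (xs, masks, output_cache) calling convention."""
    def __call__(self, xs, masks, output_cache=None):
        # A real layer would reuse `output_cache` to avoid recomputing the
        # already-seen prefix; the toy just passes its input through.
        return xs, masks, None

layers = [ToyLayer(), ToyLayer()]
cache = None
for t in range(1, 4):                      # three decoding steps
    xs = paddle.randn([1, t, 8])           # embeddings of the prefix so far
    masks = paddle.ones([1, t, t])         # causal mask omitted in the toy
    if cache is None:
        cache = [None] * len(layers)
    new_cache = []
    for c, layer in zip(cache, layers):
        xs, masks, _ = layer(xs, masks, output_cache=c)
        new_cache.append(xs)               # this step's outputs seed the next step
    cache = new_cache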
deepspeech/modules/encoder_layer.py

@@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
-            pos_emb: paddle.Tensor,
+            pos_emb: Optional[paddle.Tensor]=None,
             mask_pad: Optional[paddle.Tensor]=None,
             output_cache: Optional[paddle.Tensor]=None,
             cnn_cache: Optional[paddle.Tensor]=None,
@@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer):
             mask (paddle.Tensor): Mask tensor for the input (#batch, time).
             pos_emb (paddle.Tensor): just for interface compatibility
                 to ConformerEncoderLayer
-            mask_pad (paddle.Tensor): does not used in transformer layer,
-                just for unified api with conformer.
+            mask_pad (paddle.Tensor): not used here, it's for interface
+                compatibility to ConformerEncoderLayer
             output_cache (paddle.Tensor): Cache tensor of the output
                 (#batch, time2, size), time2 < time in x.
             cnn_cache (paddle.Tensor): not used here, it's for interface
deepspeech/modules/subsampling.py

@@ -82,8 +82,11 @@ class LinearNoSubsampling(BaseSubsampling):
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask
 
 
+class Conv2dSubsampling(BaseSubsampling):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
-class Conv2dSubsampling4(BaseSubsampling):
+class Conv2dSubsampling4(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/4 length)."""
 
     def __init__(self,
@@ -134,7 +137,7 @@ class Conv2dSubsampling4(BaseSubsampling):
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
 
 
-class Conv2dSubsampling6(BaseSubsampling):
+class Conv2dSubsampling6(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/6 length)."""
 
     def __init__(self,
@@ -187,7 +190,7 @@ class Conv2dSubsampling6(BaseSubsampling):
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
 
 
-class Conv2dSubsampling8(BaseSubsampling):
+class Conv2dSubsampling8(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/8 length)."""
 
     def __init__(self,
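The point of the new Conv2dSubsampling base class appears to be the single isinstance() dispatch in TransformerEncoder.forward_one_step above: every convolutional subsampling variant now matches one check, while the linear input layer does not. A tiny sanity-check sketch, assuming only the class relationships shown in this diff:

from deepspeech.modules.subsampling import (Conv2dSubsampling, Conv2dSubsampling4,
                                            Conv2dSubsampling6, Conv2dSubsampling8,
                                            LinearNoSubsampling)

# all conv variants hit the Conv2dSubsampling branch of forward_one_step ...
for cls in (Conv2dSubsampling4, Conv2dSubsampling6, Conv2dSubsampling8):
    assert issubclass(cls, Conv2dSubsampling)
# ... while LinearNoSubsampling still takes the else branch
assert not issubclass(LinearNoSubsampling, Conv2dSubsampling)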