PaddlePaddle / DeepSpeech · Commit 1a1ce92c
Unverified commit 1a1ce92c, authored on Sep 22, 2022 by Hui Zhang; committed via GitHub on Sep 22, 2022.
Merge pull request #2415 from Zth9730/u2++_decoder
[s2t] support bitransformer decoder
Parents: 52af86fc · d3e59375
Showing 7 changed files with 332 additions and 37 deletions (+332 −37).
paddlespeech/audio/utils/tensor_utils.py                     +112 −9
paddlespeech/s2t/exps/u2/bin/test_wav.py                       +3 −2
paddlespeech/s2t/exps/u2/model.py                              +4 −2
paddlespeech/s2t/io/dataloader.py                              +1 −1
paddlespeech/s2t/models/u2/u2.py                              +86 −19
paddlespeech/s2t/modules/decoder.py                          +124 −3
paddlespeech/server/engine/asr/online/python/asr_engine.py     +2 −1
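In short: the commit threads a new `reverse_weight` option through the U2 pipeline and adds the right-to-left decoder it controls. `tensor_utils.py` gains `reverse_pad_list` plus an export-friendly `st_reverse_pad_list`; `decoder.py` gains `BiTransformerDecoder`; `u2.py` wires the reversed targets into the training loss, attention rescoring, and the export interface; the remaining files pass `reverse_weight` through their decode calls.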
paddlespeech/audio/utils/tensor_utils.py

@@ -31,7 +31,6 @@ def has_tensor(val):
         return True
     elif isinstance(val, dict):
         for k, v in val.items():
-            print(k)
             if has_tensor(v):
                 return True
     else:
@@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
         [ 7, 8, 9, 11, -1, -1]])
     """
     # TODO(Hui Zhang): using comment code,
-    #_sos = paddle.to_tensor(
-    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #_eos = paddle.to_tensor(
-    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
-    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
-    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
-    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
+    # _sos = paddle.to_tensor(
+    #     [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # _eos = paddle.to_tensor(
+    #     [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
+    # ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
+    # ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
+    # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
     B = ys_pad.shape[0]
     _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
     _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
@@ -190,3 +190,106 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     # denominator = paddle.sum(mask)
     denominator = paddle.sum(mask.type_as(pad_targets))
     return float(numerator) / float(denominator)
+
+
+def reverse_pad_list(ys_pad: paddle.Tensor,
+                     ys_lens: paddle.Tensor,
+                     pad_value: float=-1.0) -> paddle.Tensor:
+    """Reverse padding for the list of tensors.
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B)
+        pad_value (int): Value for padding.
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> pad_list(x, 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+    """
+    r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
+                             for y, i in zip(ys_pad, ys_lens)], True, pad_value)
+    return r_ys_pad
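For orientation, a minimal usage sketch of the new helper (not part of the commit; values are illustrative, and `pad_sequence` is the torch-style, batch-first variant used in this file):

# Hedged sketch, not part of the commit: reverse_pad_list on the docstring
# example. Assumes the import path introduced by this diff.
import paddle
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list

ys_pad = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
ys_lens = paddle.to_tensor([4, 3, 2])  # true (unpadded) lengths per row
r = reverse_pad_list(ys_pad, ys_lens, pad_value=0.0)
# Each row is reversed up to its length, then re-padded:
# [[4, 3, 2, 1],
#  [7, 6, 5, 0],
#  [9, 8, 0, 0]]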
+
+
+def st_reverse_pad_list(ys_pad: paddle.Tensor,
+                        ys_lens: paddle.Tensor,
+                        sos: float,
+                        eos: float) -> paddle.Tensor:
+    """Reverse padding for the list of tensors.
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B)
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> pad_list(x, 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+    """
+    # Equal to:
+    #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
+    #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
+    B = ys_pad.shape[0]
+    _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
+    max_len = paddle.max(ys_lens)
+    index_range = paddle.arange(0, max_len, 1)
+    seq_len_expand = ys_lens.unsqueeze(1)
+    seq_mask = seq_len_expand > index_range  # (beam, max_len)
+
+    index = (seq_len_expand - 1) - index_range  # (beam, max_len)
+    #   >>> index
+    #   >>> tensor([[ 2, 1, 0],
+    #   >>>         [ 2, 1, 0],
+    #   >>>         [ 0, -1, -2]])
+    index = index * seq_mask
+    #   >>> index
+    #   >>> tensor([[2, 1, 0],
+    #   >>>         [2, 1, 0],
+    #   >>>         [0, 0, 0]])
+
+    def paddle_gather(x, dim, index):
+        index_shape = index.shape
+        index_flatten = index.flatten()
+        if dim < 0:
+            dim = len(x.shape) + dim
+        nd_index = []
+        for k in range(len(x.shape)):
+            if k == dim:
+                nd_index.append(index_flatten)
+            else:
+                reshape_shape = [1] * len(x.shape)
+                reshape_shape[k] = x.shape[k]
+                x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+                x_arange = x_arange.reshape(reshape_shape)
+                dim_index = paddle.expand(x_arange, index_shape).flatten()
+                nd_index.append(dim_index)
+        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+        paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+        return paddle_out
+
+    r_hyps = paddle_gather(ys_pad, 1, index)
+    #   >>> r_hyps
+    #   >>> tensor([[3, 2, 1],
+    #   >>>         [4, 8, 9],
+    #   >>>         [2, 2, 2]])
+    eos = paddle.full([1], eos, dtype=r_hyps.dtype)
+    r_hyps = paddle.where(seq_mask, r_hyps, eos)
+    #   >>> r_hyps
+    #   >>> tensor([[3, 2, 1],
+    #   >>>         [4, 8, 9],
+    #   >>>         [2, eos, eos]])
+    r_hyps = paddle.cat([_sos, r_hyps], dim=1)
+    # r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
+    #   >>> r_hyps
+    #   >>> tensor([[sos, 3, 2, 1],
+    #   >>>         [sos, 4, 8, 9],
+    #   >>>         [sos, 2, eos, eos]])
+    return r_hyps
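`st_reverse_pad_list` exists so the same computation can be traced for static-graph export: it replaces the Python list comprehension of `reverse_pad_list` with masked arithmetic plus a gather, and folds in the `<sos>`/`<eos>` handling of `add_sos_eos`. A hedged equivalence sketch, not part of the commit (token ids are illustrative):

# Hedged sketch: checks the "Equal to" comment above on toy data.
import paddle
from paddlespeech.audio.utils.tensor_utils import (add_sos_eos, reverse_pad_list,
                                                   st_reverse_pad_list)

sos, eos, ignore_id = 10, 11, -1  # illustrative values
hyps = paddle.to_tensor([[1, 2, 3], [4, 5, -1], [6, -1, -1]])
hyps_lens = paddle.to_tensor([3, 2, 1])

r_ref = reverse_pad_list(hyps, hyps_lens, float(ignore_id))
r_ref, _ = add_sos_eos(r_ref, sos, eos, ignore_id)      # two-pass reference
r_st = st_reverse_pad_list(hyps, hyps_lens, sos, eos)   # single traceable pass
# Both should yield [[sos, 3, 2, 1], [sos, 5, 4, eos], [sos, 6, eos, eos]].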
paddlespeech/s2t/exps/u2/bin/test_wav.py

@@ -40,7 +40,7 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -89,7 +89,8 @@ class U2Infer():
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-            simulate_streaming=decode_config.simulate_streaming)
+            simulate_streaming=decode_config.simulate_streaming,
+            reverse_weight=self.reverse_weight)
         rsl = result_transcripts[0][0]
         utt = Path(self.audio_file).name
         logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
paddlespeech/s2t/exps/u2/model.py

@@ -253,7 +253,6 @@ class U2Trainer(Trainer):
         model_conf.output_dim = self.test_loader.vocab_size
         model = U2Model.from_config(model_conf)
         if self.parallel:
             model = paddle.DataParallel(model)
@@ -317,6 +316,7 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
         result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
@@ -350,7 +351,8 @@ class U2Tester(U2Trainer):
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-            simulate_streaming=decode_config.simulate_streaming)
+            simulate_streaming=decode_config.simulate_streaming,
+            reverse_weight=self.reverse_weight)
         decode_time = time.time() - start_time

         for utt, target, result, rec_tids in zip(
paddlespeech/s2t/io/dataloader.py

@@ -361,7 +361,7 @@ class DataLoaderFactory():
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False
-        elif mode l == 'test' or mode == 'align':
+        elif mode == 'test' or mode == 'align':
             config['manifest'] = config.test_manifest
             config['train_mode'] = False
             config['dither'] = 0.0
paddlespeech/s2t/models/u2/u2.py

@@ -31,6 +31,8 @@ from paddle import nn
 from paddlespeech.audio.utils.tensor_utils import add_sos_eos
 from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
@@ -38,6 +40,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn
 from paddlespeech.s2t.models.asr_interface import ASRInterface
 from paddlespeech.s2t.modules.cmvn import GlobalCMVN
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase
+from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
 from paddlespeech.s2t.modules.decoder import TransformerDecoder
 from paddlespeech.s2t.modules.encoder import ConformerEncoder
 from paddlespeech.s2t.modules.encoder import TransformerEncoder
@@ -69,6 +72,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
                  ctc: CTCDecoderBase,
                  ctc_weight: float=0.5,
                  ignore_id: int=IGNORE_ID,
+                 reverse_weight: float=0.0,
                  lsm_weight: float=0.0,
                  length_normalized_loss: bool=False,
                  **kwargs):
@@ -82,6 +86,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
+        self.reverse_weight = reverse_weight
         self.encoder = encoder
         self.decoder = decoder
@@ -171,12 +176,21 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                             self.ignore_id)
         ys_in_lens = ys_pad_lens + 1

+        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
+        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
+                                                self.ignore_id)
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
-                                      ys_in_lens)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
+            self.reverse_weight)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        r_loss_att = paddle.to_tensor(0.0)
+        if self.reverse_weight > 0.0:
+            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
+        loss_att = loss_att * (1 - self.reverse_weight
+                               ) + r_loss_att * self.reverse_weight
         acc_att = th_accuracy(
             decoder_out.view(-1, self.vocab_size),
             ys_out_pad,
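The training objective now interpolates the left-to-right and right-to-left attention losses. Purely illustrative arithmetic, not from the commit (loss values are made up):

# Hedged sketch: the interpolation applied above, with made-up loss values.
reverse_weight = 0.3
loss_l2r = 2.0  # criterion_att(decoder_out, ys_out_pad), example value
loss_r2l = 2.4  # criterion_att(r_decoder_out, r_ys_out_pad), example value
loss_att = loss_l2r * (1 - reverse_weight) + loss_r2l * reverse_weight
# 2.0 * 0.7 + 2.4 * 0.3 = 2.12; with reverse_weight == 0.0 the r2l branch
# is skipped entirely and loss_att reduces to the old single-decoder loss.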
@@ -359,6 +373,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # Let's assume B = batch_size
         # encoder_out: (B, maxlen, encoder_dim)
         # encoder_mask: (B, 1, Tmax)
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks, simulate_streaming)
@@ -500,7 +515,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                             decoding_chunk_size: int=-1,
                             num_decoding_left_chunks: int=-1,
                             ctc_weight: float=0.0,
-                            simulate_streaming: bool=False, ) -> List[int]:
+                            simulate_streaming: bool=False,
+                            reverse_weight: float=0.0, ) -> List[int]:
         """ Apply attention rescoring decoding, CTC prefix beam search
         is applied first to get nbest, then we rescore the nbest on the
         attention decoder with corresponding encoder out
@@ -520,6 +536,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
+        if reverse_weight > 0.0:
+            # decoder should be a bitransformer decoder if reverse_weight > 0.0
+            assert hasattr(self.decoder, 'right_decoder')
         device = speech.place
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
@@ -541,22 +560,30 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
+        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=device,
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at beginning
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
+                                         self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
+            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         decoder_out = decoder_out.numpy()
+        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+        # conventional transformer decoder.
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        r_decoder_out = r_decoder_out.numpy()

         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
@@ -567,6 +594,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 score += decoder_out[i][j][w]
             # last decoder output token is `eos`, for last decoder input token.
             score += decoder_out[i][len(hyp[0])][self.eos]
+            if reverse_weight > 0:
+                r_score = 0.0
+                for j, w in enumerate(hyp[0]):
+                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+                r_score += r_decoder_out[i][len(hyp[0])][self.eos]
+                score = score * (1 - reverse_weight) + r_score * reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
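Rescoring therefore combines three log-domain terms per hypothesis: the forward attention score, the backward attention score over the reversed token sequence, and the CTC prefix score. A hedged sketch of the per-hypothesis arithmetic (scores are made up, not from the commit):

# Hedged sketch, illustrative numbers only.
ctc_weight, reverse_weight = 0.5, 0.3
att_l2r = -4.2    # sum of decoder_out log-probs along the hypothesis + eos
att_r2l = -4.6    # sum of r_decoder_out log-probs along the reversed hypothesis + eos
ctc_score = -5.0  # hyp[1], ln-domain CTC prefix score from beam search
score = att_l2r * (1 - reverse_weight) + att_r2l * reverse_weight
score += ctc_score * ctc_weight
# The hypothesis with the largest combined score wins (best_index above).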
@@ -653,12 +686,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)

-    @jit.to_static
+    # @jit.to_static
+    def is_bidirectional_decoder(self) -> bool:
+        """
+        Returns:
+            paddle.Tensor: decoder output
+        """
+        if hasattr(self.decoder, 'right_decoder'):
+            return True
+        else:
+            return False
+
+    # @jit.to_static
     def forward_attention_decoder(self,
                                   hyps: paddle.Tensor,
                                   hyps_lens: paddle.Tensor,
-                                  encoder_out: paddle.Tensor, ) -> paddle.Tensor:
+                                  encoder_out: paddle.Tensor,
+                                  reverse_weight: float=0.0, ) -> paddle.Tensor:
         """ Export interface for c++ call, forward decoder with multiple
         hypothesis from ctc prefix beam search and one encoder output
         Args:
@@ -676,11 +721,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # (B, 1, T)
         encoder_mask = paddle.ones(
             [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
+        # input for right-to-left decoder
+        # hyps_lens has counted the <sos> token, so we need to subtract it
+        r_hyps_lens = hyps_lens - 1
+        # hyps already includes the <sos> token, so strip it before
+        # reversing the original hyps
+        r_hyps = hyps[:, 1:]
         # (num_hyps, max_hyps_len, vocab_size)
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
-                                      hyps_lens)
+        r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        return decoder_out
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        return decoder_out, r_decoder_out

     @paddle.no_grad()
     def decode(self,
@@ -692,7 +748,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                ctc_weight: float=0.0,
                decoding_chunk_size: int=-1,
                num_decoding_left_chunks: int=-1,
-               simulate_streaming: bool=False):
+               simulate_streaming: bool=False,
+               reverse_weight: float=0.0):
         """u2 decoding.
         Args:
@@ -764,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 ctc_weight=ctc_weight,
-                simulate_streaming=simulate_streaming)
+                simulate_streaming=simulate_streaming,
+                reverse_weight=reverse_weight)
             hyps = [hyp]
         else:
             raise ValueError(f"Not support decoding method: {decoding_method}")
@@ -801,7 +859,6 @@ class U2Model(U2DecodeModel):
         with DefaultInitializerContext(init_type):
             vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
         super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
@@ -851,10 +908,20 @@ class U2Model(U2DecodeModel):
             raise ValueError(f"not support encoder type: {encoder_type}")

         # decoder
+        decoder_type = configs.get('decoder', 'transformer')
+        logger.debug(f"U2 Decoder type: {decoder_type}")
+        if decoder_type == 'transformer':
+            decoder = TransformerDecoder(vocab_size,
+                                         encoder.output_size(),
+                                         **configs['decoder_conf'])
+        elif decoder_type == 'bitransformer':
+            assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
+            assert configs['decoder_conf']['r_num_blocks'] > 0
+            decoder = BiTransformerDecoder(vocab_size,
+                                           encoder.output_size(),
+                                           **configs['decoder_conf'])
+        else:
+            raise ValueError(f"not support decoder type: {decoder_type}")

         # ctc decoder and ctc loss
         model_conf = configs.get('model_conf', dict())
         dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
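To select the new decoder, a model config needs three cooperating keys, matching the assertions above. A hedged sketch of the relevant fragment, expressed as the Python dict `_init_from_config` consumes (values illustrative; surrounding keys omitted):

# Hedged sketch: only the keys read by the dispatch above.
configs = {
    'decoder': 'bitransformer',
    'decoder_conf': {
        'num_blocks': 6,     # left-to-right decoder depth
        'r_num_blocks': 3,   # right-to-left depth; asserted > 0
    },
    'model_conf': {
        'reverse_weight': 0.3,  # asserted to lie strictly in (0.0, 1.0)
    },
}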
paddlespeech/s2t/modules/decoder.py

@@ -35,7 +35,6 @@ from paddlespeech.s2t.modules.mask import make_xs_mask
 from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

 __all__ = ["TransformerDecoder"]
@@ -116,13 +115,19 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             memory: paddle.Tensor,
             memory_mask: paddle.Tensor,
             ys_in_pad: paddle.Tensor,
-            ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
+            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Forward decoder.
         Args:
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoder memory mask, (batch, 1, maxlen_in)
             ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
             ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: not used in transformer decoder, in order to unify api
+                with bidirectional decoder
+            reverse_weight: not used in transformer decoder, in order to unify
+                api with bidirectional decoder
         Returns:
             (tuple): tuple containing:
                 x: decoded token score before softmax (batch, maxlen_out, vocab_size)
@@ -151,7 +156,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         # TODO(Hui Zhang): reduce_sum not support bool type
         # olens = tgt_mask.sum(1)
         olens = tgt_mask.astype(paddle.int).sum(1)
-        return x, olens
+        return x, paddle.to_tensor(0.0), olens

     def forward_one_step(
             self,
@@ -251,3 +256,119 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         state_list = [[states[i][b] for i in range(n_layers)]
                       for b in range(n_batch)]
         return logp, state_list
+
+
+class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
+    """Base class of Transformer decoder module.
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the hidden units number of position-wise feedforward
+        num_blocks: the number of decoder blocks
+        r_num_blocks: the number of right to left decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before:
+            True: use layer_norm before each sub-block of a layer.
+            False: use layer_norm after each sub-block of a layer.
+        concat_after: whether to concat attention layer's input and output
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 encoder_output_size: int,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 r_num_blocks: int=0,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 self_attention_dropout_rate: float=0.0,
+                 src_attention_dropout_rate: float=0.0,
+                 input_layer: str="embed",
+                 use_output_layer: bool=True,
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 max_len: int=5000):
+        assert check_argument_types()
+        nn.Layer.__init__(self)
+        self.left_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+        self.right_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            r_num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+
+    def forward(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            ys_in_pad: paddle.Tensor,
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor,
+            reverse_weight: float=0.0,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Forward decoder.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+            ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+                used for right to left decoder
+            reverse_weight: used for right to left decoder
+        Returns:
+            (tuple): tuple containing:
+                x: decoded token score before softmax (batch, maxlen_out,
+                    vocab_size) if use_output_layer is True,
+                r_x: decoded token score (right to left decoder)
+                    before softmax (batch, maxlen_out, vocab_size)
+                    if use_output_layer is True,
+                olens: (batch, )
+        """
+        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
+                                          ys_in_lens)
+        r_x = paddle.to_tensor(0.0)
+        if reverse_weight > 0.0:
+            r_x, _, olens = self.right_decoder(memory, memory_mask,
+                                               r_ys_in_pad, ys_in_lens)
+        return l_x, r_x, olens
+
+    def forward_one_step(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            tgt: paddle.Tensor,
+            tgt_mask: paddle.Tensor,
+            cache: Optional[List[paddle.Tensor]]=None,
+    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
+        """Forward one step.
+        This is only used for decoding.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                dtype=paddle.bool
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+        """
+        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
+                                                  tgt_mask, cache)
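A minimal construction-and-forward sketch of the new class, not part of the commit (shapes and hyperparameters are illustrative):

# Hedged sketch: exercising BiTransformerDecoder.forward as defined above.
import paddle
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder

decoder = BiTransformerDecoder(
    vocab_size=100, encoder_output_size=256, num_blocks=6, r_num_blocks=3)

memory = paddle.randn([2, 50, 256])                       # (B, Tmax, feat)
memory_mask = paddle.ones([2, 1, 50], dtype=paddle.bool)  # (B, 1, Tmax)
ys_in_pad = paddle.randint(0, 100, [2, 10])               # l2r inputs, <sos>-first
r_ys_in_pad = paddle.randint(0, 100, [2, 10])             # r2l inputs, <sos>-first
ys_in_lens = paddle.to_tensor([10, 8])

l_x, r_x, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens,
                          r_ys_in_pad, reverse_weight=0.3)
# l_x: (2, 10, 100) from left_decoder; r_x: same shape from right_decoder,
# or the scalar 0.0 when reverse_weight == 0.0; olens: (2,).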
paddlespeech/server/engine/asr/online/python/asr_engine.py

@@ -612,7 +612,8 @@ class PaddleASRConnectionHanddler:
         encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.model.decoder(
+        decoder_out, _, _ = self.model.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
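Note: the streaming server still rescores left-to-right only; with the decoder now returning a triple, the right-to-left output is simply discarded (`decoder_out, _, _ = ...`).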