PaddlePaddle / DeepSpeech · commit 1a1ce92c

Authored Sep 22, 2022 by Hui Zhang; committed by GitHub on Sep 22, 2022.

Merge pull request #2415 from Zth9730/u2++_decoder

[s2t] support bitransformer decoder

Parents: 52af86fc, d3e59375
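At a glance: for each n-best hypothesis, attention rescoring can now also score the token sequence right-to-left with a second decoder and blend the two scores by a new reverse_weight parameter. A toy illustration of the reversed decoder input that the new helpers construct (the sos id is made up; st_reverse_pad_list in the diff below does this for padded tensors, additionally filling positions past each true length with eos):

sos = 10
hyp = [1, 2, 3]             # one CTC n-best hypothesis
l2r_in = [sos] + hyp        # input for the left-to-right decoder
r2l_in = [sos] + hyp[::-1]  # input for the right-to-left decoder
print(l2r_in, r2l_in)       # [10, 1, 2, 3] [10, 3, 2, 1]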
Showing 7 changed files with 332 additions and 37 deletions (+332 −37):
paddlespeech/audio/utils/tensor_utils.py                     +112  −9
paddlespeech/s2t/exps/u2/bin/test_wav.py                       +3  −2
paddlespeech/s2t/exps/u2/model.py                              +4  −2
paddlespeech/s2t/io/dataloader.py                              +1  −1
paddlespeech/s2t/models/u2/u2.py                              +86 −19
paddlespeech/s2t/modules/decoder.py                          +124  −3
paddlespeech/server/engine/asr/online/python/asr_engine.py     +2  −1
paddlespeech/audio/utils/tensor_utils.py @ 1a1ce92c
@@ -31,7 +31,6 @@ def has_tensor(val):
                 return True
     elif isinstance(val, dict):
         for k, v in val.items():
-            print(k)
             if has_tensor(v):
                 return True
     else:
@@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                 [ 7,  8,  9, 11, -1, -1]])
     """
     # TODO(Hui Zhang): using comment code,
-    #_sos = paddle.to_tensor(
-    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #_eos = paddle.to_tensor(
-    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
-    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
-    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
-    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
-    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
+    # _sos = paddle.to_tensor(
+    #    [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # _eos = paddle.to_tensor(
+    #    [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
+    # ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
+    # ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
+    # ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
+    # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
     B = ys_pad.shape[0]
     _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
     _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
@@ -190,3 +190,106 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     # denominator = paddle.sum(mask)
     denominator = paddle.sum(mask.type_as(pad_targets))
     return float(numerator) / float(denominator)
+
+
+def reverse_pad_list(ys_pad: paddle.Tensor,
+                     ys_lens: paddle.Tensor,
+                     pad_value: float=-1.0) -> paddle.Tensor:
+    """Reverse padding for the list of tensors.
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B)
+        pad_value (int): Value for padding.
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> pad_list(x, 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+    """
+    r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
+                             for y, i in zip(ys_pad, ys_lens)], True, pad_value)
+    return r_ys_pad
+
+
+def st_reverse_pad_list(ys_pad: paddle.Tensor,
+                        ys_lens: paddle.Tensor,
+                        sos: float,
+                        eos: float) -> paddle.Tensor:
+    """Reverse padding for the list of tensors.
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B)
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> pad_list(x, 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+    """
+    # Equal to:
+    #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
+    #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
+    B = ys_pad.shape[0]
+    _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
+    max_len = paddle.max(ys_lens)
+    index_range = paddle.arange(0, max_len, 1)
+    seq_len_expand = ys_lens.unsqueeze(1)
+    seq_mask = seq_len_expand > index_range  # (beam, max_len)
+
+    index = (seq_len_expand - 1) - index_range  # (beam, max_len)
+    #   >>> index
+    #   >>> tensor([[ 2,  1,  0],
+    #   >>>         [ 2,  1,  0],
+    #   >>>         [ 0, -1, -2]])
+    index = index * seq_mask
+    #   >>> index
+    #   >>> tensor([[2, 1, 0],
+    #   >>>         [2, 1, 0],
+    #   >>>         [0, 0, 0]])
+
+    def paddle_gather(x, dim, index):
+        index_shape = index.shape
+        index_flatten = index.flatten()
+        if dim < 0:
+            dim = len(x.shape) + dim
+        nd_index = []
+        for k in range(len(x.shape)):
+            if k == dim:
+                nd_index.append(index_flatten)
+            else:
+                reshape_shape = [1] * len(x.shape)
+                reshape_shape[k] = x.shape[k]
+                x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+                x_arange = x_arange.reshape(reshape_shape)
+                dim_index = paddle.expand(x_arange, index_shape).flatten()
+                nd_index.append(dim_index)
+        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+        paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+        return paddle_out
+
+    r_hyps = paddle_gather(ys_pad, 1, index)
+    #   >>> r_hyps
+    #   >>> tensor([[3, 2, 1],
+    #   >>>         [4, 8, 9],
+    #   >>>         [2, 2, 2]])
+    eos = paddle.full([1], eos, dtype=r_hyps.dtype)
+    r_hyps = paddle.where(seq_mask, r_hyps, eos)
+    #   >>> r_hyps
+    #   >>> tensor([[3, 2, 1],
+    #   >>>         [4, 8, 9],
+    #   >>>         [2, eos, eos]])
+    r_hyps = paddle.cat([_sos, r_hyps], dim=1)
+    # r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
+    #   >>> r_hyps
+    #   >>> tensor([[sos, 3, 2, 1],
+    #   >>>         [sos, 4, 8, 9],
+    #   >>>         [sos, 2, eos, eos]])
+    return r_hyps
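As a quick sanity check of the new helper, the sketch below (assuming a paddlespeech environment where these functions are importable) reverses a small padded batch; each row comes back reversed, prefixed with sos, with positions past the true length filled with eos. Token and sos/eos ids are illustrative:

import paddle
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list

# Two padded token sequences (pad id 0) with true lengths 3 and 2.
ys_pad = paddle.to_tensor([[1, 2, 3], [4, 5, 0]])
ys_lens = paddle.to_tensor([3, 2])

r = st_reverse_pad_list(ys_pad, ys_lens, sos=10, eos=11)
print(r.numpy())
# [[10  3  2  1]
#  [10  5  4 11]]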
paddlespeech/s2t/exps/u2/bin/test_wav.py @ 1a1ce92c
@@ -40,7 +40,7 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -89,7 +89,8 @@ class U2Infer():
                 ctc_weight=decode_config.ctc_weight,
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-                simulate_streaming=decode_config.simulate_streaming)
+                simulate_streaming=decode_config.simulate_streaming,
+                reverse_weight=self.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
paddlespeech/s2t/exps/u2/model.py @ 1a1ce92c
@@ -253,7 +253,6 @@ class U2Trainer(Trainer):
         model_conf.output_dim = self.test_loader.vocab_size
         model = U2Model.from_config(model_conf)
         if self.parallel:
             model = paddle.DataParallel(model)
@@ -317,6 +316,7 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
         result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
@@ -350,7 +351,8 @@ class U2Tester(U2Trainer):
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-            simulate_streaming=decode_config.simulate_streaming)
+            simulate_streaming=decode_config.simulate_streaming,
+            reverse_weight=self.reverse_weight)
         decode_time = time.time() - start_time

         for utt, target, result, rec_tids in zip(
paddlespeech/s2t/io/dataloader.py @ 1a1ce92c
@@ -361,7 +361,7 @@ class DataLoaderFactory():
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False
-        elif mode l == 'test' or mode == 'align':
+        elif mode == 'test' or mode == 'align':
             config['manifest'] = config.test_manifest
             config['train_mode'] = False
             config['dither'] = 0.0
paddlespeech/s2t/models/u2/u2.py @ 1a1ce92c
@@ -31,6 +31,8 @@ from paddle import nn
 from paddlespeech.audio.utils.tensor_utils import add_sos_eos
 from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
@@ -38,6 +40,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn
 from paddlespeech.s2t.models.asr_interface import ASRInterface
 from paddlespeech.s2t.modules.cmvn import GlobalCMVN
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase
+from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
 from paddlespeech.s2t.modules.decoder import TransformerDecoder
 from paddlespeech.s2t.modules.encoder import ConformerEncoder
 from paddlespeech.s2t.modules.encoder import TransformerEncoder
@@ -69,6 +72,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
                  ctc: CTCDecoderBase,
                  ctc_weight: float=0.5,
                  ignore_id: int=IGNORE_ID,
+                 reverse_weight: float=0.0,
                  lsm_weight: float=0.0,
                  length_normalized_loss: bool=False,
                  **kwargs):
@@ -82,6 +86,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
+        self.reverse_weight = reverse_weight
         self.encoder = encoder
         self.decoder = decoder
@@ -171,12 +176,21 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                             self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
+        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
+        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
+                                                self.ignore_id)
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
-                                      ys_in_lens)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
+            self.reverse_weight)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        r_loss_att = paddle.to_tensor(0.0)
+        if self.reverse_weight > 0.0:
+            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
+        loss_att = loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight
         acc_att = th_accuracy(
             decoder_out.view(-1, self.vocab_size),
             ys_out_pad,
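The bidirectional attention loss is a simple convex combination: with reverse_weight w, loss_att = (1 - w) * L_l2r + w * L_r2l. A tiny numeric sketch of the blend (the loss values are made up, w = 0.3):

# Hypothetical per-batch attention losses, just to illustrate the blend above.
l2r_loss = 2.0           # left-to-right decoder branch
r2l_loss = 2.4           # right-to-left decoder branch
reverse_weight = 0.3

loss_att = l2r_loss * (1 - reverse_weight) + r2l_loss * reverse_weight
print(loss_att)          # 2.0 * 0.7 + 2.4 * 0.3 = 2.12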
@@ -359,6 +373,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # Let's assume B = batch_size
         # encoder_out: (B, maxlen, encoder_dim)
         # encoder_mask: (B, 1, Tmax)
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks, simulate_streaming)
@@ -500,7 +515,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             ctc_weight: float=0.0,
-            simulate_streaming: bool=False, ) -> List[int]:
+            simulate_streaming: bool=False,
+            reverse_weight: float=0.0, ) -> List[int]:
         """ Apply attention rescoring decoding, CTC prefix beam search
         is applied first to get nbest, then we resoring the nbest on
         attention decoder with corresponding encoder out
@@ -520,6 +536,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
+        if reverse_weight > 0.0:
+            # decoder should be a bitransformer decoder if reverse_weight > 0.0
+            assert hasattr(self.decoder, 'right_decoder')
         device = speech.place
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
@@ -541,22 +560,30 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
+        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=device,
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at begining

         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
+                                         self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
+            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         decoder_out = decoder_out.numpy()
+        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+        # conventional transformer decoder.
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        r_decoder_out = r_decoder_out.numpy()

         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
@@ -567,6 +594,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 score += decoder_out[i][j][w]
             # last decoder output token is `eos`, for laste decoder input token.
             score += decoder_out[i][len(hyp[0])][self.eos]
+            if reverse_weight > 0:
+                r_score = 0.0
+                for j, w in enumerate(hyp[0]):
+                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+                r_score += r_decoder_out[i][len(hyp[0])][self.eos]
+                score = score * (1 - reverse_weight) + r_score * reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
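The rescoring blend mirrors the training loss: each n-best hypothesis gets a forward attention score, optionally a reversed-direction score read off r_decoder_out backwards, and the two are mixed by reverse_weight before the CTC score is added. A minimal numeric sketch of that combination (all scores made up):

# Hypothetical log-domain scores for one hypothesis, to show the blending order.
att_score = -3.0         # left-to-right attention score (incl. eos term)
r_att_score = -3.5       # right-to-left attention score (incl. eos term)
ctc_score = -4.0         # CTC prefix beam search score (ln domain)
reverse_weight, ctc_weight = 0.3, 0.5

score = att_score * (1 - reverse_weight) + r_att_score * reverse_weight
score += ctc_score * ctc_weight
print(score)             # -3.0*0.7 + -3.5*0.3 + -4.0*0.5 = -5.15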
@@ -653,12 +686,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)

-    @jit.to_static
+    # @jit.to_static
+    def is_bidirectional_decoder(self) -> bool:
+        """
+        Returns:
+            paddle.Tensor: decoder output
+        """
+        if hasattr(self.decoder, 'right_decoder'):
+            return True
+        else:
+            return False
+
+    # @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
             hyps_lens: paddle.Tensor,
-            encoder_out: paddle.Tensor, ) -> paddle.Tensor:
+            encoder_out: paddle.Tensor,
+            reverse_weight: float=0.0, ) -> paddle.Tensor:
         """ Export interface for c++ call, forward decoder with multiple
         hypothesis from ctc prefix beam search and one encoder output
         Args:
@@ -676,11 +721,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # (B, 1, T)
         encoder_mask = paddle.ones(
             [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
+        # input for right to left decoder
+        # this hyps_lens has count <sos> token, we need minus it.
+        r_hyps_lens = hyps_lens - 1
+        # this hyps has included <sos> token, so it should be
+        # convert the original hyps.
+        r_hyps = hyps[:, 1:]
         # (num_hyps, max_hyps_len, vocab_size)
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
-                                      hyps_lens)
+        r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        return decoder_out
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
+        return decoder_out, r_decoder_out

     @paddle.no_grad()
     def decode(self,
@@ -692,7 +748,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                ctc_weight: float=0.0,
                decoding_chunk_size: int=-1,
                num_decoding_left_chunks: int=-1,
-               simulate_streaming: bool=False):
+               simulate_streaming: bool=False,
+               reverse_weight: float=0.0):
         """u2 decoding.
         Args:
@@ -764,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 ctc_weight=ctc_weight,
-                simulate_streaming=simulate_streaming)
+                simulate_streaming=simulate_streaming,
+                reverse_weight=reverse_weight)
             hyps = [hyp]
         else:
             raise ValueError(f"Not support decoding method: {decoding_method}")
@@ -801,7 +859,6 @@ class U2Model(U2DecodeModel):
         with DefaultInitializerContext(init_type):
             vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
                 configs)
         super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
@@ -851,10 +908,20 @@ class U2Model(U2DecodeModel):
             raise ValueError(f"not support encoder type: {encoder_type}")

         # decoder
-        decoder = TransformerDecoder(vocab_size,
-                                     encoder.output_size(),
-                                     **configs['decoder_conf'])
+        decoder_type = configs.get('decoder', 'transformer')
+        logger.debug(f"U2 Decoder type: {decoder_type}")
+        if decoder_type == 'transformer':
+            decoder = TransformerDecoder(vocab_size,
+                                         encoder.output_size(),
+                                         **configs['decoder_conf'])
+        elif decoder_type == 'bitransformer':
+            assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
+            assert configs['decoder_conf']['r_num_blocks'] > 0
+            decoder = BiTransformerDecoder(vocab_size,
+                                           encoder.output_size(),
+                                           **configs['decoder_conf'])
+        else:
+            raise ValueError(f"not support decoder type: {decoder_type}")

         # ctc decoder and ctc loss
         model_conf = configs.get('model_conf', dict())
         dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
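Per the dispatch above, switching a U2 model to the bidirectional decoder is a config change. A minimal sketch of the relevant configs entries (the field names come from the asserts above; the values and the elided keys are illustrative):

# Illustrative config fragment for U2Model._init_from_config; values are examples only.
configs = {
    'decoder': 'bitransformer',   # default is 'transformer'
    'decoder_conf': {
        'num_blocks': 3,          # left-to-right decoder blocks
        'r_num_blocks': 3,        # right-to-left blocks, must be > 0
        # ... remaining decoder kwargs ...
    },
    'model_conf': {
        'reverse_weight': 0.3,    # must satisfy 0.0 < w < 1.0
        # ... remaining model kwargs ...
    },
}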
paddlespeech/s2t/modules/decoder.py @ 1a1ce92c
@@ -35,7 +35,6 @@ from paddlespeech.s2t.modules.mask import make_xs_mask
 from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

 __all__ = ["TransformerDecoder"]
@@ -116,13 +115,19 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             memory: paddle.Tensor,
             memory_mask: paddle.Tensor,
             ys_in_pad: paddle.Tensor,
-            ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
+            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Forward decoder.
         Args:
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoder memory mask, (batch, 1, maxlen_in)
             ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
             ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: not used in transformer decoder, in order to unify api
+                with bidirectional decoder
+            reverse_weight: not used in transformer decoder, in order to unify
+                api with bidirectional decode
         Returns:
             (tuple): tuple containing:
                 x: decoded token score before softmax (batch, maxlen_out, vocab_size)
@@ -151,7 +156,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         # TODO(Hui Zhang): reduce_sum not support bool type
         # olens = tgt_mask.sum(1)
         olens = tgt_mask.astype(paddle.int).sum(1)
-        return x, olens
+        return x, paddle.to_tensor(0.0), olens

     def forward_one_step(
             self,
@@ -251,3 +256,119 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             state_list = [[states[i][b] for i in range(n_layers)]
                           for b in range(n_batch)]
         return logp, state_list
+
+
+class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
+    """Base class of Transfomer decoder module.
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the hidden units number of position-wise feedforward
+        num_blocks: the number of decoder blocks
+        r_num_blocks: the number of right to left decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before:
+            True: use layer_norm before each sub-block of a layer.
+            False: use layer_norm after each sub-block of a layer.
+        concat_after: whether to concat attention layer's input and output
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 encoder_output_size: int,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 r_num_blocks: int=0,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 self_attention_dropout_rate: float=0.0,
+                 src_attention_dropout_rate: float=0.0,
+                 input_layer: str="embed",
+                 use_output_layer: bool=True,
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 max_len: int=5000):
+        assert check_argument_types()
+        nn.Layer.__init__(self)
+        self.left_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+
+        self.right_decoder = TransformerDecoder(
+            vocab_size, encoder_output_size, attention_heads, linear_units,
+            r_num_blocks, dropout_rate, positional_dropout_rate,
+            self_attention_dropout_rate, src_attention_dropout_rate,
+            input_layer, use_output_layer, normalize_before, concat_after,
+            max_len)
+
+    def forward(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            ys_in_pad: paddle.Tensor,
+            ys_in_lens: paddle.Tensor,
+            r_ys_in_pad: paddle.Tensor,
+            reverse_weight: float=0.0,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Forward decoder.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+            ys_in_lens: input lengths of this batch (batch)
+            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+                used for right to left decoder
+            reverse_weight: used for right to left decoder
+        Returns:
+            (tuple): tuple containing:
+                x: decoded token score before softmax (batch, maxlen_out,
+                    vocab_size) if use_output_layer is True,
+                r_x: x: decoded token score (right to left decoder)
+                    before softmax (batch, maxlen_out, vocab_size)
+                    if use_output_layer is True,
+                olens: (batch, )
+        """
+        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
+                                          ys_in_lens)
+        r_x = paddle.to_tensor(0.0)
+        if reverse_weight > 0.0:
+            r_x, _, olens = self.right_decoder(memory, memory_mask,
+                                               r_ys_in_pad, ys_in_lens)
+        return l_x, r_x, olens
+
+    def forward_one_step(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            tgt: paddle.Tensor,
+            tgt_mask: paddle.Tensor,
+            cache: Optional[List[paddle.Tensor]]=None,
+    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
+        """Forward one step.
+        This is only used for decoding.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                dtype=paddle.bool
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+                y.shape` is (batch, maxlen_out, token)
+        """
+        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
+                                                  tgt_mask, cache)
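To see the unified interface in action: BiTransformerDecoder.forward returns (l_x, r_x, olens), while the patched TransformerDecoder now returns a dummy scalar in the middle slot. A minimal sketch, assuming a working paddlespeech install; shapes and hyperparameters are illustrative, and the r2l input would normally come from st_reverse_pad_list:

import paddle
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder

vocab_size, enc_dim = 100, 256
decoder = BiTransformerDecoder(vocab_size, enc_dim,
                               num_blocks=3, r_num_blocks=3)

B, T_enc, T_dec = 2, 50, 10
memory = paddle.randn([B, T_enc, enc_dim])                # encoder output
memory_mask = paddle.ones([B, 1, T_enc], dtype='bool')    # encoder mask
ys_in_pad = paddle.randint(0, vocab_size, [B, T_dec])     # <sos> + tokens
r_ys_in_pad = paddle.randint(0, vocab_size, [B, T_dec])   # <sos> + reversed tokens
ys_in_lens = paddle.to_tensor([T_dec, T_dec])

l_x, r_x, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens,
                          r_ys_in_pad, reverse_weight=0.3)
# l_x, r_x: (B, T_dec, vocab_size); r_x stays a 0-d zero tensor when
# reverse_weight == 0.0, so callers can always unpack three values.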
paddlespeech/server/engine/asr/online/python/asr_engine.py @ 1a1ce92c
@@ -612,7 +612,8 @@ class PaddleASRConnectionHanddler:
         encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.model.decoder(
+        decoder_out, _, _ = self.model.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain