PaddlePaddle / DeepSpeech
Commit ac0ae57e
Authored Aug 04, 2021 by Junkun
add collactor and evaluation code for ST
Parent commit: 03231519

Showing 5 changed files with 1484 additions and 15 deletions (+1484 / -15):
deepspeech/exps/u2_st/model.py    +17   -12
deepspeech/io/collator_st.py      +666  -0
deepspeech/io/dataset.py          +14   -3
deepspeech/models/u2_st.py        +734  -0
deepspeech/utils/bleu_score.py    +53   -0
deepspeech/exps/u2_st/model.py

```diff
@@ -24,7 +24,6 @@ from typing import Tuple
 import numpy as np
 import paddle
-import sacrebleu
 from paddle import distributed as dist
 from paddle.io import DataLoader
 from yacs.config import CfgNode
@@ -32,6 +31,7 @@ from yacs.config import CfgNode
 from deepspeech.io.collator_st import KaldiPrePorocessedCollator
 from deepspeech.io.collator_st import SpeechCollator
 from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
+from deepspeech.io.collator_st import TripletSpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.dataset import TripletManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -40,6 +40,7 @@ from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
 from deepspeech.training.trainer import Trainer
+from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
@@ -248,6 +249,10 @@ class U2STTrainer(Trainer):
         dev_dataset = Dataset.from_config(config)
 
         if config.collator.raw_wav:
+            if config.model.model_conf.asr_weight > 0.:
+                Collator = TripletSpeechCollator
+                TestCollator = SpeechCollator
+            else:
                 TestCollator = Collator = SpeechCollator
             # Not yet implement the mtl loader for raw_wav.
         else:
@@ -393,7 +398,7 @@ class U2STTester(U2STTrainer):
             lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
             decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
             # 'ctc_prefix_beam_search', 'attention_rescoring'
-            error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
+            error_rate_type='bleu',  # Error rate type for evaluation. Options `bleu`, 'char_bleu'
             num_proc_bsearch=8,  # # of CPUs for beam search.
             beam_size=10,  # Beam search width.
             batch_size=16,  # decoding batch size
@@ -428,10 +433,10 @@ class U2STTester(U2STTrainer):
                                     audio_len,
                                     texts,
                                     texts_len,
+                                    bleu_func,
                                     fout=None):
         cfg = self.config.decoding
         len_refs, num_ins = 0, 0
-        bleu_func = sacrebleu.corpus_bleu
         start_time = time.time()
         text_feature = self.test_loader.collate_fn.text_feature
@@ -487,6 +492,9 @@ class U2STTester(U2STTrainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        cfg = self.config.decoding
+        bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' \
+            else bleu_score.bleu
         stride_ms = self.test_loader.collate_fn.stride_ms
         hyps, refs = [], []
         len_refs, num_ins = 0, 0
@@ -495,7 +503,7 @@ class U2STTester(U2STTrainer):
         with open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 metrics = self.compute_translation_metrics(
-                    *batch, fout=fout)
+                    *batch, bleu_func=bleu_func, fout=fout)
                 hyps += metrics['hyps']
                 refs += metrics['refs']
                 bleu = metrics['bleu']
@@ -504,19 +512,16 @@ class U2STTester(U2STTrainer):
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
                 rtf = num_time / (num_frames * stride_ms)
                 logger.info("RTF: %f, BELU (%d) = %f" %
                             (rtf, num_ins, bleu))
         rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
         msg += "RTF: {}, ".format(rtf)
         msg += "Test set [%s]: %s" % (
-            len(hyps), str(sacrebleu.corpus_bleu(hyps, [refs])))
+            len(hyps), str(bleu_func(hyps, [refs])))
         logger.info(msg)
         bleu_meta_path = os.path.splitext(
             self.args.result_file)[0] + '.bleu'
         err_type_str = "BLEU"
         with open(bleu_meta_path, 'w') as f:
             data = json.dumps({
@@ -527,7 +532,7 @@ class U2STTester(U2STTrainer):
                     "rtf": rtf,
-                    err_type_str: sacrebleu.corpus_bleu(hyps, [refs]).score,
+                    err_type_str: bleu_func(hyps, [refs]).score,
                     "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
                     "process_hour": num_time / 1000.0 / 3600.0,
```
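The net effect of these hunks: instead of hard-coding `sacrebleu.corpus_bleu`, `test()` now picks a scoring function from `config.decoding.error_rate_type` and threads it through `compute_translation_metrics` as `bleu_func`. A minimal sketch of that selection, with an invented config value and sentences:

```python
from deepspeech.utils import bleu_score

error_rate_type = 'char-bleu'  # stand-in for config.decoding.error_rate_type

# the same selection the updated test() performs
bleu_func = bleu_score.char_bleu if error_rate_type == 'char-bleu' \
    else bleu_score.bleu

hyps = ['wie geht es dir']       # invented decoded translations
refs = ['wie geht es dir denn']  # invented references, one per utterance

# sacrebleu convention: a list of hypotheses and a list of reference streams
print(bleu_func(hyps, [refs]).score)
```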
deepspeech/io/collator_st.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile  # used by _parse_tar; missing from the original listing
from collections import namedtuple
from typing import Optional
from typing import Tuple

import kaldiio
import numpy as np
from yacs.config import CfgNode

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log

__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]

logger = Log(__name__).getlog()

# namedtuple needs to be global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


class SpeechCollator():
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'specgram_type' in config.collator
        assert 'n_fft' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            mean_std_filepath=config.collator.mean_std_filepath,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            specgram_type=config.collator.specgram_type,
            feat_dim=config.collator.feat_dim,
            delta_delta=config.collator.delta_delta,
            stride_ms=config.collator.stride_ms,
            window_ms=config.collator.window_ms,
            n_fft=config.collator.n_fft,
            max_freq=config.collator.max_freq,
            target_sample_rate=config.collator.target_sample_rate,
            use_dB_normalization=config.collator.use_dB_normalization,
            target_dB=config.collator.target_dB,
            dither=config.collator.dither,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): fft points for rfft. Defaults to None.
            max_freq (int, optional): max cut freq. Defaults to None.
            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to None.
            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): when not in training mode,
                do not tokenize; if False, text is token ids, else raw string.
                Defaults to False.

        Does augmentations.
        Pads audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map containing tarinfos.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from the tar file
        and cache tar file info for the next reading request.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): batch is (audio, text)
                audio (np.ndarray) shape (D, T)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            audio, text = self.process_utterance(audio, text)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, convert to unicode ord
            tokens = []
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
                tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_texts = pad_sequence(
            texts, padding_value=IGNORE_ID).astype(np.int64)
        text_lens = np.array(text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, padded_texts, text_lens

    @property
    def manifest(self):
        return self._manifest

    @property
    def vocab_size(self):
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._speech_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._speech_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._speech_featurizer.text_feature

    @property
    def feature_size(self):
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        return self._speech_featurizer.stride_ms


class TripletSpeechCollator(SpeechCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :param transcript: transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, which could be token ids or text.
        :rtype: tuple of (2darray, list, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        transcript_part = self._speech_featurizer._text_featurizer.featurize(
            transcript)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part, transcript_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): batch is (audio, translation, transcription)
                audio (np.ndarray) shape (D, T)
                translation (List[int] or str): shape (U,)
                transcription (List[int] or str): shape (V,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                translation_text : (B, Umax)
                translation_text_lens: (B)
                transcription_text : (B, Vmax)
                transcription_text_lens: (B)
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []
        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, convert to unicode ord
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                tokens[idx] = tokens[idx] if isinstance(
                    tokens[idx], np.ndarray) else np.array(
                        tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)


class KaldiPrePorocessedCollator(SpeechCollator):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                feat_dim=0,
                stride_ms=10.0,
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'vocab_filepath' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            feat_dim=config.collator.feat_dim,
            stride_ms=config.collator.stride_ms,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(self,
                 aug_file,
                 vocab_filepath,
                 spm_model_prefix,
                 random_seed=0,
                 unit_type="char",
                 feat_dim=0,
                 stride_ms=10.0,
                 keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): when not in training mode,
                do not tokenize; if False, text is token ids, else raw string.
                Defaults to False.

        Does augmentations.
        Pads audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text
        self._feat_dim = feat_dim
        self._stride_ms = stride_ms

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
                                               spm_model_prefix)

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        specgram = kaldiio.load_mat(audio_file)
        specgram = specgram.transpose([1, 0])
        assert specgram.shape[
            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[0])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        if self._keep_transcription_text:
            return specgram, translation
        else:
            text_ids = self._text_featurizer.featurize(translation)
            return specgram, text_ids

    @property
    def manifest(self):
        return self._manifest

    @property
    def vocab_size(self):
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._text_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._text_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._text_featurizer

    @property
    def feature_size(self):
        return self._feat_dim

    @property
    def stride_ms(self):
        return self._stride_ms


class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, which could be token ids or text.
        :rtype: tuple of (2darray, (list, list))
        """
        specgram = kaldiio.load_mat(audio_file)
        specgram = specgram.transpose([1, 0])
        assert specgram.shape[
            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[0])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        if self._keep_transcription_text:
            return specgram, translation, transcript
        else:
            translation_text_ids = self._text_featurizer.featurize(translation)
            transcript_text_ids = self._text_featurizer.featurize(transcript)
            return specgram, translation_text_ids, transcript_text_ids

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): batch is (audio, translation, transcription)
                audio (np.ndarray) shape (D, T)
                translation (List[int] or str): shape (U,)
                transcription (List[int] or str): shape (V,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                translation_text : (B, Umax)
                translation_text_lens: (B)
                transcription_text : (B, Vmax)
                transcription_text_lens: (B)
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []
        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, convert to unicode ord
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                tokens[idx] = tokens[idx] if isinstance(
                    tokens[idx], np.ndarray) else np.array(
                        tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)
```
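For orientation, here is a rough usage sketch for the new collator (not part of the commit; the file paths and feature settings below are placeholders):

```python
import io

from deepspeech.io.collator_st import SpeechCollator

collator = SpeechCollator(
    aug_file=io.StringIO('{}'),             # empty augmentation pipeline
    mean_std_filepath='data/mean_std.npy',  # placeholder path
    vocab_filepath='data/vocab.txt',        # placeholder path
    spm_model_prefix='',
    unit_type='char',
    specgram_type='fbank',
    feat_dim=80,
    keep_transcription_text=False)

# Each dataset item is (utt, audio_path, translation); the collator batches
# them into (utts, padded_audios [B, Tmax, D], audio_lens [B],
# padded_texts [B, Umax], text_lens [B]), e.g. via
# paddle.io.DataLoader(dataset, batch_size=16, collate_fn=collator).
```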
deepspeech/io/dataset.py

```diff
@@ -19,9 +19,7 @@ from yacs.config import CfgNode
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
 
-__all__ = [
-    "ManifestDataset",
-]
+__all__ = ["ManifestDataset", "TripletManifestDataset"]
 
 logger = Log(__name__).getlog()
@@ -105,3 +103,16 @@ class ManifestDataset(Dataset):
     def __getitem__(self, idx):
         instance = self._manifest[idx]
         return instance["utt"], instance["feat"], instance["text"]
+
+
+class TripletManifestDataset(ManifestDataset):
+    """
+        For Joint Training of Speech Translation and ASR.
+        text: translation,
+        text1: transcript.
+    """
+
+    def __getitem__(self, idx):
+        instance = self._manifest[idx]
+        return instance["utt"], instance["feat"], instance["text"], instance[
+            "text1"]
```
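The triplet dataset assumes each manifest record carries both the translation and the source transcript. An illustrative record (all field values invented; real manifests come from the recipe's data-prep scripts) and what `__getitem__` yields from it:

```python
# hypothetical manifest record for TripletManifestDataset
record = {
    "utt": "ted_0001_0001",
    "feat": "exp/fbank/feats.ark:42",  # placeholder feature reference
    "text": "wie geht es dir",         # translation target
    "text1": "how are you",            # source-language transcript
}

# TripletManifestDataset.__getitem__(idx) returns, in order:
# record["utt"], record["feat"], record["text"], record["text1"]
```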
deepspeech/models/u2_st.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys
import time
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode

from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.loss import LabelSmoothingLoss
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add

__all__ = ["U2STModel", "U2STInferModel"]

logger = Log(__name__).getlog()


class U2STBaseModel(nn.Module):
    """CTC-Attention hybrid Encoder-Decoder model"""

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # network architecture
        default = CfgNode()
        # allow add new item when merge_with_file
        default.cmvn_file = ""
        default.cmvn_file_type = "json"
        default.input_dim = 0
        default.output_dim = 0
        # encoder related
        default.encoder = 'transformer'
        default.encoder_conf = CfgNode(
            dict(
                output_size=256,  # dimension of attention
                attention_heads=4,
                linear_units=2048,  # the number of units of position-wise feed forward
                num_blocks=12,  # the number of encoder blocks
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                attention_dropout_rate=0.0,
                input_layer='conv2d',  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
                normalize_before=True,
                # use_cnn_module=True,
                # cnn_module_kernel=15,
                # activation_type='swish',
                # pos_enc_layer_type='rel_pos',
                # selfattention_layer_type='rel_selfattn',
            ))
        # decoder related
        default.decoder = 'transformer'
        default.decoder_conf = CfgNode(
            dict(
                attention_heads=4,
                linear_units=2048,
                num_blocks=6,
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                self_attention_dropout_rate=0.0,
                src_attention_dropout_rate=0.0, ))
        # hybrid CTC/attention
        default.model_conf = CfgNode(
            dict(
                asr_weight=0.0,
                ctc_weight=0.0,
                lsm_weight=0.1,  # label smoothing option
                length_normalized_loss=False, ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 vocab_size: int,
                 encoder: TransformerEncoder,
                 st_decoder: TransformerDecoder,
                 decoder: TransformerDecoder=None,
                 ctc: CTCDecoder=None,
                 ctc_weight: float=0.0,
                 asr_weight: float=0.0,
                 ignore_id: int=IGNORE_ID,
                 lsm_weight: float=0.0,
                 length_normalized_loss: bool=False):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.asr_weight = asr_weight

        self.encoder = encoder
        self.st_decoder = st_decoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss, )

    def forward(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            asr_text: paddle.Tensor=None,
            asr_text_lengths: paddle.Tensor=None,
    ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[
            paddle.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            total_loss, attention_loss, ctc_loss
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # 1. Encoder
        start = time.time()
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_time = time.time() - start
        #logger.debug(f"encoder time: {encoder_time}")
        #TODO(Hui Zhang): sum not support bool type
        #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
        encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
            1)  #[B, 1, T] -> [B]

        # 2a. ST-decoder branch
        start = time.time()
        loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text,
                                             text_lengths)
        decoder_time = time.time() - start

        loss_asr_att = None
        loss_asr_ctc = None
        # 2b. ASR Attention-decoder branch
        if self.asr_weight > 0.:
            if self.ctc_weight != 1.0:
                start = time.time()
                loss_asr_att, acc_att = self._calc_att_loss(
                    encoder_out, encoder_mask, asr_text, asr_text_lengths)
                decoder_time = time.time() - start

            # 2c. CTC branch
            if self.ctc_weight != 0.0:
                start = time.time()
                loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text,
                                        asr_text_lengths)
                ctc_time = time.time() - start

            if loss_asr_ctc is None:
                loss_asr = loss_asr_att
            elif loss_asr_att is None:
                loss_asr = loss_asr_ctc
            else:
                loss_asr = self.ctc_weight * loss_asr_ctc + (
                    1 - self.ctc_weight) * loss_asr_att
            loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
        else:
            loss = loss_st
        return loss, loss_st, loss_asr_att, loss_asr_ctc

    def _calc_st_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc attention loss.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]

        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
                                         ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id, )
        return loss_att, acc_att

    def _calc_att_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc attention loss.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]

        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
                                      ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id, )
        return loss_att, acc_att

    def _forward_encoder(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Encoder pass.

        Args:
            speech (paddle.Tensor): [B, Tmax, D]
            speech_lengths (paddle.Tensor): [B]
            decoding_chunk_size (int, optional): chunk size. Defaults to -1.
            num_decoding_left_chunks (int, optional): number of left chunks. Defaults to -1.
            simulate_streaming (bool, optional): streaming or not. Defaults to False.

        Returns:
            Tuple[paddle.Tensor, paddle.Tensor]:
                encoder hiddens (B, Tmax, D),
                encoder hiddens mask (B, 1, Tmax).
        """
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask

    def translate(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int=10,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False, ) -> paddle.Tensor:
        """ Apply beam search on attention decoder

        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_lengths (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion

        Returns:
            paddle.Tensor: decoding result, (batch, max_result_len)
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.place
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)
        running_size = batch_size * beam_size
        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
        encoder_mask = encoder_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(running_size, 1,
                                     maxlen)  # (B*N, 1, max_len)

        hyps = paddle.ones(
            [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
        # log scale score
        scores = paddle.to_tensor(
            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
        scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
            device)  # (B*N, 1)
        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
        cache: Optional[List[paddle.Tensor]] = None
        # 2. Decoder forward step by step
        for i in range(1, maxlen + 1):
            # Stop if all batch and all beam produce eos
            # TODO(Hui Zhang): if end_flag.sum() == running_size:
            if end_flag.cast(paddle.int64).sum() == running_size:
                break

            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
            # logp: (B*N, vocab)
            logp, cache = self.st_decoder.forward_one_step(
                encoder_out, encoder_mask, hyps, hyps_mask, cache)

            # 2.2 First beam prune: select topk best prob at current time
            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)

            # 2.3 Second beam prune: select topk score with history
            scores = scores + top_k_logp  # (B*N, N), broadcast add
            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
            scores = scores.view(-1, 1)  # (B*N, 1)

            # 2.4 Compute base index in top_k_index,
            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
            # then find offset_k_index in top_k_index
            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
                1, beam_size)  # (B, N)
            base_k_index = base_k_index * beam_size * beam_size
            best_k_index = base_k_index.view(-1) + offset_k_index.view(
                -1)  # (B*N)

            # 2.5 Update best hyps
            best_k_pred = paddle.index_select(
                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
            best_hyps_index = best_k_index // beam_size
            last_best_k_hyps = paddle.index_select(
                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
            hyps = paddle.cat(
                (last_best_k_hyps, best_k_pred.view(-1, 1)),
                dim=1)  # (B*N, i+1)

            # 2.6 Update end flag
            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)

        # 3. Select best of best
        scores = scores.view(batch_size, beam_size)
        # TODO: length normalization
        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
        best_hyps_index = best_index + paddle.arange(
            batch_size, dtype=paddle.long) * beam_size
        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
        best_hyps = best_hyps[:, 1:]
        return best_hyps

    @jit.export
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    @jit.export
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    @jit.export
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    @jit.export
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos

    @jit.export
    def forward_encoder_chunk(
            self,
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            subsampling_cache: Optional[paddle.Tensor]=None,
            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
            paddle.Tensor]]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.

        Args:
            xs (paddle.Tensor): chunk input
            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
            elayers_output_cache (Optional[List[paddle.Tensor]]):
                transformer/conformer encoder layers output cache
            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
                cnn cache

        Returns:
            paddle.Tensor: output, it ranges from time 0 to current chunk.
            paddle.Tensor: subsampling cache
            List[paddle.Tensor]: attention cache
            List[paddle.Tensor]: conformer cnn cache
        """
        return self.encoder.forward_chunk(
            xs, offset, required_cache_size, subsampling_cache,
            elayers_output_cache, conformer_cnn_cache)

    @jit.export
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc

        Args:
            xs (paddle.Tensor): encoder output

        Returns:
            paddle.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)

    @jit.export
    def forward_attention_decoder(
            self,
            hyps: paddle.Tensor,
            hyps_lens: paddle.Tensor,
            encoder_out: paddle.Tensor, ) -> paddle.Tensor:
        """ Export interface for c++ call, forward decoder with multiple
            hypotheses from ctc prefix beam search and one encoder output

        Args:
            hyps (paddle.Tensor): hyps from ctc prefix beam search, sos
                already padded at the beginning, (B, T)
            hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)

        Returns:
            paddle.Tensor: decoder output, (B, L)
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        # (B, 1, T)
        encoder_mask = paddle.ones(
            [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
        # (num_hyps, max_hyps_len, vocab_size)
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                      hyps_lens)
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
        return decoder_out

    @paddle.no_grad()
    def decode(self,
               feats: paddle.Tensor,
               feats_lengths: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
               lang_model_path: str,
               beam_alpha: float,
               beam_beta: float,
               beam_size: int,
               cutoff_prob: float,
               cutoff_top_n: int,
               num_processes: int,
               ctc_weight: float=0.0,
               decoding_chunk_size: int=-1,
               num_decoding_left_chunks: int=-1,
               simulate_streaming: bool=False):
        """u2 decoding.

        Args:
            feats (Tensor): audio features, (B, T, D)
            feats_lengths (Tensor): (B)
            text_feature (TextFeaturizer): text feature object.
            decoding_method (str): decoding mode, e.g.
                'fullsentence', 'simultaneous'
            lang_model_path (str): lm path.
            beam_alpha (float): lm weight.
            beam_beta (float): length penalty.
            beam_size (int): beam size for search.
            cutoff_prob (float): for pruning.
            cutoff_top_n (int): for pruning.
            num_processes (int):
            ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
            decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here.
            num_decoding_left_chunks (int, optional):
                number of left chunks for decoding. Defaults to -1.
            simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.

        Raises:
            ValueError: when decoding_method is not supported.

        Returns:
            List[List[int]]: transcripts.
        """
        batch_size = feats.size(0)

        if decoding_method == 'fullsentence':
            hyps = self.translate(
                feats,
                feats_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming)
            hyps = [hyp.tolist() for hyp in hyps]
        else:
            raise ValueError(f"Not support decoding method: {decoding_method}")

        res = [text_feature.defeaturize(hyp) for hyp in hyps]
        return res


class U2STModel(U2STBaseModel):
    def __init__(self, configs: dict):
        vocab_size, encoder, decoder = U2STModel._init_from_config(configs)

        if isinstance(decoder, Tuple):
            st_decoder, asr_decoder, ctc = decoder
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=st_decoder,
                decoder=asr_decoder,
                ctc=ctc,
                **configs['model_conf'])
        else:
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=decoder,
                **configs['model_conf'])

    @classmethod
    def _init_from_config(cls, configs: dict):
        """init sub module for model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raised when using an unsupported encoder type.

        Returns:
            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
            global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean, dtype=paddle.float),
                paddle.to_tensor(istd, dtype=paddle.float))
        else:
            global_cmvn = None

        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        assert input_dim != 0, input_dim
        assert vocab_size != 0, vocab_size

        encoder_type = configs.get('encoder', 'transformer')
        logger.info(f"U2 Encoder type: {encoder_type}")
        if encoder_type == 'transformer':
            encoder = TransformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"not support encoder type: {encoder_type}")

        st_decoder = TransformerDecoder(vocab_size,
                                        encoder.output_size(),
                                        **configs['decoder_conf'])

        asr_weight = configs['model_conf']['asr_weight']
        logger.info(f"ASR Joint Training Weight: {asr_weight}")

        if asr_weight > 0.:
            decoder = TransformerDecoder(vocab_size,
                                         encoder.output_size(),
                                         **configs['decoder_conf'])
            ctc = CTCDecoder(
                odim=vocab_size,
                enc_n_units=encoder.output_size(),
                blank_id=0,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True)  # sum / batch_size

            return vocab_size, encoder, (st_decoder, decoder, ctc)
        else:
            return vocab_size, encoder, st_decoder

    @classmethod
    def from_config(cls, configs: dict):
        """init model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raised when using an unsupported encoder type.

        Returns:
            nn.Layer: U2STModel
        """
        model = cls(configs)
        return model

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a U2STModel from a pretrained model.

        Args:
            dataloader (paddle.io.DataLoader): not used.
            config (yacs.config.CfgNode): model configs
            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name

        Returns:
            U2STModel: The model built from the pretrained result.
        """
        config.defrost()
        config.input_dim = dataloader.collate_fn.feature_size
        config.output_dim = dataloader.collate_fn.vocab_size
        config.freeze()
        model = cls.from_config(config)

        if checkpoint_path:
            infos = checkpoint.load_parameters(
                model, checkpoint_path=checkpoint_path)
            logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model


class U2STInferModel(U2STModel):
    def __init__(self, configs: dict):
        super().__init__(configs)

    def forward(self,
                feats,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False):
        """export model function

        Args:
            feats (Tensor): [B, T, D]
            feats_lengths (Tensor): [B]

        Returns:
            List[List[int]]: best path result
        """
        return self.translate(
            feats,
            feats_lengths,
            decoding_chunk_size=decoding_chunk_size,
            num_decoding_left_chunks=num_decoding_left_chunks,
            simulate_streaming=simulate_streaming)
```
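A condensed construction sketch for the new model, reusing the default values from `params()` above; the dimensions are invented and far smaller than the real recipes:

```python
from deepspeech.models.u2_st import U2STModel

configs = dict(
    cmvn_file=None,  # skip GlobalCMVN in this sketch
    cmvn_file_type='json',
    input_dim=80,    # fbank feature dim
    output_dim=100,  # vocab size; sos = eos = 99
    encoder='transformer',
    encoder_conf=dict(
        output_size=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=2,  # 12 in the defaults; shrunk for the sketch
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer='conv2d',
        normalize_before=True),
    decoder='transformer',
    decoder_conf=dict(
        attention_heads=4,
        linear_units=2048,
        num_blocks=2,  # 6 in the defaults; shrunk for the sketch
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        self_attention_dropout_rate=0.0,
        src_attention_dropout_rate=0.0),
    model_conf=dict(
        asr_weight=0.5,  # >0 builds the ASR decoder and CTC branch too
        ctc_weight=0.3,
        lsm_weight=0.1,
        length_normalized_loss=False))

model = U2STModel(configs)
# model(speech, speech_lengths, text, text_lengths, asr_text, asr_text_lengths)
# returns (loss, loss_st, loss_asr_att, loss_asr_ctc);
# model.translate(feats, feats_lengths, beam_size=10) beam-searches translations.
```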
deepspeech/utils/bleu_score.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate BLEU score at different levels,
e.g. bleu for word-level, char_bleu for char-level.
"""
import numpy as np
import sacrebleu

__all__ = ['bleu', 'char_bleu']


def bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in word-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference length is zero.
    """
    return sacrebleu.corpus_bleu(hypothesis, reference)


def char_bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in char-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference number is zero.
    """
    hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
    reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
                 for ref in reference]
    return sacrebleu.corpus_bleu(hypothesis, reference)
```
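A quick usage sketch for the two helpers (the sentences are invented examples):

```python
from deepspeech.utils.bleu_score import bleu, char_bleu

hyps = ['how are you', 'it is a nice day']
refs = [['how are you', 'it is a fine day']]  # one reference stream

print(bleu(hyps, refs).score)       # word-level corpus BLEU
print(char_bleu(hyps, refs).score)  # char-level: spaces stripped, chars space-joined
```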