Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
279348d7
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
279348d7
编写于
6月 08, 2021
作者:
H
Haoxin Ma
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move process utt to collator
上级
8781ab58
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
120 addition
and
85 deletion
+120
-85
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+1
-1
deepspeech/io/collator.py
deepspeech/io/collator.py
+116
-1
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+1
-81
examples/tiny/s0/conf/deepspeech2.yaml
examples/tiny/s0/conf/deepspeech2.yaml
+2
-2
未找到文件。
deepspeech/exps/deepspeech2/model.py
浏览文件 @
279348d7
...
...
@@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad
=
config
.
data
.
sortagrad
,
shuffle_method
=
config
.
data
.
shuffle_method
)
collate_fn
=
SpeechCollator
(
keep_transcription_text
=
False
)
collate_fn
=
SpeechCollator
(
config
,
keep_transcription_text
=
False
)
self
.
train_loader
=
DataLoader
(
train_dataset
,
batch_sampler
=
batch_sampler
,
...
...
deepspeech/io/collator.py
浏览文件 @
279348d7
...
...
@@ -16,14 +16,22 @@ import numpy as np
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.io.utility
import
pad_sequence
from
deepspeech.utils.log
import
Log
from
deepspeech.frontend.augmentor.augmentation
import
AugmentationPipeline
from
deepspeech.frontend.featurizer.speech_featurizer
import
SpeechFeaturizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.speech
import
SpeechSegment
import
io
import
time
__all__
=
[
"SpeechCollator"
]
logger
=
Log
(
__name__
).
getlog
()
# namedtupe need global for pickle.
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
class
SpeechCollator
():
def
__init__
(
self
,
keep_transcription_text
=
True
):
def
__init__
(
self
,
config
,
keep_transcription_text
=
True
):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
...
...
@@ -32,6 +40,112 @@ class SpeechCollator():
"""
self
.
_keep_transcription_text
=
keep_transcription_text
if
isinstance
(
config
.
data
.
augmentation_config
,
(
str
,
bytes
)):
if
config
.
data
.
augmentation_config
:
aug_file
=
io
.
open
(
config
.
data
.
augmentation_config
,
mode
=
'r'
,
encoding
=
'utf8'
)
else
:
aug_file
=
io
.
StringIO
(
initial_value
=
'{}'
,
newline
=
''
)
else
:
aug_file
=
config
.
data
.
augmentation_config
assert
isinstance
(
aug_file
,
io
.
StringIO
)
self
.
_local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{}
)
self
.
_augmentation_pipeline
=
AugmentationPipeline
(
augmentation_config
=
aug_file
.
read
(),
random_seed
=
config
.
data
.
random_seed
)
self
.
_normalizer
=
FeatureNormalizer
(
config
.
data
.
mean_std_filepath
)
if
config
.
data
.
mean_std_filepath
else
None
self
.
_stride_ms
=
config
.
data
.
stride_ms
self
.
_target_sample_rate
=
config
.
data
.
target_sample_rate
self
.
_speech_featurizer
=
SpeechFeaturizer
(
unit_type
=
config
.
data
.
unit_type
,
vocab_filepath
=
config
.
data
.
vocab_filepath
,
spm_model_prefix
=
config
.
data
.
spm_model_prefix
,
specgram_type
=
config
.
data
.
specgram_type
,
feat_dim
=
config
.
data
.
feat_dim
,
delta_delta
=
config
.
data
.
delta_delta
,
stride_ms
=
config
.
data
.
stride_ms
,
window_ms
=
config
.
data
.
window_ms
,
n_fft
=
config
.
data
.
n_fft
,
max_freq
=
config
.
data
.
max_freq
,
target_sample_rate
=
config
.
data
.
target_sample_rate
,
use_dB_normalization
=
config
.
data
.
use_dB_normalization
,
target_dB
=
config
.
data
.
target_dB
,
dither
=
config
.
data
.
dither
)
def
_parse_tar
(
self
,
file
):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result
=
{}
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
_subfile_from_tar
(
self
,
file
):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
'tar2info'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2info
=
{}
if
'tar2object'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2object
=
{}
if
tarpath
not
in
self
.
_local_data
.
tar2info
:
object
,
infoes
=
self
.
_parse_tar
(
tarpath
)
self
.
_local_data
.
tar2info
[
tarpath
]
=
infoes
self
.
_local_data
.
tar2object
[
tarpath
]
=
object
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
def
process_utterance
(
self
,
audio_file
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
start_time
=
time
.
time
()
if
isinstance
(
audio_file
,
str
)
and
audio_file
.
startswith
(
'tar:'
):
speech_segment
=
SpeechSegment
.
from_file
(
self
.
_subfile_from_tar
(
audio_file
),
transcript
)
else
:
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
)
load_wav_time
=
time
.
time
()
-
start_time
#logger.debug(f"load wav time: {load_wav_time}")
# audio augment
start_time
=
time
.
time
()
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
audio_aug_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time
=
time
.
time
()
specgram
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
_keep_transcription_text
)
if
self
.
_normalizer
:
specgram
=
self
.
_normalizer
.
apply
(
specgram
)
feature_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment
start_time
=
time
.
time
()
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
feature_aug_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return
specgram
,
transcript_part
def
__call__
(
self
,
batch
):
"""batch examples
...
...
@@ -53,6 +167,7 @@ class SpeechCollator():
text_lens
=
[]
utts
=
[]
for
utt
,
audio
,
text
in
batch
:
audio
,
text
=
self
.
process_utterance
(
audio
,
text
)
#utt
utts
.
append
(
utt
)
# audio
...
...
deepspeech/io/dataset.py
浏览文件 @
279348d7
...
...
@@ -34,9 +34,6 @@ __all__ = [
logger
=
Log
(
__name__
).
getlog
()
# namedtupe need global for pickle.
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
class
ManifestDataset
(
Dataset
):
@
classmethod
...
...
@@ -192,10 +189,6 @@ class ManifestDataset(Dataset):
self
.
_stride_ms
=
stride_ms
self
.
_target_sample_rate
=
target_sample_rate
self
.
_normalizer
=
FeatureNormalizer
(
mean_std_filepath
)
if
mean_std_filepath
else
None
self
.
_augmentation_pipeline
=
AugmentationPipeline
(
augmentation_config
=
augmentation_config
,
random_seed
=
random_seed
)
self
.
_speech_featurizer
=
SpeechFeaturizer
(
unit_type
=
unit_type
,
vocab_filepath
=
vocab_filepath
,
...
...
@@ -214,8 +207,6 @@ class ManifestDataset(Dataset):
self
.
_rng
=
np
.
random
.
RandomState
(
random_seed
)
self
.
_keep_transcription_text
=
keep_transcription_text
# for caching tar files info
self
.
_local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
# read manifest
self
.
_manifest
=
read_manifest
(
...
...
@@ -256,74 +247,7 @@ class ManifestDataset(Dataset):
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
def
_parse_tar
(
self
,
file
):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result
=
{}
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
_subfile_from_tar
(
self
,
file
):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
'tar2info'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2info
=
{}
if
'tar2object'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2object
=
{}
if
tarpath
not
in
self
.
_local_data
.
tar2info
:
object
,
infoes
=
self
.
_parse_tar
(
tarpath
)
self
.
_local_data
.
tar2info
[
tarpath
]
=
infoes
self
.
_local_data
.
tar2object
[
tarpath
]
=
object
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
def
process_utterance
(
self
,
utt
,
audio_file
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
start_time
=
time
.
time
()
if
isinstance
(
audio_file
,
str
)
and
audio_file
.
startswith
(
'tar:'
):
speech_segment
=
SpeechSegment
.
from_file
(
self
.
_subfile_from_tar
(
audio_file
),
transcript
)
else
:
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
)
load_wav_time
=
time
.
time
()
-
start_time
#logger.debug(f"load wav time: {load_wav_time}")
# audio augment
start_time
=
time
.
time
()
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
audio_aug_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time
=
time
.
time
()
specgram
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
_keep_transcription_text
)
if
self
.
_normalizer
:
specgram
=
self
.
_normalizer
.
apply
(
specgram
)
feature_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment
start_time
=
time
.
time
()
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
feature_aug_time
=
time
.
time
()
-
start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return
utt
,
specgram
,
transcript_part
def
_instance_reader_creator
(
self
,
manifest
):
"""
...
...
@@ -336,8 +260,6 @@ class ManifestDataset(Dataset):
def
reader
():
for
instance
in
manifest
:
# inst = self.process_utterance(instance["feat"],
# instance["text"])
inst
=
self
.
process_utterance
(
instance
[
"utt"
],
instance
[
"feat"
],
instance
[
"text"
])
yield
inst
...
...
@@ -349,6 +271,4 @@ class ManifestDataset(Dataset):
def
__getitem__
(
self
,
idx
):
instance
=
self
.
_manifest
[
idx
]
return
self
.
process_utterance
(
instance
[
"utt"
],
instance
[
"feat"
],
instance
[
"text"
])
# return self.process_utterance(instance["feat"], instance["text"])
return
(
instance
[
"utt"
],
instance
[
"feat"
],
instance
[
"text"
])
examples/tiny/s0/conf/deepspeech2.yaml
浏览文件 @
279348d7
...
...
@@ -6,7 +6,7 @@ data:
mean_std_filepath
:
data/mean_std.json
vocab_filepath
:
data/vocab.txt
augmentation_config
:
conf/augmentation.json
batch_size
:
4
batch_size
:
2
min_input_len
:
0.0
max_input_len
:
27.0
min_output_len
:
0.0
...
...
@@ -37,7 +37,7 @@ model:
share_rnn_weights
:
True
training
:
n_epoch
:
2
0
n_epoch
:
1
0
lr
:
1e-5
lr_decay
:
1.0
weight_decay
:
1e-06
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录